# SHGAT Class

The SHGAT class is the main orchestrator for SuperHyperGraph Attention Networks. It handles graph construction, multi-level message passing, K-head scoring, and parameter persistence. Published on JSR.

Use the factory functions instead of calling the constructor directly; they handle adaptive head selection, graph building, and level computation automatically.

## createSHGAT(nodes, config?)

Create a SHGAT instance from unified Node objects. This is the recommended API.

```typescript
import { createSHGAT, type Node } from "@casys/shgat";

const nodes: Node[] = [
  { id: "psql_query", embedding: queryEmb, children: [], level: 0 },
  { id: "psql_exec", embedding: execEmb, children: [], level: 0 },
  { id: "database", embedding: dbEmb, children: ["psql_query", "psql_exec"], level: 1 },
];

const model = createSHGAT(nodes);

// or with config overrides:
const model2 = createSHGAT(nodes, { numHeads: 8, learningRate: 0.01 });
```
| Parameter | Type | Description |
| --- | --- | --- |
| `nodes` | `Node[]` | Array of nodes. Leaves have `children: []`; composites list their children. |
| `config` | `Partial<SHGATConfig>` | Optional configuration overrides. |

Returns: SHGAT — a fully initialized instance with all nodes registered, levels computed, and indices built.

## new SHGAT(config?)

```typescript
import { SHGAT } from "@casys/shgat";

// Signature: new SHGAT(config?: Partial<SHGATConfig>)
const shgat = new SHGAT();
```

Creates a bare SHGAT instance. You must register nodes manually and call finalizeNodes() before scoring. In most cases, use createSHGAT() instead.

| Method | Signature | Description |
| --- | --- | --- |
| `registerNode()` | `(node: Node) => void` | Register a unified node. Call `finalizeNodes()` after all registrations. |
| `finalizeNodes()` | `() => void` | Rebuild indices and compute hierarchy levels. Call once after registering all nodes. |
| `registerTool()` | `(node: ToolNode) => void` | Register a tool (leaf). Deprecated: use `registerNode()` with `children: []`. |
| `registerCapability()` | `(node: CapabilityNode) => void` | Register a capability (hyperedge). Deprecated: use `registerNode()` with children. |
| `buildFromData()` | `(tools, capabilities) => void` | Batch-register tools and capabilities from raw data. |
| `hasToolNode()` | `(id: string) => boolean` | Check whether a tool node exists. |
| `hasCapabilityNode()` | `(id: string) => boolean` | Check whether a capability node exists. |
| `getToolCount()` | `() => number` | Number of registered tools. |
| `getCapabilityCount()` | `() => number` | Number of registered capabilities. |
| `getToolIds()` | `() => string[]` | All registered tool IDs. |
| `getCapabilityIds()` | `() => string[]` | All registered capability IDs. |
| Method | Signature | Description |
| --- | --- | --- |
| `setCooccurrenceData()` | `(data: CooccurrenceEntry[]) => void` | Set V-to-V co-occurrence edges for tool-embedding enrichment. |
| `getToolIndexMap()` | `() => Map<string, number>` | Get the tool-ID-to-index mapping (used by the co-occurrence loader). |

## scoreNodes(intentEmbedding, level?)

The main scoring function. It runs a tensor-native forward pass (multi-level message passing), then K-head attention scoring. All tensor operations stay on the native backend until the final result conversion.

```typescript
const ranked = model.scoreNodes(intentEmbedding);

// Filter by level
const toolScores = model.scoreNodes(intentEmbedding, 0); // leaves only
const capScores = model.scoreNodes(intentEmbedding, 1); // composites only
```
| Parameter | Type | Description |
| --- | --- | --- |
| `intentEmbedding` | `number[]` | User intent embedding (1024-dim BGE-M3). |
| `level` | `number \| undefined` | Optional level filter: 0 = leaves, 1 = composites. Omit to score all levels. |

Returns: NodeScore[] — sorted by score descending.

```typescript
interface NodeScore {
  nodeId: string;
  score: number;
  headScores: number[];
  level: number;
}
```
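The sorted results can be consumed like this (a sketch; the `ranked` array here is mock data standing in for the return value of `scoreNodes()`):

```typescript
interface NodeScore {
  nodeId: string;
  score: number;
  headScores: number[];
  level: number;
}

// Mock results standing in for model.scoreNodes(intentEmbedding)
const ranked: NodeScore[] = [
  { nodeId: "database", score: 0.91, headScores: [0.9, 0.92], level: 1 },
  { nodeId: "psql_query", score: 0.84, headScores: [0.8, 0.88], level: 0 },
  { nodeId: "psql_exec", score: 0.42, headScores: [0.4, 0.44], level: 0 },
];

// Best-scoring tool (level 0) above a confidence threshold
const topTool = ranked.find((s) => s.level === 0 && s.score > 0.5);
console.log(topTool?.nodeId); // "psql_query"
```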

## Convenience wrappers

These methods wrap `scoreNodes()` with a preset level filter.

| Method | Equivalent | Description |
| --- | --- | --- |
| `scoreLeaves(intent)` | `scoreNodes(intent, 0)` | Score leaf nodes (tools) only. |
| `scoreComposites(intent, level?)` | `scoreNodes(intent, level ?? 1)` | Score composite nodes at a given level. |
## Deprecated scoring methods

| Method | Signature | Description |
| --- | --- | --- |
| `scoreAllCapabilities()` | `(intent: number[], contextToolIds?: string[]) => AttentionResult[]` | Deprecated. Use `scoreNodes(intent, 1)`. |
| `scoreAllTools()` | `(intent: number[], contextToolIds?: string[]) => Array<{ toolId, score, headScores }>` | Deprecated. Use `scoreNodes(intent, 0)`. |
```typescript
interface AttentionResult {
  capabilityId: string;
  score: number;
  headScores: number[];
  toolAttention: number[];
}
```
| Method | Signature | Description |
| --- | --- | --- |
| `predictPathSuccess()` | `(intent: number[], path: string[]) => number` | Predict the success probability for a tool execution path. |
| `computeAttention()` | `(intent, contextEmbeddings, capId, contextCapIds?) => AttentionResult` | Compute detailed attention for a single capability. |

## forward()

Execute multi-level message passing (V → E → … → V). Results are cached until the graph changes.

```typescript
const { H, E, cache } = model.forward();
// H: number[][] — enriched tool embeddings
// E: number[][] — enriched capability embeddings
// cache: ForwardCache — intermediate results
```

## Training

The SHGAT class itself does not perform training. Use `AutogradTrainer` from the training module, which provides automatic differentiation via TensorFlow.js.

```typescript
import {
  AutogradTrainer,
  DEFAULT_TRAINER_CONFIG,
  type TrainingMetrics,
} from "@casys/shgat";

const trainer = new AutogradTrainer({
  ...DEFAULT_TRAINER_CONFIG,
  numHeads: 16,
  embeddingDim: 1024,
  learningRate: 0.05,
});

// Set embeddings from your SHGAT graph
trainer.setNodeEmbeddings(embeddings);

// Train on a batch
const metrics: TrainingMetrics = trainer.trainBatch(examples);
console.log(`Loss: ${metrics.loss}, Accuracy: ${metrics.accuracy}`);
```

## Deprecated training functions

These functions are exported for backward compatibility but throw errors at runtime. Use `AutogradTrainer` instead.

| Function | Status |
| --- | --- |
| `trainSHGATOnEpisodes()` | Deprecated. Throws. |
| `trainSHGATOnEpisodesKHead()` | Deprecated. Throws. |
| `trainSHGATOnExecution()` | Deprecated. Throws. |
| `shgat.trainBatchV1KHeadBatched()` | Deprecated. Throws. |

## PERBuffer

Prioritized Experience Replay for sample-efficient training.

```typescript
import { PERBuffer, annealBeta, annealTemperature } from "@casys/shgat";

const buffer = new PERBuffer(maxSize);
buffer.add(example, priority);

const { samples, weights } = buffer.sample(batchSize);
const beta = annealBeta(epoch, maxEpochs);
```

## exportParams()

Serialize all learned parameters to a JSON-compatible object for storage.

```typescript
const serialized = model.exportParams();

// Store in a database, file, etc.
await Deno.writeTextFile("params.json", JSON.stringify(serialized));
```

Returns: Record<string, unknown> — contains config, head parameters, level parameters, V2V parameters, and fusion weights.

## importParams(serialized)

Restore parameters from a previously exported object.

```typescript
const data = JSON.parse(await Deno.readTextFile("params.json"));
model.importParams(data);
```
| Parameter | Type | Description |
| --- | --- | --- |
| `serialized` | `Record<string, unknown>` | Output from `exportParams()`. |

Importing parameters invalidates cached tensor parameters. They are recreated lazily on the next scoreNodes() call.

## dispose()

Free GPU memory used by tensor parameters. Call this when the SHGAT instance is no longer needed.

```typescript
model.dispose();
```

After calling dispose(), the instance can still be used — tensor parameters are recreated lazily on the next scoreNodes() call. Even so, call dispose() whenever you discard an instance to avoid GPU memory leaks.

## SHGATConfig

```typescript
interface SHGATConfig {
  // Architecture
  numHeads: number;         // Attention heads (4-16, adaptive). Default: 16
  hiddenDim: number;        // Hidden dim for scoring. Default: 1024
  headDim: number;          // Dim per head (hiddenDim / numHeads). Default: 64
  embeddingDim: number;     // Embedding dim (BGE-M3: 1024). Default: 1024
  numLayers: number;        // Message passing layers. Default: 2
  mlpHiddenDim: number;     // Fusion MLP hidden size. Default: 32

  // Training
  learningRate: number;     // Default: 0.05
  batchSize: number;        // Default: 32
  maxContextLength: number; // Max recent tools in context. Default: 5

  // Buffer management
  maxBufferSize: number;        // PER buffer cap. Default: 50_000
  minTracesForTraining: number; // Cold start threshold. Default: 100

  // Regularization
  dropout: number;          // Default: 0.1
  l2Lambda: number;         // L2 weight. Default: 0.0001
  leakyReluSlope: number;   // Default: 0.2
  depthDecay: number;       // Recursive depth decay. Default: 0.8

  // Dimension preservation
  preserveDim?: boolean;           // Keep 1024-dim through MP. Default: true
  preserveDimResidual?: number;    // Residual blend weight. Default: 0.3
  preserveDimResiduals?: number[]; // Per-level residual weights

  // Multi-location residuals
  v2vResidual?: number;      // V2V phase residual. Default: 0
  downwardResidual?: number; // Downward phase residual. Default: 0

  // Gradient scaling
  mpLearningRateScale?: number; // LR multiplier for MP params. Default: 1

  // Projection head
  useProjectionHead?: boolean;    // Default: false
  projectionHiddenDim?: number;   // Default: 256
  projectionOutputDim?: number;   // Default: 256
  projectionBlendAlpha?: number;  // Default: 0.5
  projectionTemperature?: number; // Default: 0.07
}
```

## DEFAULT_SHGAT_CONFIG

The default configuration object. It uses 16 heads for optimal performance (16 heads × 64 dims = 1024, matching BGE-M3 embeddings).

```typescript
import { DEFAULT_SHGAT_CONFIG } from "@casys/shgat";
```

## getAdaptiveConfig(traceCount)

Returns an adaptive configuration based on trace count. Deprecated: createSHGAT() uses getAdaptiveHeadsByGraphSize() automatically.

```typescript
import { getAdaptiveConfig } from "@casys/shgat";

const overrides = getAdaptiveConfig(traceCount);
// Always returns { numHeads: 16, hiddenDim: 1024, headDim: 64, mlpHiddenDim: 32 }
```

## TrainingExample

The input format for training data.

```typescript
interface TrainingExample {
  intentEmbedding: number[]; // 1024-dim intent embedding
  contextTools: string[];    // Active tool IDs in session
  candidateId: string;       // Positive capability ID (was executed)
  outcome: number;           // 1 = success, 0 = failure
  negativeCapIds?: string[]; // Negative IDs for contrastive learning
  allNegativesSorted?: string[]; // All negatives sorted hard-to-easy (for curriculum)
}
```
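Building an example from a completed episode might look like this (a sketch; the episode values below are hypothetical placeholders, and the interface is repeated locally so the snippet stands alone):

```typescript
interface TrainingExample {
  intentEmbedding: number[];
  contextTools: string[];
  candidateId: string;
  outcome: number;
  negativeCapIds?: string[];
  allNegativesSorted?: string[];
}

// Hypothetical episode data from a finished session
const intentEmbedding = new Array(1024).fill(0.01); // real code: BGE-M3 embedding

const example: TrainingExample = {
  intentEmbedding,
  contextTools: ["psql_query"],           // tools already active in the session
  candidateId: "database",                // the capability that was executed
  outcome: 1,                             // it succeeded
  negativeCapIds: ["filesystem", "http"], // capabilities that were not chosen
};

console.log(example.candidateId); // "database"
```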
## Logging

```typescript
import { setLogger, resetLogger, getLogger, type Logger } from "@casys/shgat";

// Custom logger
setLogger({
  info: (msg) => myLogger.info(msg),
  warn: (msg) => myLogger.warn(msg),
  error: (msg) => myLogger.error(msg),
  debug: (msg) => myLogger.debug(msg),
});

// Reset to the default console logger
resetLogger();
```