Training
SHGAT trains on production execution traces. Every time a tool is selected and executed, the outcome feeds back into the model. No synthetic data, no manually curated datasets — the training signal comes from actual use.
Training from production traces
Each training example is a tuple of (intent, chosen_tool, outcome) extracted from real executions:
```ts
interface TrainingExample {
  intentEmbedding: number[];     // 1024D BGE-M3 embedding of the user intent
  candidateId: string;           // The tool that was actually selected (positive)
  outcome: number;               // 1 = success, 0 = failure
  contextTools: string[];        // Tools active in the current session
  negativeCapIds?: string[];     // Other tools (negatives for contrastive loss)
  allNegativesSorted?: string[]; // All negatives sorted by similarity (hard -> easy)
}
```

The model does not need thousands of examples to start producing useful rankings. The `minTracesForTraining` config (default: 100) sets the cold-start threshold. Below that threshold, the model falls back to pure embedding similarity.
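The cold-start gate can be sketched as a simple guard. Everything here except `minTracesForTraining` and its 100-trace default is a hypothetical illustration (the `rankTools` helper and the stubbed trained scorer are not library APIs):

```typescript
// Hypothetical sketch of the cold-start gate: below the trace threshold,
// ranking falls back to pure embedding similarity.
const minTracesForTraining = 100; // config default

function cosineSimilarity(a: number[], b: number[]): number {
  let dot = 0, na = 0, nb = 0;
  for (let i = 0; i < a.length; i++) {
    dot += a[i] * b[i];
    na += a[i] * a[i];
    nb += b[i] * b[i];
  }
  return dot / (Math.sqrt(na) * Math.sqrt(nb));
}

function rankTools(
  traceCount: number,
  intent: number[],
  tools: { id: string; embedding: number[] }[],
  trainedScore?: (toolId: string) => number, // the trained model, if ready
): string[] {
  const useModel = traceCount >= minTracesForTraining && trainedScore !== undefined;
  const score = (t: { id: string; embedding: number[] }) =>
    useModel ? trainedScore!(t.id) : cosineSimilarity(intent, t.embedding);
  // Highest score first
  return [...tools].sort((a, b) => score(b) - score(a)).map((t) => t.id);
}
```

With only a handful of traces collected, `rankTools` ignores the model entirely and orders candidates by raw intent/tool similarity.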
InfoNCE Contrastive Loss
SHGAT uses InfoNCE (Noise-Contrastive Estimation) as its training objective. The idea is straightforward: push the positive pair (intent, chosen_tool) closer together in embedding space, and push negative pairs apart.
For a batch with one positive and N negatives:
```
L = -log( exp(sim(q, k+) / tau)
        / ( exp(sim(q, k+) / tau) + sum_i exp(sim(q, k_i-) / tau) ) )
```

Where:

- `q` = projected intent embedding
- `k+` = positive tool embedding (the one that was selected)
- `k_i-` = negative tool embeddings (tools that were not selected)
- `tau` = temperature parameter (controls sharpness of the distribution)
- `sim` = dot product similarity after K-head attention
Low temperature makes the distribution sharper — the model must be more confident to score the positive higher than close negatives. High temperature is more forgiving.
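The loss above can be computed directly from precomputed similarity scores. This is an illustrative standalone sketch, not the library's internal implementation:

```typescript
// InfoNCE loss for one positive and N negatives, given precomputed
// similarities. Uses the standard max-subtraction trick for numerical
// stability in the softmax denominator.
function infoNCELoss(simPos: number, simNegs: number[], tau: number): number {
  const logits = [simPos, ...simNegs].map((s) => s / tau);
  const maxLogit = Math.max(...logits);
  const sumExp = logits.reduce((acc, l) => acc + Math.exp(l - maxLogit), 0);
  // L = -log( exp(pos) / sum(exp(all)) )
  return -(logits[0] - maxLogit - Math.log(sumExp));
}
```

When the positive already outscores every negative, lowering `tau` sharpens the softmax and drives the loss toward zero faster; when a negative is close to the positive, a low `tau` magnifies the penalty instead.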
Temperature Annealing
Temperature starts warm and cools over training. Early on, the model explores broadly (high temperature = soft probabilities). Later, it focuses on fine-grained distinctions (low temperature = sharp probabilities).
SHGAT uses cosine annealing:
```
tau(t) = tau_end + (tau_start - tau_end) * 0.5 * (1 + cos(pi * t / T))
```

```ts
import { annealTemperature } from "@casys/shgat";

const totalEpochs = 25;
const tauStart = 0.10; // Warm: soft probabilities
const tauEnd = 0.06;   // Cool: sharp distinctions

for (let epoch = 0; epoch < totalEpochs; epoch++) {
  const tau = annealTemperature(epoch, totalEpochs, tauStart, tauEnd);
  // epoch 0:  tau = 0.100
  // epoch 12: tau = 0.080
  // epoch 24: tau = 0.060
}
```

```
tau
0.10 |*
     | **
     |   ***
0.08 |      ****
     |          ****
     |              ****
0.06 |                  *****
     +-------------------------
     0        12        24  epoch
```

Cosine annealing slows down at the end, giving the model time to settle into a stable minimum rather than overshooting.
Hard Negative Mining
Not all negatives are equally useful. A negative that is obviously irrelevant (e.g., git_commit when the intent is about database queries) provides almost zero gradient signal. The model already scores it low.
Hard negatives are tools that are semantically close to the positive but functionally different. These are the cases where the model needs to develop fine-grained discrimination.
SHGAT pre-sorts all negatives by similarity to the positive:
```ts
interface TrainingExample {
  // ...
  allNegativesSorted?: string[]; // Hard -> easy order
}
```

During training, curriculum learning selects negatives from different difficulty tiers based on current model accuracy:
| Accuracy | Tier | Negatives sampled from |
|---|---|---|
| < 0.35 | Easy | Last third (most dissimilar) |
| 0.35 - 0.55 | Medium | Middle third |
| > 0.55 | Hard | First third (most similar) |
This prevents the model from getting stuck on examples it cannot yet handle, while progressively increasing difficulty as accuracy improves.
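The tier table above can be sketched as a selection function. The helper name and the shuffle-based sampling are hypothetical; only the accuracy thresholds and the hard-to-easy sort order come from the text:

```typescript
// Pick k negatives from the difficulty tier matching current model accuracy.
// allNegativesSorted is ordered hard -> easy, as in TrainingExample.
function sampleNegatives(
  allNegativesSorted: string[],
  accuracy: number,
  k: number,
): string[] {
  const n = allNegativesSorted.length;
  const third = Math.floor(n / 3);
  let tier: string[];
  if (accuracy < 0.35) {
    tier = allNegativesSorted.slice(n - third);        // easy: most dissimilar
  } else if (accuracy <= 0.55) {
    tier = allNegativesSorted.slice(third, n - third); // medium: middle third
  } else {
    tier = allNegativesSorted.slice(0, third);         // hard: most similar
  }
  // Fisher-Yates shuffle, then take k (sampling without replacement)
  const pool = [...tier];
  for (let i = pool.length - 1; i > 0; i--) {
    const j = Math.floor(Math.random() * (i + 1));
    [pool[i], pool[j]] = [pool[j], pool[i]];
  }
  return pool.slice(0, k);
}
```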
Prioritized Experience Replay (PER)
Not all training examples deserve equal attention. Examples where the model makes large errors are more informative than examples it already handles correctly. PER samples training examples proportionally to their error magnitude.
```ts
import { PERBuffer, annealBeta } from "@casys/shgat";

// Wrap your training examples in a PER buffer
const buffer = new PERBuffer(trainingExamples, {
  alpha: 0.6,       // Priority exponent (0 = uniform, 1 = full prioritization)
  beta: 0.4,        // IS weight correction (annealed to 1.0)
  epsilon: 0.01,    // Floor to prevent starvation
  maxPriority: 1.0, // Initial priority for new examples
});

// Sample a batch -- high-error examples are sampled more often
const { items, indices, weights } = buffer.sample(batchSize);

// After training, update priorities with the new errors
const tdErrors = items.map((ex) => computeError(ex));
buffer.updatePriorities(indices, tdErrors);

// Decay priorities periodically to prevent stale high-priority examples
buffer.decayPriorities(0.9);
```

The importance sampling weights (`weights`) correct for the sampling bias. Without them, the model would overfit to hard examples. Beta is annealed from 0.4 to 1.0 over training: partial correction early, full correction at convergence.
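The math behind the buffer is compact. This sketch mirrors the standard PER formulation using the `alpha`, `beta`, and `epsilon` knobs above; it is not the `PERBuffer` internals:

```typescript
// Sampling probability for each example: p_i = (|error_i| + epsilon)^alpha,
// normalized over the buffer.
function samplingProbs(tdErrors: number[], alpha: number, epsilon: number): number[] {
  const priorities = tdErrors.map((e) => Math.pow(Math.abs(e) + epsilon, alpha));
  const total = priorities.reduce((a, b) => a + b, 0);
  return priorities.map((p) => p / total);
}

// Importance-sampling weights: w_i = (N * P(i))^-beta, normalized by the max
// so every weight is <= 1. Oversampled (high-error) examples get down-weighted.
function isWeights(probs: number[], beta: number): number[] {
  const n = probs.length;
  const raw = probs.map((p) => Math.pow(n * p, -beta));
  const max = Math.max(...raw);
  return raw.map((w) => w / max);
}
```

At `alpha = 0` every example is equally likely; at `beta = 1` the weights fully cancel the sampling bias, which is why beta is annealed toward 1.0 as training converges.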
Training loop
A complete training loop with temperature annealing, PER, and curriculum learning:
1. Collect traces from production executions into `TrainingExample[]`.

2. Initialize the buffer and trainer:

   ```ts
   import { SHGAT, PERBuffer, annealTemperature, annealBeta } from "@casys/shgat";
   import { AutogradTrainer } from "@casys/shgat/training";

   const shgat = new SHGAT({ numHeads: 16 });
   // ... register nodes, finalize ...

   const buffer = new PERBuffer(trainingExamples);
   const trainer = new AutogradTrainer({
     learningRate: 0.001,
     batchSize: 32,
     temperature: 0.07,
     gradientClip: 1.0,
     l2Lambda: 0.0001,
   });
   ```

3. Run the training loop:

   ```ts
   const totalEpochs = 25;

   for (let epoch = 0; epoch < totalEpochs; epoch++) {
     // Anneal temperature: 0.10 -> 0.06
     const tau = annealTemperature(epoch, totalEpochs, 0.10, 0.06);

     // Anneal beta for IS correction: 0.4 -> 1.0
     const beta = annealBeta(epoch, totalEpochs, 0.4);

     // Sample batch with PER
     const { items, indices, weights } = buffer.sample(32, beta);

     // Train on batch
     const metrics = trainer.trainBatch(items, {
       temperature: tau,
       isWeights: weights,
     });

     // Update priorities from training errors
     const errors = items.map((_, i) => metrics.perExampleLoss?.[i] ?? metrics.loss);
     buffer.updatePriorities(indices, errors);

     // Decay stale priorities every 5 epochs
     if (epoch % 5 === 0) {
       buffer.decayPriorities(0.9);
     }

     console.log(
       `Epoch ${epoch}: loss=${metrics.loss.toFixed(4)} ` +
         `acc=${metrics.accuracy.toFixed(3)} tau=${tau.toFixed(3)}`,
     );
   }
   ```

4. Export trained parameters:

   ```ts
   const params = shgat.exportParams();
   await Deno.writeTextFile("shgat-params.json", JSON.stringify(params));
   ```
See Also
- Architecture — K-head attention and message passing
- Persistence — Saving and loading trained parameters