
Training

SHGAT trains on production execution traces. Every time a tool is selected and executed, the outcome feeds back into the model. No synthetic data, no manually curated datasets — the training signal comes from actual use.

Each training example is a tuple of (intent, chosen_tool, outcome) extracted from real executions:

interface TrainingExample {
  intentEmbedding: number[];     // 1024D BGE-M3 embedding of the user intent
  candidateId: string;           // The tool that was actually selected (positive)
  outcome: number;               // 1 = success, 0 = failure
  contextTools: string[];        // Tools active in the current session
  negativeCapIds?: string[];     // Other tools (negatives for contrastive loss)
  allNegativesSorted?: string[]; // All negatives sorted by similarity (hard -> easy)
}
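To make the shape concrete, here is an illustrative literal matching the interface. All field values are invented for the sketch, and the embedding is truncated to 4 dimensions for readability (real intents use the full 1024D vector):

```typescript
// Illustrative TrainingExample; tool names and values are hypothetical.
interface TrainingExample {
  intentEmbedding: number[];
  candidateId: string;
  outcome: number;
  contextTools: string[];
  negativeCapIds?: string[];
  allNegativesSorted?: string[];
}

const example: TrainingExample = {
  intentEmbedding: [0.12, -0.03, 0.44, 0.08], // truncated for readability
  candidateId: "db_query",                    // the tool that was selected
  outcome: 1,                                 // the execution succeeded
  contextTools: ["db_query", "db_migrate"],   // active in the session
  negativeCapIds: ["git_commit", "http_get"], // tools that were not selected
};
```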

The model does not need thousands of examples to start producing useful rankings. The minTracesForTraining config (default: 100) sets the cold-start threshold. Below that, the model falls back to pure embedding similarity.
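The cold-start gate can be sketched as follows. Only the minTracesForTraining default (100) comes from the config above; selectRanker is a hypothetical helper, not a library export:

```typescript
// Sketch of the cold-start fallback described above.
// selectRanker is hypothetical; only the threshold default is documented.
const minTracesForTraining = 100;

function selectRanker(traceCount: number): "shgat" | "embedding-similarity" {
  // Below the threshold, fall back to pure embedding similarity;
  // at or above it, the trained SHGAT model ranks candidates.
  return traceCount >= minTracesForTraining ? "shgat" : "embedding-similarity";
}
```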

SHGAT uses InfoNCE, a contrastive loss built on noise-contrastive estimation, as its training objective. The idea is straightforward: pull the positive pair (intent, chosen_tool) closer together in embedding space, and push negative pairs apart.

For a batch with one positive and N negatives:

L = -log( exp(sim(q, k+) / tau) / (exp(sim(q, k+) / tau) + sum_i exp(sim(q, k_i-) / tau)) )

Where:

  • q = projected intent embedding
  • k+ = positive tool embedding (the one that was selected)
  • k_i- = negative tool embeddings (tools that were not selected)
  • tau = temperature parameter (controls sharpness of the distribution)
  • sim = dot product similarity after K-head attention

Low temperature makes the distribution sharper — the model must be more confident to score the positive higher than close negatives. High temperature is more forgiving.
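Plugging numbers into the loss makes this concrete. The standalone infoNCE function below mirrors the formula above; it is a sketch, not the library's implementation, and the similarity values are invented:

```typescript
// InfoNCE for one positive similarity and N negative similarities,
// following the formula above (sketch, not the library code).
function infoNCE(simPos: number, simNegs: number[], tau: number): number {
  const expPos = Math.exp(simPos / tau);
  const expSum = expPos + simNegs.reduce((s, x) => s + Math.exp(x / tau), 0);
  return -Math.log(expPos / expSum);
}

// A clearly separated positive yields a much lower loss than one that is
// contested by a semantically close negative:
const easy = infoNCE(0.9, [0.2, 0.1], 0.07);
const contested = infoNCE(0.9, [0.85, 0.1], 0.07);
// easy < contested
```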

Temperature starts warm and cools over training. Early on, the model explores broadly (high temperature = soft probabilities). Later, it focuses on fine-grained distinctions (low temperature = sharp probabilities).

SHGAT uses cosine annealing:

tau(t) = tau_end + (tau_start - tau_end) * 0.5 * (1 + cos(pi * t / T))
import { annealTemperature } from "@casys/shgat";

const totalEpochs = 25;
const tauStart = 0.10; // Warm: soft probabilities
const tauEnd = 0.06;   // Cool: sharp distinctions

for (let epoch = 0; epoch < totalEpochs; epoch++) {
  const tau = annealTemperature(epoch, totalEpochs, tauStart, tauEnd);
  // epoch 0:  tau = 0.100
  // epoch 12: tau ≈ 0.081
  // epoch 24: tau ≈ 0.060
}

Cosine annealing slows down at the end, giving the model time to settle into a stable minimum rather than overshooting.
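The schedule formula translates directly into code. This is a sketch of the math; the actual annealTemperature export may differ in signature or edge handling:

```typescript
// Cosine annealing per the tau(t) formula above (illustrative sketch).
function cosineAnneal(t: number, T: number, tauStart: number, tauEnd: number): number {
  return tauEnd + (tauStart - tauEnd) * 0.5 * (1 + Math.cos((Math.PI * t) / T));
}

// tau starts at tauStart, decreases smoothly, and flattens near tauEnd:
const schedule = Array.from({ length: 26 }, (_, t) => cosineAnneal(t, 25, 0.10, 0.06));
// schedule[0] = 0.100, schedule[25] = 0.060
```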

Not all negatives are equally useful. A negative that is obviously irrelevant (e.g., git_commit when the intent is about database queries) provides almost zero gradient signal. The model already scores it low.

Hard negatives are tools that are semantically close to the positive but functionally different. These are the cases where the model needs to develop fine-grained discrimination.

SHGAT pre-sorts all negatives by similarity to the positive:

interface TrainingExample {
  // ...
  allNegativesSorted?: string[]; // Hard -> easy order
}
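One way to produce that ordering is to rank each negative by embedding similarity to the positive. The helper below is an illustrative sketch under that assumption, not how the library computes it:

```typescript
// Sort negatives hard -> easy by cosine similarity to the positive
// tool's embedding (illustrative; the embeddings map is assumed).
function cosine(a: number[], b: number[]): number {
  let dot = 0, na = 0, nb = 0;
  for (let i = 0; i < a.length; i++) {
    dot += a[i] * b[i];
    na += a[i] * a[i];
    nb += b[i] * b[i];
  }
  return dot / (Math.sqrt(na * nb) || 1);
}

function sortNegativesHardToEasy(
  positive: number[],
  negatives: Map<string, number[]>,
): string[] {
  return [...negatives.keys()].sort(
    (x, y) => cosine(negatives.get(y)!, positive) - cosine(negatives.get(x)!, positive),
  );
}
```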

During training, curriculum learning selects negatives from different difficulty tiers based on current model accuracy:

Accuracy       Tier     Negatives sampled from
< 0.35         Easy     Last third (most dissimilar)
0.35 - 0.55    Medium   Middle third
> 0.55         Hard     First third (most similar)

This prevents the model from getting stuck on examples it cannot yet handle, while progressively increasing difficulty as accuracy improves.
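The tier logic above amounts to slicing the pre-sorted list. The sketch below follows the table; behaviour exactly at the 0.35 and 0.55 boundaries is an assumption:

```typescript
// Curriculum tier selection over negatives pre-sorted hard -> easy.
// Boundary behaviour at exactly 0.35 and 0.55 is assumed, not documented.
function curriculumPool(sortedHardToEasy: string[], accuracy: number): string[] {
  const third = Math.ceil(sortedHardToEasy.length / 3);
  if (accuracy < 0.35) return sortedHardToEasy.slice(-third);            // easy tier
  if (accuracy <= 0.55) return sortedHardToEasy.slice(third, 2 * third); // medium tier
  return sortedHardToEasy.slice(0, third);                               // hard tier
}
```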

Not all training examples deserve equal attention. Examples where the model makes large errors are more informative than ones it already handles correctly. Prioritized Experience Replay (PER) samples training examples with probability proportional to their error magnitude.

import { PERBuffer, annealBeta } from "@casys/shgat";

// Wrap your training examples in a PER buffer
const buffer = new PERBuffer(trainingExamples, {
  alpha: 0.6,       // Priority exponent (0 = uniform, 1 = full prioritization)
  beta: 0.4,        // IS weight correction (annealed to 1.0)
  epsilon: 0.01,    // Floor to prevent starvation
  maxPriority: 1.0, // Initial priority for new examples
});

// Sample a batch -- high-error examples are sampled more often
const { items, indices, weights } = buffer.sample(batchSize);

// After training, update priorities with the new errors
const tdErrors = items.map((ex) => computeError(ex));
buffer.updatePriorities(indices, tdErrors);

// Decay priorities periodically to prevent stale high-priority examples
buffer.decayPriorities(0.9);

The importance sampling weights returned by sample() correct for this sampling bias. Without them, the model would overfit to hard examples. Beta is annealed from 0.4 to 1.0 over training: partial correction early, full correction at convergence.
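The correction can be sketched end to end. Both the linear beta ramp and the (n·p)^(−beta) weight form below are standard PER conventions and are assumptions here; annealBeta's exact schedule is not shown on this page:

```typescript
// Sketch of IS-weight correction (standard PER form, assumed here).
// beta ramps linearly from beta0 to 1.0 over training.
function linearBeta(epoch: number, totalEpochs: number, beta0 = 0.4): number {
  if (totalEpochs <= 1) return 1;
  return Math.min(1, beta0 + (1 - beta0) * (epoch / (totalEpochs - 1)));
}

// Weight for an example drawn with probability p from a buffer of size n,
// normalized by the batch maximum so weights stay in (0, 1].
function importanceWeights(probs: number[], n: number, beta: number): number[] {
  const raw = probs.map((p) => Math.pow(n * p, -beta));
  const max = Math.max(...raw);
  return raw.map((w) => w / max);
}
```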

A complete training loop with temperature annealing, PER, and curriculum learning:

  1. Collect traces from production executions into TrainingExample[].

  2. Initialize the buffer and trainer:

    import { SHGAT, PERBuffer, annealTemperature, annealBeta } from "@casys/shgat";
    import { AutogradTrainer } from "@casys/shgat/training";

    const shgat = new SHGAT({ numHeads: 16 });
    // ... register nodes, finalize ...

    const buffer = new PERBuffer(trainingExamples);
    const trainer = new AutogradTrainer({
      learningRate: 0.001,
      batchSize: 32,
      temperature: 0.07,
      gradientClip: 1.0,
      l2Lambda: 0.0001,
    });
  3. Run the training loop:

    const totalEpochs = 25;

    for (let epoch = 0; epoch < totalEpochs; epoch++) {
      // Anneal temperature: 0.10 -> 0.06
      const tau = annealTemperature(epoch, totalEpochs, 0.10, 0.06);
      // Anneal beta for IS correction: 0.4 -> 1.0
      const beta = annealBeta(epoch, totalEpochs, 0.4);

      // Sample batch with PER
      const { items, indices, weights } = buffer.sample(32, beta);

      // Train on batch
      const metrics = trainer.trainBatch(items, {
        temperature: tau,
        isWeights: weights,
      });

      // Update priorities from training errors
      const errors = items.map((_, i) => metrics.perExampleLoss?.[i] ?? metrics.loss);
      buffer.updatePriorities(indices, errors);

      // Decay stale priorities every 5 epochs
      if (epoch % 5 === 0) {
        buffer.decayPriorities(0.9);
      }

      console.log(
        `Epoch ${epoch}: loss=${metrics.loss.toFixed(4)} ` +
          `acc=${metrics.accuracy.toFixed(3)} tau=${tau.toFixed(3)}`,
      );
    }
  4. Export trained parameters:

    const params = shgat.exportParams();
    await Deno.writeTextFile("shgat-params.json", JSON.stringify(params));