Introduction
When training large language models (LLMs), the most important question is simple: how do we measure whether the model is doing well? For regression you use mean squared error; for classification you might use cross-entropy or hinge loss. But for LLMs — which predict sequences of discrete tokens — the right way to turn “this output feels wrong” into a number you can optimize is a specific kind of probability loss: categorical cross-entropy / negative log likelihood, and the closely related, more interpretable metric, perplexity.
This post explains, step by step and with practical code snippets, how inputs → model outputs → loss → training come together for autoregressive LLMs (GPT style). It covers shapes, batching, indexing the correct probabilities, computing loss efficiently in PyTorch, and how to interpret perplexity.
Quick overview: the training objective for autoregressive LLMs
An autoregressive LLM is trained to predict the next token at each position. Given a sequence of tokens x = [x₁, x₂, ..., x_T],
the model provides at each position t a probability distribution P(y | x₁..x_t) over the vocabulary. Training reduces to maximizing the probability the model assigns to the true next tokens seen in data — equivalently, minimizing the average negative log probability of true next tokens across all positions and examples.
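For a single sequence, that objective can be written out explicitly, averaging over the T−1 positions that have a known next token:

L = -(1/(T-1)) · Σ_{t=1..T-1} log P(x_{t+1} | x₁..x_t)

In practice the same average is also taken over every sequence in the batch.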
Put another way:
- For each input token position the model predicts a probability vector over the vocabulary (via logits → softmax).
- For each position we pick out the probability the model assigned to the correct target token.
- Take the negative log of those probabilities (giving a positive loss per token).
- Average across tokens (and batches) → scalar loss to backpropagate.
That scalar is the standard categorical cross-entropy / negative log likelihood used everywhere in LLM training.
Shapes and indexing — why dimensionality matters
It helps to keep dimensions explicit. Typical shapes:
- batch_size = B
- seq_len = T (context length; e.g., 256)
- vocab_size = V (GPT-2 token set: 50,257)
Model output (logits) has shape (B, T, V).
Each (V,) slice along the last axis holds the (unnormalized) scores for the next token at that position.
Targets (the true next token IDs used for loss) have shape (B, T). Targets are essentially the input sequence shifted left by one position (the next-token prediction target).
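As a minimal sketch of that shift (assuming a pre-tokenized batch `tokens` of shape (B, T+1); all names and sizes here are illustrative):

```python
import torch

B, T, V = 4, 256, 50257                    # illustrative sizes
tokens = torch.randint(0, V, (B, T + 1))   # stand-in for a pre-tokenized batch

inputs = tokens[:, :-1]    # (B, T): what the model reads
targets = tokens[:, 1:]    # (B, T): the same stream shifted left by one position

assert inputs.shape == targets.shape == (B, T)
```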
When computing loss we usually flatten the first two dims into a single axis of B*T so we can compute cross-entropy in one go:
logits_flat.shape = (B*T, V)
targets_flat.shape = (B*T,)
This flattening makes it straightforward to call library loss functions (e.g., torch.nn.functional.cross_entropy) which expect (*, V) logits and (*) integer targets.
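A minimal sketch of that call, with random tensors standing in for real model output and real targets:

```python
import torch
import torch.nn.functional as F

B, T, V = 4, 256, 50257
logits = torch.randn(B, T, V)             # stand-in for the model's output
targets = torch.randint(0, V, (B, T))     # true next-token IDs

# Flatten (B, T, V) -> (B*T, V) and (B, T) -> (B*T,), then one library call.
loss = F.cross_entropy(logits.view(-1, V), targets.view(-1))
print(loss.item())   # scalar loss over all B*T positions (ln(V) ≈ 10.8 for uniform predictions)
```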
Manual view of the loss computation for an LLM (conceptual)
- Model returns logits of shape (B, T, V).
- Convert logits → probabilities with softmax across V. Each row gives a distribution over next tokens.
- For each example and position, select the probability assigned to the true token: p = probs[b, t, target_id].
- Compute -log(p) (negative log likelihood) for that position.
- Average these values across all (b, t) to produce the scalar loss.
This is exactly what cross_entropy does under the hood (it combines the softmax and negative log likelihood in a numerically stable way).
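A sketch of both paths side by side, to make the equivalence concrete (small illustrative sizes, random tensors):

```python
import torch
import torch.nn.functional as F

B, T, V = 2, 3, 7
logits = torch.randn(B, T, V)
targets = torch.randint(0, V, (B, T))

# Manual path: softmax over V, pick the true-token probability, average -log(p).
probs = torch.softmax(logits, dim=-1)                         # (B, T, V)
p_true = probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)  # (B, T)
manual_loss = (-torch.log(p_true)).mean()

# Library path: fused log-softmax + NLL, numerically stable.
library_loss = F.cross_entropy(logits.view(-1, V), targets.view(-1))

print(torch.allclose(manual_loss, library_loss))  # True (up to float precision)
```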
Example with a toy vocabulary
Imagine a tiny vocabulary V = 7 and the model outputs, for a single sequence of length 3, these per-position probability rows:
pos 1: [0.10, 0.60, 0.20, 0.05, 0.00, 0.02, 0.01]
pos 2: [ ... ]
pos 3: [ ... ]
If the true next tokens at positions 1,2,3 are indices i1, i2, i3, we pick probabilities p1 = probs[0,i1], p2 = probs[1,i2], p3 = probs[2,i3]. Loss = -mean(log p1, log p2, log p3).
The training goal is to raise each p_k toward 1 (equivalently, maximize log p_k), so that the model assigns near-unit probability to correct next tokens.
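A sketch of that arithmetic in code. Only the pos 1 row comes from the example above; the pos 2 and pos 3 rows and the target indices are invented for illustration:

```python
import torch

probs = torch.tensor([
    [0.10, 0.60, 0.20, 0.05, 0.00, 0.02, 0.01],  # pos 1 (the row from the example)
    [0.05, 0.05, 0.70, 0.10, 0.05, 0.03, 0.02],  # pos 2 (illustrative)
    [0.25, 0.05, 0.05, 0.50, 0.05, 0.05, 0.05],  # pos 3 (illustrative)
])
targets = torch.tensor([1, 2, 3])                # illustrative i1, i2, i3

p = probs[torch.arange(3), targets]              # p1, p2, p3 = 0.60, 0.70, 0.50
loss = -torch.log(p).mean()
print(loss.item())                               # ≈ 0.52
```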
Perplexity — an interpretable scalar
Perplexity is an alternative view of the same underlying loss, and it’s often used for LMs because it’s more directly interpretable.
Definition:
If L is the average negative log likelihood (in natural log base e) per token (the loss we computed), then
perplexity = exp(L)
Interpretation: perplexity is the effective branching factor — the number of equally likely choices the model behaves as if it is choosing between when predicting the next token. Lower is better.
- If perplexity = 2 → model is about as uncertain as choosing between 2 equally likely tokens on average.
- If perplexity = V (vocab size) → model is nearly uniform across the whole vocabulary; essentially guessing at random.
Numerical example: if loss L = 10.79, then
perplexity = exp(10.79) ≈ 48,533
This would mean “the model is as uncertain as choosing from ≈48.5k tokens,” which is nearly the size of the ~50k vocabulary — so the model is effectively guessing at random.
Perplexity is helpful because it maps the log-loss back to a quantity with an intuitive meaning related to vocabulary size.
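In code, the conversion is a single call; the loss value below is just the one from the numeric example above:

```python
import torch

loss = torch.tensor(10.79)     # average per-token negative log likelihood
perplexity = torch.exp(loss)
print(perplexity.item())       # ≈ 48,533, close to the ~50k vocabulary size

# A well-trained model sits far lower, e.g. loss = 3.0 -> perplexity ≈ 20.
```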
Why cross-entropy is the right fit
- The model outputs a probability distribution over discrete tokens → we need a probability-based loss.
- Cross-entropy / negative log-likelihood directly penalizes low probability assigned to true data tokens.
- It chains neatly into gradient descent: it is differentiable with respect to the logits, with a simple closed-form gradient (see the sketch after this list).
- It generalizes classification loss to the multi-class, per-position nature of language modeling.
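On the differentiability point, a small sketch (illustrative sizes) checking the well-known closed form of the softmax cross-entropy gradient, (softmax(logits) - one_hot(target)) / N for a mean over N tokens:

```python
import torch
import torch.nn.functional as F

N, V = 6, 7
logits = torch.randn(N, V, requires_grad=True)
targets = torch.randint(0, V, (N,))

loss = F.cross_entropy(logits, targets)   # mean over the N tokens
loss.backward()

# Autograd's gradient matches the closed form (softmax(logits) - one_hot) / N.
expected = (torch.softmax(logits.detach(), dim=-1) - F.one_hot(targets, V).float()) / N
print(torch.allclose(logits.grad, expected, atol=1e-6))  # True
```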
Summary
- LLMs are trained by predicting the next token at every position. The standard objective is the negative log likelihood / categorical cross-entropy averaged over token positions and examples.
- In practice, compute loss = F.cross_entropy(logits.view(-1, V), targets.view(-1)) in PyTorch.
- Perplexity = exp(loss) is an interpretable alternative: it tells you the effective number of choices the model is hedging between when predicting the next token.
- Watch tensor shapes carefully, mask padding, and use library loss functions for numeric stability.
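On the padding point in the last bullet, a minimal sketch, assuming padded target positions are marked with -100 (the default ignore_index of F.cross_entropy):

```python
import torch
import torch.nn.functional as F

B, T, V = 2, 5, 11
logits = torch.randn(B, T, V)
targets = torch.randint(0, V, (B, T))
targets[1, 3:] = -100   # pretend the last two positions of example 1 are padding

# Ignored positions contribute neither to the loss nor to the gradients, and they
# are excluded from the mean, so the loss (and perplexity) stays meaningful.
loss = F.cross_entropy(logits.view(-1, V), targets.view(-1), ignore_index=-100)
perplexity = torch.exp(loss)
```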