An Overview of the LLM Training Pipeline
Large Language Models (LLMs) are transforming the modern world in ways that are both exciting and unsettling. I’m writing a series of posts about them, partly to learn and partly as an experiment to explore the productivity boost from using AI. In this post, I map out the process of training and deploying LLMs, using diagrams and code to support the explanation. I’ll start at a high level and then go into depth.
High Level LLM Training and Deployment Process
The training of LLMs involves several main stages:
Stage | What happens | Key outputs / checkpoints |
---|---|---|
1. Raw Data Collection & Processing | Gather large‑scale text, code, and multimodal data ⇒ clean, deduplicate, filter toxic or private content, tokenize, and shard into training files. | Curated, versioned dataset + token statistics |
2. Foundation Pre‑training | Train a base transformer on the full corpus with next‑token prediction (or masked modeling) across thousands of GPU hours; periodically checkpoint and validate perplexity. | Foundational model checkpoints (billions of parameters) |
3. Supervised Fine‑Tuning (SFT) | Further train the base model on high‑quality, human‑written prompt‑response pairs to teach task formats, instruction following, and chain‑of‑thought style. | Instruction‑tuned weights + alignment eval scores |
4. RLHF / Alignment | Collect preference rankings or comparisons, train a reward model, then optimize the policy with PPO, DPO, or RLAIF to reduce harmful or unhelpful responses and improve UX. | Aligned model weights; reward‑model checkpoints |
5. Inference Service | Package the final model behind an efficient runtime (vLLM, TGI, TensorRT‑LLM), add batching & KV‑cache, expose streaming endpoints, autoscale in Kubernetes. | Production API endpoints, latency/throughput SLOs |
6. Monitoring & User Feedback | Log prompts, completions, costs, safety verdicts, and real‑time metrics; collect thumbs‑up/down, harvest new preference data; trigger rollback or retraining when drift is detected. | Telemetry dashboards, new labels feeding back into SFT / alignment loops |
Training in depth
A more detailed look at the LLM training process:
1. Pre-training
Concept | Why it matters | Typical choices / tips |
---|---|---|
Pre‑training objective | The model learns a general‑purpose prior by predicting the next token (causal LM) or filling masks (MLM) across huge corpora. | Nearly all large LLMs today use a causal autoregressive loss with byte‑pair (BPE) or SentencePiece tokens. |
Tokenizer & sequence packing | Converts raw text → IDs and assembles fixed‑length training examples without wasting context windows. | Train a BPE/Unigram tokenizer on the same corpus (see the sketch after this table); use dynamic sequence packing so batches are ~99 % full. |
Model backbone | Defines parameter count, attention layout, positional encoding, etc. | GPT‑style decoder‑only transformer with FlashAttention 2, RoPE or ALiBi positions; optional SwiGLU activations. |
Optimizer & schedule | Handles huge batches (>4 M tokens) and learning‑rate stability. | AdamW or Lion with β₂ ≈ 0.95, grad‑clip 1.0, linear warm‑up → cosine decay; BF16 or FP16 mixed precision for memory (QLoRA‑style 4‑bit tricks come later, at fine‑tuning time). |
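To make the tokenizer row concrete, here is a minimal sketch of training a byte‑level BPE tokenizer with the Hugging Face tokenizers library; the corpus file name and vocabulary size are placeholders, not values from this post.

# Sketch: train a byte-level BPE tokenizer on your own corpus.
# "corpus.txt" and vocab_size are illustrative placeholders.
from tokenizers import Tokenizer, models, pre_tokenizers, trainers, decoders

tokenizer = Tokenizer(models.BPE(unk_token="<unk>"))
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
tokenizer.decoder = decoders.ByteLevel()

trainer = trainers.BpeTrainer(
    vocab_size=32_000,
    special_tokens=["<unk>", "<s>", "</s>", "<pad>"],
)
tokenizer.train(files=["corpus.txt"], trainer=trainer)
tokenizer.save("tokenizer.json")

print(tokenizer.encode("Hello, world!").tokens)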
Basic Training Loop
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer
import wandb
from tqdm import tqdm


def train_loop(
    model: torch.nn.Module,
    train_dataloader: DataLoader,
    val_dataloader: DataLoader,
    num_epochs: int = 3,
    learning_rate: float = 2e-4,
    warmup_steps: int = 100,
    max_grad_norm: float = 1.0,
    device: str = "cuda",
):
    """Basic training loop with linear warm-up, cosine decay, and mixed precision.

    val_dataloader is accepted for a held-out evaluation pass, omitted here for brevity.
    """
    # Setup
    model = model.to(device)
    optimizer = AdamW(model.parameters(), lr=learning_rate, betas=(0.9, 0.95))
    # Cosine decay over all post-warmup steps (assumes the dataloader has a defined length)
    total_steps = num_epochs * len(train_dataloader)
    scheduler = CosineAnnealingLR(optimizer, T_max=max(1, total_steps - warmup_steps))
    scaler = torch.cuda.amp.GradScaler()  # For mixed precision

    # Training loop
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch + 1}")
        for step, batch in enumerate(progress_bar):
            global_step = epoch * len(train_dataloader) + step

            # Move batch to device
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)

            # Forward pass with mixed precision (labels = input_ids for causal LM loss)
            with torch.cuda.amp.autocast():
                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=input_ids,
                )
                loss = outputs.loss

            # Backward pass with gradient scaling
            scaler.scale(loss).backward()

            # Gradient clipping (unscale first so the threshold applies to true gradients)
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

            # Optimizer step with gradient scaling
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

            # Update learning rate: linear warm-up, then cosine decay
            if global_step < warmup_steps:
                lr_scale = min(1.0, float(global_step + 1) / float(warmup_steps))
                for pg in optimizer.param_groups:
                    pg["lr"] = lr_scale * learning_rate
            else:
                scheduler.step()

            # Log metrics
            total_loss += loss.item()
            progress_bar.set_postfix({"loss": f"{loss.item():.4f}"})
            wandb.log({"train/loss": loss.item(), "train/lr": optimizer.param_groups[0]["lr"]})


# Usage example:
if __name__ == "__main__":
    # Initialize model and tokenizer
    model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
    tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")

    # Initialize wandb
    wandb.init(project="llm-training", name="basic-training-loop")

    # Create dataloaders (using the chunk_generator from previous example)
    train_ds = chunk_generator(ds, seq_len=4096)
    train_dataloader = DataLoader(train_ds, batch_size=1, shuffle=True)
    val_dataloader = DataLoader(train_ds, batch_size=1)  # In practice, use a separate validation set

    # Train
    train_loop(
        model=model,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        num_epochs=3,
        learning_rate=2e-4,
        warmup_steps=100,
    )
More Advanced Training
Concept | Why it matters | Typical choices / tips |
---|---|---|
Distributed parallelism | Spreads the model and data across 100–10 000 GPUs. | Data Parallel + ZeRO Stage 3 (DeepSpeed) for most cases; add Tensor & Pipeline Parallel (Megatron‑LM) for >30 B params. |
Gradient accumulation | Virtual large batches without exceeding GPU RAM. | Accumulate 8–64 micro‑batches before an optimizer step; sync grads only at the step boundary (see the sketch after this table). |
Mixed precision & kernel fusion | Doubles throughput and halves memory. | BF16 + FlashAttention, fused RMSNorm, rotary cache priming. |
Evaluation / early warning | Tracks quality and detects divergence. | Perplexity on held‑out shards every N steps; log with WandB / TensorBoard. |
Checkpointing & resumption | Protects days of GPU time from crashes; enables later SFT or RLHF. | Save model+optimizer+LR sched every 500–2 000 steps to S3/GCS; keep last 2 + every power‑of‑2 for time‑travel debugging. |
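Here is a minimal gradient‑accumulation sketch to make that row concrete. It reuses the model, optimizer, dataloader, and device names from the basic loop above; accumulation_steps is an illustrative choice, and mixed precision is omitted for brevity.

# Gradients from several micro-batches are summed before a single optimizer step,
# simulating a larger effective batch size without extra GPU memory.
accumulation_steps = 16

optimizer.zero_grad()
for step, batch in enumerate(train_dataloader):
    outputs = model(
        input_ids=batch["input_ids"].to(device),
        attention_mask=batch["attention_mask"].to(device),
        labels=batch["input_ids"].to(device),
    )
    # Scale the loss so the accumulated gradient matches one large batch.
    loss = outputs.loss / accumulation_steps
    loss.backward()

    if (step + 1) % accumulation_steps == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad()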
Supervised Fine-tuning
Concept | Why it matters | Practical notes |
---|---|---|
Instruction tuning | Teaches the foundation model to follow tasks and formats. | Mix narrow task data (e.g. SQL‑gen) with broad instruction sets; instruction data is typically <5 % of total tokens but has a strong effect. |
LoRA / QLoRA | Adapter layers let you fine‑tune multi‑B‑param models on 1–4 GPUs. | Rank = 8–32, α ≈ 16; QLoRA keeps the frozen base weights in 4‑bit (NF4), cutting weight memory roughly 4× and skipping optimizer state for frozen parameters (see the sketch after this table). |
Data curriculum | Over‑fitting to synthetic instructions hurts creativity. | Interleave human‑written data (e.g. ShareGPT) with synthetic data (Self‑Instruct) generated with temperature‑1 sampling. |
Loss weighting | Certain tasks (e.g. JSON tools) deserve higher weight. | Group by “source”, apply sample‑level weights in the collate_fn. |
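As a rough illustration of the LoRA / QLoRA row, here is a sketch using the peft library to wrap a base model with rank‑16 adapters. The target module names follow common Llama/Mistral conventions, the 4‑bit loading via bitsandbytes is optional, and the model name is just the one used earlier in this post.

# Sketch: attach LoRA adapters to a causal LM with peft.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model

# Optional 4-bit (QLoRA-style) loading of the frozen base weights.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
base_model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1",
    quantization_config=bnb_config,
    device_map="auto",
)

lora_config = LoraConfig(
    r=16,                       # adapter rank
    lora_alpha=16,              # scaling factor
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    task_type="CAUSAL_LM",
)
model = get_peft_model(base_model, lora_config)
model.print_trainable_parameters()  # typically well under 1 % of total parameters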
RLHF/Alignment
Concept | Role in pipeline | Implementation hints |
---|---|---|
Reward model (RM) | Approximates human preferences from ranked pairs. | Same backbone as policy; freeze layer norms; use pairwise log‑softmax loss (a minimal sketch follows this table). |
Policy optimisation | Improves helpfulness while controlling deviation from SFT. | PPO (OpenAI), DPO (Rafailov et al.), RLAIF (feedback generated by an AI rather than humans). |
KL‑penalty / reference model | Keeps policy near SFT to avoid mode collapse. | Calculate KL(p‖p_ref) token‑wise; β ≈ 0.1–0.3. |
Safety tuning | Extra pass with refusal data, heuristics, jailbreak tests. | Can be applied as a reward shaping term or small SFT on refusal demonstrations. |
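For the reward‑model row, a minimal sketch of the pairwise log‑softmax (Bradley–Terry) loss: given scalar rewards for a chosen and a rejected response, maximize the log‑probability that the chosen one scores higher. The tensor names and dummy values are illustrative.

import torch
import torch.nn.functional as F

def pairwise_rm_loss(chosen_rewards: torch.Tensor, rejected_rewards: torch.Tensor) -> torch.Tensor:
    """Pairwise preference loss: -log sigmoid(r_chosen - r_rejected), averaged over the batch."""
    return -F.logsigmoid(chosen_rewards - rejected_rewards).mean()

# Illustrative usage with dummy reward scores for a batch of 4 comparisons.
chosen = torch.tensor([1.2, 0.3, 0.8, 2.0])
rejected = torch.tensor([0.5, 0.1, 1.0, 0.7])
print(pairwise_rm_loss(chosen, rejected))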
Serving/Inference
Pillar | Key idea | Tools / best practice |
---|---|---|
Runtime engine | Kernels optimised for KV‑cache & batch stitching. | vLLM (PyTorch + FlashAttention, continuous batching); TensorRT‑LLM (CUDA Graphs); TGI (HF). A streaming client sketch follows this table. |
Quantisation & MoE | Cut memory & cost with minimal quality loss. | GPTQ, AWQ, SmoothQuant for 4‑bit; vLLM can load 4‑bit quantized checkpoints (e.g. AWQ) directly. |
Autoscaling | Match GPU count to request rate (QPS); absorb bursty traffic. | KEDA + Prometheus custom metric (tokens_generated_total). |
Observability | Structured logs of queries, latencies, token counts, and costs. | OpenTelemetry traces; Loki + Grafana dashboards. |
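As a small illustration of the runtime row, here is a sketch of a client streaming tokens from an OpenAI‑compatible endpoint such as the one vLLM or TGI can expose; the base URL, model name, and API key are placeholders for your own deployment.

# Sketch: stream completions from an OpenAI-compatible serving endpoint
# (URL, model name, and API key below are placeholders).
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

stream = client.chat.completions.create(
    model="mistralai/Mistral-7B-Instruct-v0.2",
    messages=[{"role": "user", "content": "Explain the KV-cache in one sentence."}],
    stream=True,
)
for chunk in stream:
    delta = chunk.choices[0].delta.content
    if delta:
        print(delta, end="", flush=True)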
Deep Dive into LLM Training
Data curation at trillion‑token scale
Volume & mixture. Meta’s Llama 3.1 pre‑trained on ≈15.6 trillion text tokens, an order‑of‑magnitude jump over Llama 2 (1.8 T) and in line with other recent frontier runs.
Filtering & deduplication. Frontier teams now apply multi‑stage “quality cascades”: aggressive near‑duplicate removal, per‑domain quality classifiers, heuristics for adult/hate content, and language balancing to avoid Anglo‑centric bias.
Mixture‑of‑sources. A typical recipe is ≈ 50‑60 % web crawl (CommonCrawl variants), 15–20 % curated corpora (books, papers, code), 10–15 % synthetic model‑generated text, and task‑specialised “gold” data (<1 %) used later for supervised fine‑tuning (SFT).
Packing & prefixing. Token‑level sequence packing (to minimise padding) and metadata prefixing (domain, language, license) are now standard to raise effective throughput by 15‑25 %.
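A minimal sketch of the token‑level sequence packing described above: greedily concatenate tokenized documents, separated by an EOS token, and slice the stream into fixed‑length training examples so almost no context window is wasted on padding. The function and variable names are illustrative.

from typing import Iterable, Iterator, List

def pack_sequences(
    tokenized_docs: Iterable[List[int]],
    seq_len: int = 4096,
    eos_token_id: int = 2,
) -> Iterator[List[int]]:
    """Concatenate tokenized documents (with EOS separators) and yield fixed-length chunks."""
    buffer: List[int] = []
    for doc in tokenized_docs:
        buffer.extend(doc + [eos_token_id])
        while len(buffer) >= seq_len:
            yield buffer[:seq_len]       # one fully packed training example, no padding
            buffer = buffer[seq_len:]
    # The final partial buffer is dropped here; in practice it can be padded or carried over.

# Illustrative usage with tiny fake documents and a short sequence length.
docs = [[5, 6, 7], [8, 9], [10, 11, 12, 13]]
print(list(pack_sequences(docs, seq_len=4)))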
Model architecture choices
Generation | Parameterisation | Context | Core design | Why it matters |
---|---|---|---|---|
Llama 3 | 8 B / 70 B / 405 B dense | up to 128 k | classic Transformer + minor rotary‑embedding tweaks | Dense is simpler to train & debug at 16 k GPU scale (ar5iv) |
Llama 4 | 109 B (SCOUT) / 400 B (MAVERICK) MoE (16 / 128 experts) | up to 1 M | router + shared expert + SwiGLU blocks | Activates ≈10 % of params per token, a better FLOP ↔ quality trade‑off (TechTalks) |
Optimiser & precision recipe
AdamW β₁ = 0.9, β₂ = 0.95, ε = 1e‑8 remains the default for stability.
LR schedule: 2 % warm‑up → cosine decay to 10 % of peak (sketched after this list).
Mixed precision: BF16 for activations & gradients, FP8 (E4M3) for certain matmuls using FlashAttention‑3 kernels, giving 1.3–1.4× speed‑ups on Hopper GPUs.
Gradient clipping at 1.0; weight decay 0.1; dropout only in embeddings for long‑context models.
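The schedule above (2 % warm‑up, cosine decay to 10 % of peak) can be expressed as a simple LambdaLR multiplier. This is a sketch: the total step count, peak learning rate, and throwaway parameter are placeholders rather than values from any particular run.

import math
import torch
from torch.optim.lr_scheduler import LambdaLR

def warmup_cosine_lambda(total_steps: int, warmup_frac: float = 0.02, min_ratio: float = 0.1):
    """Return an LR multiplier: linear warm-up, then cosine decay to min_ratio of the peak LR."""
    warmup_steps = max(1, int(total_steps * warmup_frac))

    def lr_lambda(step: int) -> float:
        if step < warmup_steps:
            return (step + 1) / warmup_steps
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
        return min_ratio + (1.0 - min_ratio) * cosine

    return lr_lambda

# Illustrative usage with a dummy parameter and placeholder step count.
params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.AdamW(params, lr=3e-4, betas=(0.9, 0.95), eps=1e-8, weight_decay=0.1)
scheduler = LambdaLR(optimizer, lr_lambda=warmup_cosine_lambda(total_steps=10_000))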
Post‑training alignment pipeline
Supervised fine‑tuning (SFT) on curated instruction‑response sets (1–5 M examples).
Rejection sampling to prune low‑quality generations.
Direct Preference Optimisation (DPO): a KL‑regularised, pairwise‑ranking objective that is simpler and more stable than PPO yet matches RLHF quality (ar5iv). A minimal loss sketch follows this list.
Safety models such as Llama Guard 3 or “red‑team” classifiers are attached as routing layers or applied as post‑generation filters.
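A minimal sketch of the DPO objective from step 3: given summed log‑probabilities of the chosen and rejected responses under the policy and the frozen reference model, the loss pushes the policy’s preference margin above the reference’s, scaled by β. The tensor names and dummy values are illustrative.

import torch
import torch.nn.functional as F

def dpo_loss(
    policy_chosen_logps: torch.Tensor,
    policy_rejected_logps: torch.Tensor,
    ref_chosen_logps: torch.Tensor,
    ref_rejected_logps: torch.Tensor,
    beta: float = 0.1,
) -> torch.Tensor:
    """DPO: -log sigmoid(beta * ((pi_chosen - pi_rejected) - (ref_chosen - ref_rejected)))."""
    policy_margin = policy_chosen_logps - policy_rejected_logps
    ref_margin = ref_chosen_logps - ref_rejected_logps
    return -F.logsigmoid(beta * (policy_margin - ref_margin)).mean()

# Illustrative usage with dummy per-response log-probabilities.
loss = dpo_loss(
    policy_chosen_logps=torch.tensor([-12.0, -8.5]),
    policy_rejected_logps=torch.tensor([-14.0, -9.0]),
    ref_chosen_logps=torch.tensor([-13.0, -8.8]),
    ref_rejected_logps=torch.tensor([-13.5, -8.9]),
)
print(loss)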
Evaluation & safety gates
Automated evals (MMLU, GSM‑8K, GPQA, CodeEval) every 1–2 B training tokens.
Human preference eval on 2 k–4 k prompts to monitor helpfulness/harmlessness.
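Alongside benchmark suites, a simple held‑out perplexity check (as in the early‑warning row earlier) is easy to run between checkpoints. This sketch assumes a dataloader that yields input_ids and attention_mask like the training loop above.

import math
import torch

@torch.no_grad()
def evaluate_perplexity(model, dataloader, device: str = "cuda") -> float:
    """Average next-token cross-entropy over a held-out set, reported as perplexity."""
    model.eval()
    total_loss, num_batches = 0.0, 0
    for batch in dataloader:
        outputs = model(
            input_ids=batch["input_ids"].to(device),
            attention_mask=batch["attention_mask"].to(device),
            labels=batch["input_ids"].to(device),
        )
        total_loss += outputs.loss.item()
        num_batches += 1
    return math.exp(total_loss / max(1, num_batches))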
Architecture Overview
There are many existing great posts that explain LLM model architectures. I include some key components here, with a future goal of going into depth:
1. Transformer Architecture
- Self-Attention Mechanism: Allows the model to weigh the importance of different words in a sequence (see the sketch after these lists)
- Multi-Head Attention: Enables the model to focus on different parts of the sequence simultaneously
- Feed-Forward Networks: Process the attended information
- Layer Normalization: Helps stabilize training
- Residual Connections: Facilitate gradient flow during training
2. Model Components
- Embedding Layer: Converts input tokens into dense vectors
- Positional Encoding: Provides information about the position of tokens in the sequence
- Decoder/Encoder Blocks: Process the input through multiple layers of attention and feed-forward networks
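To make the self-attention item above concrete, here is a minimal single-head sketch of scaled dot-product attention; the tensor shapes are arbitrary examples and multi-head attention, masking, and projections are omitted.

import math
import torch

def scaled_dot_product_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Single-head self-attention: softmax(QK^T / sqrt(d)) V."""
    d = q.size(-1)
    scores = q @ k.transpose(-2, -1) / math.sqrt(d)   # similarity of each token to every other token
    weights = torch.softmax(scores, dim=-1)           # attention weights sum to 1 over the keys
    return weights @ v                                # weighted mix of value vectors

# Illustrative usage: 5 tokens with 16-dimensional query/key/value vectors.
x = torch.randn(5, 16)
out = scaled_dot_product_attention(x, x, x)
print(out.shape)  # torch.Size([5, 16])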
Challenges and Considerations
- Computational Resources
  - Large models require significant computational power
  - Training can take weeks or months on specialized hardware
- Data Quality
  - The quality of training data significantly impacts model performance
  - Careful filtering and preprocessing are essential
- Ethical Considerations
  - Bias in training data
  - Potential for misuse
  - Environmental impact of training large models
Future Directions
- Efficiency Improvements
  - Model compression techniques
  - More efficient architectures
  - Better training algorithms
- Multimodal Capabilities
  - Integration with vision and audio
  - Cross-modal understanding
- Specialized Applications
  - Domain-specific fine-tuning
  - Customized solutions for specific industries
Conclusion
Understanding the architecture and training process of LLMs is crucial for both researchers and practitioners in the field of AI. As these models continue to evolve, they present both exciting opportunities and important challenges that need to be addressed.
This post provides a high-level overview of LLM architecture and training. For more detailed information, please refer to the original research papers and technical documentation.
References
- https://ar5iv.labs.arxiv.org/html/2407.21783v1 (The Llama 3 Herd of Models)