Transformers & Vision Transformers: A Deep Technical Guide
1. Big Picture
Why Transformers Were Introduced
To understand why Transformers matter, you need to understand what they replaced and why those predecessors failed at scale.
Recurrent Neural Networks (RNNs / LSTMs / GRUs) process sequences token by token, maintaining a hidden state that is updated at each step. This creates two fundamental problems:
-
Sequential bottleneck. Token t cannot be processed until token t-1 is complete. This makes training inherently serial—you cannot parallelize across the sequence dimension. On modern GPUs, which derive their power from parallelism, this is devastating. Training a large RNN on long sequences underutilizes hardware by orders of magnitude.
-
Long-range dependency decay. Information from early tokens must survive repeated multiplicative transformations through the hidden state to influence predictions at later tokens. Despite gating mechanisms (LSTM’s forget/input/output gates), information still degrades over hundreds of steps. In practice, LSTMs struggle with dependencies beyond ~200–500 tokens.
Convolutional Neural Networks (CNNs) when applied to sequences (e.g., WaveNet, ConvS2S) offer parallelism but require stacking many layers to achieve a large receptive field. A 1D convolution with kernel size k sees k tokens; to see n tokens requires O(n/k) layers (or O(log n) with dilated convolutions). Long-range dependency is possible but architecturally expensive.
The core idea of the Transformer (“Attention Is All You Need,” Vaswani et al., 2017) is deceptively simple: let every token directly attend to every other token in a single operation. No sequential bottleneck, no stacking layers to grow a receptive field. One attention layer gives every token a global view of the entire sequence.
This single architectural decision—replacing recurrence with attention—unlocked three things simultaneously:
- Full parallelism during training (all positions computed in one matrix multiply)
- O(1) path length between any two tokens (direct attention, not through intermediaries)
- Content-based routing (attention weights are computed from the data, not fixed by architecture)
Why Transformers Became Dominant
Transformers dominate NLP (GPT, BERT, T5), vision (ViT, Swin), audio (Whisper), and multimodal systems (CLIP, Flamingo, GPT-4V) for a reason that goes beyond architectural elegance: they scale. The Transformer’s reliance on dense matrix multiplications maps perfectly onto GPU hardware. Doubling the model size or data yields predictable performance improvements (scaling laws). No other architecture family has demonstrated this property as consistently.
The counterpoint—which matters in practice—is that this dominance is partially contingent on having massive data and compute. For small-data regimes, CNNs (with their built-in spatial inductive biases) or even well-tuned RNNs can outperform Transformers. Dominance is a statement about the current resource landscape, not an absolute architectural superiority.
2. Transformer Fundamentals
The Transformer architecture consists of an encoder, a decoder, or both, depending on the task. Each is a stack of identical blocks. We will build up from primitives.
2.1 Input Representation
Tokenization converts raw text into a sequence of integer IDs. Modern systems use subword tokenizers (BPE, SentencePiece, WordPiece) that balance vocabulary size against sequence length. “Unbelievable” might become [“un”, “believ”, “able”]. This is an engineering decision with real consequences: smaller vocabularies mean longer sequences (more compute in O(n²) attention); larger vocabularies mean bigger embedding tables (more parameters, sparser updates).
Embedding maps each token ID to a dense vector in ℝ^d. This is a learnable lookup table—mathematically, a matrix E ∈ ℝ^{V×d} where V is vocabulary size and d is the model dimension. The embedding for token i is simply row i of E.
Positional Encoding is necessary because attention is permutation-equivariant: without positional information, the model cannot distinguish “dog bites man” from “man bites dog.” The original Transformer uses sinusoidal positional encodings:
PE(pos, 2i) = sin(pos / 10000^(2i/d))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d))
where pos is the position and i is the dimension index.
Why sinusoidal? Two reasons often cited:
- Relative position awareness. For any fixed offset k, PE(pos+k) can be expressed as a linear function of PE(pos). This means the model can learn to attend to relative positions through linear transformations.
- Extrapolation. Sinusoidal encodings generalize to sequence lengths unseen during training (in theory; in practice, this is limited).
Learned positional embeddings (used in BERT, GPT-2) are simply an additional learnable matrix P ∈ ℝ^{L×d} where L is the maximum sequence length. They tend to perform comparably to sinusoidal encodings for fixed-length training but cannot extrapolate beyond the training length.
Modern alternatives include Rotary Position Embeddings (RoPE, used in LLaMA, Qwen) and ALiBi, which encode relative position directly into the attention computation rather than adding it to the input. RoPE has become the de facto standard for autoregressive models due to its superior length extrapolation.
The final input to the Transformer is: x = TokenEmbedding(token) + PositionalEncoding(pos), a vector in ℝ^d for each position.
2.2 Self-Attention
This is the most important mechanism to understand deeply.
What problem it solves. Consider translating “The animal didn’t cross the street because it was too tired.” What does “it” refer to? A human instantly knows “it” = “the animal” because of semantic context. Self-attention lets the model build this kind of context-dependent representation: the representation of “it” can directly incorporate information from “animal” based on their semantic relationship.
Query / Key / Value Intuition. Think of self-attention as a soft dictionary lookup:
- Query (Q): “What am I looking for?” — the current token’s request for information
- Key (K): “What do I contain?” — each token’s advertisement of its content
- Value (V): “What information do I provide if selected?” — the actual content to retrieve
The analogy to a database is precise: Q is the search query, K is the index, V is the data. The difference is that this lookup is soft (weighted average) rather than hard (exact match).
Each is computed as a linear projection of the input:
Q = XW_Q, K = XW_K, V = XW_V
where X ∈ ℝ^{n×d} is the input sequence and W_Q, W_K, W_V ∈ ℝ^{d×d_k} are learned projection matrices.
Scaled Dot-Product Attention. The full computation:
Attention(Q, K, V) = softmax(QK^T / √d_k) V
Step by step:
-
QK^T ∈ ℝ^{n×n}: compute pairwise similarity between all queries and keys. Entry (i,j) measures how much token i wants to attend to token j. This is a dot product—tokens with aligned query and key vectors produce high scores.
-
/ √d_k: scale by the square root of the key dimension. This is not cosmetic—it is mathematically critical.
-
softmax(·): normalize each row to a probability distribution. Token i’s attention weights over all tokens sum to 1.
-
× V: weighted average of value vectors, using the attention weights. The output for token i is a weighted combination of all value vectors, where the weights reflect semantic relevance.
Why is the √d_k scaling needed? This is a high-frequency technical question. The reasoning:
Assume q and k are random vectors with entries drawn from a distribution with mean 0 and variance 1. Their dot product q·k = Σ(q_i × k_i) is a sum of d_k random variables, each with mean 0 and variance 1. By the CLT, the dot product has mean 0 and variance d_k. As d_k grows, the dot products grow in magnitude, pushing softmax inputs into regions where gradients are extremely small (the softmax saturates). Dividing by √d_k normalizes the variance back to 1, keeping the softmax in its sensitive region where gradients flow.
Concrete example: With d_k = 512, unscaled dot products can easily reach values of ±30–50. softmax([50, 1, 2]) ≈ [1.0, 0.0, 0.0]—essentially hard attention with near-zero gradients for non-maximum entries. After scaling by √512 ≈ 22.6, values shrink to ±1–2, and softmax produces smooth distributions that allow gradient-based learning.
2.3 Multi-Head Attention
Instead of performing a single attention function with d-dimensional keys, queries, and values, we split them into h heads, each operating on d_k = d/h dimensions:
head_i = Attention(XW_Q^i, XW_K^i, XW_V^i)
MultiHead(X) = Concat(head_1, ..., head_h) W_O
where W_O ∈ ℝ^{d×d} is the output projection.
Why multiple heads? A single attention head computes one set of attention weights per position—one “query” per token. But a token may need to attend to different tokens for different reasons simultaneously. In “The cat sat on the mat because it was soft,” the word “it” needs to attend to “mat” (for coreference) and to “soft” (for predicate compatibility) and to “sat on” (for syntactic role). Different heads can learn different attention patterns:
- Head 1: syntactic relationships (subject-verb, noun-adjective)
- Head 2: coreference (pronoun-antecedent)
- Head 3: positional proximity (local context)
- Head 4: semantic similarity (synonyms, related concepts)
Empirically, attention head visualizations confirm this specialization. Some heads attend to the previous token, some to syntactically related tokens, some to rare or informative tokens.
Compute cost: Multi-head attention has the same total computational cost as single-head attention with full dimensionality, because each head operates on d/h dimensions. The benefit is representational—multiple independent attention patterns—at no additional compute cost.
2.4 Feed-Forward Network
Each Transformer block contains a position-wise feed-forward network (FFN) applied independently to each token:
FFN(x) = W_2 · σ(W_1 · x + b_1) + b_2
where W_1 ∈ ℝ^{d×d_ff}, W_2 ∈ ℝ^{d_ff×d}, and σ is a nonlinear activation (ReLU in the original, GELU or SiLU/Swish in modern variants). Typically d_ff = 4d.
Why does the FFN exist? Attention is, fundamentally, a weighted averaging operation—it is linear with respect to the values. Without the FFN, stacking attention layers produces increasingly refined weighted averages but cannot compute arbitrary nonlinear functions. The FFN provides per-token nonlinear transformation, enabling the network to:
- Apply nonlinear feature transformations that attention alone cannot express
- Store and recall factual knowledge (recent interpretability work shows FFN layers act as key-value memories)
- Separate the “what to attend to” computation (attention) from the “what to do with the information” computation (FFN)
The FFN typically contains ~2/3 of the Transformer block’s parameters (due to the 4x expansion), making it the primary parameter reservoir. This is why MoE (Mixture of Experts) architectures replace the dense FFN with a sparse mixture—it is the highest-leverage component to sparsify.
2.5 Residual Connections + Layer Normalization
Residual connections add the input of each sub-layer to its output:
output = x + SubLayer(x)
This is critical for three reasons:
-
Gradient highways. In a deep network, gradients must flow backward through many layers. Each multiplicative transformation risks either vanishing or exploding gradients. The residual connection provides a direct additive path: ∂(x + f(x))/∂x = I + ∂f/∂x. Even if ∂f/∂x is small, the gradient is at least I (the identity). This is why 96-layer Transformers train successfully while 96-layer vanilla networks do not.
-
Iterative refinement. Each layer computes a correction to the current representation rather than a complete new representation. This is a more learnable formulation—the network only needs to learn the delta, not the full transformation.
-
Ensemble interpretation. Residual networks can be viewed as implicit ensembles of exponentially many shallower networks (Veit et al., 2016). Deleting individual layers has a smooth degradation effect rather than catastrophic failure.
Layer Normalization normalizes activations across the feature dimension:
LayerNorm(x) = γ ⊙ (x - μ) / (σ + ε) + β
where μ and σ are the mean and standard deviation computed across the d-dimensional feature vector for each token independently, and γ, β are learnable scale and shift parameters.
Pre-Norm vs. Post-Norm:
- Post-Norm (original Transformer):
output = LayerNorm(x + SubLayer(x)) - Pre-Norm (GPT-2, most modern models):
output = x + SubLayer(LayerNorm(x))
Pre-Norm is now standard because it produces more stable gradient magnitudes across layers, enabling training of very deep models (100+ layers) without careful learning rate tuning. The trade-off is that Pre-Norm can slightly underperform Post-Norm when Post-Norm successfully converges—but Post-Norm often requires learning rate warmup and careful initialization to converge at all.
3. Mathematical Deep Dive
Attention Formula — Full Derivation
Starting from first principles. Given an input sequence X ∈ ℝ^{n×d}:
Q = XW_Q ∈ ℝ^{n×d_k}
K = XW_K ∈ ℝ^{n×d_k}
V = XW_V ∈ ℝ^{n×d_v}
A = softmax(QK^T / √d_k) ∈ ℝ^{n×n} (attention weight matrix)
Output = AV ∈ ℝ^{n×d_v}
The attention weight matrix A has entry:
A_ij = exp(q_i · k_j / √d_k) / Σ_m exp(q_i · k_m / √d_k)
This is a softmax over the similarity scores between query i and all keys. The output for position i is:
output_i = Σ_j A_ij · v_j
A weighted average of all value vectors, where the weights are determined by query-key compatibility.
Time and Space Complexity
The dominant cost is computing QK^T:
- Time: O(n² · d_k) — for each of the n query vectors, compute a dot product with each of the n key vectors, each of dimension d_k.
- Space: O(n²) — the attention weight matrix A has n² entries.
This quadratic scaling is the Transformer’s central computational limitation. For n = 1024 (common in NLP), n² ≈ 1M — manageable. For n = 16384 (long documents), n² ≈ 268M — expensive. For n = 65536 (high-resolution images), n² ≈ 4.3B — infeasible without approximation.
The full forward pass of one Transformer block:
| Operation | Time Complexity | Space Complexity |
|---|---|---|
| Q, K, V projections | O(n · d · d_k) | O(n · d_k) |
| QK^T computation | O(n² · d_k) | O(n²) |
| Softmax | O(n²) | O(n²) |
| Attention × V | O(n² · d_v) | O(n · d_v) |
| FFN (both layers) | O(n · d · d_ff) | O(n · d_ff) |
| Total | O(n² · d + n · d · d_ff) | O(n² + n · d_ff) |
For typical settings (d_ff = 4d), the FFN cost O(n · d²) dominates for short sequences (n < d), while attention cost O(n² · d) dominates for long sequences (n > d).
Gradient Flow: Why Transformers Train Better Than RNNs
In an RNN, the gradient from time step t to time step t-k passes through k matrix multiplications:
∂h_t/∂h_{t-k} = Π_{i=t-k+1}^{t} ∂h_i/∂h_{i-1} = Π W_hh (simplified)
This product of matrices either vanishes (if spectral radius < 1) or explodes (if spectral radius > 1) exponentially with k. LSTMs mitigate this with additive gates, but the fundamental problem—long multiplicative chains—remains.
In a Transformer with residual connections, the gradient from layer L to layer l has a direct additive path:
x_L = x_l + Σ_{i=l}^{L-1} f_i(x_i)
∂x_L/∂x_l = I + Σ (terms involving ∂f_i/∂x_i)
The identity matrix I guarantees that gradients flow unimpeded regardless of depth. The additional terms allow gradient-based correction but cannot destroy the base gradient signal. This is why Transformers with 96+ layers train stably while deep RNNs require heroic engineering.
Why Attention Captures Long-Range Dependencies
The attention mechanism provides O(1) path length between any two tokens. In an RNN, information from token 1 must survive sequential processing through tokens 2, 3, …, n-1 to reach token n—an O(n) path. In a CNN with kernel size k, it requires O(n/k) layers. In attention, token n directly attends to token 1 in a single operation.
But path length is only half the story. The other half is content-based routing: attention weights are computed from the data, not hard-coded by architecture. This means the model dynamically decides which long-range connections matter for each input. An RNN must allocate fixed hidden state capacity to all possible dependencies; attention allocates capacity proportionally to relevance.
4. Vision Transformer (ViT)
4.1 Key Idea
The Vision Transformer (Dosovitskiy et al., 2020) asks: “What if we treat an image exactly like a sentence?” Instead of designing vision-specific architectures, ViT applies the standard Transformer encoder to images with minimal modification. The insight is that vision-specific inductive biases (locality, translation equivariance) may not be necessary—given sufficient data, a general-purpose architecture can learn them.
4.2 Patch Embedding
An image of size H×W×C is divided into a grid of non-overlapping patches of size P×P:
Number of patches: N = (H/P) × (W/P)
For a 224×224 image with P=16: N = 14×14 = 196 patches.
Each patch (P×P×C = 16×16×3 = 768 values) is flattened into a vector and linearly projected to the model dimension d:
z_i = flatten(patch_i) · E + e_pos_i, E ∈ ℝ^{(P²·C) × d}
This linear projection is equivalent to a convolution with kernel size P and stride P applied to the image—a useful conceptual bridge for CNN practitioners. The output is a sequence of N patch embeddings in ℝ^d, which are treated identically to word embeddings in NLP.
Why patches and not pixels? With 224×224 images, pixel-level tokenization produces 50,176 tokens. At O(n²) attention cost, this is computationally infeasible. Patches of 16×16 reduce the sequence to 196 tokens—a 256x reduction in attention cost. This is a pragmatic compression, not an optimal one: information is lost at patch boundaries, and the model cannot attend to sub-patch details in early layers.
4.3 Positional Encoding in Images
ViT uses learned 1D positional embeddings—one learnable vector per patch position, added to the patch embeddings. Despite the 2D spatial structure of images, 1D encodings work well because the model can learn spatial relationships from data.
This is initially surprising: shouldn’t a 2D-aware encoding be better? Experiments show that learned 1D and 2D positional encodings perform comparably when sufficient training data is available. The learned 1D embeddings, when visualized, do recover 2D spatial structure—nearby patches in 2D space have similar positional embeddings.
For higher resolution inference, positional embeddings trained at 224×224 must be interpolated (typically bicubic interpolation in the 2D grid) to handle, say, 384×384 images. This works reasonably well but is not perfect—a known limitation.
4.4 CLS Token and Classification
Following BERT, ViT prepends a learnable [CLS] token to the sequence:
Input: [CLS, patch_1, patch_2, ..., patch_N]
After processing through L Transformer blocks, the [CLS] token’s output representation is used as the aggregate image representation, fed to a classification head (MLP):
y = MLP(LayerNorm(z_L^0))
where z_L^0 is the [CLS] token output at layer L.
Why a CLS token instead of global average pooling? The CLS token acts as a learnable aggregation query—it attends to all patches and learns to extract task-relevant information. Global average pooling (averaging all patch representations) works comparably in practice and is simpler. Many subsequent works (DeiT, BEiT) use pooling instead. The choice is empirical, not principled.
4.5 ViT vs. CNN Comparison
| Aspect | CNN | ViT |
|---|---|---|
| Inductive bias | Strong: locality (convolutions see local neighborhoods), translation equivariance (same filter everywhere), hierarchical structure (progressively larger receptive fields) | Weak: only sequence ordering via positional encoding. No built-in locality or translation equivariance |
| Data efficiency | High: inductive biases act as strong priors, enabling learning from smaller datasets. A ResNet-50 trains well on ImageNet (1.3M images) | Low: without inductive biases, ViT must learn spatial relationships from data. Requires ~14–300M images (JFT-300M) to match CNN performance. Below this, ViTs underperform CNNs |
| Scaling behavior | Saturates: performance plateaus beyond a certain model/data size. ResNets scale poorly beyond ~1B parameters | Continues improving: ViT performance improves log-linearly with data and compute. At sufficient scale, ViT surpasses CNNs and does not plateau as early |
| Receptive field | Grows gradually with depth. Global context requires very deep networks or dilated convolutions | Global from layer 1. Every patch can attend to every other patch in the first layer |
| Compute pattern | Efficient on small inputs (convolutions are fast). Scales linearly with image area for each layer | O(n²) attention on patch count. Expensive for high-resolution images unless optimized |
| Transfer learning | Strong for similar domains. Representations are somewhat domain-specific due to convolutional structure | Excellent for diverse downstream tasks when pretrained at scale, due to flexibility of learned representations |
| Interpretability | Feature maps correspond to spatial locations, amenable to visualization (Grad-CAM, etc.) | Attention maps provide some interpretability but are less directly spatial. Attention ≠ explanation (see Section 9) |
The key takeaway: ViT is not strictly better than CNNs. It is better at scale. The crossover point—where ViT begins to outperform CNNs—depends on data volume. For small datasets or resource-constrained settings, CNNs (with their built-in priors) remain competitive or superior. This is a nuanced point that experienced teams often stress.
5. Engineering Perspective
Building Transformers in PyTorch
A minimal but complete self-attention implementation:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class MultiHeadSelfAttention(nn.Module):
def __init__(self, d_model, n_heads):
super().__init__()
assert d_model % n_heads == 0
self.d_k = d_model // n_heads
self.n_heads = n_heads
self.qkv_proj = nn.Linear(d_model, 3 * d_model)
self.out_proj = nn.Linear(d_model, d_model)
def forward(self, x, mask=None):
B, N, D = x.shape
qkv = self.qkv_proj(x) # (B, N, 3D)
qkv = qkv.reshape(B, N, 3, self.n_heads, self.d_k)
qkv = qkv.permute(2, 0, 3, 1, 4) # (3, B, H, N, d_k)
q, k, v = qkv.unbind(0) # each (B, H, N, d_k)
attn = (q @ k.transpose(-2, -1)) / math.sqrt(self.d_k) # (B, H, N, N)
if mask is not None:
attn = attn.masked_fill(mask == 0, float('-inf'))
attn = F.softmax(attn, dim=-1)
out = attn @ v # (B, H, N, d_k)
out = out.transpose(1, 2).reshape(B, N, D) # (B, N, D)
return self.out_proj(out)
Key implementation details that matter in practice:
-
Fused QKV projection. A single linear layer for Q, K, V is faster than three separate layers—one matrix multiply instead of three, reducing kernel launch overhead and memory accesses.
-
Reshape for multi-head. The (B, N, D) → (B, H, N, d_k) reshape is a view operation (no memory copy), but the subsequent permute requires non-contiguous memory access. Understanding memory layout matters for performance.
-
Masking. For causal (autoregressive) attention, a triangular mask prevents attending to future tokens. The mask is applied before softmax by setting masked positions to -∞, so softmax produces 0 for those positions.
Memory Bottlenecks
The dominant memory consumers during training:
-
Attention weight matrix: O(B × H × N² × sizeof(float)). For B=32, H=12, N=2048: 32 × 12 × 2048² × 4 bytes ≈ 6.4 GB. This is the primary reason long sequences are expensive.
-
Activations stored for backpropagation. Every intermediate tensor (after each attention, FFN, norm) must be cached for the backward pass. A 24-layer Transformer stores 24× more activations than a single layer.
-
Optimizer states. Adam stores first and second moment estimates for every parameter—3x the parameter memory. For a 7B parameter model in FP32: 7B × 4 bytes × 3 = 84 GB of optimizer state alone.
Training Tricks
Learning Rate Scheduling. Transformers are notoriously sensitive to learning rate. The standard recipe:
-
Warmup: Linearly increase learning rate from 0 to peak over the first T_warmup steps (typically 1–5% of training). This is necessary because early gradients are unreliable—the model’s attention patterns are essentially random at initialization, producing high-variance gradients. A large learning rate on random gradients can permanently damage training.
-
Cosine decay: After warmup, decay the learning rate following a cosine schedule to near-zero. This provides aggressive learning early (when the loss landscape has clear gradients) and fine-grained optimization late (when the model is near a minimum).
if step < warmup_steps:
lr = peak_lr * (step / warmup_steps)
else:
progress = (step - warmup_steps) / (total_steps - warmup_steps)
lr = peak_lr * 0.5 * (1 + cos(π * progress))
Batch Size Effects. Transformers benefit from large batch sizes more than most architectures. The attention mechanism’s O(n²) cost means that larger batches improve GPU utilization (better amortization of kernel launch overhead). Large-batch training also produces more stable gradient estimates, which is important for attention layers where gradients can be noisy. Typical batch sizes for large Transformers: 256–4096 sequences (achieved via gradient accumulation across multiple GPUs).
Optimization Techniques
Mixed Precision Training (FP16/BF16). Store model parameters in FP32 (master copy) but perform forward and backward passes in FP16 or BF16:
- 2x memory reduction for activations
- 2–3x faster matrix multiplications on tensor cores
- Loss scaling prevents FP16 underflow (multiply loss by a large factor before backward, divide gradients by the same factor after)
- BF16 (available on A100+) is preferred over FP16: same range as FP32 (no underflow), slightly less precision, no loss scaling needed
Gradient Checkpointing. Trade compute for memory: instead of storing all intermediate activations, store only at selected layers (“checkpoints”) and recompute the others during backward pass. Reduces activation memory from O(L × N × d) to O(√L × N × d) at the cost of ~33% more compute. This is essential for training large models on limited GPU memory.
Flash Attention. Fuses the softmax(QK^T)V computation into a single GPU kernel, avoiding materialization of the N×N attention matrix in HBM (high-bandwidth memory). Reduces memory from O(N²) to O(N) and achieves 2–4x wall-clock speedup by exploiting the GPU memory hierarchy (SRAM vs HBM). This is now the standard attention implementation (PyTorch 2.0+ includes it via torch.nn.functional.scaled_dot_product_attention).
6. Common Variants and Improvements
BERT vs. GPT: Encoder vs. Decoder
This is a fundamental architectural distinction that practitioners must understand clearly.
BERT (Encoder-only). Uses bidirectional self-attention—every token attends to every other token, including those to its right. Trained with masked language modeling (predict randomly masked tokens from context). Produces contextualized representations for each token. Used for classification, NER, retrieval, and understanding tasks. Cannot generate text autoregressively.
GPT (Decoder-only). Uses causal (masked) self-attention—each token can only attend to itself and previous tokens (the upper triangle of the attention matrix is masked to -∞). Trained with next-token prediction. Generates text autoregressively (one token at a time, left to right). The architecture behind GPT-4, Claude, Llama, and all modern LLMs.
Why did decoder-only win? The scaling advantage. Next-token prediction provides a training signal at every position (n signals per sequence of length n), scales trivially to any text data, and naturally supports generation. Masked language modeling wastes ~85% of positions (only 15% are masked) and requires task-specific fine-tuning for generation. At scale, the simplicity and data efficiency of next-token prediction dominates.
Encoder-decoder (T5, original Transformer). Encoder processes input with bidirectional attention; decoder generates output autoregressively while cross-attending to encoder representations. Natural for sequence-to-sequence tasks (translation, summarization). Less popular at the frontier but still used in specialized systems (Whisper for speech, Flan-T5 for instruction following).
Swin Transformer
Swin addresses ViT’s two weaknesses for vision: quadratic cost and lack of multi-scale features.
Shifted window attention: Instead of global attention over all patches, Swin restricts attention to local windows (e.g., 7×7 patches). This reduces cost from O(N²) to O(N × W²) where W is the window size—linear in image size. To enable cross-window information flow, the windows are shifted by half the window size in alternate layers.
Hierarchical structure: Like CNNs, Swin progressively merges patches (2×2 patch merging = 2x spatial downsampling) to create multi-scale feature maps. This produces the pyramid-shaped feature hierarchy that object detection and segmentation architectures (FPN, Feature Pyramid Network) expect, making Swin a drop-in replacement for CNN backbones.
Swin is the dominant vision Transformer for dense prediction tasks (detection, segmentation) where multi-scale features and linear cost are requirements.
Efficient Attention Mechanisms
Linformer approximates the N×N attention matrix by projecting K and V to lower dimensionality: K’ = K·P_K where P_K ∈ ℝ^{N×k} for k « N. This reduces complexity to O(N·k) but introduces approximation error and cannot be used for autoregressive generation (the projection depends on the full sequence).
FlashAttention (discussed in Section 5) is not an approximation—it computes exact attention but reorganizes the computation to minimize memory transfers. It has become the default implementation, making many approximate attention methods obsolete for practical purposes.
Sparse attention (BigBird, Longformer) combines local window attention with global tokens and random attention patterns to achieve O(N√N) or O(N·log N) complexity while maintaining reasonable approximation quality.
Multimodal Models (Brief)
CLIP (Contrastive Language-Image Pretraining) trains a vision encoder and a text encoder jointly using contrastive learning on image-text pairs. It creates aligned embedding spaces where images and text can be directly compared. This is the foundation for zero-shot image classification, image retrieval, and multimodal model architectures.
Modern multimodal models (GPT-4V, Gemini, Qwen-VL) typically use a vision encoder (ViT-based) to produce visual tokens, which are then interleaved with or projected into the language model’s token sequence. The Transformer processes both modalities jointly, enabling cross-modal reasoning. The architectural bet is that a single powerful sequence model can handle heterogeneous modalities if the tokenization is right.
7. Technical Questions — With Answers
Basic
Q: Why use attention instead of RNNs?
Three concrete reasons. First, parallelism: attention processes all positions simultaneously in a single matrix multiply, while RNNs must process sequentially—this gives 10–100x training speedup on GPUs. Second, long-range dependency: attention provides O(1) path length between any two tokens vs. O(n) for RNNs, which means information does not decay over distance. Third, gradient flow: residual connections around attention layers provide direct gradient pathways, avoiding the vanishing/exploding gradient problem that plagues deep RNNs. The trade-off is O(n²) cost vs. O(n) for RNNs, which matters for very long sequences.
Q: What is positional encoding and why is it needed?
Attention is permutation-equivariant—it treats the input as a set, not a sequence. Without positional information, “dog bites man” and “man bites dog” produce identical representations. Positional encoding injects ordering information by adding position-dependent vectors to token embeddings. The original Transformer uses sinusoidal functions with different frequencies, which provide a unique encoding per position and allow the model to learn relative position through linear transformations. Modern models often use learned embeddings (BERT) or rotary positional embeddings (RoPE, used in LLaMA) that encode relative position directly in the attention computation.
Intermediate
Q: Why multi-head attention instead of single-head?
A single attention head computes one set of attention weights—one “query” per token position. But tokens need to capture multiple types of relationships simultaneously (syntactic, semantic, positional, coreference). Multi-head attention runs h parallel attention computations, each in a d/h-dimensional subspace, allowing different heads to specialize in different relationship types. The total compute cost is identical (h heads of dimension d/h = one head of dimension d), but the representational capacity is richer. Empirically, individual heads learn interpretably different patterns: some attend locally, some attend to syntactic dependencies, some attend to rare tokens.
Q: Why divide by √d_k in attention?
Without scaling, dot products between query and key vectors have variance proportional to d_k (by the central limit theorem, the sum of d_k products of unit-variance terms has variance d_k). As d_k grows (typical values: 64–128), dot products become large in magnitude, pushing softmax inputs into saturation where the output is approximately one-hot and gradients approach zero. Dividing by √d_k normalizes variance to 1, keeping softmax in its informative (non-saturated) regime. The practical consequence: without scaling, Transformers train extremely slowly or not at all for large d_k values.
Q: Why LayerNorm instead of BatchNorm?
BatchNorm normalizes across the batch dimension—it computes statistics from all examples in a mini-batch for each feature. This creates two problems for Transformers. First, sequence lengths vary across examples, making batch statistics unreliable for sequence models. Second, during autoregressive inference, batch size is typically 1, and accumulating running statistics from training is a poor approximation for variable-length generation. LayerNorm normalizes across the feature dimension within each individual example and position, avoiding both issues. It is also independent of batch size, making it compatible with any inference setting.
Advanced
Q: How would you reduce the O(n²) complexity of attention?
Several approaches, ranked by practical impact:
-
FlashAttention (most impactful). Does not reduce theoretical complexity but eliminates the O(n²) memory bottleneck by tiling the attention computation to fit in GPU SRAM. Achieves exact attention with 2–4x wall-clock speedup. This is the first thing to try.
-
Windowed/local attention (Swin, Longformer). Restrict attention to a local window of size w, giving O(n·w) complexity. Combined with shifted windows or sparse global tokens to maintain global information flow. Well-suited for vision (spatial locality) and long documents.
-
Linear attention (Linformer, Performer, RWKV). Replace softmax(QK^T)V with kernel approximations or recurrent formulations that avoid materializing the n×n matrix. Reduces to O(n·d) but sacrifices some quality. RWKV and Mamba (state-space models) are promising alternatives that achieve O(n) complexity with competitive quality.
-
Sparse attention (BigBird). Combine local, global, and random attention patterns. O(n·√n) typical. Theoretically elegant but often slower than FlashAttention on real hardware due to irregular memory access patterns.
The practical insight: algorithmic complexity and wall-clock time are different things. FlashAttention is O(n²) but faster than many O(n) methods in practice because it exploits hardware memory hierarchy.
Q: How would you adapt a Transformer for very long sequences (100K+ tokens)?
This is a systems question, not just an algorithms question.
Model architecture: Use RoPE or ALiBi for positional encoding (both extrapolate to unseen lengths better than learned embeddings). Apply FlashAttention-2 or ring attention for memory-efficient computation. Consider sliding window attention for layers that do not need global context.
Training: Train at shorter context (e.g., 4K tokens) and fine-tune at longer context with continued pretraining (this is how Llama 3 extends to 128K). Use YaRN (Yet another RoPE extensioN) or dynamic NTK scaling to extend RoPE beyond training lengths.
Inference: Implement KV-cache for autoregressive generation (store key-value pairs from previous tokens to avoid recomputation). For very long contexts, use PagedAttention (vLLM) to manage KV-cache memory efficiently with paging, avoiding fragmentation. Consider GQA (Grouped Query Attention) which shares key-value heads across query heads, reducing KV-cache size by 4–8x with minimal quality loss.
Q: How do you debug a Transformer that does not converge?
Systematic checklist:
-
Check learning rate. Too high → loss oscillates or diverges. Too low → loss decreases very slowly. Try the classic sweep: 1e-5, 3e-5, 1e-4, 3e-4, 1e-3. Transformers often need warmup—if you skipped it, add it.
-
Check loss computation. Is the loss masked correctly? For language models, are padding tokens excluded from the loss? For masked LM, are only masked positions contributing? Off-by-one errors in masking are common and devastating.
-
Check attention masks. For causal models, is the causal mask correct? A bug here means the model can “cheat” by seeing future tokens, producing artificially low training loss but garbage at inference.
-
Check data. Are tokens correctly encoded/decoded? Print actual inputs and targets. Shuffling errors, truncation bugs, and encoding mismatches are common.
-
Check gradient norms. Use
torch.nn.utils.clip_grad_norm_and log the pre-clip gradient norm. Spikes indicate instability (often from a few bad examples or learning rate too high). Consistently near-zero indicates vanishing gradients (check residual connections, initialization). -
Verify with tiny experiment. Overfit on a single batch of 8 examples. If the model cannot memorize 8 examples, the architecture or training loop has a bug. If it can, the issue is data/scale/hyperparameters.
-
Check numerical stability. NaN/Inf in loss often comes from log(0) in cross-entropy (add epsilon), underflow in FP16 softmax (use BF16 or loss scaling), or division by zero in LayerNorm (check epsilon).
8. System Design Perspective
Recommendation Systems
Modern recommendation systems (YouTube, TikTok, Amazon) increasingly use Transformer-based sequence models to capture user behavior patterns. The user’s interaction history (clicks, views, purchases) is treated as a sequence of item tokens, and a Transformer encodes this sequence to predict the next interaction.
Why Transformers here? User behavior has long-range dependencies (a user who watched cooking videos three weeks ago may return to that interest) and complex multi-factor patterns (combining time-of-day, recent behavior, and long-term preferences). Attention naturally captures these patterns without hand-engineered feature crosses.
Engineering considerations: Inference latency is critical (recommendations must be computed in <50ms for real-time serving). Solutions include precomputing user embeddings offline, using small (2–4 layer) Transformers, and caching KV states for incremental updates as new interactions arrive.
Retrieval and RAG (Retrieval-Augmented Generation)
The typical RAG pipeline:
- Embedding model (Transformer-based, e.g., E5, BGE, GTE) encodes documents into dense vectors
- Vector database (Pinecone, Weaviate, pgvector) stores and indexes these embeddings
- Retrieval finds the k most similar documents to a query embedding
- Generation model (LLM) produces an answer conditioned on the retrieved context
The Transformer plays two roles: as the encoder that creates representations (understanding), and as the generator that produces output (reasoning). The embedding model is typically a BERT-style encoder fine-tuned with contrastive learning; the generator is a GPT-style decoder.
Design trade-offs: Larger embedding models produce better retrieval but are slower to encode. Larger chunk sizes preserve context but reduce retrieval precision. More retrieved chunks improve recall but increase the generator’s context length (and cost). The system design question is about balancing these trade-offs for a specific latency and cost budget.
ViT vs. CNN in Production Vision Systems
ViT is preferred when: large pretrained models are available (leveraging transfer learning), the task benefits from global context (scene understanding, document analysis), and inference can tolerate higher latency.
CNNs are preferred when: latency is critical (mobile, edge, real-time video), the dataset is small (CNNs’ inductive biases help), and the deployment target has limited compute (CNNs are more parameter-efficient for simple tasks).
Hybrid architectures (EfficientFormer, MobileViT, FastViT) combine CNN stems (efficient local feature extraction) with Transformer blocks (global reasoning), aiming for the best of both worlds. These are increasingly common in production settings where both efficiency and capability matter.
9. Common Pitfalls and Misconceptions
“Attention weights explain model decisions”
This is one of the most persistent and dangerous misconceptions. Attention weights show where the model “looks” but not what it “thinks.” Multiple studies (Jain & Wallace, 2019; Wiegreffe & Pinter, 2019) demonstrate that:
- Different attention patterns can produce identical outputs
- Adversarially modified attention weights that look nothing like the originals produce similar predictions
- Attention weights are not gradients—they do not measure feature importance
Attention maps are at best a loose heuristic for interpretability, not a faithful explanation. For actual feature attribution, use gradient-based methods (integrated gradients, attention rollout) or perturbation-based approaches.
“Transformers are always better than CNNs”
False in several important regimes:
- Small data. With <10K training images, a ResNet-50 typically outperforms a ViT-Base. CNNs’ built-in locality and translation equivariance provide strong priors that compensate for limited data.
- Low latency. MobileNetV3 or EfficientNet at 5ms inference vs. ViT-Tiny at 15ms—for real-time mobile applications, the 3x latency difference matters.
- Simple tasks. For tasks where local features suffice (edge detection, texture classification), CNNs are more parameter-efficient.
- Resource-constrained deployment. CNN inference is highly optimized across all hardware (CPUs, GPUs, NPUs, DSPs). Transformer inference optimization is improving (FlashAttention, TensorRT) but not yet as universal.
“More parameters = better model”
The relationship between parameter count and performance is mediated by training data, compute, and architecture efficiency. Key counterexamples:
- A well-tuned 7B model (Mistral 7B) outperforms a poorly tuned 13B model on many tasks
- MoE models (Mixtral 8x7B = 47B total, 13B active) outperform dense models with the same active parameter count because they use total parameters as a memory bank while keeping compute constant
- Distilled models (student models trained to mimic larger teachers) regularly achieve 90%+ of teacher performance at 10–20% the size
- The scaling laws show that optimal performance for a fixed compute budget requires specific model-size-to-data-size ratios. Scaling parameters without proportionally scaling data is wasteful.
The nuanced view: parameter count is a resource to be allocated optimally, not a metric to maximize.
10. Synthesis
Why Transformers Are a Paradigm Shift
The Transformer is not just another architecture—it is a computational primitive that has proven effective across nearly every data modality. This universality is unprecedented. Previous architectures were domain-specific: RNNs for sequences, CNNs for images, GNNs for graphs. The Transformer handles all of these, plus audio, video, protein sequences, molecular structures, and robotic action sequences. This universality enables transfer learning across modalities, unified multimodal architectures, and shared infrastructure investments.
The deeper reason for this universality: attention is a general-purpose information routing mechanism. Given any collection of entities (tokens, patches, graph nodes), attention learns which entities should communicate and what information should flow between them. This is a computation pattern that transcends specific domains.
When NOT to Use Transformers
Transformers are the wrong choice when:
- The sequence is very long and low-bandwidth. Time-series with millions of steps and simple patterns (sensor monitoring, signal processing) are better served by state-space models (Mamba, S4) or even classical methods (Kalman filters, wavelets).
- Data is tiny. With hundreds of training examples, a Random Forest or gradient-boosted tree will outperform any Transformer. Inductive biases from domain knowledge beat learned representations when data is scarce.
- Latency budget is microseconds. Transformers have high per-token cost. For real-time control systems (robotic motor control at 1kHz), simpler models with guaranteed latency are necessary.
- The problem has strong known structure. If you know the underlying physics (fluid dynamics, molecular dynamics), physics-informed architectures outperform general-purpose Transformers that must discover the structure from data.
- Interpretability is legally required. Attention weights do not constitute explanations. For regulated domains requiring auditable decision-making, linear models or decision trees may be necessary.
What Skills an MLE Should Demonstrate
When discussing Transformers in depth, the strongest practitioners demonstrate:
-
Layered understanding. They can explain attention at the intuition level (“soft dictionary lookup”), the mathematical level (scaled dot-product, gradient flow), and the implementation level (memory layout, Flash Attention, KV-cache). They move between these levels fluidly based on follow-up questions.
-
Trade-off awareness. They never say “Transformers are the best” without qualifying when and why. They understand the data efficiency trade-off (CNNs vs. ViT), the compute trade-off (O(n²) vs. efficient alternatives), and the complexity vs. performance trade-off (simple baselines vs. complex architectures).
-
Systems thinking. They connect model architecture to deployment reality: how attention complexity affects inference cost, how KV-cache size determines maximum batch size during serving, how model size determines whether cloud or edge deployment is feasible.
-
Debugging intuition. They can diagnose training failures systematically rather than guessing: checking learning rate, verifying masking, monitoring gradient norms, isolating model vs. data issues with single-batch overfitting.
-
Awareness of the frontier. They know about current developments (FlashAttention, GQA, MoE, RoPE, state-space models as alternatives) and can discuss them substantively, not just name-drop. They understand that the field moves fast and that specific implementation details change, but the principles (parallelism, gradient flow, inductive bias trade-offs) are durable.
The Transformer is the most important architecture for any ML engineer to understand deeply—not because it is perfect, but because it is the foundation on which virtually all modern AI systems are built, and understanding its strengths, limitations, and trade-offs is the clearest signal of engineering maturity in the field.
Enjoy Reading This Article?
Here are some more articles you might like to read next: