Steve Kinney

The Attention Mechanism

Key Takeaways

  • Self-attention builds context via Q/K/V projections and scaled dot-product softmax
  • Multi-head attention captures diverse patterns; residuals and MLPs add depth
  • Causal masks enable generation; bidirectional masks power understanding tasks
  • Positional encodings (sinusoidal/learned/RoPE/ALiBi) inject order information
  • Attention is O(T²); use optimized kernels and long-context strategies when needed

Overview

Attention lets models weigh relationships between tokens to build context-aware representations. When you read “The cat sat on the mat,” your brain instantly knows that “sat” is related to “cat” (the one doing the sitting) and “mat” (where it sat). The attention mechanism gives models this same ability to understand relationships between words, no matter how far apart they are.

Notebook

View the companion notebook: Attention Mechanism

From Word Embeddings to Contextual Understanding

Before attention, each token gets mapped to a dense vector embedding that represents the word’s meaning in multi-dimensional space. The problem is that initially, the embedding for “bank” is identical whether in “river bank” or “money bank.” Attention solves this by taking these context-free embeddings and enriching them with information from all other words in the sentence.

The process starts with tokenization, where text becomes token IDs:

"The cat sat"[101, 1996, 4937, 3323, 102]

These IDs map to embeddings, but without context, they lack the nuance needed for understanding. This is where attention transforms static embeddings into dynamic, context-aware representations.

Queries, Keys, and Values

The attention mechanism creates three special vectors for each word’s embedding:

Query (Q) represents what the current word is looking for. Key (K) represents what each word has to offer. Value (V) contains the actual information each word provides.

The Library Analogy

Imagine researching “artificial intelligence” in a library. Your Query is your research topic. You go to the card catalog where each book’s title serves as a Key. You compare your query to each key to find matches. A book titled “A History of AI” strongly matches, while “Gardening” weakly matches. The book contents are the Values. You pull out the best-matching books (highest attention scores) and blend their information for a rich, contextual understanding.

The attention mechanism performs this process for every single word in the input, allowing the model to learn which other words are most important for understanding each word’s meaning in that specific context.

Visualizing Attention with Code

Let’s see attention in action using BERT and the bertviz library:

from transformers import AutoTokenizer, AutoModel
from bertviz import head_view

model_name = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name, output_attentions=True)

sentence_a = "The cat sat on the mat"
sentence_b = "The dog played in the park"

# Tokenize and get model outputs
inputs = tokenizer(sentence_a, sentence_b, return_tensors='pt')
outputs = model(**inputs)
attention = outputs[-1]  # The attention scores

# Visualize the attention patterns
head_view(attention, tokens=tokenizer.convert_ids_to_tokens(inputs['input_ids'][0]))

The visualization reveals strong connections between related words: cat connects to sat and mat, while dog connects to played and park. The model learns to keep the two sentences separate thanks to the [SEP] token.

Attention Mechanics

The core computation involves projecting inputs into Queries, Keys, and Values through learned linear transformations. The scaled dot-product attention formula is:

Attention(Q, K, V) = softmax(QK^T / √d_k) V

The scaling factor √d_k prevents the dot products from growing too large, which would cause the softmax to saturate. After attention, outputs pass through residual connections and feed-forward networks for additional depth and nonlinearity.

Multi-Head Attention

Instead of computing attention once, multi-head attention runs several attention operations in parallel, each with different learned projections. This allows the model to jointly attend to information from different representation subspaces:

# Conceptually, for H heads:
for head in range(H):
    Q_h = Linear_Q[head](X)
    K_h = Linear_K[head](X)
    V_h = Linear_V[head](X)
    head_output[head] = Attention(Q_h, K_h, V_h)

output = Concat(head_outputs) @ W_O

Different heads learn different types of relationships. Some might focus on syntactic patterns, others on semantic relationships, and yet others on positional information.

Masking Strategies

Masking controls which tokens can attend to which other tokens:

Padding mask prevents attending to padded tokens in batched sequences. Causal mask ensures token t cannot see future tokens (t+1, t+2, …), essential for autoregressive decoders. Bidirectional attention allows tokens to attend freely in both directions, used in encoder models for understanding tasks.

The causal mask is implemented by setting future positions to negative infinity before the softmax:

# Create a causal mask
mask = torch.triu(torch.ones(seq_len, seq_len) * -inf, diagonal=1)
scores = (Q @ K.T) / sqrt(d_k) + mask
attention_weights = softmax(scores)

Positional Information

Pure self-attention is permutation-invariant, meaning it treats sequences as sets without order. Positional encodings inject order information:

Sinusoidal encodings use sine and cosine functions of different frequencies. Learned embeddings treat positions as learnable parameters. Rotary Position Embeddings (RoPE) rotate attention scores based on relative positions. ALiBi adds linear biases to attention scores based on distance.

For long sequences, strategies include sliding windows (attend only to nearby tokens), dilated patterns (skip tokens at regular intervals), and chunking with overlap for maintaining context.

Complexity and Performance

Vanilla attention has O(n²) complexity in sequence length, causing memory and time to grow quadratically. This becomes prohibitive for long sequences.

Optimizations include FlashAttention, which fuses operations and avoids materializing the full attention matrix. Sparse attention patterns reduce the number of token pairs considered. Linear attention approximations achieve O(n) complexity through kernel tricks.

When implementing, use mixed precision (fp16/bf16) and optimized kernels. Monitor peak memory usage and consider gradient checkpointing during training for longer sequences.

Shapes and Implementation

Given input sequence X of shape [B, T, D], linear projections produce Q, K, V with the same dimensions. For multi-head attention with H heads, the dimension D is split: [B, H, T, D_h] where D_h = D/H.

The attention computation preserves these shapes:

# Shapes during attention computation
Q: [B, H, T, D_h]
K: [B, H, T, D_h]
V: [B, H, T, D_h]
scores: [B, H, T, T]  # After Q @ K.T
weights: [B, H, T, T]  # After softmax
output: [B, H, T, D_h]  # After weights @ V
final: [B, T, D]  # After concatenation and projection

Causal vs Bidirectional vs Cross-Attention

Causal attention restricts each position to attend only to itself and past tokens, essential for decoder-only language models that generate text left-to-right.

Bidirectional attention allows each token to see the entire sequence, typical for encoder stacks in understanding tasks like BERT.

Cross-attention attends from one sequence (decoder) to another (encoder outputs), crucial in sequence-to-sequence models for translation or in multimodal models connecting text to images.

Inspecting and Debugging Attention

While attention maps provide some interpretability, they don’t fully explain model reasoning. However, visualization can reveal learned patterns:

Induction heads learn to copy or continue patterns. Positional heads focus on specific relative positions. Syntactic heads align with grammatical structures. Semantic heads connect related concepts.

Use attention visualization to debug issues with masking, positional encoding, or unexpected focus patterns. Tools like bertviz make this analysis accessible:

from bertviz import model_view

# Show attention across all layers and heads
model_view(attention, tokens, html_action='return')

Practical Considerations

When working with attention, balance model size with available memory. Start with pre-trained models that already have learned attention patterns. For custom implementations, carefully validate masking and use established libraries when possible.

The attention mechanism fundamentally changed how models process sequences, enabling them to capture long-range dependencies and complex relationships. Understanding attention is key to understanding modern transformer models and their remarkable capabilities in language understanding and generation.

Further Reading

Last modified on .