Understanding Attention Mechanisms
Before 2017, sequence models relied heavily on recurrence. Then came "Attention Is All You Need," and the landscape of NLP fundamentally shifted.
At its core, self-attention lets a model weigh every other token in the input sequence while it processes a given token. Because these pairwise comparisons are computed for all positions at once, context is captured without the step-by-step bottleneck of recurrence.
The Mathematics of Self-Attention
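The intuition is simple: given a query, how relevant is each available key? The dot product between a query vector and a key vector gives a similarity score. We divide by the square root of the key dimension d_k because dot products grow in magnitude as d_k grows; without scaling, large scores push the softmax into saturated regions where its gradients become vanishingly small. The softmax then turns the scaled scores into weights that sum to 1, and the output is the corresponding weighted sum of the value vectors.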
Formula: Attention(Q, K, V) = softmax(QK^T / √d_k)V
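To make the formula concrete, here is a minimal sketch that works through it by hand for a single query, using only the Python standard library. The helper names (softmax, attend) and the toy vectors are illustrative assumptions, not part of any library or of the original paper.

```python
import math

def softmax(xs):
    # Subtract the max before exponentiating for numerical stability.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def attend(query, keys, values):
    """Attention for one query vector over lists of key and value vectors."""
    d_k = len(query)
    # Dot product of the query with each key, scaled by sqrt(d_k).
    scores = [
        sum(q * k for q, k in zip(query, key)) / math.sqrt(d_k)
        for key in keys
    ]
    weights = softmax(scores)
    # Weighted sum of the value vectors, one output component at a time.
    dim_v = len(values[0])
    output = [
        sum(w * value[i] for w, value in zip(weights, values))
        for i in range(dim_v)
    ]
    return weights, output

# Toy data (hypothetical): d_k = 4, two keys, two-dimensional values.
query = [1.0, 0.0, 1.0, 0.0]
keys = [
    [1.0, 0.0, 1.0, 0.0],  # aligned with the query: raw score 2, scaled 1
    [0.0, 1.0, 0.0, 1.0],  # orthogonal to the query: score 0
]
values = [[10.0, 0.0], [0.0, 10.0]]

weights, output = attend(query, keys, values)
print(weights)  # approximately [0.731, 0.269]
print(output)   # approximately [7.311, 2.689]
```

The aligned key receives roughly 73% of the weight, so the output sits closer to the first value vector. Without the division by √d_k the scores would be 2 and 0, and the first weight would rise to about 0.88, which hints at how unscaled scores sharpen the softmax as d_k grows.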
Implementation in PyTorch
Building this from scratch strips away the magic and reveals the elegant linear algebra underneath.
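import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class ScaledDotProductAttention(nn.Module):
    def __init__(self, d_k):
        super().__init__()
        self.d_k = d_k  # key dimension, used to scale the scores

    def forward(self, q, k, v, mask=None):
        # q, k, v shape: (batch_size, seq_len, hidden_dim)
        # scores shape: (batch_size, seq_len, seq_len)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            # Positions where mask == 0 get a large negative score,
            # so softmax assigns them near-zero weight.
            scores = scores.masked_fill(mask == 0, -1e9)
        # Normalize over the key axis so each query's weights sum to 1.
        attention = F.softmax(scores, dim=-1)
        output = torch.matmul(attention, v)
        return output, attention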
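As a usage sketch, the same computation can be exercised with a causal mask, where each position may attend only to itself and earlier positions. To keep the snippet self-contained it uses a small functional stand-in that mirrors the forward pass above; the function name, tensor sizes, and random inputs are assumptions chosen for illustration.

```python
import math

import torch
import torch.nn.functional as F

def scaled_dot_product_attention(q, k, v, mask=None):
    # Functional stand-in for the module's forward pass (illustrative name).
    d_k = q.size(-1)
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    attention = F.softmax(scores, dim=-1)
    return torch.matmul(attention, v), attention

torch.manual_seed(0)
batch_size, seq_len, d_k = 2, 4, 8  # arbitrary toy sizes
q = torch.randn(batch_size, seq_len, d_k)
k = torch.randn(batch_size, seq_len, d_k)
v = torch.randn(batch_size, seq_len, d_k)

# Lower-triangular mask: row i has ones in columns 0..i and zeros after.
# Its (seq_len, seq_len) shape broadcasts across the batch dimension.
causal_mask = torch.tril(torch.ones(seq_len, seq_len))

output, attention = scaled_dot_product_attention(q, k, v, mask=causal_mask)

print(output.shape)     # torch.Size([2, 4, 8])
print(attention.shape)  # torch.Size([2, 4, 4])
print(attention[0])     # upper triangle is zero; each row sums to 1
```

The first row of each attention matrix has all its weight on position 0, since that is the only position the first token is allowed to see.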