softmax(Q @ K^T) gives you a probability matrix of shape [seq, seq]: each row sums to 1. Think of it like an adjacency matrix whose edge weights are flow probabilities. Attention is therefore semantic routing: each output position is a probability-weighted mixture of the rows of V, i.e. of parts of X projected through W_V.
where
- Q = X @ W_Q [query]
- K = X @ W_K [key]
- V = X @ W_V [value]
- X [input]
hence
attn_head_i = softmax(Q @ K^T / sqrt(d_k)) @ V, where sqrt(d_k) is the normalizing term and d_k is the key dimension.
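A minimal numpy sketch of a single head under these definitions; the names and shapes (d_model, d_k, attention_head, etc.) are illustrative assumptions, not anything fixed above:

```python
import numpy as np

def softmax(x, axis=-1):
    # subtract the row max for numerical stability
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention_head(X, W_Q, W_K, W_V):
    # X: [seq, d_model]; W_Q, W_K, W_V: [d_model, d_k]
    Q = X @ W_Q            # queries
    K = X @ W_K            # keys
    V = X @ W_V            # values
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)   # [seq, seq]; sqrt(d_k) is the normalizing term
    A = softmax(scores, axis=-1)      # row-stochastic routing matrix: each row sums to 1
    return A @ V                      # each output row is a probability-weighted mix of V rows

# tiny usage example with random weights
rng = np.random.default_rng(0)
seq, d_model, d_k = 5, 16, 8
X = rng.normal(size=(seq, d_model))
W_Q, W_K, W_V = (rng.normal(size=(d_model, d_k)) for _ in range(3))
out = attention_head(X, W_Q, W_K, W_V)   # [seq, d_k]
```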
Each head is a separate routing system of this kind running concurrently; multi-head attention concatenates the heads' outputs and projects them back to the model dimension.
The transformer block just wraps this with residual connections, normalization, and an MLP for per-position feature learning.
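A rough, self-contained numpy sketch of one such block; the post-norm arrangement, ReLU MLP, and all names and shapes (transformer_block, W_O, W_1, W_2, d_model, ...) are illustrative assumptions rather than a definitive implementation:

```python
import numpy as np

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def layer_norm(x, eps=1e-5):
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def transformer_block(X, params):
    # --- multi-head attention: several routing systems running concurrently ---
    heads = []
    for W_Q, W_K, W_V in params["heads"]:
        Q, K, V = X @ W_Q, X @ W_K, X @ W_V
        A = softmax(Q @ K.T / np.sqrt(Q.shape[-1]))   # [seq, seq] routing matrix per head
        heads.append(A @ V)
    attn = np.concatenate(heads, axis=-1) @ params["W_O"]  # combine heads
    X = layer_norm(X + attn)                # residual connection + normalization
    # --- position-wise MLP: per-token feature learning ---
    h = np.maximum(0, X @ params["W_1"])    # ReLU hidden layer
    return layer_norm(X + h @ params["W_2"])  # residual connection + normalization

# tiny usage example with random weights (shapes are illustrative)
rng = np.random.default_rng(0)
seq, d_model, d_k, n_heads = 5, 16, 8, 2
params = {
    "heads": [tuple(rng.normal(size=(d_model, d_k)) for _ in range(3)) for _ in range(n_heads)],
    "W_O": rng.normal(size=(n_heads * d_k, d_model)),
    "W_1": rng.normal(size=(d_model, 4 * d_model)),
    "W_2": rng.normal(size=(4 * d_model, d_model)),
}
out = transformer_block(rng.normal(size=(seq, d_model)), params)  # [seq, d_model]
```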