In this post I will walk through the transformer layer and several improvements over this architecture that are commonly employed in popular open source large language models (LLMs) today, such as Llama. The topics include the SwiGLU and RMSNorm layers, the RoPE and ALiBi position embeddings, and finally Flash Attention for scaling the attention calculation to long sequences. We will use the Llama source code as the example implementation, and toward the end I’ll go through the rest of Llama’s source code.
Transformers
A Transformer block consists of an attention layer, two normalization layers, residual connections and a feedforward layer. Models like Llama are made of dozens of such transformer blocks stacked together (32 in Llama 7B), plus a normalization layer and a linear projection at the end. A transformer layer is a “cube”, meaning that the input and the output have the same shape: N vectors in, N vectors out.
For language modeling (next-token prediction), each output vector $i=1,\ldots,N$ is mapped to an unnormalized logit vector of length vocab_size that represents the predicted distribution over the next token, given the attention over previous tokens. During training, the logit of the ground-truth next token is maximized for each position in the sentence. For example, with “I eat apple with happiness”, we maximize “eat” given “I”, maximize “apple” given “I eat”, and so on. At inference time, only the final vector ("out[-1]") in the output tensor is needed, and we map this final output vector to logits to predict the next token.
Among transformer’s components, the attention layer is the most important one.
The attention layer
As with a transformer layer, we can also think of an attention layer as a cube, because the output of an attention layer has the same shape as its input: it takes a sequence of embedding vectors as input and outputs an updated version of those embeddings, where each updated vector now attends to all other vectors in the sequence (including itself), or, in the case of autoregressive modeling, attends to all previous vectors in the sequence (including itself).
An attention layer has three matrices as parameters: the query matrix, the key matrix, and the value matrix, each with shape $d\times d$, where $d$ is the token embedding dimension. Let’s take $d=2$, so that we can see the essence more clearly. In this case, we have three $2\times2$ matrices, for a total of 12 parameters. We have:
- the 2x2 query matrix $Q$ transforms any input vector $v\in\mathbb{R}^2$ to a query vector $q_v\in \mathbb{R}^2$.
- the 2x2 key matrix $K$ transforms any input vector to a key vector in $\mathbb{R}^2$.
- the 2x2 value matrix $V$ transforms any input vector to a value vector in $\mathbb{R}^2$.
Let’s walk through the attention calculation. The input is a sequence of $N$ tokens, each embedded as a 2d vector. The input tensor is then an $N\times2$ matrix. Applying the three matrix multiplications, we get
- Nx2 query tensor,
- Nx2 key tensor,
- Nx2 value tensor,
where each row is a 2d embedding vector.
For each vector $v_i$, we want to update it as a weighted sum of all the vectors in the sequence (including itself), weighted by dot-product similarity:
$$v_i^{\text{new}} = \sum_j\alpha_jv_j$$ $$\alpha_j =v_i \cdot v_j \quad\Rightarrow\quad \alpha_j=e^{v_i \cdot v_j}/\sum_k e^{v_i \cdot v_k}\,\text{(normalize)}$$
The query tensor (Nx2) times the transpose of the key tensor (2xN) is an NxN matrix of dot products between each pair of vectors in the sequence of length N.
Each row $i$ is a list of similarity scores between token $i$ and all other tokens in the sequence. If we don’t want token $i$ to attend to subsequent tokens, we set those similarity scores to $-\infty$, since $e^{-\infty} = 0$. Otherwise, as in BERT, we leave them unchanged. Applying softmax row-wise, we get normalized weights for each token $i$.
Note that it is in this sense that the attention layer can handle sequences of arbitrary length: as N gets larger the score matrix grows as $O(N^2)$, yet there are still only three 2x2 parameter matrices Q, K and V, for a total of 12 parameters.
Note that softmax per se is not a linear operation. For an input vector $v$, $\mathrm{sm}(v)$ applies $e^{(\cdot)}$ to each element of $v$ and then normalizes them by the sum of the exponentials. It is not true that $\mathrm{sm}(v + w) = \mathrm{sm}(v) + \mathrm{sm}(w)$.
After that, we just multiply the score matrix with the value tensor. Each row $i$ in the output is then an updated version of token $i$’s embedding.
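To make the walkthrough concrete, here is a toy single-head attention pass in PyTorch with $d=2$ and a causal mask. The weights and the input are random; this is an illustration, not Llama code.

```python
import torch

# Toy single-head attention with d = 2 and a causal mask (illustration only).
torch.manual_seed(0)
N, d = 4, 2
x = torch.randn(N, d)                        # N token embeddings
Wq, Wk, Wv = (torch.randn(d, d) for _ in range(3))

q, k, v = x @ Wq, x @ Wk, x @ Wv             # (N, d) query/key/value tensors
scores = q @ k.T                             # (N, N) pairwise dot products
mask = torch.triu(torch.full((N, N), float("-inf")), diagonal=1)
scores = scores + mask                       # block attention to future tokens
weights = torch.softmax(scores, dim=-1)      # row-wise normalization
out = weights @ v                            # (N, d) updated embeddings
print(out.shape)                             # torch.Size([4, 2])
```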
So this is the vanilla attention. In practice there are several additional variants:
- Since we multiply an $N\times N$ square matrix with the $N\times d_v$ value tensor, the dimension of the value vectors $d_v$ can be different, i.e. it doesn’t have to be 2, but can be 3, 4, or whatever. This means the value matrix $V$ can have shape $2\times d_v$, mapping each 2d token embedding to a $d_v$-dimensional vector. The query and key vectors, though, must have the same dimension so that their dot product is defined.
- The attention layer we just illustrated is called single-head attention, with one 2x2 query matrix, one 2x2 key matrix and one 2x2 value matrix as parameters. Similar to mapping an image with 3 channels to multiple channels with multiple kernels in a convolutional layer of a CNN, one may want to use multiple Q, K and V matrices to perform the attention calculation multiple times and concatenate the outputs, in order to capture different aspects of an input sequence. This is called “multi-head attention”.
- It is also called self-attention because each token attends to other tokens within the same sequence. If the key and value vectors instead come from an outside source, for example image embeddings obtained from an image encoder, then it is called “cross attention”.
Better Transformers
SwiGLU layer
SwiGLU (Shazeer, 2020, arXiv:2002.05202) is an improvement over the feed forward layer in transformers. A feed forward layer (also called a fully connected (FC) layer, or multi-layer perceptron (MLP)) is defined as
$$ \mathrm{FFN}(x, W_1, W_2, b_1, b_2) = W_2f(W_1x + b_1) + b_2, $$
where $f$ is the nonlinear activation function applied element-wise to its input vector. The most commonly used one is the ReLU activation function $f(x) = \max(0, x)$.
The Swish function (Ramachandran, Zoph & Le, 2017, arXiv:1710.05941), also called the Sigmoid Linear Unit (SiLU), is defined as $f_\beta(x) = x\cdot\sigma(\beta x)$, where $\sigma(x)$ is the sigmoid function
$$ \sigma(x) = \frac{1}{1 + e^{-x}}. $$
As $\beta\to0$, the function approaches the linear function $\frac{1}{2}x$. As $\beta\to\infty$, $\sigma(\beta x)$ approaches either $0$ or $1$, so the Swish function approaches the ReLU function. It is common to set $\beta=1$. A distinctive aspect of the Swish function is the non-monotonic “bump” for $x<0$.
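As a quick sanity check (a toy snippet, not taken from the paper or from any library docs), PyTorch’s F.silu is exactly this function with $\beta=1$, and it dips slightly below zero for negative inputs:

```python
import torch
import torch.nn.functional as F

# SiLU/Swish with beta = 1 equals x * sigmoid(x); note the dip below zero for x < 0.
x = torch.linspace(-4, 4, 9)
assert torch.allclose(F.silu(x), x * torch.sigmoid(x))
print(F.silu(torch.tensor([-1.0])))  # ~ -0.269, the non-monotonic "bump"
```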
Gated Linear Units (GLU) refers to multiplying the vector coming out of the nonlinear activation function $f(W_1x)$ by a linear transformation of the input vector $x$, in element-wise fashion. This is denoted as
$$ f(W_1x) \otimes W_3x. $$
Now, the SwiGLU layer is
(1) feed forward layer without bias, plus
(2) GLU, plus
(3) the Swish function $f(x)=x\cdot\sigma(x)$ as activation function:
$$ \mathrm{SwiGLU}(x,W_1, W_2, W_3) = W_2(\mathrm{Swish}_1(W_1x) \otimes W_3x) $$
Below is the implementation of the SwiGLU layer from Llama. I have deleted some non-essential code so that we can focus on the main part.
```python
class FeedForward(nn.Module):
    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.w1 = ColumnParallelLinear(dim, hidden_dim, bias=False, ...)
        self.w2 = RowParallelLinear(hidden_dim, dim, bias=False, ...)
        self.w3 = ColumnParallelLinear(dim, hidden_dim, bias=False, ...)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
```
The code is very straightforward. The ColumnParallelLinear
and RowParallelLinear
layers are from the fairscale library, which is designed for high performance and parallel training; here we can simply regard them as linear layers (fully connected layers). The PyTorch F.silu
function stands for “Sigmoid Linear Unit (SiLU)”, and it is exactly the Swish function $x\cdot\sigma(x)$ with $\beta=1$.
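If you want to try the layer outside of fairscale’s parallel setup, a single-GPU sketch with plain nn.Linear standing in for the parallel layers could look like the following (the hidden dimension here is arbitrary, chosen only for the example):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLU(nn.Module):
    """Single-GPU sketch of Llama's FeedForward, with nn.Linear in place of the
    fairscale ColumnParallelLinear / RowParallelLinear layers."""
    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

x = torch.randn(2, 16, 512)            # (batch, seqlen, dim)
print(SwiGLU(512, 1376)(x).shape)      # torch.Size([2, 16, 512])
```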
RMSNorm layer
Neural network training is sensitive to parameter initialization and hyperparameter tuning. Batch normalization (Ioffe & Szegedy, 2015) was proposed to make the training process more robust and less dependent on such hyperparameter tuning. The method can make training converge faster, thus reducing training time, and when it was first proposed it achieved the best accuracy on the ImageNet classification task. For each mini-batch, batch normalization first calculates mean and variance statistics over the mini-batch, then, for each data point in the mini-batch, subtracts the mean and divides by the (regularized) standard deviation. We can represent the batch normalization layer as
$$ \mathrm{BN}(x) = \frac{x - \mathbb{E}_{x\sim\mathcal{B}}x}{\sqrt{\mathrm{var}(x) + \epsilon}}, $$
where the mean and variance statistics are calculated over a batch $\mathcal{B}$.
Layer normalization (Ba, Kiros & Hinton, 2016, arXiv:1607.06450) was first proposed to apply the same idea to recurrent neural networks (RNN), where a layer may take only one input vector at a time. Layer normalization is “orthogonal” to batch normalization: instead of looking at the data distribution across a batch, it calculates mean and variance statistics across all dimensions of an input vector, and normalizes each dimension with the calculated mean and variance. Layer normalization can be represented as
$$ \mathrm{LN}(x) = \frac{x - \mathbb{E}x_j}{\sqrt{\mathrm{var}(x_j) + \epsilon}}g, $$
where the mean and variance statistics are calculated over $x=(x_1,\ldots, x_j, \ldots, x_n)\in\mathbb{R}^n$. $g\in\mathbb{R}^n$ is a trainable scaling parameter initialized to $(1,\ldots,1)$.
Root Mean Square Normalization (RMSNorm) (Zhang & Sennrich, 2019, NeurIPS) proposes to keep only the scaling invariance:
$$ \mathrm{RMSNorm}(x) = \frac{x}{\|x\|/\sqrt{n}}g, $$
where $\|x\| = \sqrt{x_1^2 + \cdots + x_n^2}$. The output is thus confined to a sphere with radius $\sqrt{n}$.
In Llama, RMSNorm is implemented as follows:
```python
class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x)
        return output * self.weight
```
x.pow(2).mean(-1, keepdim=True) computes the mean of the squared entries of $x$, i.e. the squared RMS $\|x\|^2/n$. torch.rsqrt computes $y = 1/\sqrt{x}$, the reciprocal square root, so _norm divides $x$ by its RMS. The weight parameter tensor has the same dimension as the input $x$, and is initialized to $(1,\ldots,1)$.
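As a quick check of the “sphere of radius $\sqrt{n}$” claim, we can run the class above on a toy tensor while the weight is still at its initial value of all ones:

```python
import torch

# Toy check (assumes the RMSNorm class defined above): with weight = (1, ..., 1),
# every normalized vector has norm close to sqrt(n).
norm = RMSNorm(dim=8)
x = torch.randn(4, 8) * 5.0
print(norm(x).norm(dim=-1))  # all entries ~ sqrt(8) ~ 2.83
```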
Rotary Position Embedding (RoPE)
When modeling a token’s dependencies on previous tokens in a sentence, we need to decide how to encode the order of the sequence. The Transformer (or self-attention) models each word vector as a weighted sum over all previous word vectors in a sequence. To see the issue with token positioning, let’s briefly review the self-attention layer. Let $(x_1,\ldots,x_N)$ be a sequence of word vectors, and suppose we want to predict the next vector $x$. The self-attention calculation can be summarized as the following equation:

$$ f(x) = \frac{\exp(q_x\cdot k_1)}{\sum_{n=1}^N\exp(q_x\cdot k_n)}v_1 + \cdots + \frac{\exp(q_x\cdot k_N)}{\sum_{n=1}^N\exp(q_x\cdot k_n)}v_N $$

where

$$ \begin{align*} q_x &= W^qx,\newline k_i &= W^kx_i, \quad i=1,\ldots, N\newline v_i &= W^vx_i, \quad i=1, \ldots, N \end{align*} $$

and $W^q, W^k, W^v$ are three weight matrices to be learned. Since addition is commutative, arbitrarily permuting the sequence $(x_1,\ldots,x_N)$ would not affect the output, which is not desired. So we have to somehow encode position information into either the embedded vectors $x_i$ or the query, key and value vectors.

In the original Transformer paper (Vaswani et al., 2017, NeurIPS), the authors proposed adding an absolute position encoding to each input vector:

$$ x_i = x_i + p_i $$

where

$$ \begin{cases} p_{i, 2t}&= \sin(i/10000^{2t/d}),\newline p_{i, 2t+1}&= \cos(i/10000^{2t/d}). \end{cases} $$

Note that both sine and cosine have range $[-1, 1]$, so each value in each position vector $p_i$ is also confined to $[-1, 1]$. The large number 10000 was chosen in order to accommodate large context lengths, but theoretically the uniqueness of the position vectors is bounded, since we have the two basic trigonometric identities

$$ \sin(x + 2k\pi) = \sin(x) \qquad\text{and}\qquad \cos(x + 2k\pi) = \cos(x). $$

Let’s denote the large scaling factor in the denominator by $S=10000$ and choose a position $m$. The lowest-frequency components of the position vector at $m$ repeat for the word that is $2\pi S$ positions away, since

$$ \sin\left(\frac{m + 2\pi\cdot S}{S}\right) = \sin\left(\frac{m}{S} + 2\pi\right) = \sin\left(\frac{m}{S}\right). $$

This means the context length within which the position vectors are guaranteed to be unique is around $2\pi\cdot10000\approx62\mathrm{K}$; beyond that, the position vectors start to repeat. 62K is a very large length even by today’s (2023) standards. Larger sequence lengths also put significantly more pressure on GPU memory when using the vanilla Transformer, as we will see below when we discuss Flash Attention.
Rotary position embedding (RoPE) (Su et al., 2021, arXiv:2104.09864) is an alternative to additive position encoding: it is multiplicative instead of additive, and it applies to the key and query vectors rather than the input vectors. Let’s first see it in the 2d case. Suppose we have embedded each word token as a 2d vector, and calculated the query and key vectors. The RoPE method rotates each query or key vector with a rotation matrix. Take the query vector $q = (q_1, q_2)^T\in\mathbb{R}^2$ as an example:
$$\tag{🌟} \hat{q} = \begin{pmatrix}\cos m\theta & -\sin m\theta \newline \sin m\theta & \cos m\theta\end{pmatrix}\begin{pmatrix}q_1\newline q_2\end{pmatrix} = \begin{pmatrix} q_1\cdot\cos m\theta - q_2\cdot\sin m\theta \newline q_1\cdot\sin m\theta + q_2\cdot\cos m\theta\end{pmatrix} $$ where $m$ is the position and $\theta$ is a preset small non-zero constant. Word vectors at different positions are rotated by different amounts according to $m\theta$. To generalize to $d$ dimensions, where $d$ has to be an even number, the method applies a rotation matrix to each consecutive pair of dimensions, for a total of $d/2$ pairs. Specifically, the transformation matrix is a block matrix where each block is a 2d rotation matrix: $$ \begin{pmatrix} \cos m\theta_1 & -\sin m\theta_1 & 0 & 0 & \cdots & 0 & 0 \newline \sin m\theta_1 & \cos m\theta_1 & 0 & 0 & \cdots & 0 & 0 \newline \vdots & \vdots & \vdots & \vdots & \ddots & \vdots & \vdots \newline 0 & 0 & \cdots & 0 & 0 & \cos m\theta_{d/2} & -\sin m\theta_{d/2} \newline 0 & 0 & \cdots & 0 & 0 & \sin m\theta_{d/2} & \cos m\theta_{d/2}\newline \end{pmatrix} $$ and the angles are pre-defined with $\{\theta_i = 10000^{-2(i-1)/d}, i=1,\ldots,d/2\}$.
Implementing RoPE as matrix multiplication takes a lot of memory when $d$ is large. Representing each 2d sub-vector pair as a complex number lets us avoid the matrix computation. Let’s re-focus on the 2d case. The 2d rotation matrix has a close connection with arithmetic in the complex plane. Recall Euler’s formula $$ e^{i\theta} = \cos\theta + i\sin\theta. $$ The core observation that is very useful for our implementation is the following: for a complex number $q_1 + q_2i$, we have $$ \begin{split} (q_1 + q_2i)e^{im\theta} &= (q_1 + q_2i)(\cos m\theta + i\sin m\theta)\newline &= (q_1\cdot\cos m\theta - q_2\cdot\sin m\theta) + (q_1\cdot\sin m\theta + q_2\cdot\cos m\theta)i. \end{split} $$ The real and imaginary parts of this complex number are exactly the first and second entries of the transformed query vector above (🌟)! So to get the rotary embedding for a vector $q$, we first view it as a complex number, multiply it by $e^{im\theta}$, then assemble the real and imaginary parts into the position-encoded vector.
So, here are the implementation steps for a single 2d query/key vector $q$ at position $m$ (a toy end-to-end sketch follows the list):
- View the 2d vector $q=(q_1, q_2)^T$ as a complex number $q_1 + q_2i$, where the first entry is the real part, and the second entry is the imaginary part. In PyTorch, this is
q = torch.view_as_complex(q)
- Multiply the vector by $e^{im\theta}$. In PyTorch, $e^{im\theta}$ is torch.polar(torch.tensor(1.0), torch.tensor(m * theta)) (magnitude 1, angle $m\theta$), so we do q = q * torch.polar(torch.tensor(1.0), torch.tensor(m * theta))
- Get the real part and the imaginary part of this complex number. In PyTorch, this is
q_out = torch.view_as_real(q)
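Putting the three steps together as a toy sketch (the values of m, theta and q below are made up for illustration; this is not Llama code), and checking against the rotation matrix in (🌟):

```python
import math
import torch

m, theta = 3, 0.1
q = torch.tensor([1.0, 2.0])

q_complex = torch.view_as_complex(q)                                # q1 + q2*i
rotation = torch.polar(torch.tensor(1.0), torch.tensor(m * theta))  # e^{i*m*theta}
q_out = torch.view_as_real(q_complex * rotation)
print(q_out)

# Same result as applying the 2x2 rotation matrix directly:
R = torch.tensor([[math.cos(m * theta), -math.sin(m * theta)],
                  [math.sin(m * theta),  math.cos(m * theta)]])
print(R @ q)
```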
The 2d implementation above showed the essence of RoPE. If you already understand the 2d case, then it is easy to understand the code for the general case. Here are the steps for $d$-dimensional query/key vectors, for all positions $m=0,\ldots,N-1$, with reference to the Llama source code. The input will have dimension (batch_size, seqlen, n_heads, head_dim).
dim | meaning |
---|---|
0 | batch size |
1 | sequence length |
2 | number of heads |
3 | head dimension |
- View each vector as a group of $d/2$ complex numbers: dimensions $(1, 2)$ as one complex number, dimensions $(3, 4)$ as the next, and so on.
torch.view_as_complex
expects the last dimension of its input to be of size 2, so we first reshape the input tensor as a list of 2 number pairs
q.reshape(*q.shape[:-1], -1, 2)
then get the list of complex numbers via
q_ = torch.view_as_complex(q.reshape(*q.shape[:-1], -1, 2))
- Multiply the first complex number by $e^{im\theta_1}$, the second complex number by $e^{im\theta_2}$, and so on, up to multiplying the final complex number by $e^{im\theta_{d/2}}$. In the Llama source code, the complex numbers $\{e^{im\theta_1},e^{im\theta_2},\ldots,e^{im\theta_{d/2}}\}$ for all $m=0,\ldots,N-1$ are prepared by
```python
def compute_exps(dim: int, N: int):
    thetas = 1.0 / (10000.0 ** (torch.arange(0, dim, 2) / dim))
    m = torch.arange(N, device=thetas.device)
    angles = torch.outer(m, thetas).float()
    exps = torch.polar(torch.ones_like(angles), angles)  # complex64
    return exps
```
To avoid potential confusion, I have renamed several variables to match the corresponding variable names in the original paper. Now let’s examine this function line by line:
(1)
thetas = 1.0 / (10000.0 ** (torch.arange(0, dim, 2) / dim))
This line prepares all the thetas $\{\theta_i = 10000^{-2(i-1)/d}, i=1,\ldots,d/2\}$. The dimension of the output tensor thetas
will be $d/2$, half the dimension of input $x$.
(2)
m = torch.arange(N, device=thetas.device)
This line prepares all the positions $(0, 1, \ldots, N-1)$.
(3)
angles = torch.outer(m, thetas).float()
This line prepares all the angles $m\theta_i$, for all $m$ and all $i$. For two vectors $a$ and $b$ with sizes $N$ and $d/2$, torch.outer(a, b)
is the outer product $ab^T$, so the output has dimension $N\times(d/2)$.
(4)
exps = torch.polar(torch.ones_like(angles), angles) # complex64
This line computes $\{e^{im\theta_i}\}$, for all $m$ and all $\theta_i$. Thus the output of the compute_exps
function is a matrix, where each row is $\{e^{im\theta_1},e^{im\theta_2},\ldots,e^{im\theta_{d/2}}\}$ for a particular position $m$, for all positions $m=0,\ldots,N-1$.
Suppose now we have prepared a query vector q_
as a list of complex numbers and we also have complex exponentials as exps
. We simply do
q_ * exps
to multiply them together.
- For each pair of dimensions $(1, 2), (3,4), \ldots, (d-1, d)$, we have computed a complex number $q\cdot e^{im\theta}$, for a total of $d/2$ complex numbers. Now we take the real and imaginary parts of each complex number and stack all of them together to get the output. torch.view_as_real is the inverse operation of torch.view_as_complex, which means the last dimension of the output from torch.view_as_real will be 2. That’s why we need to call .flatten(3) to flatten the last dimension (remember from the input dimension table above that 3 is the last input shape position).
```python
def apply_rotary_emb(
    q: torch.Tensor,
    k: torch.Tensor,
    exps: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    q_ = torch.view_as_complex(q.float().reshape(*q.shape[:-1], -1, 2))
    k_ = torch.view_as_complex(k.float().reshape(*k.shape[:-1], -1, 2))
    exps = reshape_for_broadcast(exps, q_)
    q_out = torch.view_as_real(q_ * exps).flatten(3)
    k_out = torch.view_as_real(k_ * exps).flatten(3)
    return q_out, k_out
```
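The helper reshape_for_broadcast is not shown here. Roughly, it reshapes exps from (seqlen, head_dim/2) to (1, seqlen, 1, head_dim/2) so that it broadcasts against q_ of shape (batch_size, seqlen, n_heads, head_dim/2). A sketch along the lines of the Llama source:

```python
import torch

def reshape_for_broadcast(exps: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # exps: (seqlen, head_dim/2) complex; x: (bsz, seqlen, n_heads, head_dim/2) complex.
    # Keep the seqlen and last dimensions, insert singleton dims everywhere else so
    # that exps broadcasts over batch and heads.
    assert exps.shape == (x.shape[1], x.shape[-1])
    shape = [d if i in (1, x.ndim - 1) else 1 for i, d in enumerate(x.shape)]
    return exps.view(*shape)
```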
We mention that RoPE is claimed to be a relative position embedding method: the similarity $\hat{q}_m\cdot\hat{k}_n$ between the position-encoded query $\hat{q}_m$ (at position $m$) and the position-encoded key $\hat{k}_n$ (at position $n$) depends only on the relative position $m-n$, not on the absolute positions in the sentence. Still, when implementing RoPE we have to rotate each query and key vector according to its absolute position $m$ or $n$ in the sequence, obtaining $\hat{q}_m$ and $\hat{k}_n$, before computing $\hat{q}_m\cdot\hat{k}_n$, because the query and key vectors themselves are needed to compute the attention matrix. In this sense, RoPE is not a purely relative position embedding.
Also, RoPE has a long-term decay property: the dot product of two position-encoded vectors approaches zero as their distance increases. The built-in assumption that a pair of distant tokens should be less related is not always desirable in light of LLM applications to long texts, e.g. contracts, legal documents, and financial reports. For example, a user could give an instruction at the end of a prompt to extract some information from a long piece of text, where that information is located right at the very beginning of the text. The long-term decay property could lead to degraded performance in such cases.
ALiBi position embedding
Attention with Linear Biases (ALiBi) (Press, Smith & Lewis, 2021, arXiv:2108.12409) is an even simpler method to encode position information.
In ALiBi, a head-specific slope $m$ is a fixed scalar and is not learned. Ignoring $m$ for a moment, the method works as follows (the figure in the paper illustrates it clearly): biases $(-(N-1), \ldots, -2, -1, 0)$ are added to the query-key dot products before applying softmax. For example, $-4$ is added to $q_5\cdot k_1$, $-1$ is added to $q_5\cdot k_4$, and no bias is added to $q_5\cdot k_5$. There is no other positional encoding added elsewhere. Clearly, as the distance $N$ between two tokens gets larger, the negative bias grows, and $\exp\{q\cdot k - N\}$ decays to $0$ very fast as $N\to\infty$, resulting in a near-zero score and thus no lookups to distant tokens when generating the next token. As with RoPE, this property may not be desirable in real applications where one wants an LLM to process long and complicated professional documents.
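To make the bias pattern concrete, here is a toy sketch (not from any particular codebase) that builds the ALiBi bias matrix for a single head with slope m; in the paper the per-head slopes form a geometric sequence, but any fixed slope illustrates the idea:

```python
import torch

# ALiBi bias for one head with slope m: bias[i, j] = -m * (i - j) for j <= i.
# Positions j > i are handled by the causal mask anyway.
def alibi_bias(seqlen: int, m: float) -> torch.Tensor:
    pos = torch.arange(seqlen)
    distance = (pos.unsqueeze(1) - pos.unsqueeze(0)).clamp(min=0)  # i - j, zero above diagonal
    return -m * distance.float()

print(alibi_bias(5, m=1.0)[-1])  # tensor([-4., -3., -2., -1.,  0.]) -- the q5 row above
```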
Flash Attention
Flash Attention (Dao et al., 2022, NeurIPS) is a technique for speeding up attention computation on Nvidia GPUs for long sequences. There is a memory hierarchy on the GPU: high bandwidth memory (HBM) makes up the bulk of GPU memory, but it is slow compared to SRAM, which is far smaller in size but roughly a dozen times faster than HBM.
For a sequence with length $N$, the attention computation defined by the authors is $$ \mathrm{O} = \mathrm{dropout}(\mathrm{softmax}(\mathrm{mask}(QK^T)))V $$
Memory usage grows quadratically with the sequence length $N$. As $N$ goes to thousands and beyond, the attention matrix $QK^T$ will occupy a lot of space in GPU memory. Masking, softmax and dropout computation all require reading and writing the $N\times N$ attention matrix. It turns out that those memory reads/writes take up most of the computation time, much more than the actual matrix multiplication $QK^T$.
To reduce memory i/o, and to take advantage of SRAM’s fast speed, the authors proposed two techniques:
Tiling. This refers to loading partial blocks of the attention matrix from HBM to SRAM, computing attention on SRAM for such blocks, then combining the results with the correct scaling factors. This technique uses the fact that for a matrix $A = [A_1, A_2]$ that consists of two sub-matrices $A_1$ and $A_2$, the softmax of $A$ can be written as $$ \mathrm{softmax}([A_1, A_2]) = [\alpha\cdot\mathrm{softmax}(A_1), \beta\cdot\mathrm{softmax}(A_2)] $$ for some scaling factors $\alpha$ and $\beta$, and $$ \mathrm{softmax}([A_1, A_2])\begin{bmatrix}V_1\newline V_2\end{bmatrix} = \alpha\cdot\mathrm{softmax}(A_1)V_1 + \beta\cdot\mathrm{softmax}(A_2)V_2. $$
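Before moving on to recomputation, here is a toy numerical check of that softmax identity (just the math, not the FlashAttention kernel, which additionally tracks running maxima for numerical stability). The partial normalizers of the two blocks give the scaling factors $\alpha$ and $\beta$:

```python
import torch

torch.manual_seed(0)
A = torch.randn(8)                 # one row of attention scores q . k
A1, A2 = A[:5], A[5:]

Z1, Z2 = torch.exp(A1).sum(), torch.exp(A2).sum()  # partial normalizers
alpha, beta = Z1 / (Z1 + Z2), Z2 / (Z1 + Z2)

blockwise = torch.cat([alpha * torch.softmax(A1, dim=0),
                       beta * torch.softmax(A2, dim=0)])
assert torch.allclose(blockwise, torch.softmax(A, dim=0))
```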
Recomputation. During the backward pass, the attention matrix is needed for computing gradients of the weights, and again, reading this large matrix from HBM memory would be slow. The recomputation technique is straightforward: do not store the $N\times N$ attention matrix from forward pass (only store $N$ softmax normalizing factors), but recompute it in the backward pass. This incurs additional floating point operations (flops), but the runtime is reduced even with increased flops.
| Attention | Standard | FlashAttention |
|---|---|---|
| GFLOPs | 66.6 | 75.2 |
| HBM R/W (GB) | 40.3 | 4.4 |
| Runtime (ms) | 41.7 | 7.3 |

Source: (Dao et al., 2022, NeurIPS)
For the same sequence length $N$, Flash Attention allows faster computation compared with traditional attention. In other words, this means that for the same training time budget, one can train models that can deal with longer contexts. In this aspect, Flash Attention is very useful, because such models are very much needed in industrial applications.
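As a side note, if you are on PyTorch 2.x you do not have to write such a kernel yourself: torch.nn.functional.scaled_dot_product_attention is a fused attention op that can dispatch to a FlashAttention-style kernel on supported GPUs and dtypes. A minimal sketch (shapes, dtype and device are illustrative):

```python
import torch
import torch.nn.functional as F

# Shapes are (batch, n_heads, seqlen, head_dim); is_causal applies the autoregressive mask.
q = torch.randn(1, 8, 2048, 64, device="cuda", dtype=torch.float16)
k = torch.randn(1, 8, 2048, 64, device="cuda", dtype=torch.float16)
v = torch.randn(1, 8, 2048, 64, device="cuda", dtype=torch.float16)
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([1, 8, 2048, 64])
```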
Llama source code
We have already walked through some of Llama’s source code, including the feedforward layer, the RMSNorm layer, and the apply_rotary_emb
function. In this last section, I will walk through the rest of Llama’s source code. First let’s look at the model.py
file. We are left with three classes:
```python
class Attention(nn.Module):
    ...

class TransformerBlock(nn.Module):
    ...

class Transformer(nn.Module):
    ...
```
Now let’s study them one by one.
The Attention layer
The attention layer is implemented as follows. Again, I remind the reader that I have refactored the code and removed non-essential parts. What the Attention layer does is multiply the queries with the keys, apply softmax to get the scores, multiply the scores with the values, and finally apply a linear layer to get the output.
```python
class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        self.head_dim = args.dim // args.n_heads
        self.wq = Linear(args.dim, args.n_heads * self.head_dim, bias=False, ...)
        self.wk = Linear(args.dim, args.n_heads * self.head_dim, bias=False, ...)
        self.wv = Linear(args.dim, args.n_heads * self.head_dim, bias=False, ...)
        self.wo = Linear(args.n_heads * self.head_dim, args.dim, bias=False, ...)
        # self.cache_k and self.cache_v (the KV cache) are pre-allocated here in the full source

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        exps: torch.Tensor,
        mask: Optional[torch.Tensor],
    ):
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        # split the last dimension into heads: (bsz, seqlen, n_heads, head_dim)
        xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_heads, self.head_dim)
        xq, xk = apply_rotary_emb(xq, xk, exps=exps)
        # write the new keys/values into the KV cache, then read back the whole prefix
        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv
        keys = self.cache_k[:bsz, : start_pos + seqlen]
        values = self.cache_v[:bsz, : start_pos + seqlen]
        # move heads in front of the sequence dimension
        xq = xq.transpose(1, 2)        # (bsz, n_heads, seqlen, head_dim)
        keys = keys.transpose(1, 2)    # (bsz, n_heads, cache_len + seqlen, head_dim)
        values = values.transpose(1, 2)
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores + mask  # (bsz, n_heads, seqlen, cache_len + seqlen)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        output = torch.matmul(scores, values)  # (bsz, n_heads, seqlen, head_dim)
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.wo(output)
```
The TransformerBlock layer
A TransformerBlock
layer in Llama is:
```
x --->RMSNorm --->Attention ------>RMSNorm --->FeedForward --->out
|                          |      |                         |
|                          |      |                         |
·----------- + -----------·      ·----------- + -----------·
```
Different from the original transformer architecture, the normalization layer is placed at the beginning of each sub-block rather than at the end (pre-normalization). Below is the simplified implementation. The self.attention attribute is the Attention layer, self.feed_forward is the SwiGLU FeedForward layer, and self.ffn_norm and self.attention_norm are both RMSNorm layers.
```python
class TransformerBlock(nn.Module):
    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.layer_id = layer_id
        self.head_dim = args.dim // args.n_heads
        self.attention = Attention(args)
        self.feed_forward = FeedForward(dim=args.dim, hidden_dim=4 * args.dim)
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

    def forward(self, x, start_pos, exps, mask):
        # residual connection around the (pre-normalized) attention layer
        h = x + self.attention.forward(self.attention_norm(x), start_pos, exps, mask)
        # residual connection around the (pre-normalized) feedforward layer
        out = h + self.feed_forward.forward(self.ffn_norm(h))
        return out
```
The Transformer layer
Finally, the Llama model is defined as a Transformer class. It is a stack (a for loop) of TransformerBlocks. The last dimension of the Transformer output holds un-normalized scores (logits) over the vocabulary, with size vocab_size. Different sampling methods make different use of these scores: basic greedy sampling chooses the token with the largest score, while top-k sampling samples from the k highest-scoring tokens.
```python
class Transformer(nn.Module):
    def __init__(self, params: ModelArgs):
        super().__init__()
        self.params = params
        self.tok_embeddings = Embedding(params.vocab_size, params.dim)
        self.layers = torch.nn.ModuleList()
        for layer_id in range(params.n_layers):
            self.layers.append(TransformerBlock(layer_id, params))
        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = Linear(params.dim, params.vocab_size, bias=False)
        # RoPE complex exponentials, precomputed once for the maximum sequence length
        self.exps = compute_exps(params.dim // params.n_heads, params.max_seq_len * 2)

    @torch.inference_mode()
    def forward(self, tokens: torch.Tensor, start_pos: int):
        _, seqlen = tokens.shape
        h = self.tok_embeddings(tokens)
        exps = self.exps[start_pos : start_pos + seqlen]
        # causal mask: each token may only attend to itself and earlier positions
        mask = None
        if seqlen > 1:
            mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device)
            mask = torch.triu(mask, diagonal=1)
            # prepend zeros for the cached prefix so the mask lines up with the KV cache
            mask = torch.hstack(
                [torch.zeros((seqlen, start_pos), device=tokens.device), mask]
            ).type_as(h)
        for layer in self.layers:
            h = layer(h, start_pos, exps, mask)
        h = self.norm(h)
        output = self.output(h)
        return output
```
The Transformer
class above is the Llama model, which is pretty concise and clean. Having looked at the Llama model, now let’s look at how inference is implemented.
Generation
There is one Llama
class defined in generation.py
. First, the build
method loads weights from local paths, and initializes the Llama model. It is used in official examples example_chat_completion.py
and example_text_completion.py
to build a generator
. The generator
exposes two interfaces, text_completion
and chat_completion
. Both are wrappers around the generate
method.
```python
from llama.model import ModelArgs, Transformer
from llama.tokenizer import Tokenizer


class Llama:
    @staticmethod
    def build(
        ckpt_dir: str,
        tokenizer_path: str,
        max_seq_len: int,
        max_batch_size: int,
    ) -> "Llama":
        # model_args and checkpoint are loaded from ckpt_dir in code elided here
        tokenizer = Tokenizer(model_path=tokenizer_path)
        model = Transformer(model_args)
        model.load_state_dict(checkpoint, strict=False)
        return Llama(model, tokenizer)

    def __init__(self, model: Transformer, tokenizer: Tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    def generate(self, ...):
        pass
```
The core part of the generate
method is the following lines of code: get the logits from model’s forward
method, then sample the next token from the logits. As mentioned before, greedy sampling chooses the token with the largest score; top-k sampling samples from the k tokens with the largest scores; and top-p (nucleus) sampling samples from the smallest set of top tokens whose cumulative probability exceeds the threshold p, so the effective k is adjusted dynamically.
```python
# ......
@torch.inference_mode()
def generate(
    self,
    prompt_tokens: List[List[int]],
    max_gen_len: int,
    temperature: float = 0.6,
    top_p: float = 0.9,
    logprobs: bool = False,
    echo: bool = False,
) -> Tuple[List[List[int]], ...]:
    params = self.model.params
    bsz = len(prompt_tokens)
    # total_len, pad_id, min_prompt_len and prev_pos are set up (and prev_pos updated)
    # in code elided here
    tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda")
    for cur_pos in range(min_prompt_len, total_len):
        logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
        if temperature > 0:
            probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
            next_token = sample_top_p(probs, top_p)
        else:
            next_token = torch.argmax(logits[:, -1], dim=-1)
```
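The sample_top_p helper called above lives in the same generation.py. Roughly, it sorts the probabilities, keeps the smallest prefix whose cumulative mass exceeds p, renormalizes, and samples from that set; the following closely mirrors the Llama implementation:

```python
def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
    # Sort probabilities in descending order and compute their cumulative sum.
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    # Zero out tokens outside the smallest set whose cumulative probability exceeds p.
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    # Renormalize, sample, then map back to the original vocabulary indices.
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token
```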
Note that the temperature parameter appears in the denominator of the softmax input. If the temperature is zero, we simply have greedy sampling. A small temperature magnifies the logits that are already large, so that other tokens have an even slimmer chance of being selected; 1.0 is a neutral choice. At values larger than 1.0, the temperature shrinks all the logits, reducing their differences after exponentiation, so that sampling becomes more random.
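A toy illustration of the effect (the logits are made up; the printed numbers are approximate):

```python
import torch

logits = torch.tensor([2.0, 1.0, 0.5])
for temperature in (0.5, 1.0, 2.0):
    print(temperature, torch.softmax(logits / temperature, dim=-1))
# 0.5 -> ~[0.84, 0.11, 0.04]  (sharper: the top token dominates)
# 1.0 -> ~[0.63, 0.23, 0.14]
# 2.0 -> ~[0.48, 0.29, 0.23]  (flatter: sampling is more random)
```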
Text completion:
```python
# ......
def text_completion(
    self,
    prompts: List[str],
    temperature: float = 0.6,
    top_p: float = 0.9,
    max_gen_len: Optional[int] = None,
) -> List[str]:
    prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts]
    generation_tokens, generation_logprobs = self.generate(
        prompt_tokens=prompt_tokens,
        max_gen_len=max_gen_len,
        temperature=temperature,
        top_p=top_p,
    )
    return [{"generation": self.tokenizer.decode(t)} for t in generation_tokens]
```
Chat completion:
```python
class Message(TypedDict):
    role: Role
    content: str

Dialog = List[Message]

B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

DEFAULT_SYSTEM_PROMPT = "You are a helpful, respectful and honest assistant..."

# ......
def chat_completion(
    self,
    dialogs: List[Dialog],
    temperature: float = 0.6,
    top_p: float = 0.9,
    max_gen_len: Optional[int] = None,
    logprobs: bool = False,
) -> List[Message]:
    if max_gen_len is None:
        max_gen_len = self.model.params.max_seq_len - 1
    prompt_tokens = []
    for dialog in dialogs:
        # prepend the default system prompt if the dialog does not start with one
        if dialog[0]["role"] != "system":
            dialog = [
                {
                    "role": "system",
                    "content": DEFAULT_SYSTEM_PROMPT,
                }
            ] + dialog
        # fold the system prompt into the first user message, wrapped in <<SYS>> tags
        dialog = [
            {
                "role": dialog[1]["role"],
                "content": B_SYS
                + dialog[0]["content"]
                + E_SYS
                + dialog[1]["content"],
            }
        ] + dialog[2:]
        assert all([msg["role"] == "user" for msg in dialog[::2]]) and all(
            [msg["role"] == "assistant" for msg in dialog[1::2]]
        ), (
            "model only supports 'system', 'user' and 'assistant' roles, "
            "starting with 'system', then 'user' and alternating (u/a/u/a/u...)"
        )
        # encode each (user prompt, assistant answer) pair, then flatten into one token list
        dialog_tokens: List[int] = sum(
            [
                self.tokenizer.encode(
                    f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} ",
                    bos=True,
                    eos=True,
                )
                for prompt, answer in zip(
                    dialog[::2],
                    dialog[1::2],
                )
            ],
            [],
        )
        assert (
            dialog[-1]["role"] == "user"
        ), f"Last message must be from user, got {dialog[-1]['role']}"
        # the final user message has no answer yet; encode it without eos
        dialog_tokens += self.tokenizer.encode(
            f"{B_INST} {(dialog[-1]['content']).strip()} {E_INST}",
            bos=True,
            eos=False,
        )
        prompt_tokens.append(dialog_tokens)
    generation_tokens, generation_logprobs = self.generate(
        prompt_tokens=prompt_tokens,
        max_gen_len=max_gen_len,
        temperature=temperature,
        top_p=top_p,
        logprobs=logprobs,
    )
    return [
        {"generation": {"role": "assistant", "content": self.tokenizer.decode(t)}}
        for t in generation_tokens
    ]
```
As with text_completion
, chat_completion
is also a wrapper around the generate
method. Chat histories are concatenated into a single string before being fed into the generate
method. To distinguish between users’ prompts and assistant’s answers, special tokens like [INST]
and [/INST]
are inserted:
f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} "
A pre-trained model is fine-tuned on chat data in this format, so that it “recognizes” the format at inference time. Here is the description from the Llama 2 paper (Touvron et al., 2023, arXiv:2307.09288):
> For the fine-tuning process, each sample consists of a prompt and an answer. To ensure the model sequence length is properly filled, we concatenate all the prompts and answers from the training set. A special token is utilized to separate the prompt and answer segments. We utilize an autoregressive objective and zero-out the loss on tokens from the user prompt, so as a result, we backpropagate only on answer tokens. Finally, we fine-tune the model for 2 epochs.
self.tokenizer.encode
will encode the input string to a list of integers. So the type of
```python
[
    self.tokenizer.encode(
        f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} ",
        bos=True,
        eos=True,
    )
    for prompt, answer in zip(
        dialog[::2],
        dialog[1::2],
    )
]
```
will be list[list[int]]
, a list of integer lists. sum(lst, [])
where lst
is of type list[list[int]]
will flatten the list (remove all inner lists), so the output of that will be list[int]
. For example,
```
In [1]: sum([[1,2,3], [4,5,6], [7,8,9]], [])
Out[1]: [1, 2, 3, 4, 5, 6, 7, 8, 9]
```
This dialog_tokens
list (type list[int]
) is then appended to prompt_tokens
(type list[list[int]]
), and finally this list is fed to the self.generate
method to get the generated tokens. This is how text_completion
and chat_completion
work under the hood.
Summary
In this post, I talked about various techniques for improving the transformer architecture, including SwiGLU, RMSNorm, Rotary Position Embedding (RoPE), ALiBi and Flash Attention. I have also walked through Llama’s source code. My post could help AI practitioners better understand LLMs’ behaviors as well as how to use them properly. With more open source LLMs coming out, this post could also help provide a direction for finding models that best suit various application scenarios.