Transformer models are at the heart of generative AI (artificial intelligence) and machine learning. These models often involve substantial computational and memory costs due to their large number of weights (parameters). Therefore, there exists a need to reduce the number of weights and the computational complexity of transformer neural network architectures. Reducing the number of weights in such architectures leads to higher computational efficiency, lower latency, higher throughput, smaller memories, and lower cost per token.
Existing methods for simplifying transformers, such as those described by Bobby He and Thomas Hofmann in “Simplifying Transformer Blocks” (arXiv preprint arXiv: 2311.01906, 2023), focus on specific types of attention mechanisms but are not universally applicable to multi-query attention (MQA) or grouped-query attention (GQA), which are widely employed in large language models (LLMs) like Llama 2 and Mistral. Therefore, there exists a need for simplified transformer architectures that are applicable to various attention mechanisms including MQA, GQA, and MHA (multi-head attention).
This specification describes various systems and methods for modified and simplified transformer neural network architectures with and without skip connections (aka residual connections) that are suitable for various types of attention mechanisms including MHA, MQA, and GQA. The modified architectures reduce the number of weights while maintaining functionality and model accuracy and are mainly derived by merging and eliminating one or more linear layers (also known as fully-connected layers, dense layers, or projection layers) of the transformer architecture.
By merging and eliminating certain projection linear layers, the architectures described in this specification achieve significant weight reductions. For example, eliminating the attention query (Q) projection layers and the post-attention output projection (P) layers from a skipless version of Mistral-7B removes 15% of its weights and thus reduces its compute and memory complexity.
Bobby He et al. detailed in “Deep Transformers without Shortcuts: Modifying Self-attention for Faithful Signal Propagation” (ICLR 2023; arXiv preprint arXiv: 2302.10322, 2023) how transformer neural networks without skip connections and normalization layers can be successfully trained.
The removal of the skip connections and normalization layers enables linear layers to be merged (or collapsed) in a mathematically equivalent way as shown in
Merging Pi into Mi* and Qi into Oi-1* eliminates 2d² weights per transformer block. For MHA (multi-head attention), where e=d holds, merging Pi into Mi* and Ki or Vi into Oi-1* likewise eliminates 2d² weights per transformer block.
In total, we can eliminate up to two of the four attention linear layers (Q, K, V, P).
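As an illustrative sketch (assuming a single attention head, scaled dot-product attention, a row-vector convention, and randomly initialized weights; the variable names and shapes below are assumptions of this example, not taken from the specification's figures), the following NumPy example verifies that merging Pi into Mi and Qi into Oi-1 leaves the computed output unchanged:

```python
import numpy as np

# Illustrative sketch only: single attention head, no normalization, no skip
# connections, row-vector convention (activations are [tokens, features] and
# multiply weight matrices from the right).
rng = np.random.default_rng(0)
d, f, n_tok = 8, 32, 5

O_prev = rng.standard_normal((f, d))   # second FFN linear layer of block i-1
Q, K, V, P = [rng.standard_normal((d, d)) for _ in range(4)]
M = rng.standard_normal((d, f))        # first FFN linear layer of block i
h = rng.standard_normal((n_tok, f))    # hidden activations inside block i-1's FFN

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def partial_blocks(h, O_prev, Q, K, V, P, M):
    """Last FFN layer of block i-1, then attention and first FFN layer of block i."""
    y = h @ O_prev                          # skipless: y feeds block i directly
    q, k, v = y @ Q, y @ K, y @ V           # attention projections of block i
    a = softmax(q @ k.T / np.sqrt(d)) @ v   # scaled dot-product attention
    return (a @ P) @ M                      # post-attention projection P, then M

# Merge Q into O_prev and P into M; the eliminated Q and P become identities.
Q_inv = np.linalg.inv(Q)
O_prev_star = O_prev @ Q
K_star, V_star = Q_inv @ K, Q_inv @ V
M_star = P @ M
I = np.eye(d)

out_ref = partial_blocks(h, O_prev, Q, K, V, P, M)
out_new = partial_blocks(h, O_prev_star, I, K_star, V_star, I, M_star)
print("outputs match:", np.allclose(out_ref, out_new))   # True
```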
The merging of linear layers described above requires that Qi, Ki, and Vi are invertible (i.e., nonsingular). It is extremely rare for a square matrix with random values to be non-invertible, because this would require its determinant to be exactly 0.
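A quick numerical sanity check (a sketch with illustrative matrix sizes; NumPy's matrix_rank is used here merely to confirm invertibility):

```python
import numpy as np

# Sketch: random square matrices are essentially never singular, since
# singularity requires the determinant to be exactly 0 (a probability-zero
# event for continuous random entries). The sizes here are illustrative.
rng = np.random.default_rng(0)
full_rank = all(np.linalg.matrix_rank(rng.standard_normal((64, 64))) == 64
                for _ in range(1000))
print(full_rank)   # True: all 1,000 random 64x64 matrices are full rank, hence invertible
```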
Dimension d is the embedding dimension, and dimension e is the output dimension of the key (K) and value (V) linear layers. For MHA, e=d holds; for MQA, e=d/n_heads holds; and for GQA, e=d*n_kv_heads/n_heads holds.
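As a concrete illustration (the configuration values d=4096, n_heads=32, n_kv_heads=8 below are assumed, publicly reported Mistral-7B-style values, not taken from this specification), e can be evaluated for each attention type:

```python
# Sketch: output dimension e of the K and V linear layers for each attention type.
d, n_heads, n_kv_heads = 4096, 32, 8

e_mha = d                              # MHA: e = d                        -> 4096
e_mqa = d // n_heads                   # MQA: e = d / n_heads              -> 128
e_gqa = d * n_kv_heads // n_heads      # GQA: e = d * n_kv_heads / n_heads -> 1024

print(e_mha, e_mqa, e_gqa)             # 4096 128 1024
```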
Dimension f is the hidden dimension of the FFN (feedforward network, aka MLP or multi-layer perceptron). f=4d in the original transformer, but oftentimes f>4d. For models that use a GLU variant (such as the transformer models Llama 2 and Mistral), the effective f for the first linear layer M is 2f, because the GLU variant uses two linear layers that are combined (via pointwise multiplication) with a non-linear activation function.
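To illustrate the GLU variant, a minimal sketch of a SwiGLU-style FFN (as used by Llama 2 and Mistral) follows; the names and shapes are assumptions of this example:

```python
import numpy as np

# Minimal sketch of a GLU-variant FFN (SwiGLU-style). The first "layer" M
# effectively consists of two d-by-f linear layers (a gate branch and a linear
# branch) combined by pointwise multiplication, so its effective hidden width is 2f.
rng = np.random.default_rng(0)
d, f, n_tok = 8, 32, 4

W_gate = rng.standard_normal((d, f))   # first half of M (passes through the activation)
W_up   = rng.standard_normal((d, f))   # second half of M (purely linear branch)
W_down = rng.standard_normal((f, d))   # second FFN linear layer O

def silu(x):
    return x / (1.0 + np.exp(-x))      # the non-linear activation used by SwiGLU

def glu_ffn(x):
    return (silu(x @ W_gate) * (x @ W_up)) @ W_down   # pointwise product of both branches

x = rng.standard_normal((n_tok, d))
print(glu_ffn(x).shape)                # (4, 8)
```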
The blocks labeled “attention” in all figures of this specification refer to any type of attention such as scaled dot-product attention and shaped attention (see for example Lorenzo Noci et al. “The shaped transformer: Attention models in the infinite depth-and-width limit,” arXiv preprint arXiv: 2306.17759, 2023), as well as MHA, MQA, and GQA.
For MHA where e=d holds, Ki can be removed as shown in
For MHA where e=d holds, Vi can be eliminated as shown in
Note that for multi-query attention (MQA) and grouped-query attention (GQA), eliminating K or V is not possible (because dimension e does not equal d in these cases). In this case, eliminating Q is the only option, which also saves more compute and weights because Q is h times larger than K or V for MQA, where h is the number of attention heads. Specifically, note that
The table below specifies how the new weight matrices (Mi*, Qi*, Ki*, Vi*, Oi-1*) are calculated from the original weight matrices Qi, Ki, Vi, Mi, and Oi-1 for each elimination option.
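Although the exact entries of that table depend on the chosen matrix convention, the calculations can be sketched as follows (a minimal sketch assuming a row-vector convention and invertible square matrices; the function names are illustrative and not taken from the specification):

```python
import numpy as np

# Sketch of the three elimination options, assuming a row-vector convention
# (activations multiply weight matrices from the right). An eliminated matrix is
# returned as None because it reduces to the identity and need not be stored.

def eliminate_q(O_prev, Q, K, V, P, M):
    """Applicable to MHA, MQA, and GQA: fold Q into O_prev and P into M."""
    Q_inv = np.linalg.inv(Q)
    return O_prev @ Q, None, Q_inv @ K, Q_inv @ V, P @ M    # O_prev*, Q*, K*, V*, M*

def eliminate_k(O_prev, Q, K, V, P, M):
    """MHA only (requires e = d so that K is square and invertible)."""
    K_inv = np.linalg.inv(K)
    return O_prev @ K, K_inv @ Q, None, K_inv @ V, P @ M    # O_prev*, Q*, K*, V*, M*

def eliminate_v(O_prev, Q, K, V, P, M):
    """MHA only (requires e = d so that V is square and invertible)."""
    V_inv = np.linalg.inv(V)
    return O_prev @ V, V_inv @ Q, V_inv @ K, None, P @ M    # O_prev*, Q*, K*, V*, M*
```

In all three options, Pi is additionally folded into the first FFN layer Mi, so two of the four attention linear layers are eliminated per block.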
Similar to the parallel transformer (see for example Wang, B. and Komatsuzaki, A., “GPT-J-6B: A 6 billion parameter autoregressive language model”),
Alternative embodiments of this invention include versions with skip connections and normalization layers, as shown in
So far we have shown architectures without normalization layers, such as layer normalization and RMS normalization. In general, normalization can be added before the attention network and feedforward network (which is known as Pre-LN or Pre-Norm) or after the attention network and the feedforward network (known as Post-LN or Post-Norm). Similarly, normalization layers can be added anywhere to the architectures detailed in the preceding sections.
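As an illustrative sketch of the two placements (the norm and sublayer callables below are placeholders and are not defined in this specification):

```python
# Sketch of Pre-Norm vs. Post-Norm placement around a sub-layer (attention or
# FFN). The norm() and sublayer() callables are placeholders; the optional skip
# flag covers both the skipless and the skip-connection embodiments.

def pre_norm(x, sublayer, norm, skip=True):
    y = sublayer(norm(x))                    # Pre-LN/Pre-Norm: normalize the sub-layer input
    return x + y if skip else y

def post_norm(x, sublayer, norm, skip=True):
    y = sublayer(x)
    return norm(x + y) if skip else norm(y)  # Post-LN/Post-Norm: normalize the output
```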
For example,
Practical examples are provided in the table below, which lists the configurations and number of weights (or parameters) for the transformer models Pythia-6.9B and Mistral-7B. A skipless version of Mistral-7B saves 15% of its weights after merging the Q and P linear layers into the FFN layers. For a batch-size 1 system that is limited by memory bandwidth, these 15% weight savings can speed up inference by 1.17× during the autoregressive next-token-generation phase as detailed in the table below.
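As a back-of-the-envelope check, the 1.17× figure follows from the weight savings under the assumption that, for a memory-bandwidth-limited batch-size-1 system, time per generated token scales with the number of weights read; the configuration values below are assumed, publicly reported Mistral-7B values rather than entries of the table:

```python
# Back-of-the-envelope sketch: weight savings and speedup for a skipless
# Mistral-7B after eliminating the Q and P linear layers. Assumed values:
# d=4096, 32 transformer blocks, ~7.24B total parameters; batch size 1,
# memory-bandwidth-limited next-token generation.
d, n_blocks, total_params = 4096, 32, 7.24e9

removed = 2 * d * d * n_blocks        # Q (d x d) and P (d x d) removed per block
savings = removed / total_params      # ~0.148, i.e. about 15%
speedup = 1.0 / (1.0 - savings)       # ~1.17x during next-token generation

print(f"{removed/1e9:.2f}B weights removed, {savings:.1%} saved, {speedup:.2f}x speedup")
# -> 1.07B weights removed, 14.8% saved, 1.17x speedup
```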
This application claims priority to U.S. Provisional Application No. 63/609,293, filed on Dec. 12, 2023. The entire contents of the foregoing provisional application are hereby incorporated by reference.
| Number | Date | Country |
|---|---|---|
| 63/609,293 | Dec. 12, 2023 | US |