This application is drawn to improvements in the performance of neural networks, and specifically to techniques for increasing throughput through neural networks.
Deep learning is ubiquitous in today's world, from robots and autonomous vehicles, to improvements in voice-powered assistants, to aiding scientific breakthroughs in a variety of fields. Deep learning approaches typically utilize large models to improve effectiveness. For example, OpenAI's GPT-3 model utilizes over 100 billion parameters. These large models consume a very large amount of compute resources and require large amounts of energy to power those resources. By one estimation, producing enough energy to run one instance of the GPT-3 model for three months would generate more than 14,256 pounds of CO2 emissions.
Such requirements are a heavy burden as these models are scaled up and democratized. Thus, there is a need for improving the performance of such models.
Various deficiencies in the prior art are addressed below by the disclosed compositions of matter and techniques.
In one aspect of the present disclosure, a method for improving the throughput of a neural network may be provided, with various versions sometimes referred to herein as "DataMUX" or "PT-DataMUX". The method may include a multiplexing phase of receiving a plurality of inputs, and generating transformed inputs by performing, via a multiplexing layer, a transformation (such as a fixed linear transformation, or other transformation) to each input of the plurality of inputs. These transformed inputs are then combined into a single compact representation of the plurality of inputs. In some embodiments, the single compact representation may be transmitted to a base neural network. After the base neural network performs its operations, a demultiplexing phase may occur, wherein output from the neural network is received, and a plurality of distinct output values are generated by converting, via a demultiplexing layer, the output back into independent representations. These distinct output values can then be used to produce predictions for each input.
In some embodiments, the demultiplexing layer may utilize a multi-head neural network, such as a multilayer perceptron. In some embodiments, the demultiplexing layer may use a set of input-specific keys or indices, which may include index embedding.
In some embodiments, a training phase may be used. The training phase may include a retrieval warmup step, which may include retrieving correct tokens and order for each position and sequence of the plurality of inputs. In some embodiments, the training phase may include pretraining the neural network after the warmup step, which may use a masked language modeling objective. In some embodiments, the training phase may include finetuning the neural network after pretraining, which may include training on a specific downstream task.
In some embodiments, one or more transformers in the model, such as a transformer between the multiplexing layer and a demultiplexing layer, may be transformed via pruning and/or distillation. In some embodiments, the method may include a prediction step, using a task accuracy model and a throughput model, that predicts parameters for improving throughput while also meeting a given accuracy budget.
In another aspect of the present disclosure, a non-transitory computer-readable storage medium may be provided. The non-transitory computer-readable storage medium may contain instructions that, when executed, cause a processor to perform operations that include some or all of the embodiments of the method as disclosed herein.
In another aspect of the present disclosure, a system may be provided. The system may include a processor operably coupled to a non-transitory computer-readable storage medium. The non-transitory computer-readable storage medium may contain instructions that, when executed, cause the processor to perform operations that include some or all of the embodiments of the method as disclosed herein.
In another aspect of the present disclosure, a neural network apparatus may be provided. The apparatus may include a processor operably coupled to memory. The processor may be configured to generate a neural network with a plurality of layers. The plurality of layers may include (i) a multiplexing layer configured to perform a fixed linear transformation to each received input before combining them into a single compact representation; (ii) one or more layers defining a base neural network, the base neural network configured to receive output from the multiplexing layer; and (iii) a demultiplexing layer configured to convert output of the base neural network back into independent representations. The multiplexing layer may be further configured to perform a retrieval warmup to promote distinguishing order and content of individual sequences in a multiplexed representation. In some embodiments, the one or more layers may include a transformer that has been transformed via pruning and/or distilling.
The accompanying drawings, which are incorporated in and constitute a part of this specification, illustrate embodiments of the present invention and, together with a general description of the invention given above, and the detailed description of the embodiments given below, serve to explain the principles of the present invention.
The following description and drawings merely illustrate the principles of the invention. It will thus be appreciated that those skilled in the art will be able to devise various arrangements that, although not explicitly described or shown herein, embody the principles of the invention and are included within its scope. Furthermore, all examples recited herein are principally intended expressly to be only for illustrative purposes to aid the reader in understanding the principles of the invention and the concepts contributed by the inventor(s) to furthering the art and are to be construed as being without limitation to such specifically recited examples and conditions. Additionally, the term, “or,” as used herein, refers to a non-exclusive or, unless otherwise indicated (e.g., “or else” or “or in the alternative”). Also, the various embodiments described herein are not necessarily mutually exclusive, as some embodiments can be combined with one or more other embodiments to form new embodiments.
The numerous innovative teachings of the present application will be described with particular reference to the presently preferred exemplary embodiments. However, it should be understood that this class of embodiments provides only a few examples of the many advantageous uses of the innovative teachings herein. In general, statements made in the specification of the present application do not necessarily limit any of the various claimed inventions. Moreover, some statements may apply to some inventive features but not to others. Those skilled in the art and informed by the teachings herein will realize that the invention is also applicable to various other technical areas or embodiments.
References listed herein are incorporated by reference in their entirety as if fully set forth herein.
In one aspect of the present disclosure, a method for improving the throughput of a neural network may be understood with reference to
Various transformations can be used here. In some embodiments, a fixed linear transformation is used. However, other transformations may also readily be used, including a non-linear function that is either fixed or learnable.
This can be seen in
In some embodiments, the single compact representation may then be transmitted 20 to a base neural network 120.
The architecture of the base neural network may include any network architecture as understood by those of skill in the art. This may include the use of, e.g., Transformers, Multi-layer Perceptrons (MLPs), or Convolutional Neural Networks (CNNs).
The multiplexer preserves the order of the combined instances and therefore allows mixed instances to be used during inference. The primary goal is to process a set of input instances simultaneously over a shared neural network (multiplexing) with minimal overhead during inference. To this end, the multiplexer module may be designed to combine the multiple input instances into a superposed representation.
The multiplexer module, denoted Φ, combines a tuple of inputs, e.g., images or sentences from a batch, (x_1, . . . , x_N), with x_i ∈ ℝ^d, into a more compact representation x^{1:N} ∈ ℝ^d in an order-dependent way, which enables effective demultiplexing after processing, as well as distinguishing intra-sequence interactions in the case of sequenced inputs (e.g., token sequences). Towards this end, for each input x_i with index i ∈ [1, N] of the input tuple, the multiplexer module performs a transformation ϕ_i: ℝ^d → ℝ^d on the instance before finally averaging all inputs into a single multiplexed representation:

x^{1:N} = Φ(x_1, . . . , x_N) = (1/N) Σ_{i=1}^{N} ϕ_i(x_i).
For sequenced inputs (e.g., token sequences), N sequences are combined by multiplexing token-wise. That is, for inputs of the form x_i = {w_j^i}_{j∈[1,L]}, where w_j^i ∈ ℝ^d is a token's input vector representation, this operation uses the same transformation ϕ_i for each token in the sequence before averaging over each position across indices, such that x^{1:N} = {w_j^{1:N}}_{j∈[1,L]}, where w_j^{1:N} ∈ ℝ^d is a multiplexed representation of the N tokens at position j, i.e., x^{1:N} = {Φ(w_j^1, . . . , w_j^N)}_{j∈[1,L]}.
For ϕi, various options can be used. For example, one can experiment with using either (1) a linear projection with a random fixed orthogonal matrix (denoted “Ortho”) or (2) the Hadamard product with a fixed Gaussian random vector (denoted “Hadamard”, equivalent to a linear map using a diagonal matrix here). These transformations map instances at different indices into distinguishable regions and consequently reduce interference between their representations. Finally, this multiplexed representation, x1:N, is used as input to the neural network backbone (or “base neural network”), which is architecturally unchanged.
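By way of illustration only, the following sketch (in PyTorch) shows one possible form of such a multiplexing layer for a tuple of N already-embedded sequences; the class and argument names (e.g., Multiplexer, mode) are hypothetical and not taken from the disclosure.

```python
import torch


class Multiplexer(torch.nn.Module):
    """Combines N embedded inputs of dimension d into one multiplexed representation (a sketch)."""

    def __init__(self, num_instances: int, dim: int, mode: str = "hadamard"):
        super().__init__()
        self.mode = mode
        if mode == "ortho":
            # Fixed random orthogonal matrix per index ("Ortho" variant).
            mats = [torch.linalg.qr(torch.randn(dim, dim))[0] for _ in range(num_instances)]
            self.register_buffer("phi", torch.stack(mats))                 # (N, d, d)
        else:
            # Fixed Gaussian vector per index ("Hadamard" variant, a diagonal linear map).
            self.register_buffer("phi", torch.randn(num_instances, dim))   # (N, d)

    def forward(self, x):                 # x: (N, L, d) token embeddings for N sequences
        if self.mode == "ortho":
            transformed = torch.einsum("nij,nlj->nli", self.phi, x)        # phi_i applied per token
        else:
            transformed = self.phi[:, None, :] * x                          # element-wise product
        return transformed.mean(dim=0)    # (L, d): multiplexed representation x^{1:N}
```

The multiplexed (L, d) output can then be fed to the unchanged backbone in place of a single sequence.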
After the base neural network performs its operations, a demultiplexing phase 30 may occur. There, output 125 from the base neural network 120 may be received 32, and a plurality of distinct output values 135 are generated 34 by converting, e.g., via a demultiplexing layer of a demultiplexer 130, the output back into independent representations.
In some embodiments, the demultiplexing layer may utilize a multi-head neural network. The multi-head neural network may be, e.g., a multilayer perceptron with non-linear activations or a Transformer.
In some embodiments, the demultiplexing layer may utilize a set of input-specific keys or indices, such as index embedding.
The output of the neural network backbone will generally be a multiplexed hidden representation h^{1:N} of the input x^{1:N}. One can explicitly disentangle h^{1:N} into N individual hidden representations, h^1, . . . , h^N. One first obtains each h^i with a demultiplexing function g_i, i.e., h^i = g_i(h^{1:N}).
For sequenced input, demultiplexing is done position-wise, i.e., h_j^i = g_i(h_j^{1:N}) for each position j.
In various embodiments, the demultiplexing function g_i may be modified.
In some embodiments, the demultiplexing function uses MLP demuxing. Here, N MLPs are employed to generate each index's hidden representation as h^i = MLP_i(h^{1:N}). This approach may be used for, e.g., both natural language processing (NLP) and vision tasks. Although this method is conceptually simple, it adds learnable parameters proportional to N.
In some embodiments, the demultiplexing function utilizes index embeddings. Here, index embeddings p_i are generated, which are then concatenated to h^{1:N} and transformed by a shared multi-layer network to generate each individual hidden representation, i.e., h_j^i = MLP_shared(h_j^{1:N}, p_i). To generate the index embeddings p_i, one can add a sequence of N special tokens, called the prefix, to the beginning of each sequence of the input tuple. For multiplexing with N sequences, one can add N corresponding prefixes, denoted prefix_i for i ∈ [1, N]. Each prefix_i consists of an index token ε_i in its i-th position while the remaining tokens are a special pad token ε_pad. The prefix sequences then take on the following pattern:

prefix_i = (ε_pad, . . . , ε_pad, ε_i, ε_pad, . . . , ε_pad), with ε_i at the i-th position.
One then prepends each sequence x_i of the input tuple with the corresponding prefix_i. The tuple of prepended sequences is then passed to the multiplexing module. When finally generating individual hidden representations, one can use the hidden representation corresponding to each index token ε_i as the index embedding p_i. One can use the Index Embeddings demultiplexing strategy on, e.g., language tasks for the Transformer architecture. For example, the prefix tokens may implicitly enable the Transformer to do instance-specific computations when processing the multiplexed representation and further enable demultiplexing for large N.
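As an illustrative sketch only, the prefix construction and index-embedding demultiplexing could take roughly the following form; the names (build_prefix, index_token_ids, shared_mlp) and the assumption that the shared MLP maps 2d-dimensional inputs to d dimensions are hypothetical.

```python
import torch


def build_prefix(num_instances: int, index: int, pad_id: int, index_token_ids):
    """Sketch: prefix_i has the reserved index token at position i and pad tokens elsewhere."""
    prefix = [pad_id] * num_instances
    prefix[index] = index_token_ids[index]
    return torch.tensor(prefix)


def demux_with_index_embeddings(hidden, shared_mlp, num_instances: int):
    """hidden: (N + L, d) backbone output for a multiplexed sequence whose first N
    positions correspond to the prefix; returns (N, L, d) per-instance representations."""
    p = hidden[:num_instances]                          # index embeddings p_1..p_N, (N, d)
    h = hidden[num_instances:]                          # multiplexed token states, (L, d)
    L = h.shape[0]
    p = p[:, None, :].expand(num_instances, L, -1)      # broadcast p_i over positions
    h = h[None, :, :].expand(num_instances, L, -1)
    return shared_mlp(torch.cat([h, p], dim=-1))        # shared MLP assumed to map 2d -> d
```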
Finally, these distinct output values/independent representations (i.e., values 135) from the demultiplexing phase can then be used to produce 40 predictions 145 for each input. Predictions can be made using a shared task head 140 on each input's respective individual hidden representation to prevent a substantial increase in the number of parameters and to improve training efficiency.
These methods may be implemented in a variety of ways. In some embodiments, a system may be provided. Referring to
Thus, a neural network apparatus may be provided, that may include a processor operably coupled to memory. The processor may be configured to generate a neural network with a plurality of layers. The plurality of layers may include (i) a multiplexing layer configured to perform a fixed linear transformation to each received input before combining them into a single compact representation; (ii) one or more layers defining a base neural network, the base neural network configured to receive output from the multiplexing layer; and (iii) a demultiplexing layer configured to convert output of the base neural network back into independent representations.
In some embodiments of the disclosed method, a retrieval warmup step may be utilized. For example, when using transformer models, it has been found that naively adding the multiplexing and demultiplexing layers to the model may sometimes cause training to fail to converge. This is likely because the gradients for individual instances from the task loss get mixed up in the backward pass. To overcome this, a retrieval warm-up step is proposed, e.g., for use with sequenced inputs. This is a self-supervised pre-training task to promote the ability of the disclosed models to distinguish the order and the content of individual sequences in a multiplexed representation. This task consists of retrieving the correct tokens and order for each position and sequence of the input tuple. See
In principle, one could add this loss for every sentence in each position. However, due to memory constraints, a preferred approach is to instead retrieve a token from a random sentence for each token position, yielding the following objective:

L_retrieval = Σ_{j=1}^{L} −log p(w_j^I | h_j^I), with I ~ U[1, N],

where h_j^I is a demultiplexed hidden representation of the j-th token in a randomly selected sentence with index I ~ U[1, N], generated using the methods described above relating to index embedding.
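As a minimal illustration only (with hypothetical tensor names and a vocabulary-projection head that is assumed rather than taken from the disclosure), the retrieval objective can be computed roughly as follows:

```python
import torch
import torch.nn.functional as F


def retrieval_loss(demuxed_hidden, token_ids, vocab_head):
    """demuxed_hidden: (N, L, d) per-instance hidden states; token_ids: (N, L) original tokens."""
    N, L, _ = demuxed_hidden.shape
    # For each position j, retrieve the token of one randomly chosen sentence I ~ U[1, N].
    rand_idx = torch.randint(0, N, (L,))                        # sentence index I per position
    picked_hidden = demuxed_hidden[rand_idx, torch.arange(L)]   # (L, d)
    picked_tokens = token_ids[rand_idx, torch.arange(L)]        # (L,)
    logits = vocab_head(picked_hidden)                          # (L, vocab_size)
    return F.cross_entropy(logits, picked_tokens)
```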
The capabilities and limits of data multiplexing for the Transformer architecture specifically were evaluated on a range of text classification tasks. DataMUX was applied to a 12-layer Transformer-based model with a hidden dimension size of 768 and 12 self-attention heads, built on the Huggingface framework; the resulting model is referred to as T-MUX. The T-MUX models are compared to two baselines: (B1) a 12-layer, 768 hidden dimension vanilla Transformer, and (B2) a 12-layer, 768 hidden dimension Transformer pretrained using the retrieval task described herein. Though there is no multiplexing done for B2 (meaning the retrieval task could be solved by simply copying input tokens to the output layer), it was found that this pretraining produces differences in performance, and for completeness this baseline is shown. After performing an empirical study to determine which smaller transformer models performed best, DataMUX was also applied to two smaller transformer models: (i) a 12-layer model with a hidden size of 384 (12L/384H), and (ii) a 4-layer model with a hidden size of 768 (4L/768H), which were compared with similar baselines.
The models and baselines were evaluated on two types of text classification tasks: sentence-level classification tasks and token-level classification tasks.
The T-MUX models were all pre-trained using the retrieval warm-up on the Wikitext-103 dataset (Merity et al., 2017) as described previously. In addition, the example also continued to use the retrieval task as an auxiliary objective during task training. The total loss is a combination of the task loss and the retrieval loss, i.e.,

L_total = L_task + α·L_retrieval,

where the coefficient α = 0.1 in these experiments.
In these examples, it was found that multiplexing leads to minimal drop in performance even for large N.
It is observed that it is possible to multiplex up to 40 instances and maintain reasonable performance. For easier tasks like QQP, SST-2 and QNLI, it is observed that the drop in performance with increasing N is insignificant, while for more difficult tasks like MNLI and NER there is a trade-off between performance and N, with performance dropping 10%-15% for 40 instances. Since unstable optimization was encountered when using MLP Demuxing, only results using Index Embeddings demultiplexing were provided. It was found that the choice of multiplexing strategy does not impact performance across different tasks, with performance even slightly increasing for small values of N (e.g., 2, 5). This increase may be attributable to implicit regularization.
In these examples, it was found that there was perfect multiplexing on the retrieval warm-up task for large N.
In these examples, it was found that throughput can be increased multi-fold using the disclosed techniques. Throughput of a multiplexed model (Hadamard+Index Embed) was measured across different numbers of instances by calculating inference speed for processing ˜20,000 samples on the MNLI dataset on a single Nvidia RTX 2080 machine. Four different batch sizes were used for all the configurations and the maximum throughput was taken.
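For illustration, throughput can be measured roughly as follows; this is a sketch with hypothetical function and argument names, placeholder batch sizes, and the assumption that the evaluation data is a single tensor of token IDs, rather than the exact benchmarking script used in the experiments.

```python
import time
import torch


def measure_throughput(model, token_ids, batch_sizes=(32, 64, 128, 256), device="cuda"):
    """Return the maximum samples/second over several batch sizes (a rough sketch).

    token_ids is assumed to be a tensor of shape (num_samples, seq_len).
    """
    model.eval().to(device)
    best = 0.0
    for bs in batch_sizes:
        loader = torch.utils.data.DataLoader(token_ids, batch_size=bs)
        count, start = 0, time.time()
        with torch.no_grad():
            for batch in loader:
                model(batch.to(device))
                count += batch.shape[0]
        if device == "cuda":
            torch.cuda.synchronize()   # make sure GPU work is finished before timing
        best = max(best, count / (time.time() - start))
    return best
```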
In these examples, the number of attention heads seems unrelated to multiplexing performance. To understand the role of the number of attention heads in multiplexing, a variant of T-MUX with 2 self-attention heads per layer was trained, specifically in the (Hadamard + Index Embed.) configuration. It was found that reducing self-attention heads has minimal effect on performance on the retrieval warm-up task and achieves a retrieval accuracy of ˜100% up to N=20. It was found that on token and sentence-level classification tasks, T-MUX with 2 self-attention heads performs comparably to T-MUX with 12 self-attention heads.
The disclosed techniques also provide throughput boosts with smaller Transformers. Two smaller Transformers were chosen to multiplex: a 12-layer model with a hidden size of 384 (12L/384H), and a 4-layer model with a hidden size of 768 (4L/768H).
With these techniques, performance varies more across different indices as N increases, since the prediction for an instance is conditioned on the index of the instance.
Further, one can construct self-attention based neural networks which always process multiplexed token embeddings in N independent subspaces. One possible construction relies on particular structures of singular spaces of linear transformations across layers.
The multi-head attention projects the queries, keys, and values h times with different, learned linear projections to d_K, d_K, and d_V dimensions, respectively. That is, given a sequence of d-dimensional multiplexed token embeddings {w_t^{1:N}}_{t∈[1,L]}, the i-th head computes

head_i = Attention(W_i^Q w_t^{1:N}, W_i^K w_t^{1:N}, W_i^V w_t^{1:N}) over all positions t ∈ [1, L].

Since

w_t^{1:N} = (1/N) Σ_{k=1}^{N} ϕ_k(w_t^k),

it is assumed that each function ϕ_k projects each embedding into a subspace 𝕌_k which is least linearly-dependent with the others. That is, defining

u_t^{(k)} = ϕ_k(w_t^k),

it is assumed that ⟨u_t^{(k)}, u_t^{(k′)}⟩ ≈ 0 for all pairs of indices k ≠ k′ and all positions t. To preserve such independent subspace structure after projection, one will first need the projected components W_i^V u_t^{(1)}, . . . , W_i^V u_t^{(N)} to lie in mutually independent subspaces. This is achievable if the eigenvectors of W_i^V span the subspaces 𝕌_1, . . . , 𝕌_N. In this case, the vector after transformation by W_i^V can be expressed as a sum of N vectors v_{i,t}^{(1)}, . . . , v_{i,t}^{(N)} in dual subspaces 𝕍_1, . . . , 𝕍_N which are independent of each other. The linear maps from 𝕌_k to 𝕍_k still allow rich operations on each component of a multiplexed input without interference from the other components.

In addition to decomposable value vectors, one can set the query and key matrices, W_i^Q and W_i^K, to have some subsets of right and left singular vectors such that they span 𝕌_1, . . . , 𝕌_N and 𝕍_1, . . . , 𝕍_N, respectively. Then one can show that the inner product between the query and keys of the i-th head can be rewritten as

⟨W_i^Q w_t^{1:N}, W_i^K w_{t′}^{1:N}⟩ = (1/N²) Σ_{k=1}^{N} τ_{i,t,t′}^{(k)},

where τ_{i,t,t′}^{(k)} is a scalar depending only on the k-th input sequence. Thus, the self-attention operation at each position can be seen as retrieving values based on the average of query-key similarity scores of the N sequences.

The average retrieval by soft-max could be a desired property as implicit regularization. However, if one wants perfect non-interference in retrieval, the network always has the option to specialize each head to focus on only one input sequence, by setting τ_{i,t,t′}^{(k′)} = 0 for all k′ ≠ k, which is easily achievable by controlling the singular values of W_i^Q or W_i^K.
Certain cases of data multiplexing have been explored for convolutional architectures on image classification as techniques for robustness and data augmentation. Those works implicitly employed the frontend layers of a convolutional net as multiplexing layers and suggest that a convolutional neural network can learn at most 3-4 independent subnetworks concurrently. However, for Transformers, the disclosed technique has demonstrated multiplexing up to 40 instances without a severe performance drop. To better understand data multiplexing on such alternative architectures, one can apply DataMUX to MLPs and CNNs for the MNIST image classification task.
For two inputs multiplexed with linear transformations G and H and lying on a data manifold ℳ, one would expect zero interference only if vᵀGᵀHu = 0 for all v, u ∈ ℳ. If the dimensionality of the manifold ℳ is not too low, then this requires GᵀH = 0, which implies rank(G) + rank(H) ≤ d; but the 'Ortho' transformations are full rank. A low-rank transformation could therefore help enable zero interference among inputs. This example therefore generated a set of N low-rank independent transformations for multiplexing (MLP+LowRank), and it was observed that this improves performance for larger N.
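One possible construction of such low-rank, mutually non-interfering transformations is sketched below, purely for illustration and under the assumption that d is divisible by N; it is not necessarily the construction used in the experiments.

```python
import numpy as np


def low_rank_transforms(num_instances: int, dim: int, seed: int = 0):
    """Sketch: N fixed low-rank maps whose output coordinates occupy disjoint blocks,
    so that G_i^T G_k = 0 for i != k (zero interference between instances)."""
    rng = np.random.default_rng(seed)
    rank = dim // num_instances                          # rank(G_1) + ... + rank(G_N) <= d
    transforms = []
    for i in range(num_instances):
        basis = np.zeros((dim, rank))
        basis[i * rank:(i + 1) * rank] = np.eye(rank)    # disjoint block of output coordinates
        mixer = rng.standard_normal((rank, dim))         # random map into that block
        transforms.append(basis @ mixer)                 # (d, d) matrix of rank `rank`
    return transforms
```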
Different multiplexing strategies for CNN architectures were compared, with performance results shown in
Additionally, it is noted that a training phase, which may include, e.g., pretraining techniques, can be used effectively with the above to improve overall performance of the models. The training may include, e.g., a warmup step, a pretraining step, a finetuning step, or a combination thereof. In a preferred embodiment, the training includes a warmup step, a pretraining step, and a finetuning step.
In some embodiments, the training phase may include a warmup step comprising retrieving correct tokens and order for each position and sequence of the plurality of inputs. This process may include multiplexing multiple sequences of tokens, feeding the multiplexed sequence through the base neural network, and demultiplexing to predict the original token in each position of each sequence.
In some embodiments, the training phase may include pretraining the neural network after the warmup step. The pretraining may include using a masked language modeling objective. This may include masking certain tokens in the inputs (i.e., replacing them with a special [MASK] token) and using the model to predict the masked tokens from the context of the other, unmasked tokens.
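As an illustrative sketch of a standard BERT-style masking scheme (an assumption for exposition, not taken verbatim from the disclosure; the common 80/10/10 mask/replace/keep refinement is omitted for brevity):

```python
import torch


def mask_tokens(input_ids, mask_token_id, mlm_prob=0.15):
    """Select ~15% of tokens for prediction; positions not selected get label -100."""
    labels = input_ids.clone()
    selected = torch.rand(input_ids.shape) < mlm_prob
    labels[~selected] = -100                  # ignored by the cross-entropy loss
    masked = input_ids.clone()
    masked[selected] = mask_token_id          # replace selected tokens with [MASK]
    return masked, labels
```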
In some embodiments, the training phase may include finetuning the neural network after pretraining. The finetuning may include training on a specific downstream task. This may include training the model to perform a specific task (like sentiment analysis) using human annotated data.
Here, techniques were explored to obtain pretrained language models with data multiplexing, with a focus on masked LMs like BERT and ELECTRA. The disclosed models can process multiple inputs (2-10) in parallel at roughly the cost of a forward pass over a single instance and, like those base models, can be fine-tuned for any downstream task. Importantly, MUX-PLMs do not require fine-tuning on, or a priori access to, task-specific data, in contrast to other methods like pruning.
The disclosed approach involves a three-stage training procedure including 1) a retrieval warmup, 2) multiplexed pretraining and 3) finetuning on downstream tasks. See
In this example, a demultiplexing module (see
The module is initialized with N key vectors which are used to demultiplex the transformed multiplexed representations (hMUX). The keys are concatenated with hMUX and are processed with an MLP to generate the demultiplexed output representations (h1, . . . , h4).
This helps improve the inference throughput, greatly improves pre-training, and consequently leads to more performant fine-tuned models. Finally, an attention-based multiplexing approach (see
In the attention-based multiplexing approach, the multiplexing module first generates contextual representations for x_1, . . . , x_4 with a Transformer layer; then a Hadamard product is taken between the contextual representations and the corresponding multivariate Gaussian vectors to generate instance-conditioned representations. The multiplexed representations are then generated with another Transformer layer, by attending across the instances for all the positions in the sequence.
The example models (MUX-BERT and MUX-ELECTRA) were evaluated on several downstream sequence classification tasks from the GLUE benchmark as well as token classification tasks like named entity recognition and part-of-speech tagging. The exemplary models achieve close to the state-of-the-art scores that standard BERT and ELECTRA models obtain while attaining a multi-fold throughput increase. For instance, MUX-BERT can get a 4.9× speedup over BERT while only being 4 points and 2 points worse in scores for GLUE and token classification tasks, respectively. The various versions of the multiplexed models were compared along the accuracy-efficiency Pareto front to demonstrate the flexibility of the pre-trained MUX models, depending on the downstream application. Finally, several ablation studies were performed, and the internal representations of MUX-BERT were analyzed to provide more insight into data multiplexing in language models.
The multiplexer module (MUX) combines an ordered set of multiple inputs, X^{1:N} = (x_1, . . . , x_N), into a single superimposed representation x^{MUX}. If x_i ∈ ℝ^d, the multiplexer is a transformation MUX: ℝ^{N×d} → ℝ^d such that x^{MUX} = MUX(X^{1:N}). To ensure MUX is an order-preserving transformation, DataMUX samples a vector v_i ∈ ℝ^d from a standard multivariate Gaussian and applies the Hadamard product (element-wise multiplication) with the corresponding input representation x_i before combining the vectors across instances:

x^{MUX} = (1/N) Σ_{i=1}^{N} v_i ⊙ x_i.

The model M processes the multiplexed representation and emits a multiplexed hidden state, h^{MUX} = M(x^{MUX}). To multiplex the Transformer's sequenced inputs x_i = (x_i^1, . . . , x_i^L) of length L, the same v_i is applied to all L positions of instance i.
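To illustrate how this improves throughput, the sketch below folds a batch of B·N embedded sequences into B multiplexed sequences, so the backbone processes N times fewer sequences per batch; the function name and tensor layout are assumptions for exposition.

```python
import torch


def multiplex_batch(embeddings, v):
    """Fold a batch of B*N sequences into B multiplexed sequences (a sketch).

    embeddings: (B * N, L, d) token embeddings; v: (N, d) fixed Gaussian vectors.
    Returns a (B, L, d) tensor, i.e., N instances share each backbone forward pass.
    """
    N, d = v.shape
    B = embeddings.shape[0] // N
    grouped = embeddings.view(B, N, -1, d)                 # (B, N, L, d)
    return (grouped * v[None, :, None, :]).mean(dim=1)     # average over the N instances
```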
For the demultiplexer, a prediction needs to be made for each instance in X1:N. In some embodiments, this can be done by separating the superimposed output (hMUX) into N output representations corresponding to the input (h1, . . . , hN).
The vector p_i ∈ ℝ^d is dynamically generated for each instance i with the help of a prefix that is added to the input, and is re-used for all positions in the instance. A prefix_i is added to x_i, represented by the following pattern, where ε_i is a token that is unused by the model, and p_i is set to be the output corresponding to token ε_i in the prefix:

prefix_i = (ε_pad, . . . , ε_pad, ε_i, ε_pad, . . . , ε_pad), with ε_i at the i-th position.
Modern NLP systems overwhelmingly rely on pre-trained models. Therefore, an exemplary PT-DataMUX is disclosed, which applies DataMUX during pre-training (both for BERT and ELECTRA) to yield MUX-BERT and MUX-ELECTRA models. In this example, the models were trained in three stages (see
Contextual multiplexer. The multiplexer used in DataMUX multiplexes tokens independently of 1) tokens in the same position in other instances and 2) other tokens in the same instance, which could lead to suboptimal representations. Therefore, in PT-DataMUX, a contextual multiplexing scheme is explored that cleverly aggregates context both from tokens in the same instance and from tokens in the same position of other instances (see
PT-DataMUX then generates the multiplexed representations, x^{MUX}, by applying another Transformer layer, TRANS_inst, across the per-instance contextual representations, g^{ctx}, from the N instances at each position from 1 to L. This can be achieved by, e.g., transposing g^{ctx} and applying TRANS_inst.
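A minimal sketch of such a contextual multiplexer is given below; the specific layer sizes, the number of attention heads, the index-conditioning via Gaussian vectors, and the final mean reduction over instances are all assumptions for illustration rather than the exact PT-DataMUX configuration.

```python
import torch


class ContextualMux(torch.nn.Module):
    """Sketch: one Transformer layer within each instance, then one across instances."""

    def __init__(self, dim: int, num_instances: int, nhead: int = 8):
        super().__init__()
        # nhead is assumed to divide dim.
        self.trans_ctx = torch.nn.TransformerEncoderLayer(dim, nhead, batch_first=True)
        self.trans_inst = torch.nn.TransformerEncoderLayer(dim, nhead, batch_first=True)
        self.register_buffer("v", torch.randn(num_instances, dim))

    def forward(self, x):                      # x: (N, L, d) embeddings of N instances
        g_ctx = self.trans_ctx(x)              # contextualize within each instance
        g_ctx = self.v[:, None, :] * g_ctx     # condition each instance on its index
        g_t = g_ctx.transpose(0, 1)            # (L, N, d): instances become the "sequence"
        mixed = self.trans_inst(g_t)           # attend across instances at each position
        return mixed.mean(dim=1)               # (L, d) multiplexed representation
```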
RSA Demultiplexer. The demultiplexer in DataMUX requires a prefix whose length scales linearly with the number of instances (N), thus reducing the effective context length during pre-training, which impacts performance. Furthermore, it decreases throughput during inference for large N because the model needs to process an extra prefix of length N for each of the N instances.
To address these issues, PT-DataMUX draws inspiration from the RSA cryptosystem to randomly initialize and learn N (private) key vectors (k_1, . . . , k_N, with k_i ∈ ℝ^d), which can be used to demultiplex the output representation (see
Akin to RSA, v_i and k_i can be treated as the keys for multiplexing (encryption) and demultiplexing (decryption), while ensuring that, unlike DataMUX, the input sequence length does not change, thereby leading to an improvement in throughput. Importantly, this architecture ensures that the keys transfer better across the different stages of training, as they are no longer conditioned on the input instances.
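By way of illustration only, a key-based demultiplexer of this kind could be sketched as follows; the class name, the shared-MLP width, and the GELU activation are assumptions and not taken from the disclosure.

```python
import torch


class KeyDemux(torch.nn.Module):
    """Sketch of key-based demultiplexing: N learned key vectors and one shared MLP."""

    def __init__(self, num_instances: int, dim: int, hidden: int = 3072):
        super().__init__()
        self.keys = torch.nn.Parameter(torch.randn(num_instances, dim))   # k_1..k_N
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(2 * dim, hidden),
            torch.nn.GELU(),
            torch.nn.Linear(hidden, dim),
        )

    def forward(self, h_mux):                        # h_mux: (L, d) multiplexed output h^MUX
        N, L = self.keys.shape[0], h_mux.shape[0]
        k = self.keys[:, None, :].expand(N, L, -1)   # broadcast each key over positions
        h = h_mux[None, :, :].expand(N, L, -1)
        return self.mlp(torch.cat([h, k], dim=-1))   # (N, L, d) demultiplexed states
```

Because no prefix is prepended, the sequence length seen by the backbone is unchanged, which is the source of the throughput gain described above.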
The pre-training hyperparameters used in this example are described in Table I, below.
All pre-training related hyper-parameters are reported in Table I. Generally, the HuggingFace Transformers implementations were used for the BERT and ELECTRA models. All pre-training experiments were run on 8 A100 GPUs with distributed training. A small hyper-parameter search over two learning rates was performed. All pre-trained models are primed with the token retrieval task described herein. The models were trained on the Wikipedia and BooksCorpus datasets for up to 10000 training steps with a learning rate of 1e-4, and with a sequence length of.
For MUX-ELECTRA models, a generator as in the original ELECTRA work was not trained, but rather only a uniform-random token replacement was used. This is similar to what was used in ablations in ELECTRA. The generator randomly replaces 15% of tokens in the input with other tokens in the vocabulary.
All the fine-tuning related hyper-parameters are reported in Table II, below. A small hyper-parameter search was run on the learning rate, batch size and number of training steps for different tasks. All models were trained with half-precision. Numbers are reported on the validation split. For GLUE tasks, the default metrics in Wang et al. (2018) were used, and F1 was used for the token-level tasks. All fine-tuning experiments were trained on 1 V100 GPU.
In this example, all nine datasets from the GLUE benchmark were used: CoLA, SST-2, MRPC, QQP, STS-B, MNLI, QNLI, RTE, and WNLI. The exemplary approach was also evaluated on token classification tasks such as named entity recognition (NER) and POS tagging. Here, WNLI and CoLA are not included in the reported averages, as these are the two smallest tasks in GLUE and high variance was observed across different seeds. Scores for all tasks are provided in the appendix.
Variance across seeds. All models were evaluated across 5 different seeds, as variance was seen for smaller datasets. This variance is caused by the order in which instances were multiplexed in the dataset. Therefore, both the average and maximum scores across different seeds are reported in Table III, below, to understand the importance of cleverly sampling the multiplexed instances. The average across the seeds is reported for all other results.
Specifically, this example experiments with two different models pre-trained as described herein—ELECTRA (Clark et al., 2020) and BERT (Devlin et al., 2019)—and presents pre-trained multiplexed models for both for N=2, 5 and 10. They are compared against DataMUX multiplexed models and against baseline ELECTRA and BERT models across three different model configurations (small, base, and large). A random generator was used to train the MUX-ELECTRA models, as presented as an ablation in Clark et al. (2020), instead of using a smaller masked LM, due to compute limitations.
It is observed in Table III, that the method PT-DataMUX outperforms DataMUX for both ELECTRA and BERT on all values of multiplexing N, with improvements between 12 and 20 points on GLUE and token-classification tasks respectively. This shows the benefit of pre-training for data multiplexing, as opposed to using randomly initialized Transformers which are fine-tuned. It is noticed that the improvement of PT-DataMUX over DataMUX is consistent across all values of N. Furthermore, PT-DataMUX's efficient RSA inspired demultiplexing method allows it to achieve faster throughput than DataMUX, increasing it by over 16% for N=10.
PT-DataMUX provides a significant boost in throughput (up to N times faster) when compared to the standard performance of the base or backbone neural network, without a significant performance loss. For example, for ELECTRA (N=2), PT-DataMUX is 0.4 points better and only 0.3 points worse than Standard for GLUE and TOKEN, respectively, while being 2× faster. Similarly, for BERT (N=2), PT-DataMUX is within 3 and 0.6 points of Standard for GLUE and TOKEN, while being significantly faster. It is also observed that as N increases, PT-DataMUX's throughput is significantly better, but naturally the gap to Standard is larger. This is because a large N implies that PT-DataMUX has to process more instances in parallel, sharing its parameters and activations across a larger number of instances, which improves throughput but reduces performance. For example, the gap between Standard and PT-DataMUX ELECTRA on TOKEN for N=2 is 0.2 points and increases to 3.5 points for N=10, which shows that N serves as a parameter to control the performance-throughput tradeoff. N in PT-DataMUX allows the user to find the fastest model for a certain performance threshold. The results show that PT-DataMUX works with both BERT and ELECTRA, with similar trends and performance for different values of N.
It can be shown that the disclosed multiplexing techniques work on a host of model sizes, and results are reported for PT-DataMUX BERT on three model sizes (small, base, and large) for N=2 (see Table IV, below).
Similar results were seen at larger N. PT-DataMUX's performance is close to that of standard for all model sizes while having a significantly better throughput (In Table IV, ≈2×). For example, the gap is less than 0.7 points for TOKEN tasks and 2.9 points for GLUE. Multiplexing works effectively on all the model sizes, with the drops with respect to Standard being 1.6 and 1.7 points on GLUE for Small and Large respectively. In Table IV, PT-DataMUX's throughput is always ≈2× that of Standard, which shows that a spectrum of PT-DataMUX model sizes can be multiplexed during pre-training without losing much performance and with significantly higher throughput.
Pre-trained models typically have a performance-computational efficiency trade-off, with larger models having better performance but worse computational efficiency. PT-DataMUX offers a similar trade-off, with large N leading to better throughput but lower performance. To understand this trade-off, the performance and throughput of Standard (N=1) and PT-DataMUX (N=2 and N=5) were plotted for different model sizes (L=Large, B=Base, S=Small) and the pareto-optimality envelope was marked (solid black line). See
Ensembling PT-DataMUX. When multiplexing pre-trained models, N different instances can be fed to PT-DataMUX instead of a single instance, which leads to significantly improved throughput. Here, an alternate setting is considered where the same instance is fed N times and an ensemble is built by averaging the N class logits and making a single prediction. The Base models for N=2, 5, 10 for both BERT and ELECTRA are used, and it is compared to PT-DataMUX which does not use ensembling. See Table V, below.
The ensemble model does significantly better than the non-ensemble variant on both MNLI and QQP for all the values of N considered (e.g., 1.6 and 0.9 points on N=5 BERT for the two tasks). Further, it is noted that the improvement over the non-ensemble variant (Δ) is better for higher N, potentially because the ensemble size is larger as a result of logits being averaged over more samples. This result shows that the non-ensemble variant is faster but performs slightly worse, while the ensemble variant performs better but is slower. A spectrum of models lies between these two extremes, where only a fraction of the N muxed representations can be ensembled, allowing users to pick where in the performance-speed trade-off they want their model to lie.
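As a brief illustrative sketch (assuming, hypothetically, that the multiplexed model accepts a batch of N sequences and returns one set of class logits per instance), the ensembling described above amounts to the following:

```python
import torch


def ensemble_predict(model, input_ids, num_instances: int):
    """Feed the same instance N times through a multiplexed model and average the logits."""
    batch = input_ids.unsqueeze(0).expand(num_instances, -1)   # N copies of one instance
    logits = model(batch)                                      # assumed shape: (N, num_classes)
    return logits.mean(dim=0).argmax(dim=-1)                   # single ensembled prediction
```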
Ablation analysis. Multiplexing and demultiplexing components of PT-DataMUX were analyzed, and the results were reported in Table VI, below. Two variants are considered, one which uses prefix demultiplexing proposed for DataMUX instead of the proposed RSA-MUX and another which uses Contextual multiplexing instead of Non-contextual.
Variant 1 which uses prefix demultiplexing performs worse than PT-DataMUX, other than for N=1. In fact, Variant 1 does not converge for TOKEN tasks for N=2 and N=5 and performs 1.7 and 1.2 points worse on GLUE when compared to PT-DataMUX. This shows that the RSA-inspired demultiplexer performs better than that of the DataMUX variant.
Variant 2 uses contextual multiplexing which takes into account other tokens present in the instance and also tokens present in the same position of other instances. Across the board, it is noted that it performs better than PT-DataMUX for TOKEN tasks but performs worse for GLUE tasks. The improvement it has over PT-DataMUX on TOKEN for N=10 is over 1.7 points. We believe that contextual multiplexing's better performance in TOKEN is because the model needs to make a prediction for every single position in the instance, which requires it to efficiently multiplex all token positions in the output. On the contrary, for GLUE tasks, the model needs to make a prediction only for the [CLS] token, for which PT-DataMUX's multiplexing suffices.
Muxology: Analyzing hidden representations of multiplexed models. To understand the nature of representations being learned by pre-trained MUX-BERT models, the activations and attention weights were analyzed. Specifically, the absolute value of activations and entropy of the attention distribution across all the layers of the Transformer encoder were noted, averaged over the evaluation split of WikiText-103. See
Activation norms spike for multiplexed models in the last layer.
Entropy of the attention weights of multiplexed models is lower than that of non-multiplexed models for higher layers.
PT-DataMUX can process multiple instances (N) in parallel, and here this is utilized during inference by sampling N instances uniformly at random from the evaluation set. However, other, more sophisticated data-sampling strategies exist, for example, clustering similar instances and processing them together, or considering instances which have the lowest word-overlap. In this section, the effect of the composition of the N instances on the performance of PT-DataMUX is explored. For each model variant, this example considers 5 random seeds, which can be viewed as lottery tickets. Since the random seed controls the composition of the N instances, the difference between the best-performing ticket and the worst-performing ticket is measured, and the performance is averaged over all the GLUE tasks (see Table VII, below).
It is noted that the difference (Δ) is consistently greater than 1 point for all values of N for both ELECTRA and BERT, which shows that there is a significant performance difference between the best and worst ticket. This shows the importance of the composition of the N instances, and it is believed that improved data sampling strategies, both during inference and training, can lead to better performance.
Finally, model compression for high throughput transformers can be used in conjunction with the above. For example, in some embodiments, the method may include compressing at least one transformer of a neural network model located between the multiplexing layer and a demultiplexing layer. The compression may occur via, e.g., the well-understood concepts of network pruning and/or distillation.
In some embodiments, the method may include predicting, using a task accuracy model and a throughput model, parameters that improve throughput and meet a given accuracy budget. The task accuracy model is generally configured to estimate accuracy for a range of multiplexer widths (N) and pruning sparsities (s). The throughput model is generally configured to estimate throughput for a range of multiplexer widths and pruning sparsities. Starting from a given accuracy budget, the two models can be used together to find a target (N, s) that may be an optimal set of values.
Large language models (LLMs) have achieved state-of-the-art performance across various NLP tasks and resulted in impressive user-facing demonstrations such as ChatGPT. However, their large size necessitates the use of enormous amounts of compute and memory at inference time, which limits their widespread use. In addition to the multiplexing techniques disclosed herein, another type of technique that has been explored for reducing the cost of model inference is model compression including network pruning, quantization, and/or knowledge distillation.
While both types of methods leverage the over-parameterization effect in modern deep neural networks to improve the throughput-to-compute-cost ratio, the manner in which they do so is different. Model compression aims at reducing the number of parameters in the model, hence reducing the overall compute cost (the denominator) to improve the ratio. Data multiplexing, on the other hand, compresses multiple inputs into one to improve throughput (the numerator) while keeping the model size fixed.
Thus, it may be feasible to combine the two methods. However, there are two challenges to this hypothesis.
The first is that both model compression and data multiplexing aim at trading a small accuracy loss for large throughput improvement. Intuitively, the combination may incur an accuracy loss larger than either method and it is not clear how they interact with each other when combining them together. A research question is how to combine the two methods such that the combination achieves better throughput than each type of method individually, given any accuracy loss budget or accuracy threshold. The second challenge is to efficiently find the best parameters pair (N, s) where N is the width of the data multiplexing and s is the sparsity of the model compression method. Training and testing with each parameter combination is costly and time-consuming. A research question is how to automatically find the best parameters without additional training and testing.
To address the first research question, disclosed is PruMUX, a combination of model compression and data multiplexing. The method is simple and consists of three phases—multiplexed model pre-training, task-specific fine-tuning, and task-specific model compression. This implementation makes use of CoFi (Xia et al., 2022), a state-of-the-art model compression method which includes intermediate knowledge distillation steps that help minimize accuracy hits, and DataMUX, which performs vector-based input multiplexing over instances. The results over four datasets (MNLI, QNLI, QQP and SST-2) demonstrate that PruMUX achieves significantly higher throughput than CoFi and DataMUX individually for a large range of accuracy thresholds.
To address the second research question, disclosed is Auto-PruMUX, a meta-model to automatically predict high-performance parameter combinations for a desired accuracy on a task without running experiments. Linear and cubic interpolation models are used over a few sparse data points to predict the throughput and accuracy of a PruMUX model based on sparsity and multiplexing factor. This has shown promise in modeling the trade-offs accurately, and Auto-PruMUX can find high-performance combinations of known parameters as well as unknown parameters, providing a practical method for choosing a high-performance PruMUX model for a downstream task.
The reason PruMUX can achieve better throughput than model compression and data multiplexing individually is that the two methods improve the throughput of a model along two different dimensions: reducing the latency of an inference and compressing multiple inferences into one. In addition, both methods lead to non-linear drops in model accuracy at some points. PruMUX can achieve high throughput while avoiding each method's limitations.
CoFi is a state-of-the-art model compression method (Xia et al., 2022) that uses distillation and structured pruning to jointly prune a Transformer network. Its key idea is to distill the knowledge from the base model into the pruned model during training. A layer-wise distillation approach is used to guide the pruning from the teacher model, i.e., the dense model, to the student model, i.e., the pruned model, with a loss defined as:

L_layer = Σ_i MSE(W_layer · H_s^{m(i)}, H_t^i),

where H_s^{m(i)} and H_t^i are hidden representations of the m(i)-th feed-forward layer of the student model and the i-th feed-forward layer of the teacher model, W_layer is a learned transformation matrix aligning the student and teacher hidden spaces, and i is the teacher model's layer closest to layer m(i) of the student model.
CoFi prunes both coarse-grained and fine-grained units of the distilled network. The coarse-grained units include multi-head attention layers, fully-connected layers, and attention heads. The fine-grained units include hidden dimensions, and intermediate dimensions of the Transformer model. Different masks are used for different pruning units and are learned via l0 regularization during training. The units with mask variables smaller than a threshold are pruned away before inference.
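As an illustrative sketch only (not CoFi's actual implementation), the effect of such learned masks at inference time can be pictured as follows, where head_mask_logits is a hypothetical tensor of learned mask variables for the attention heads:

```python
import torch


def prune_heads_by_mask(attn_head_outputs, head_mask_logits, threshold: float = 0.0):
    """Sketch of structured pruning of attention heads via learned mask variables.

    attn_head_outputs: (batch, num_heads, seq, head_dim); head_mask_logits: (num_heads,).
    Heads whose hard-thresholded mask is 0 contribute nothing and can be removed
    entirely before inference, which is what yields the latency reduction.
    """
    z = (head_mask_logits > threshold).float()            # hard mask after training
    return attn_head_outputs * z[None, :, None, None]     # zero out pruned heads
```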
A key motivational question is the following: given an accuracy loss budget, can the combination of model compression and data multiplexing achieve better throughput than each method individually?
PruMUX is a method to combine the two methods, and it is shown that PruMUX achieves substantially better throughput than each method alone for various accuracy thresholds in our experimental results.
PruMUX is a method to convert any Transformer into a high throughput model, capable of compressing multiple inference inputs into a single input and executing it at a low latency. For multiplexing, PruMUX uses DataMUX (including, e.g., PT-DataMUX), which appends a multiplexer and demultiplexer as described herein. With width N, the inference throughput of the Transformer can be improved by a factor of up to N, as each multiplexed input takes the same amount of computing resources as performing inference over a single input.
For model compression, PruMUX can use any method such as network pruning, distillation, or a combination of the two (such as CoFi). The goal is to substantially reduce the latency of processing an inference. For our experiments, PruMUX uses CoFi as the model compression method.
Training a model with PruMUX consists of three phases as shown in
Phase 1: Priming the multiplexed model with the token retrieval objective. The multiplexed transformer model is first primed with a token retrieval task as disclosed herein. Introducing this "retrieval warm-up" self-supervised objective (described above) appears to be highly significant for improving the performance of multiplexed models.
Phase 2: Pre-training and fine-tuning multiplexed models. The multiplexed models from the previous phase are then pre-trained on large-scale text corpora with the masked language modeling (MLM) objective. The pre-trained multiplexed models are then fine-tuned on downstream tasks to yield task-specific multiplexed models.
Phase 3: Model compression. Finally, a model compression technique (here, CoFi) is used to jointly prune coarse-grained and fine-grained units in the multiplexed Transformer model. The coarse-grained units include entire attention heads, attention layers, and fully connected layers. The fine-grained units include hidden dimensions and intermediate dimensions of the Transformer model. The demultiplexer's input dimension is pruned in order to match the pruned hidden dimension of the Transformer model. During the pruning process, CoFi uses knowledge distillation to transfer knowledge from the teacher model, i.e., the task-specific multiplexed model, to the pruned model.
As understood in the art, model compression reduces the number of model parameters with minimal loss in task performance. A well-studied method is network pruning, which removes unimportant connections or weights from a network with minimal or no accuracy loss. Unstructured pruning does not impose any constraints on the locations of non-zero weights. The resulting network can achieve high sparsity but may not run efficiently on common hardware such as GPUs. Structured pruning produces structured sparse matrices that can take better advantage of the parallelism in existing hardware, but its sparsity is relatively lower than the unstructured pruning method for the same accuracy loss budget. Structured pruning has been applied to transformers to improve inference throughput.
Distillation compresses a model by transferring knowledge from a large teacher model to a small student model. General distillation for Transformer models learns from an unlabeled corpus, while task-specific distillation learns from task-specific data; the two distillation methods can also be combined to improve performance. Xia et al. (2022) propose structured pruning with a distillation objective to reduce the Transformer parameters by up to 95% and achieve over 10× speedups with small accuracy drops.
Implementation Details. The multiplexed BERT-base models are pretrained with the standard BERT pre-training recipe with the masked language modeling objective for N=2, 5, 10 on Wikipedia and BooksCorpus datasets. The multiplexed model is primed before pre-training with the token-retrieval task on the Wikipedia and Bookscorpus datasets. The pre-trained multiplexed models are then trained on the four largest GLUE Tasks—MNLI, QNLI, QQP, and SST-2. The CoFi structured pruning objective is then used to get a pruned multiplexed model on each task dataset. The hyperparameters used for the training process are shown in Table VIII, below. A single run was performed to train the model for each setting, i.e., task, multiplexer width N, model sparsity s, following the training process.
To answer the question of whether the PruMUX method can achieve a higher throughput than either CoFi or DataMUX alone, given an accuracy threshold, a PruMUXed BERT-base model was compared to three baselines: the standard BERT-base model, CoFi applied alone, and DataMUX applied alone.
As is seen, PruMUX achieves higher throughput than either CoFi or DataMUX individually in all cases starting at various accuracy thresholds:
For MNLI, with the accuracy thresholds from 80% to 74%, PruMUX achieves 5.5-23.0× throughput improvement over the BERT-base model, whereas CoFi improves by 4.0-13.3× and DataMUX by 2.0-4.9×.
For QNLI, with the accuracy thresholds from 87% to 82%, PruMUX achieves 4.3-18.6× improvement, whereas CoFi improves by 3.9-7.5× and DataMUX by 2.0-9.8×.
For QQP, with the accuracy thresholds from 89% to 86%, PruMUX achieves throughput improvement over BERT-base by 5.5-24.2×, whereas CoFi improves by 7.6-11.7× and DataMUX by 2.0-9.8×.
For SST-2, with the accuracy thresholds from 86.5% to 83%, PruMUX improves the throughput by 8.7-23.4×, whereas CoFi improves by 4.4-12.3× and DataMUX by 4.9-10.1×.
The results also confirm the intuition that PruMUX with (N, s) incurs an accuracy loss, loosely speaking, close to the sum of the accuracy loss of DataMUX with N and that of CoFi with s. In general, PruMUX can achieve substantial throughput improvement when there is a decent accuracy loss budget.
The results above require training and testing PruMUX with all 18 parameter pairs (N, s), where N=2, 5, 10 and s=0.50, 0.60, 0.70, 0.80, 0.90, and 0.95. With all of these testing results in hand, for any accuracy threshold one can quickly find the best parameters.
Exhaustive tests are impractical. First, for each N, pre-training a DataMUX model with multiplexing width N is time-consuming. Second, given each pre-trained model with multiplexer width N, different sparsities s provide different throughput and accuracy trade-offs. In order to find the sparsity s with the highest throughput given an accuracy budget, one has to train the model for all possible sparsities. The total training time for the sparsities at the granularity of 0.05 for each N takes over a thousand GPU hours on commodity GPUs, for a small original BERT-base model.
A key question is whether one can automatically find a high-throughput (N, s) with minimal number of PruMUX experiments.
To address the question above, Auto-PruMUX is proposed: a method to search for the best (N, s) parameters to help practitioners balance the performance vs. throughput trade-off. The research question is: supposing one has some experimental data for DataMUX with a set of values of N and some for CoFi with a set of values of s, how can high-performance parameters (N, s) be found without training PruMUX models?
The approach used was to develop performance models for the accuracy and throughput of PruMUX. To do this, it was required to first train PruMUX models for a large set of (N, s) combinations and measure both the accuracy and the throughput improvement. One can then use this data to fit a throughput model and an accuracy model to predict throughput and accuracy respectively given (N, s) parameters.
It is first discussed how to fit the accuracy and throughput models with a few sparse data points. Given that these examples are working with a limited set of data points, it is opted to use a simple class of interpolation models for modeling PruMUX accuracy and throughput. It is then outlined how to leverage these models to predict (N, s) parameters, given an accuracy budget. The PruMUX models are then trained with the predicted configurations to demonstrate Auto-PruMUX's ability to predict better parameters without additional training.
Task Accuracy Model. Linear interpolation is used for the task accuracy model:

f_A(N, s) = Σ_{i,j ∈ {0,1}} a_{ij} N^i s^j.

Each term is a linear combination of the data multiplexer width N and the model sparsity s. The model is fitted on the gathered data of model task accuracy at different multiplexer widths and sparsities, e.g., by choosing the coefficients a_{ij} to minimize

Σ_{N ∈ 𝒩, s ∈ 𝕊} (f_A(N, s) − Acc(N, s))²,

where Acc(N, s) is the measured task accuracy, and 𝒩 and 𝕊 are the ranges of N and s values used to fit the model.
Throughput Model. Cubic interpolation is used on the throughput data (other approaches, e.g., linear regression, may also be used, although results may not be as good as with cubic interpolation):

f_T(N, s) = Σ_{i,j ∈ {0,1,2,3}} b_{ij} N^i s^j.

Each term is a cubic combination of N and s. The throughput model is fit on the collected data points and their measured throughput T(N, s), e.g., by choosing the coefficients b_{ij} to minimize

Σ_{N ∈ 𝒩, s ∈ 𝕊} (f_T(N, s) − T(N, s))²,

where 𝒩 and 𝕊 are the ranges of N and s values used to fit the model.
The models f_A(N, s) and f_T(N, s) are used to search for (N, s) parameters that maximize ζ_f, defined as:

ζ_f(N, s) = f_T(N, s) · g(f_A(N, s) − ζ), with g(x) = 1 if x ≥ 0 and g(x) = 0 otherwise.

Intuitively, ζ_f trades off task performance and throughput: given a performance budget ζ, the goal is to maximize throughput. g(x) provides a mechanism for a strict accuracy threshold, i.e., a model that does not meet the minimum required accuracy will have ζ_f = 0.
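The following is a minimal sketch of this kind of fit-and-search procedure; the function names, the least-squares fit, and the brute-force candidate search are illustrative assumptions rather than the exact Auto-PruMUX implementation.

```python
from itertools import product

import numpy as np


def fit_poly(points, values, degree):
    """Least-squares fit of f(N, s) = sum_{i,j <= degree} c_ij * N^i * s^j."""
    terms = list(product(range(degree + 1), repeat=2))
    feats = np.array([[n**i * s**j for i, j in terms] for n, s in points])
    coef, *_ = np.linalg.lstsq(feats, np.array(values), rcond=None)
    return lambda n, s: sum(c * n**i * s**j for c, (i, j) in zip(coef, terms))


def search_best_config(f_acc, f_thr, budget, candidates):
    """Return the (N, s) maximizing predicted throughput subject to f_acc >= budget."""
    feasible = [(n, s) for n, s in candidates if f_acc(n, s) >= budget]
    return max(feasible, key=lambda p: f_thr(*p)) if feasible else None
```

For example, one might fit f_acc with degree=1 and f_thr with degree=3 on the measured (N, s) grid and then call search_best_config with the desired accuracy budget over a candidate grid of (N, s) pairs.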
The goal is to evaluate the task accuracy model, the throughput model, and the parameter prediction performance. To evaluate the performance models on the accuracy and throughput data, performance data was collected for different (N, s) parameters on each task. Leave-one-out cross-validation was used to fit the performance models using part of the data and evaluate how well they perform on the rest of the data not used in model fitting. The fraction M_A of valid accuracy predictions, i.e., those with error falling within 2% of the real accuracy, and the fraction M_T of valid throughput predictions, i.e., those with error within 30% of the real throughput, are shown in Table IX, below.
It is noted that across different tasks, the accuracy model and throughput model are accurate across a broad set of parameter combinations.
Predicting parameters without additional training. The utility of searching parameters with the throughput and accuracy models is demonstrated by fitting the models on the following subset of parameters: (N, s) for all N ∈ {1, 2, 5, 10} and s ∈ {0.00, 0.60, 0.95}, and then using the fitted models to predict from a larger set of parameters: (N, s) for all N ∈ {1, 2, 5, 10} and s ∈ {0.00, 0.50, 0.60, 0.70, 0.80, 0.90, 0.95}. The search is conditioned on an accuracy budget as defined previously. It is noted that making predictions with a finer granularity of N, i.e., N=2, 3, . . . , 10 with ΔN=1, could potentially improve the predicted throughput further; however, that was not done in this example.
Here, Auto-PruMUX was leveraged to make parameter predictions on the larger set of parameters defined earlier. Table X, below, shows parameter predictions made by Auto-PruMUX for different accuracy budgets on the MNLI task. Auto-PruMUX can generalize to parameter combinations it was not fit on and for different accuracy thresholds, predicts faster parameter configurations that are not part of the training data used to fit the Auto-PruMUX models.
For instance, Auto-PruMUX suggests that the (2, 0.80) configuration for MNLI would lead to a higher throughput increase for an accuracy budget of 77% (see row 2). This prediction was verified by training the PruMUX model with that configuration and getting an accuracy of 79.8 and a throughput improvement of 7.9×. This shows Auto-PruMUX is able to generate better configurations without additional training.
Various modifications may be made to the systems, methods, apparatus, mechanisms, techniques and portions thereof described herein with respect to the various figures, such modifications being contemplated as being within the scope of the invention. For example, while a specific order of steps or arrangement of functional elements is presented in the various embodiments described herein, various other orders/arrangements of steps or functional elements may be utilized within the context of the various embodiments. Further, while modifications to embodiments may be discussed individually, various embodiments may use multiple modifications contemporaneously or in sequence, compound modifications and the like.
Although various embodiments which incorporate the teachings of the present invention have been shown and described in detail herein, those skilled in the art can readily devise many other varied embodiments that still incorporate these teachings. Thus, while the foregoing is directed to various embodiments of the present invention, other and further embodiments of the invention may be devised without departing from the basic scope thereof. As such, the appropriate scope of the invention is to be determined according to the claims.
This application claims priority to U.S. Provisional Patent Application 63/309,903, filed Feb. 14, 2022, the entire contents of which are incorporated by reference herein.