The present disclosure relates to processing data in a machine learning computer. Particularly, but not exclusively, this disclosure relates to processing of neural networks using mixed-precision numerical formats.
Deep neural networks are machine intelligence models used to perform a wide variety of tasks in fields such as computer vision (e.g. object recognition) and natural language processing (e.g. machine translation and natural language generation).
Each node 102 represents a function of its one or more inputs as received on its input edge or edges, with the result of this function being the output(s) provided on the output edge or edges. These results are sometimes referred to as activations. Each function is parameterised by one or more respective parameters (sometimes referred to as weights, though they need not necessarily be multiplicative weights). In general the functions represented by the different nodes 102 may be different forms of function and/or may be parameterised by different parameters. In deep neural network architectures, nodes are arranged into layers, with nodes of each layer receiving tensors on the output edges of the previous layer, and communicating their own outputs to nodes of the next layer in the network.
Further, the function at each node is parameterised by one or more respective parameters, e.g. weights 151, which are applied to the input activations to compute the input to the activation function 153, which generates the output activation.
The activation function 153 receives the weighted input values and generates an output value from them. An activation function is typically attached to each node in the network and determines whether that node should be activated (“fired”) or not, based on whether the node's input is relevant for the model's prediction. Certain activation functions, such as sigmoid or tanh, also normalise the output of each node to a range, for example between 0 and 1 or between −1 and 1. Other activation functions, such as a rectified linear unit (ReLU), do not normalise the output.
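A minimal sketch of the computation at a single node, assuming the weights are simply multiplied with the inputs and summed (the text notes that parameters need not be multiplicative weights in general). The function name `node_output` and the choice of three activation functions are illustrative only.

```python
import math

def node_output(inputs, weights, activation="relu"):
    """Apply the weights to the inputs, then pass the sum through an activation function."""
    z = sum(w * x for w, x in zip(weights, inputs))   # weighted input to the activation function
    if activation == "sigmoid":
        return 1.0 / (1.0 + math.exp(-z))              # output in (0, 1)
    if activation == "tanh":
        return math.tanh(z)                            # output in (-1, 1)
    if activation == "relu":
        return max(0.0, z)                             # not normalised: unbounded above
    raise ValueError(f"unknown activation: {activation}")
```

For example, `node_output([1.0, 2.0], [0.5, -1.0])` gives 0.0, because the weighted input −1.5 is negative and ReLU maps it to zero.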
In a standard deep neural network architecture, each node of a given layer is connected via a link 104 to every node of a subsequent layer. Networks with this all-to-all connectivity may be referred to as ‘fully connected’. In a convolutional neural network however, each node of a layer applies a ‘filter’ of weights (which may also be referred to as a kernel) in a sliding window to an input tensor to determine a weighted input to a node 102, where the filter only applies to a subset of input values to the given layer at a time. The subset of inputs that the filter ‘sees’ at a time may be referred to as the receptive field. Other common neural network architectures include recurrent neural networks and transformer architectures. Various implementations of these architectures exist in the art and will not be described further herein.
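The sliding-window behaviour of a convolutional filter can be sketched in one dimension as follows. This is a simplified illustration (single channel, stride 1, no padding); the function name `conv1d` is hypothetical and, as in most deep learning frameworks, the filter is not flipped, so strictly this computes a cross-correlation.

```python
def conv1d(inputs, kernel):
    """Slide a filter over a 1-D input with stride 1 and no padding.

    Each output value depends only on len(kernel) consecutive inputs,
    which form its receptive field.
    """
    k = len(kernel)
    return [
        sum(w * x for w, x in zip(kernel, inputs[i:i + k]))
        for i in range(len(inputs) - k + 1)
    ]
```

With `conv1d([1, 2, 3, 4], [1, 0, -1])` the filter sees the windows `[1, 2, 3]` and `[2, 3, 4]`, producing `[-2, -2]`.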
As described above, the output of each node or ‘neuron’ of a neural network depends on one or more parameters or weights applied to the set of inputs to that node. To train a neural network, the parameters at each layer are updated according to a learning scheme, to optimise a training goal. For example, where a goal is to train a network to identify object classes present in an input image, the output layer may be configured to output an indicator for a predicted class from among a set of possible classes, and the training goal may be to maximise the accuracy of the neural network's prediction for a set of input images in which the classes of the objects are known. In this context, deep neural networks are obtained by stacking multiple layers. The strength of these multi-layer architectures is that successive layers can reuse features built by earlier layers, and this reuse of features allows an efficient implementation.
Learning is generally based on the iterative update of the parameters of each of the layers, typically through backpropagation. In practice, backpropagation based on gradient descent computes the gradient of the loss with respect to the output of the last layer, and then this gradient is backpropagated using the chain rule of calculus. With backpropagation, each layer receives the gradient of the loss with respect to its output, and uses this quantity to derive the gradient of the loss with respect to the parameters, the weights of that particular layer. These quantities are then used to update the corresponding weights.
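For a single fully connected layer without an activation function, y = W x, the chain-rule step described above can be sketched as below. The layer form, the squared-error loss in the example and the function name `linear_backward` are assumptions chosen for illustration; real networks interleave such layers with activation functions whose derivatives also enter the chain rule.

```python
import numpy as np

def linear_backward(x, W, grad_y):
    """Backward pass of y = W @ x, given dLoss/dy.

    Returns dLoss/dW (used to update this layer's weights) and
    dLoss/dx (passed back to the previous layer via the chain rule).
    """
    grad_W = np.outer(grad_y, x)   # dL/dW[i, j] = dL/dy[i] * x[j]
    grad_x = W.T @ grad_y          # dL/dx[j] = sum_i W[i, j] * dL/dy[i]
    return grad_W, grad_x

# Example: squared-error loss 0.5 * ||y - t||^2, so dLoss/dy = y - t.
x = np.array([1.0, 2.0])
W = np.array([[0.5, -1.0], [2.0, 0.0]])
t = np.array([0.0, 1.0])
grad_W, grad_x = linear_backward(x, W, W @ x - t)
```

Here `grad_W` would feed the weight update of this layer, while `grad_x` becomes the incoming gradient for the layer before it.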
Gradient descent methods are highly effective and widely used to train neural networks. However, a common problem when training very large deep learning models comprising up to millions or even billions of weights is that the memory required to store the weights, activations (where activations are the output values of each node) and gradients at each layer is significant.
One way to reduce the required memory usage when training a deep learning model is to choose a representation for weights, activations and/or gradients of the network such that each value occupies fewer bits of memory.
In computing, bit sequences of predefined sizes are used to represent numbers, and the particular representation determines how a given bit sequence is interpreted. One common representation is the floating-point representation, which is often used to approximately represent real numbers. The floating-point representation comprises three separate components, i.e. a sign bit s∈{0,1}, an m-bit mantissa with bits d_i, i=1, . . . , m, and an e-bit exponent p, 0≤p<2^e. In the single-precision (i.e. 32-bit) floating-point representation according to the IEEE 754 standard, the exponent consists of e=8 bits, and the mantissa consists of m=23 bits. In the half-precision (i.e. 16-bit) floating-point representation, the exponent consists of e=5 bits, and the mantissa consists of m=10 bits. In most cases, a floating-point number is given from these three components by the following formula:

$$x = (-1)^{s} \cdot 2^{\max(p,1) - b} \cdot \left( I + \sum_{i=1}^{m} d_i \, 2^{-i} \right)$$

where b is an exponent bias and I∈{0,1} is an implicit bit that is not stored.
The exponent bias b in the formula above offsets the representation of the exponent. This exponent bias is commonly given by b = 2^(e−1) − 1 and therefore depends on the number of bits e used to represent the exponent in the given floating-point format. In the single-precision representation, the exponent bias is equal to 2^7 − 1 = 127. In the half-precision format, the exponent bias is equal to 2^4 − 1 = 15.
As shown in the above formula, the representation of the mantissa typically relies on an implicit bit, which is derived from the exponent. In the case where the exponent bit sequence consists of anything other than all zeros or all ones, the implicit bit is equal to 1 and the number is known as a “norm”. In this case, the floating-point number is given by:

$$x = (-1)^{s} \cdot 2^{p - b} \cdot \left( 1 + \sum_{i=1}^{m} d_i \, 2^{-i} \right)$$
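In the case that the exponent bit sequence consists of all zeros, the implicit bit is equal to 0 and the number is known as a “denorm”. In this case, the exponent takes the same value 1 − b as for the smallest norms, and the floating-point number is given by:

$$x = (-1)^{s} \cdot 2^{1 - b} \cdot \sum_{i=1}^{m} d_i \, 2^{-i}$$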
The denorms are useful, since they allow smaller numbers to be represented than would otherwise be representable by the limited number of exponent bits.
The remaining case, in which the exponent bit sequence consists of all ones, may be used to represent special values, e.g. ±infinity or NaN (not a number). NaN is a numeric data type value representing an undefined or unrepresentable value. The presence of a NaN in the results of a calculation is often taken to signal an exception.
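The interpretation rules above (bias, implicit bit, norms, denorms and the all-ones special case) can be collected into a small decoder for a format with one sign bit, e exponent bits and m mantissa bits. This is a sketch assuming the standard bias 2^(e−1) − 1 and the IEEE 754 treatment of the all-ones exponent; some 8-bit formats use the all-ones exponent differently and would need a different special-case branch. The function name `decode_float` is hypothetical.

```python
import math

def decode_float(bits, e, m):
    """Interpret an unsigned integer as a 1.e.m IEEE-754-style floating-point value."""
    b = 2 ** (e - 1) - 1                        # exponent bias
    s = (bits >> (e + m)) & 1                   # sign bit
    p = (bits >> m) & ((1 << e) - 1)            # stored (biased) exponent
    frac = (bits & ((1 << m) - 1)) / 2 ** m     # sum of d_i * 2**-i
    sign = -1.0 if s else 1.0
    if p == (1 << e) - 1:                       # all ones: infinity or NaN
        return sign * math.inf if frac == 0 else math.nan
    if p == 0:                                  # all zeros: denorm, implicit bit 0
        return sign * 2.0 ** (1 - b) * frac
    return sign * 2.0 ** (p - b) * (1.0 + frac)  # norm, implicit bit 1
```

For example, `decode_float(0x3C00, 5, 10)` gives 1.0 in half precision, and `decode_float(0x0001, 5, 10)` gives the smallest half-precision denorm, 2^−24.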
Another form of representation is the integer representation. The integer may be signed, in which case a single bit of the bit sequence is used to represent the sign of the number, with the remaining bits of the bit sequence used to represent the magnitude of the number. Another common representation for signed integers is two's complement representation. Alternatively, the integer may be unsigned, in which all of the bits of the bit sequence are used to represent the magnitude of the number.
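The three integer interpretations can be compared by decoding the same bit pattern each way. A minimal sketch with hypothetical helper names, for an n-bit pattern held in a Python `int`:

```python
def as_unsigned(bits, n):
    """All n bits encode the magnitude."""
    return bits & ((1 << n) - 1)

def as_sign_magnitude(bits, n):
    """Top bit is the sign; the remaining n - 1 bits are the magnitude."""
    magnitude = bits & ((1 << (n - 1)) - 1)
    return -magnitude if (bits >> (n - 1)) & 1 else magnitude

def as_twos_complement(bits, n):
    """Top bit carries weight -2**(n - 1) instead of +2**(n - 1)."""
    value = bits & ((1 << n) - 1)
    return value - (1 << n) if value >> (n - 1) else value
```

For the 8-bit pattern `0b10000011`, these give 131, −3 and −125 respectively; for a pattern with the top bit clear, all three agree.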
Floating-point representation is used to represent numbers in most current implementations of neural network processing.
A standard floating-point representation FP32, known as single-precision floating-point format, uses 32 bits in memory and can represent a very large range of numbers (with absolute values from ~10^−38 to ~10^38). Lower-precision formats using 16 bits (FP16) or even 8 bits (FP8) represent a significant reduction in memory usage and computational cost, particularly when used for representing a deep learning model with millions or billions of parameters. However, using fewer exponent and mantissa bits to represent a number reduces the range and/or precision of representable values. Lower-precision formats can therefore lead to two problems when applying arithmetic operations to numbers stored in these formats: numerical underflow and overflow. Underflow occurs when the absolute value of a number is too small to be represented in the chosen number format; where an arithmetic operation gives such a result, the number will instead be represented as zero. Overflow occurs when the absolute value of a number is too large to be represented in the chosen format. In this case, the number may be represented as positive or negative infinity, or be saturated (‘clipped’) to the maximum positive or negative number that can be represented by the chosen format. As mentioned above, for FP32 numerical overflow only occurs for numbers with absolute values greater than ~10^38. For FP8, however, overflow occurs at much lower absolute values, and could occur when training a neural network if weights, activations or gradients grow sufficiently large.
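Underflow and overflow can be observed directly with FP16, which NumPy supports natively (NumPy has no built-in FP8 type, so FP16 stands in here). This sketch shows a value that underflows to zero, a value that overflows to infinity on a plain cast, and a saturating alternative that clips to the largest finite FP16 value before casting. The variable names are illustrative.

```python
import numpy as np

fp16_max = np.finfo(np.float16).max                      # 65504, largest finite FP16 value
values = np.array([1e-9, 3.0, 1e5], dtype=np.float32)

with np.errstate(over="ignore"):                         # overflow on the cast is expected here
    cast = values.astype(np.float16)                     # 1e-9 -> 0.0, 3.0 exact, 1e5 -> inf
saturated = np.clip(values, -fp16_max, fp16_max).astype(np.float16)  # 1e5 -> 65504.0
```

The plain cast loses the large value entirely, whereas saturation keeps it at the edge of the representable range, which is the ‘clipping’ behaviour referred to above.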
Some methods of offsetting the effects of overflow and underflow for lower-precision numerical formats use the concept of exponent bias.
Another way of reducing instances of numerical underflow and overflow in lower-precision formats is to apply a scaling factor to variables which are likely to take very small or very large values and are therefore prone to underflow or overflow. Applying a scaling factor to a loss function, which in turn scales the gradients of the loss function used in training the network, is referred to as loss scaling. Values of weights and activations may also be scaled by adjusting the exponent bias term applied to the floating-point representations. A challenge is to choose the scale of the gradients so as to minimise both underflow and overflow and maximise the accuracy of the representations.
A first aspect disclosed herein provides a computer-implemented method of training, based on a set of training data, a multi-layer neural network comprising a set of network weights, the method comprising: processing the training data in respective forward and backward passes through a sequence of layers of the network, the forward pass comprising computing a set of activations by applying an activation function in dependence on the network weights and training data, and the backward pass comprising: computing gradients of a pre-determined loss function with respect to the network weights and/or computing gradients of the pre-determined loss function with respect to the computed activations of the network, wherein an adjustment parameter is applied to at least a subset of values in the neural network, the values comprising at least one of: the network weights, the activations computed in the forward pass, the gradients with respect to activations computed in the backward pass, and the gradients with respect to weights computed in the backward pass; updating the network weights in dependence on the computed gradients with respect to the weights; computing a proportion of the subset of values falling above a predefined threshold; and updating the adjustment parameter applied to the subset of values in dependence on the computed proportion.
It should be noted that the terms ‘signal’ and ‘value’ are both used herein to refer collectively to the weights, activations and gradients of the network.
For a better understanding of the present disclosure, and to show how embodiments of the same may be carried into effect, reference is made by way of example only to the following figures in which:
Certain factors should be considered when selecting a numerical format with which to represent data, for example weights and activations of a deep learning model, including computational and communication efficiency, as well as accuracy. As described above, a standard floating-point representation FP32, known as single-precision floating-point format, uses 32 bits in memory, and can represent a very large range of numbers. Lower-precision floating-point formats occupy less space in memory than single-precision floating-point numbers. An 8-bit floating-point representation comprising 1 sign bit, 5 exponent bits and 2 mantissa bits occupies only 8 bits of memory compared to a number represented in 32-bit format, which occupies 32 bits, or 4 bytes of memory. However, 8-bit floating-point formats or FP8 formats have the cost of a narrower range and lower precision of representable values, making underflow and overflow more likely.
Deep learning models are usually trained using gradient descent methods. These are described in more detail later, but in general these use a technique known as backpropagation in which a gradient with respect to weights at a given layer of the network is determined as a function of the activations and gradients with respect to activations at the following layer of the network. For many deep learning models these functions include matrix multiplications. This may lead to underflow when many small quantities are multiplied together, for example as the result of the successive multiplication of gradients during backpropagation. Numerical underflow thus causes an issue when small gradients occur in a network, as the gradients cannot be accurately computed and propagated through the network. Numerical overflow may also occur, typically when weights or activations grow too large during training to be accurately represented in the chosen format. While overflow can in theory also occur for gradients, for example when many large gradients are multiplied together during backpropagation, in practice gradients take smaller values than weights and activations on average.
Some methods of offsetting the effects of overflow and underflow for lower-precision numerical formats use the concept of exponent bias. Standard floating-point numbers use a fixed exponent bias b in order to store the exponent of the floating-point number as an unsigned value, such that when the bias is applied the exponent can have positive or negative values. For example, standard single-precision floating-point numbers have exponent values in the range −126 to +127 once the exponent bias b has been applied. It should be noted that applying a negative bias b to the exponent of a floating-point representation shifts the representable values down, which has the same effect as multiplying the number by a scaling factor equal to 2^b (a factor smaller than one, since b is negative). Therefore, it is possible to effectively change the representable range of a floating-point number either by multiplying the number by a scaling factor directly, or by adding or subtracting a bias to/from the exponent of the floating-point representation of that number. Scale factors and exponent biases may be referred to herein as adjustment parameters, as their application to floating-point representations of weights, activations and gradients of a model can be used to adjust the range of representable values according to the data.
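The equivalence between shifting the exponent and scaling can be checked numerically. The sketch below (the function name `fp_range` is hypothetical) computes the smallest positive denorm and largest finite norm of a 1.e.m format for a given stored-exponent bias, assuming the all-ones exponent is reserved for special values as described earlier. Increasing that bias by k, i.e. applying −k to the exponent, shifts both limits down by 2^k, exactly as multiplying every value by 2^−k would.

```python
import math

def fp_range(e, m, bias):
    """(smallest positive denorm, largest finite norm) of a 1.e.m format with this bias."""
    smallest_denorm = 2.0 ** (1 - bias - m)                          # exponent 1 - bias, last mantissa bit set
    largest_norm = (2.0 - 2.0 ** -m) * 2.0 ** ((2 ** e - 2) - bias)  # all-ones exponent reserved
    return smallest_denorm, largest_norm

standard_bias = 2 ** (5 - 1) - 1                # 15 for the 1.5.2 format
print(fp_range(5, 2, standard_bias))            # (1.52587890625e-05, 57344.0)
print(fp_range(5, 2, standard_bias + 4))        # both limits are 2**4 times smaller

# Adding k to a binary exponent is the same as multiplying by 2**k.
x = 0.15625
mantissa, exponent = math.frexp(x)              # x == mantissa * 2**exponent
assert math.ldexp(mantissa, exponent + 3) == x * 2 ** 3
```

Choosing the bias per tensor therefore moves the fixed-width window of representable values up or down to where the data lies, without changing the number of bits.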
As mentioned above, one way of reducing instances of numerical underflow and overflow in lower-precision formats is to apply a scaling factor to variables which are likely to take very small or very large values and are therefore prone to underflow or overflow, for example by applying a multiplicative factor to scale gradients of the network in order to perform computations which do not result in underflow.
Note that, as described above, deep learning models may be trained by computing gradients with respect to a loss function and updating the model parameters based on the computed gradients. Therefore, applying a constant scaling factor to the loss function is equivalent to applying a constant scaling factor to the gradients of the loss function. Herein a scaling factor may be referred to as a ‘loss scaling factor’, but it should be noted that this is the same as multiplying the gradient of the loss function by the same factor.
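As a short worked check of this equivalence, using notation introduced only for this illustration (ℓ for the loss, L for the loss scaling factor, w for a weight, a for an activation that depends on w, and η for a learning rate): by linearity of differentiation and the chain rule, every gradient is multiplied by the same factor L, so dividing by L before the weight update recovers the unscaled update.

$$\frac{\partial (L\,\ell)}{\partial a} = L\,\frac{\partial \ell}{\partial a}, \qquad \frac{\partial (L\,\ell)}{\partial w} = \frac{\partial (L\,\ell)}{\partial a}\,\frac{\partial a}{\partial w} = L\,\frac{\partial \ell}{\partial w}, \qquad w - \frac{\eta}{L}\,\frac{\partial (L\,\ell)}{\partial w} = w - \eta\,\frac{\partial \ell}{\partial w}$$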
Described below is a method of scaling gradients of a deep neural network during training in dependence on the gradient statistics to enable gradients to be stored in a lower-precision format. An overview of neural networks and gradient-based training methods will first be provided.
Machine learning models can comprise up to millions or billions of parameters and can require significant amounts of training data to provide good performance. Thus, computing resources required for machine learning models are significant, both in terms of memory for storing parameters and intermediate data, as well as computing power to carry out arithmetic operations on large numbers of variables at once. One way to reduce the computational cost of processing large amounts of data is to use a lower-precision numerical format to represent weights and activations of the network, as well as gradients of the loss function which are used to compute updates in training.
Low-precision floating-point formats have a limited range of representable numbers compared with single-precision floating-point, which uses 32 bits to represent numbers spanning absolute values from approximately 10^−38 to 10^38. Throughout training of a neural network, the scale of weights, activations and gradients may vary significantly, such that a relatively large range of scales needs to be represented.
A standard method of training a neural network using gradient descent will now be described with reference to
The goal of learning is to arrive at a set of network weights that minimise some training objective. At a final layer of the network, a prediction is output, which may depend on the task the network is designed to perform. For example, for an image processing task, where the input is an image containing an object, the network may be structured to output a predicted class of the object, given a set of possible classes. Typically, a network is trained by providing a set of training data for which the correct output is known, and defining a loss function 100 which measures the cost of using the network prediction for a given input instead of the ‘correct prediction’ corresponding to that input. The network weights may be initialised with random values drawn from a distribution with chosen statistics. During training, the network weights are then updated so as to minimise the loss function 100, i.e. to make the network predictions as close as possible to the ‘correct’ predictions.
A common optimisation scheme used to minimise the loss function is gradient descent. According to gradient descent, a gradient of the loss function may be computed with respect to the weights of the network, and each weight may be updated in the opposite direction to its respective component of the gradient, therefore ‘pushing’ the weights in the direction of minimal loss. The gradient with respect to the activations may also be computed as an intermediate step before computing the gradient of the loss function 100 with respect to the weights. Since the activations are a function of the weights, and the loss function 100 is a function of the activations, applying the chain rule allows the gradient with respect to the weights to be calculated based on the gradient with respect to the activations.
Backpropagation is well-understood in the art and therefore will not be described in further detail herein.
As mentioned above, the weights may be updated so as to ‘push’ the weights in the direction of the negative gradient. In other words, a term proportional to the component of the gradient corresponding to the given weight may be subtracted from the current value of the weight as follows:

$$w \leftarrow w - \eta \, \frac{\partial \ell}{\partial w} \qquad (1)$$

where η is a learning rate controlling the size of the update and ∂ℓ/∂w is the gradient of the loss function 100, denoted ℓ, with respect to the weight w. This is shown by the weight update 408 applied at each layer.
Note that the term ‘gradient’, while technically referring to a vector of partial derivatives, is used more generally herein to refer to both the gradient with respect to a weight or activation vector and the corresponding partial derivative components, computed with respect to a single weight or activation of the network. In other words, ‘gradient’ herein can refer to either individual components of a vector of partial derivatives or to the vector itself. A reference to the magnitude of gradients refers to the magnitude of individual partial derivatives of the loss function with respect to a given weight or activation.
Depending on the form of gradient descent used, the weights may be updated based on one training example at a time or, more commonly, based on an aggregated gradient computed for a subset of the training examples, which may be referred to as a minibatch. In this case, an accumulation operation is applied to obtain an aggregated (e.g. average) gradient to be applied in the respective weight update. Each layer updates its respective weights based on the respective gradients with respect to the weights at that layer, as shown by the multiple weight updates 408 in
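The weight update described above (subtracting the learning rate times the gradient) combined with minibatch averaging can be sketched as follows. It assumes per-example gradients are available as rows of an array and that the accumulation is a plain mean (the text allows other aggregations); the function name `sgd_step` is hypothetical.

```python
import numpy as np

def sgd_step(w, per_example_grads, lr):
    """Average per-example gradients over a minibatch, then apply w <- w - lr * grad."""
    grad = np.mean(per_example_grads, axis=0)   # accumulation over the minibatch
    return w - lr * grad

w = np.array([1.0, -2.0])
minibatch_grads = np.array([[0.2, 0.4],
                            [0.6, -0.4]])       # one row per training example
new_w = sgd_step(w, minibatch_grads, lr=0.5)
```

The averaged gradient here is `[0.4, 0.0]`, so only the first weight moves, to 0.8.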
The techniques described below provide a way to automatically scale gradients of a neural network based on their statistics, where references to gradients of the network herein include both gradients with respect to weights and gradients with respect to activations.
Low Precision Formats
An issue with storing weights, activations and gradients in a low-precision floating-point format, such as FP16 or FP8, is that the weights, activations and gradients often take on a wide range of values. Weights and activations may grow beyond the range of numbers representable in these formats, and the magnitude of gradients may fall below the lowest representable non-zero value.
As mentioned above, floating-point numbers may be represented by a sign bit, a number of mantissa bits, and a number e of exponent bits. One example 8-bit format uses 5 exponent bits and 2 mantissa bits. This may be referred to as 1.5.2 format.
In addition to numbers falling outside the representable range of values, using a floating-point format means that numbers are not represented continuously. Only numbers which can be written as a sum of powers of two (with the range of exponents dictated by the number of exponent and mantissa bits) have an exact representation in the chosen floating-point format. In a simple example, where adjacent representable values near 2 are spaced 2^−2 = 0.25 apart, any number strictly between, e.g., 2 and 2.25 in decimal cannot be expressed exactly in this format and would therefore be rounded to the nearer of these two values. The resulting representation error is referred to as rounding noise. For weights, activations and/or gradients of a neural network, it is important that the format is chosen such that the corresponding numbers can be represented with low quantization noise, where quantization noise includes both saturation errors and rounding errors, i.e. limiting both the loss of accuracy due to saturation and underflow and the loss of accuracy due to rounding noise.
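Rounding and saturation can be simulated in software by rounding a value to the nearest number representable in a 1.e.m format. This sketch assumes the standard bias, round-to-nearest-even, gradual underflow through denorms and an all-ones exponent reserved for special values; the function name `quantize` and its `saturate` flag are hypothetical.

```python
import math

def quantize(x, e=5, m=2, saturate=True):
    """Round x to the nearest value representable in a 1.e.m floating-point format."""
    bias = 2 ** (e - 1) - 1
    max_norm = (2.0 - 2.0 ** -m) * 2.0 ** ((2 ** e - 2) - bias)
    if math.isnan(x) or x == 0.0:
        return x
    if math.isinf(x):
        return math.copysign(max_norm, x) if saturate else x
    sign = math.copysign(1.0, x)
    a = abs(x)
    # Exponent of the leading bit (a = f * 2**k with 0.5 <= f < 1, so the
    # leading bit has exponent k - 1), never below the denorm exponent 1 - bias.
    exp = max(math.frexp(a)[1] - 1, 1 - bias)
    step = 2.0 ** (exp - m)                # spacing between representable values here
    q = round(a / step) * step             # round() on floats is round-half-to-even
    if q > max_norm:                       # overflow: clip or become infinite
        q = max_norm if saturate else math.inf
    return sign * q
```

With the default 1.5.2 format, `quantize(2.3)` returns 2.5 because values near 2 are spaced 0.5 apart; with `m=3` the spacing is 0.25 and `quantize(2.2, m=3)` returns 2.25. Values beyond 57344 saturate, and values well below 2^−16 round to zero.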
As neural networks are trained, the weights are updated and the activations and gradients are recomputed at each of a set of training iterations. During this process, some weights and activations may grow comparatively large. If a weight or activation falls beyond the upper limit of representable numbers, there is no way to store the correct value in the given format. Therefore, a process known as saturation or ‘clipping’ may be applied.
For gradients, which are more likely than weights or activations to take on smaller values, a common problem is underflow, wherein the value of the gradient is too small to be represented accurately. This may be mitigated by applying a scaling factor to the gradients (or to the loss function from which the gradients are computed), in a process known as loss scaling. In this case, the scale of the gradients may be increased so as to effectively represent the gradients in a low-precision format while carrying out expensive computations such as matrix multiplications and convolutions, and the results may then be scaled back down by the same factor afterwards. However, applying too large a scaling factor may cause overflow, and may therefore cause some gradients to be clipped.
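The effect of scaling before a low-precision cast, and scaling back afterwards, can be sketched with NumPy's FP16 type (standing in for FP8, which NumPy does not provide). The gradient values and the scale 2^10 are arbitrary illustrative choices, and the function name `cast_with_loss_scale` is hypothetical.

```python
import numpy as np

def cast_with_loss_scale(grads, scale):
    """Scale FP32 gradients, store them in FP16, then unscale them in FP32."""
    g32 = np.asarray(grads, dtype=np.float32)
    with np.errstate(over="ignore"):            # overflow to inf is expected for large values
        stored = (g32 * np.float32(scale)).astype(np.float16)
    return stored.astype(np.float32) / np.float32(scale)

grads = [1e-8, 1e-3, 100.0]
unscaled = cast_with_loss_scale(grads, 1.0)      # 1e-8 underflows to 0 in FP16
scaled = cast_with_loss_scale(grads, 2.0 ** 10)  # 1e-8 survives, but 100 * 1024 overflows
```

The same scale that rescues the tiny gradient from underflow pushes the large one past the FP16 maximum, which is the trade-off that motivates choosing the scaling factor carefully.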
Adaptive Loss Scaling for the Backward Pass
One method of loss scaling adjusts the loss scaling factor based on when clipping events are observed. This may be referred to as ‘Backoff scaling’, as described, for example, in the Nvidia OpenSeq2Seq documentation, in a section titled ‘Mixed Precision Training’ (https://nvidia.github.io/OpenSeq2Seq/html/mixed-precision.html). Under this method, the loss scaling factor may be increased until a gradient becomes large enough to be clipped, at which point the loss scaling factor is ‘backed off’ to a lower level, from which it progressively increases until the next clipping event. This is based on the premise that clipping events need to be avoided.
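A simplified sketch of a backoff-style update, written for illustration only: it is not taken from the OpenSeq2Seq code, and the function name, the `growth_interval` parameter and the factor of 2 are hypothetical choices. The scale is reduced as soon as any clipping is detected and increased after a run of clean steps.

```python
def backoff_update(scale, found_clip, steps_since_clip, growth_interval=1000, factor=2.0):
    """One backoff-scaling update; returns (new_scale, new_steps_since_clip)."""
    if found_clip:                             # any gradient clipped: back off
        return scale / factor, 0
    steps_since_clip += 1
    if steps_since_clip >= growth_interval:    # a clean stretch: try a larger scale
        return scale * factor, 0
    return scale, steps_since_clip
```

Because every single clipping event triggers a reduction, this scheme treats all clipping as unacceptable, in contrast with the proportion-based approach described next.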
However, the inventors have recognised that neural network training has some tolerance of a small amount of clipping of gradients, and that there may not be a requirement to avoid all clipping events. They have recognised that a better performance may be obtained by reducing or increasing the loss scaling factor in dependence on the statistical properties of the gradients, for example a proportion of the gradients that fall above a certain threshold which indicates that a saturation of the upper end of the representable range has occurred. It should be noted that ‘proportion’ herein may refer to any relative count of gradients with respect to an overall set of gradients, and is not necessarily limited to a percentage. A software-implemented program will now be described which collects statistics for the gradients of the network and updates a loss scaling factor in dependence on the statistics in order to provide an optimal representation of gradients for a given format. It should be noted that the computer program may also be configured to collect statistics and adjust the format for the forward pass, which is described in more detail later.
At each loss scaling update step, statistics of the computed gradients are computed. Loss scaling update steps may occur, for example, every hundred training iterations. The frequency at which gradient statistics and loss scaling factor updates are computed may be user-configurable. The statistics may be computed based on a set of one or more thresholds; for example, a histogram of the gradients at each layer falling into each of a set of bins defined by bin edge thresholds 422 may be determined. The statistics may be accumulated by an accumulation operation 412, for example by summing the histograms for each layer. The accumulation may aggregate the statistics for only a subset of the layers of the network.
Note that the gradients may be accumulated over all layers in accumulation 412 with a single threshold 422 applied to determine the loss scaling factor, or a separate threshold 422 may be applied at each layer to determine a proportion above each threshold at each layer, before aggregating the computed proportions in the accumulation operation 412.
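Both accumulation options can be sketched as below. The use of gradient magnitudes, the strict ‘greater than’ comparison and the mean used to aggregate per-layer proportions are assumptions for illustration, and both function names are hypothetical.

```python
import numpy as np

def pooled_proportion(layer_grads, threshold):
    """One threshold for all layers: pool the counts, then take a single proportion."""
    above = sum(int(np.count_nonzero(np.abs(g) > threshold)) for g in layer_grads)
    total = sum(g.size for g in layer_grads)
    return above / total

def per_layer_proportion(layer_grads, thresholds):
    """One threshold per layer: a proportion per layer, then aggregated (here, a mean)."""
    props = [np.count_nonzero(np.abs(g) > t) / g.size for g, t in zip(layer_grads, thresholds)]
    return float(np.mean(props))

layer_grads = [np.array([0.1, 5.0, -7.0, 0.2]), np.array([1.0, 2.0])]
print(pooled_proportion(layer_grads, 1.5))              # 0.5
print(per_layer_proportion(layer_grads, [6.0, 0.5]))    # 0.625
```

Pooling weights each gradient equally, so large layers dominate; averaging per-layer proportions gives each layer equal weight instead.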
After the accumulated statistics are determined, a loss scaling algorithm 414 is applied to update the loss scaling factor based on the statistics. For example, where a histogram of two bins is computed and the number of gradients in the upper bin exceeds some predefined proportion of the total number of gradients, the scaling factor is reduced so as to prevent too many gradients from reaching the upper end of the representable range, which would result in too many clipping events to maintain good performance. This is described in more detail with reference to
The weights are updated within an optimiser 416 based on an accumulation 410 of the gradients computed for the given layer over a minibatch of training data. The optimiser 416 applies the weight update according to a gradient update rule such as equation (1). One example optimisation algorithm used in the field of machine learning is the Adam optimiser which applies a particular type of stochastic gradient descent update. Other gradient-based optimisation algorithms are known in the art, any of which may be used to train a neural network according to the method shown in
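As an example of an optimiser 416 consuming the gradients, the sketch below applies one standard Adam update. It assumes, for illustration, that the gradients were computed with a loss scaling factor and are divided by that factor before the update, consistent with scaling results back down after the low-precision computation. The function name, the state dictionary and the default hyperparameters (the commonly used Adam defaults) are assumptions.

```python
import numpy as np

def adam_step(w, scaled_grad, state, loss_scale, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update on gradients computed with a loss scaling factor."""
    g = scaled_grad / loss_scale                            # undo loss scaling before the update
    state["t"] += 1
    state["m"] = beta1 * state["m"] + (1 - beta1) * g       # first-moment estimate
    state["v"] = beta2 * state["v"] + (1 - beta2) * g * g   # second-moment estimate
    m_hat = state["m"] / (1 - beta1 ** state["t"])          # bias corrections
    v_hat = state["v"] / (1 - beta2 ** state["t"])
    return w - lr * m_hat / (np.sqrt(v_hat) + eps)

state = {"t": 0, "m": np.zeros(2), "v": np.zeros(2)}
w = np.array([1.0, 1.0])
w = adam_step(w, np.array([1024.0, -2048.0]), state, loss_scale=1024.0)
```

On the first step the bias-corrected moments equal g and g², so each weight moves by about the learning rate in the direction opposite to its gradient.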
Training iterations repeat with the same loss scaling factor until the next loss scaling factor update step.
Once it is determined that a proportion greater than f of the gradients lies above the threshold T, the loss scaling factor may be reduced by a factor s. This has the effect of shifting the distribution of the scaled gradients down, such that a smaller proportion of the gradients lies above the threshold T. An algorithm may be applied which either increases the loss scaling factor at every loss scaling factor update step at which the proportion above the threshold does not exceed the critical fraction f, or only increases the loss scaling factor after a number of consecutive update steps in which the proportion above the threshold does not exceed the critical fraction f.
A gradient histogram may be computed which comprises more than two bins. In the case where the gradient histogram comprises n bins, with bin edges {b_1, b_2, . . . , b_{n−1}} and bin counts {h_1, h_2, . . . , h_n}, where bin k (for k ≥ 2) has lower edge b_{k−1}, then for a given threshold T, and after M consecutive optimiser steps, the loss scaling factor L is increased only if the proportion of the total count contained in bins whose lower edges are greater than or equal to the threshold T does not exceed the user-defined fraction f. That is to say:

$$\frac{\sum_{k \ge 2,\; b_{k-1} \ge T} h_k}{\sum_{k=1}^{n} h_k} \le f$$
The loss scaling factor is decreased otherwise.
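The multi-bin condition can be sketched as follows, assuming the histogram is taken over gradient magnitudes and that T coincides with one of the bin edges. The function names and example values are hypothetical, and the M-step requirement is left out here.

```python
import numpy as np

def histogram(values, edges):
    """Counts h_1..h_n of n bins split at inner edges b_1 < ... < b_{n-1}.

    Bin 1 is (-inf, b_1), bin k is [b_{k-1}, b_k) and bin n is [b_{n-1}, inf).
    """
    idx = np.searchsorted(edges, values, side="right")    # number of edges <= each value
    return np.bincount(idx, minlength=len(edges) + 1)

def should_increase(counts, edges, T, f):
    """True if the share of counts in bins whose lower edge is >= T is at most f."""
    lower_edges = np.concatenate(([-np.inf], edges))        # lower edge of each bin
    return counts[lower_edges >= T].sum() / counts.sum() <= f

grads = np.array([0.001, -0.01, 0.5, -2.0, 8.0, 100.0])    # illustrative gradient values
edges = np.array([0.1, 1.0, 10.0])
counts = histogram(np.abs(grads), edges)                    # [2, 1, 2, 1]
```

Here half of the magnitudes lie in bins at or above T = 1.0, so with f = 0.25 the loss scaling factor would be decreased, while with f = 0.5 the condition holds and it could be increased.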
At each training iteration, a first step 602 computes the forward pass, and at step 604 the gradients are computed in a backward pass. The weights are updated at step 606 based on the computed gradients. A check 610 is then done on the current optimiser count to see if the current iteration is a multiple of the number of steps N defining the frequency of computing gradient statistics. If the current optimiser count is not a multiple of N, then at step 608 the count is incremented by 1, and the scaling count is also incremented by 1, since no change of the loss scaling factor takes place at this step. If the current training iteration is a multiple of N, then after updating the current optimiser count at step 614, the gradient statistics are computed at step 616. These may be computed as a histogram of gradient values falling into two or more bins, for example. The statistics may be computed for each layer of the network separately and accumulated for the entire network. A condition 618 is then applied to see if the proportion of gradients above a threshold is larger than the critical fraction f, where this critical fraction can be defined by the user. If the proportion of gradients above the threshold is larger than f, then the scaling factor is reduced at step 620 by a factor s, and the scaling count is reset to zero at step 624 to signify that the scaling factor has been updated and this is the first iteration with the new loss scaling factor. In one example, s=2, and the loss scaling factor L is halved. If the proportion of gradients above the threshold is less than or equal to the critical fraction f, then a further check 622 is performed to identify for how many steps the proportion has not exceeded the fraction, which is given by the scaling count.
If the scaling count is at least M steps, then the loss scaling factor L is increased by the factor s at step 624, and the scaling count is reset at step 626 to signify that this is the first iteration with the new loss scaling factor. In the above example where s=2, the loss scaling factor is therefore doubled after every M steps in which the proportion of gradients above the threshold does not exceed the critical fraction f. If at step 622 the scaling count is less than M steps, then the scaling count is incremented by one at step 628, and a new training iteration begins with a forward pass 602 without any change to the loss scaling factor. In practice, it may be desirable to adjust the scaling factor up or down at every iteration in which the gradient statistics are computed. This can be achieved by setting M=N. In this case, check 622 is unnecessary because the scaling count will always have reached M=N by the time the statistics are next computed.
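The control flow of steps 602 to 628 can be approximated by the following sketch. The class `AutoLossScaler`, its field names and its default values are hypothetical. As a simplification of steps 608 and 628, the scaling count is incremented on every iteration before the checks, so that setting M = N lets L change at every statistics step after the first.

```python
from dataclasses import dataclass
from typing import Callable


@dataclass
class AutoLossScaler:
    """Hypothetical sketch of the automatic loss scaling schedule."""
    scale: float = 2.0 ** 10          # loss scaling factor L (initial value assumed)
    n_stats: int = 1                  # N: compute gradient statistics every N steps
    m_wait: int = 1                   # M: steps needed before L may be increased
    factor: float = 2.0               # s: factor for increasing or decreasing L
    critical_fraction: float = 1e-3   # f: tolerated proportion above the threshold
    step_count: int = 0               # optimiser count
    scaling_count: int = 0            # iterations since L last changed

    def step(self, proportion_above: Callable[[], float]) -> None:
        """Call once per iteration, after the weight update (step 606)."""
        is_stats_step = self.step_count % self.n_stats == 0  # check 610
        self.step_count += 1
        self.scaling_count += 1
        if not is_stats_step:
            return                                           # L unchanged
        p = proportion_above()                               # statistics, step 616
        if p > self.critical_fraction:                       # condition 618
            self.scale /= self.factor                        # reduce L, step 620
            self.scaling_count = 0
        elif self.scaling_count >= self.m_wait:              # check 622
            self.scale *= self.factor                        # increase L
            self.scaling_count = 0


scaler = AutoLossScaler(scale=1024.0, critical_fraction=0.01)
for p in [0.0, 0.0, 0.5]:              # proportion above threshold per iteration
    scaler.step(lambda p=p: p)
print(scaler.scale)  # 2048.0: doubled twice, then halved once
```

Passing the statistics as a callable means the histogram is only built on iterations where check 610 passes, mirroring the intent of computing statistics every N steps.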
Note that the factor s in the present example is applied both to the scaling up and the scaling down of the loss scaling factor. In other implementations different factors s1 and s2 may be used to scale up or scale down the loss scaling factor as required. These factors may be constant, or may be adapted over the series of iterations based on the gradient statistics or other factors.
The above-described method of scaling up gradients allows the gradients of the network to be stored and processed in a lower-precision representation such as FP8 or FP16, which improves computational efficiency both when processing gradients and when communicating them between processors in multi-processor systems. Neural networks may combine the storage of gradients in a low-precision format with higher-precision representations of weights and activations. Alternatively, weights and activations may also be stored in low-precision formats for processing in particular layers, such as layers with convolutions and matrix multiplications. Any subset of the activations, weights and gradients of the network may be selected for storage in a low-precision format. References herein to ‘a subset of activations, weights and gradients’ include subsets containing all members of one group, such as all activations or all weights, as well as subsets containing values from different groups, such as all activations and all weights from the first layer.
In addition to scaling the loss, the representation of the gradients in the backward pass may be adjusted by selecting an appropriate exponent bias. The exponent bias offsets the exponent value in the chosen floating-point format by a fixed amount, which is equivalent to applying a fixed power-of-two multiplicative factor.
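As a short worked illustration, assuming an IEEE-754-style encoding in which a normal value has sign bit $s$, stored exponent field $E$, fraction $F$ and exponent bias $b$, shifting the bias by $\Delta$ rescales every encoded value by the same power of two:

```latex
x = (-1)^{s}\, 2^{E-b}\,(1+F)
\quad\Longrightarrow\quad
x' = (-1)^{s}\, 2^{E-(b+\Delta)}\,(1+F) = 2^{-\Delta}\, x
```

The bit pattern is unchanged; only its interpretation moves by a factor of $2^{-\Delta}$, so increasing the bias shifts the representable range towards smaller magnitudes.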
Mathematical details of an example implementation of automatic loss scaling for an L-layer neural network model M will now be provided. This implementation collects two histograms, one for gradients with respect to weights and one for gradients with respect to activations, and uses an aggregation of the two histograms to determine whether to increase or decrease the scale factor.
The loss estimated over a micro-batch (i.e. a small subset of the overall training data) of size E is given by:
where each layer $g_l$, for $0 \le l \le L-1$, with parameters $\theta_l$ (such that $\theta_l = \emptyset$ if the layer is parameterless), is defined as the mapping of its input, the model parameters are $\Theta = \bigcup_{l=0}^{L-1} \theta_l$, and $x$ is the input to the network. The composition of the first $l+1$ layers in the model is $g_l \circ g_{l-1} \circ \cdots \circ g_0$.
The following mathematical description applies generally to different configurations of neural network models, for example different optimisers and hyperparameter values. In the present example, a single histogram $H_{GW}$ is used to collect statistics of weight gradients (i.e. gradients of the loss function with respect to the weights of the network), and a second histogram $H_{GX}$ is used to collect statistics of activation gradients (i.e. gradients of the loss function with respect to the activations of the network). In this example implementation, the histograms are defined over the FP16 range, with bin edges at the powers of two $2^e$ for every exponent $e$ in the range $[-24, 15]$, although other ranges can be used with other floating-point formats. For each gradient type, a subset of $L'$ of the network layers, with $0 \le L' \le L$, is used for statistics gathering, such that at least one histogram is available.
Two alternative approaches can be used to combine the bin counts of the two histograms. The first approach sums the bin counts of the two histograms and determines the total proportion of counts lying above a cut-off bin C, which defines a threshold T. The loss scaling factor is increased if the following condition is satisfied:
where f is the critical fraction. Otherwise, the scaling factor is reduced. Excluding the underflow count, this condition is written:
The second approach compares, separately for each histogram, the proportion of the bin count exceeding a respective cut-off bin C with a respective critical fraction f, and a joint decision to increase the loss scaling factor is made only if both tests pass (i.e. a unanimous vote). Critical fractions $f_{GX}$ and $f_{GW}$ and cut-off bins $C_{GX}$ and $C_{GW}$ are defined for activation and weight gradients, respectively. The loss scaling factor is increased if the following condition is met:
which is written as follows where underflow counts are excluded:
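The two combination approaches can be sketched as follows. The sketch assumes that index 0 of each histogram is the underflow bin (zeros and values below the smallest bin edge), that a cut-off C is a bin index with C ≥ 1, and that excluding the underflow count means dropping that bin from the denominator. All function names are hypothetical.

```python
from typing import Sequence


def fraction_at_or_above(counts: Sequence[int], cutoff: int,
                         exclude_underflow: bool = False) -> float:
    """Share of counts in bins with index >= cutoff.

    counts[0] is assumed to be the underflow bin; cutoff is assumed >= 1.
    """
    total = sum(counts[1:]) if exclude_underflow else sum(counts)
    return sum(counts[cutoff:]) / total if total else 0.0


def increase_by_sum(h_gx: Sequence[int], h_gw: Sequence[int], cutoff: int,
                    f: float, exclude_underflow: bool = False) -> bool:
    """First approach: pool both histograms, then apply one threshold test."""
    combined = [a + b for a, b in zip(h_gx, h_gw)]
    return fraction_at_or_above(combined, cutoff, exclude_underflow) <= f


def increase_by_unanimous_vote(h_gx: Sequence[int], h_gw: Sequence[int],
                               c_gx: int, c_gw: int, f_gx: float, f_gw: float,
                               exclude_underflow: bool = False) -> bool:
    """Second approach: each histogram must pass its own test."""
    return (fraction_at_or_above(h_gx, c_gx, exclude_underflow) <= f_gx
            and fraction_at_or_above(h_gw, c_gw, exclude_underflow) <= f_gw)


h_gx = [50, 30, 15, 5]   # activation gradients, 50 of them underflowing
h_gw = [0, 10, 8, 2]     # weight gradients
print(increase_by_sum(h_gx, h_gw, 3, 0.09))                          # True (7/120)
print(increase_by_sum(h_gx, h_gw, 3, 0.09, exclude_underflow=True))  # False (7/70)
```

The example shows why the underflow choice matters: a large underflow count dilutes the denominator and can make an increase look safe.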
Activation gradients computed for a micro-batch of training data depend on the size of the micro-batch. Where the loss is averaged over the B elements of the micro-batch, each activation gradient carries a factor of 1/B, so the activation gradients become smaller as the batch size increases. In the histogram of bin counts, doubling the batch size halves all the activation gradients, which shifts the histogram $H_{GX}$ down by one bin (one power of two) and leads to greater underflow.
By contrast, weight gradients are computed as an average of the per-micro-batch-element weight gradient estimates, which means that, in expectation, the weight gradient estimates do not change when the batch size increases. Therefore, the histogram of weight gradient estimates $H_{GW}$ is unchanged by a change in batch size. The performance of automatic loss scaling therefore depends on the micro-batch size. If the weight and activation gradient statistics are combined by summing their bin counts as described above, then the contribution of the activation gradients to the collected statistics varies with the micro-batch size, both in quantity and in bin position. Furthermore, because the number of activation gradients grows with the batch size while the number of weight gradients is fixed by the model, the ratio of the weight gradient count to the activation gradient count in the combined histogram is inversely proportional to the batch size. For a larger batch size, the information from the weight gradients is therefore more diluted, making it hard to create a robust implementation of the ALS algorithm.
To resolve this, the ALS algorithm can be constructed such that the ratio of the weight gradient count to the activation gradient count depends only on the model definition and remains constant irrespective of the micro-batch size being used. This ratio is denoted $\rho(M)$. Taking B=1 as a reference, where B is the size of the micro-batch, the activation gradient histogram is estimated per micro-batch element. This can be done by scaling the gradients by the batch size B before gathering statistics, but doing so does not recover any of the underflow and incurs the additional cost of scaling the activation gradient tensors. Instead, since automatic loss scaling uses cut-off bin edges, for a given ratio $\rho(M)$ the weight gradient histogram counts are scaled by B (since activation gradient counts scale with batch size), and the activation gradient cut-off bin edge is lowered each time the batch size is increased, by one bin for each doubling of the batch size. A sum-based condition for increasing the loss scaling factor according to this method can be written as:
or, excluding the underflow count:
A condition according to the ‘unanimous vote’ approach as described above can be written as:
or, excluding the underflow count:
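A hypothetical sketch of the batch-size-compensated, sum-based test follows. It assumes a power-of-two micro-batch size B, histograms with the underflow bin at index 0, weight gradient counts weighted by B, and an activation gradient cut-off lowered by log2(B) bins, one bin per doubling of B. These choices are one reading of the description, not the disclosed equations.

```python
import math
from typing import Sequence


def increase_by_sum_batch_aware(h_gx: Sequence[int], h_gw: Sequence[int],
                                cutoff: int, f: float, batch_size: int,
                                exclude_underflow: bool = False) -> bool:
    """Sum-based test with batch-size compensation (hypothetical sketch).

    Weight gradient counts are weighted by the batch size B, and the
    activation gradient cut-off is lowered by log2(B) bins.
    """
    shift = int(math.log2(batch_size))
    if 2 ** shift != batch_size:
        raise ValueError("this sketch assumes a power-of-two batch size")
    c_gx = max(cutoff - shift, 1)      # never count the underflow bin as "above"
    start = 1 if exclude_underflow else 0
    above = sum(h_gx[c_gx:]) + batch_size * sum(h_gw[cutoff:])
    total = sum(h_gx[start:]) + batch_size * sum(h_gw[start:])
    return above / total <= f if total else True


h_gx, h_gw = [10, 20, 30, 40], [0, 1, 1, 2]
print(increase_by_sum_batch_aware(h_gx, h_gw, 3, 0.5, batch_size=1))  # True
print(increase_by_sum_batch_aware(h_gx, h_gw, 3, 0.5, batch_size=4))  # False
```

With B = 1 the sketch reduces to the plain sum-based test; with larger B, the lowered activation cut-off makes the test stricter for the same histograms.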
As the batch size increases, so does the underflow of activation gradients. Increasing the loss scaling factor to improve the representation of activation gradients then sacrifices representable range at the upper end, which results in clipping of the weight gradients. Furthermore, as training progresses and the histograms diverge as values become smaller or larger, the decision to reduce or increase the loss scaling factor becomes dominated by whichever set of statistics exceeds its cut-off count threshold first.
Alternatively, to balance the underflow of activation gradients against the saturation of weight gradients, the ALS algorithm can be designed to work with different scaling factors, $\alpha_{GX}$ and $\alpha_{GW}$, for activation gradients and weight gradients respectively. The loss is scaled by $\alpha_{GX}$, while the result of the weight gradient calculation for a given layer is scaled by $\alpha_{GW}/\alpha_{GX}$ so that the weight gradients carry the desired scaling $\alpha_{GW}$. The difference between the two scaling factors can be fixed according to the micro-batch size or can be set dynamically based on the tensor statistics. Setting it based on tensor statistics maximises use of the available dynamic range, in particular as the gradients diverge during training.
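The following sketch shows how separate scales $\alpha_{GX}$ and $\alpha_{GW}$ could be threaded through the backward pass of a single dot-product layer, y = sum_i w_i * x_i. The layer, the function names and the plain SGD update are illustrative assumptions; the point is only that the activation gradients keep the loss scale $\alpha_{GX}$, while the weight gradients are re-scaled by $\alpha_{GW}/\alpha_{GX}$ and unscaled by $\alpha_{GW}$ before the update.

```python
from typing import List, Sequence, Tuple


def backward_dot(x: Sequence[float], w: Sequence[float], scaled_grad_y: float,
                 alpha_gx: float, alpha_gw: float) -> Tuple[List[float], List[float]]:
    """Backward pass of y = sum_i w_i * x_i with separate gradient scales.

    scaled_grad_y is dL/dy multiplied by alpha_gx, because the loss itself
    was scaled by alpha_gx. Returns (activation gradients carrying alpha_gx,
    weight gradients re-scaled to carry alpha_gw).
    """
    grad_x = [wi * scaled_grad_y for wi in w]            # still scaled by alpha_gx
    rescale = alpha_gw / alpha_gx                        # alpha_GW / alpha_GX
    grad_w = [xi * scaled_grad_y * rescale for xi in x]  # now scaled by alpha_gw
    return grad_x, grad_w


def sgd_update(w: Sequence[float], scaled_grad_w: Sequence[float],
               alpha_gw: float, lr: float) -> List[float]:
    """Remove the weight-gradient scale before applying a plain SGD step."""
    return [wi - lr * gi / alpha_gw for wi, gi in zip(w, scaled_grad_w)]


alpha_gx, alpha_gw = 1024.0, 8.0
true_grad_y = 0.25                     # unscaled dL/dy
grad_x, grad_w = backward_dot([1.0, 2.0], [0.5, -1.0], alpha_gx * true_grad_y,
                              alpha_gx, alpha_gw)
print(grad_x)  # [128.0, -256.0], i.e. alpha_gx * w * dL/dy
print(grad_w)  # [2.0, 4.0], i.e. alpha_gw * x * dL/dy
```

Using powers of two for both scales keeps the re-scaling exact in binary floating point.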
Furthermore, different loss scaling factors can be computed for different layers or different blocks of the neural network. Statistics are gathered at the chosen layer granularity, and these statistics are used to update the scaling factors for the activation and weight gradients of the corresponding layer or block during the backward pass computation.
It should be noted that, while the above description relates to activation and weight gradients, the described techniques can be applied to determine a scale factor for any two quantities whose distributions diverge.
Adaptive Format Selection for the Forward Pass
In order to improve the accuracy of the quantization of weights and activations in the forward pass, a principle similar to the scaling factor described above may be applied to determine a representation for a set of values based on their statistics. In the forward pass, statistics of weights and activations are collected to maintain separate histograms of the weights and of the activations. The fraction of the total number of samples in each histogram that lies above a given threshold is measured, and the exponent offset (or exponent bias) is adjusted accordingly to maintain a predefined fraction of samples above that threshold. In general, the histograms will comprise a plurality of bins.
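One way to turn such a histogram into an exponent bias is sketched below, under the assumptions that the histogram is keyed by the unbiased exponent floor(log2|v|) and that the threshold moves down by one exponent for each unit increase in the bias. The function name and parameters are hypothetical. The sketch picks the largest bias (range shifted as low as possible) for which the fraction of samples above the threshold does not exceed the target.

```python
from typing import Dict, Iterable


def select_exponent_bias(exp_hist: Dict[int, int], limit_at_zero_bias: int,
                         target_fraction: float, candidates: Iterable[int]) -> int:
    """Pick the largest bias keeping the share of samples above the
    threshold exponent (limit_at_zero_bias - bias) <= target_fraction."""
    total = sum(exp_hist.values())
    ordered = sorted(candidates)
    chosen = ordered[0]   # fall back to the smallest bias (highest range)
    for bias in ordered:
        limit = limit_at_zero_bias - bias
        above = sum(c for e, c in exp_hist.items() if e > limit)
        if total and above / total > target_fraction:
            break         # larger biases only push more samples above the limit
        chosen = bias
    return chosen


hist = {-3: 10, 0: 60, 2: 25, 5: 5}          # exponent -> sample count
print(select_exponent_bias(hist, 8, 0.05, range(11)))  # 6
```

Because the fraction above the threshold can only grow as the bias increases, the search can stop at the first bias that fails.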
Adjusting the exponent bias of weights and activations shifts the representable range of values for these weights and activations. If the exponent bias or offset is increased, the range of representable values is shifted to lower magnitudes.
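The effect of the bias on the representable range can be checked with a small calculator. It assumes an IEEE-754-style layout in which the all-ones exponent field is reserved for infinities and NaNs; some FP8 definitions, such as the OCP E4M3 format, reuse most of that field, so their limits differ. The function name is hypothetical.

```python
def representable_range(exp_bits: int, man_bits: int, bias: int):
    """Return (smallest subnormal, smallest normal, largest finite) magnitudes
    for an IEEE-754-style format with the all-ones exponent field reserved."""
    max_field = 2 ** exp_bits - 2                 # largest exponent field for finite values
    largest = (2.0 - 2.0 ** -man_bits) * 2.0 ** (max_field - bias)
    smallest_normal = 2.0 ** (1 - bias)
    smallest_subnormal = 2.0 ** (1 - bias - man_bits)
    return smallest_subnormal, smallest_normal, largest


print(representable_range(5, 10, 15)[2])  # 65504.0, the FP16 maximum
# A 1.4.3 format: raising the bias from 7 to 8 halves every limit.
print(representable_range(4, 3, 7)[2], representable_range(4, 3, 8)[2])  # 240.0 120.0
```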
Histograms are collected for activations in the forward pass, and for gradients with respect to weights and gradients with respect to outputs in the backward pass, the goal being to determine an appropriate format for representing these values. As described above for gradients, histograms have at least two bins, each bin counting the values that fall within its range. In general, histograms will comprise more than two bins.
Histogram bins may be selected based on the format of the values being collected. For example, where the weights, activations and/or gradients are being converted from an FP16 format to an FP8 format, the bin edges are placed at each power of two in the range of FP16 values, i.e. at exponents from −24 to 15. If converting from an FP32 format to FP8, bin edges covering the range of FP32 values (exponents from −149 to 127) would be chosen instead.
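A power-of-two histogram with FP16 bin edges can be sketched as follows. The layout (an underflow bin, one bin per exponent from −24 to 15, and an overflow bin that also collects non-finite values) and the function name are assumptions. `math.frexp` returns the binary exponent exactly, so no logarithm is needed.

```python
import math
from typing import Iterable, List


def exponent_histogram(values: Iterable[float], min_exp: int = -24,
                       max_exp: int = 15) -> List[int]:
    """Counts per bin: [underflow, 2**min_exp, ..., 2**max_exp, overflow].

    The bin for exponent e holds magnitudes in [2**e, 2**(e + 1)).
    """
    counts = [0] * (max_exp - min_exp + 3)
    for v in values:
        if not math.isfinite(v):
            counts[-1] += 1            # inf/NaN are treated as overflow
            continue
        if v == 0.0:
            counts[0] += 1             # zeros land in the underflow bin
            continue
        e = math.frexp(abs(v))[1] - 1  # exact floor(log2(|v|))
        if e < min_exp:
            counts[0] += 1
        elif e > max_exp:
            counts[-1] += 1
        else:
            counts[1 + e - min_exp] += 1
    return counts


h = exponent_histogram([0.0, 1.0, 1.5, 3.0, 2.0 ** -30, 70000.0])
print(len(h), h[0], h[25], h[26], h[-1])  # 42 2 2 1 1
```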
Histograms may be collected for a single layer of the network, or a single histogram can be collected with aggregate statistics for values of multiple layers, assuming that the layers combined in the aggregation use the same format for the relevant values. Sets of layers may be selected heuristically, or may be determined automatically by the computer program.
Training and implementation of a neural network model may be performed on a set of multiple processors, with each processor processing a subset of the data (for example, a mini-batch) and holding a local replica of the neural network model. Each processor may compute its own histogram of the gradients for its subset of the data. At the end of the mini-batch computation, the histograms may be communicated to the other processors to determine aggregate statistics, from which a common representation for the gradients and/or the activations and weights is determined and applied when converting those values to a particular format. Bin counts may be represented as raw counts, or as proportions obtained by dividing the counts by the sum of all bins of the aggregated histogram. When deciding how frequently statistics are computed and communicated, the communication overhead of sending histograms to other processors must be weighed against the computational advantage of having an optimal scaling factor.
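The aggregation step can be sketched with plain Python lists standing in for the per-processor histograms. In a real multi-processor system the element-wise sum would typically be performed by a collective reduction (all-reduce); that communication layer, and the function name, are assumed here.

```python
from typing import List, Sequence


def aggregate_histograms(per_processor: Sequence[Sequence[int]],
                         normalise: bool = True) -> List[float]:
    """Element-wise sum of per-processor histograms over a common set of bins,
    optionally converted to proportions of the aggregated total."""
    n_bins = len(per_processor[0])
    if any(len(h) != n_bins for h in per_processor):
        raise ValueError("all processors must use a common set of bins")
    totals = [sum(column) for column in zip(*per_processor)]
    if not normalise:
        return [float(t) for t in totals]
    grand_total = sum(totals)
    return [t / grand_total if grand_total else 0.0 for t in totals]


print(aggregate_histograms([[1, 2, 1], [3, 0, 3]]))  # [0.4, 0.2, 0.4]
```

Summing raw counts before normalising gives each processor weight in proportion to how many values it observed, which matches treating the combined data as one histogram.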
As described above for gradients, one criterion for determining an appropriate representation of values from the collected histogram is to apply a threshold and to reduce the applied scaling factor when the number or, more typically, the proportion of values in the histogram exceeding the chosen threshold indicates a degree of overflow. Other criteria may be computed for the collected values and used to adjust their representation in the next stage of training; examples include mean-square error, signal-to-noise ratio, degree of underflow, and Kullback-Leibler divergence.
In addition to selecting the exponent bias for weights and activations that are to be expressed in a floating-point format such as FP8, the statistics may be used to select the type of format to be used. The spread of values in the histograms may indicate the most appropriate format for representing the values. As mentioned above, floating-point formats may allocate different numbers of bits to the mantissa and the exponent, giving different representable ranges and different numerical precision within those ranges. For an 8-bit floating-point format, two possible formats are 1.4.3, which uses one sign bit, e=4 exponent bits and m=3 mantissa bits, and 1.5.2, which uses one sign bit, e=5 exponent bits and m=2 mantissa bits. By collecting histograms in a forward pass, it is possible to analyse the range of values that need to be represented and to select between different formats according to that range. In general, it is desirable to choose the format with the highest numerical precision that can still represent most of the values in the range. An appropriate exponent bias may be determined for each of multiple candidate formats; for example, an exponent bias can be determined for both 1.4.3 and 1.5.2. The format may then be selected from the set of candidate formats using the same criteria as those described above for selecting the exponent bias, or different ones. If more than one format performs equally well according to the given criterion, the format with the smallest exponent field may be chosen, as this maximises precision by leaving more bits for the mantissa.
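A sketch of choosing between 1.4.3 and 1.5.2 follows. It scores each candidate by the fraction of samples whose exponent falls outside the representable range (one possible criterion among those listed above), picks the best bias per candidate, and breaks ties in favour of the smaller exponent field, as described. The IEEE-style range limits, the bias search range and the function names are assumptions.

```python
from typing import Dict, Iterable, Sequence, Tuple


def out_of_range_fraction(exp_hist: Dict[int, int], exp_bits: int,
                          man_bits: int, bias: int) -> float:
    """Share of samples whose exponent lies outside the format's range,
    from the smallest subnormal (1 - bias - man_bits) to the largest
    finite value ((2**exp_bits - 2) - bias), IEEE-style."""
    lo = 1 - bias - man_bits
    hi = (2 ** exp_bits - 2) - bias
    total = sum(exp_hist.values())
    outside = sum(c for e, c in exp_hist.items() if e < lo or e > hi)
    return outside / total if total else 0.0


def select_format(exp_hist: Dict[int, int],
                  candidates: Sequence[Tuple[int, int]] = ((4, 3), (5, 2)),
                  biases: Iterable[int] = range(-16, 33)) -> Tuple[int, int, int]:
    """Return (exp_bits, man_bits, bias) for the best 8-bit candidate."""
    biases = list(biases)
    best = None
    for exp_bits, man_bits in candidates:
        # Fewest samples out of range; among ties, the largest bias, which
        # places the range as low as possible while still fitting the data.
        frac, neg_bias = min((out_of_range_fraction(exp_hist, exp_bits, man_bits, b), -b)
                             for b in biases)
        key = (frac, exp_bits)  # then prefer the smaller exponent field
        if best is None or key < best[0]:
            best = (key, (exp_bits, man_bits, -neg_bias))
    return best[1]


narrow = {e: 1 for e in range(-5, 6)}    # 11 binades of data
wide = {e: 1 for e in range(-10, 11)}    # 21 binades of data
print(select_format(narrow))  # (4, 3, 9)
print(select_format(wide))    # (5, 2, 20)
```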
Once a format is determined, i.e. a scaling factor, exponent bias, and/or appropriate allocation of bits to represent the exponent and mantissa, these can be applied to the respective values for subsequent steps of training the neural network. These may be applied only to a subset of layers of the network, for example those in which matrix multiplications and convolutions occur, as these layers are compute intensive and a lower precision format is most effective in improving the efficiency of such operations. The representation is applied when the given values are converted to the new format, for example when converting weights or activations to FP8 before performing a convolution operation.
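Applying a chosen format can be simulated in software by rounding each value to the nearest representable number. The following hypothetical `quantize` function assumes IEEE-style round-to-nearest-even, gradual underflow through subnormals, and saturation to the largest finite value instead of producing infinity; hardware FP8 conversions may behave differently.

```python
import math


def quantize(x: float, exp_bits: int, man_bits: int, bias: int) -> float:
    """Round x to the nearest value of a simulated (1, exp_bits, man_bits) format."""
    if x == 0.0 or not math.isfinite(x):
        return x
    max_exp = (2 ** exp_bits - 2) - bias           # exponent of the largest finite value
    min_normal_exp = 1 - bias                      # subnormals share this exponent
    largest = (2.0 - 2.0 ** -man_bits) * 2.0 ** max_exp
    mag = abs(x)
    e = max(math.frexp(mag)[1] - 1, min_normal_exp)
    step = 2.0 ** (e - man_bits)                   # spacing of representable values near x
    q = round(mag / step) * step                   # round() is half-to-even in Python
    return math.copysign(min(q, largest), x)       # saturate rather than overflow


print(quantize(1.1, 4, 3, 7))     # 1.125
print(quantize(1.0625, 4, 3, 7))  # 1.0 (a tie, rounded to even)
print(quantize(1000.0, 4, 3, 7))  # 240.0 (saturated)
```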
A first aspect disclosed herein provides a computer-implemented method of training, based on a set of training data, a multi-layer neural network comprising a set of network weights, the method comprising: processing the training data in respective forward and backward passes through a sequence of layers of the network, the forward pass comprising computing a set of activations by applying an activation function in dependence on the network weights and training data, and the backward pass comprising: computing gradients of a pre-determined loss function with respect to the network weights and/or computing gradients of the pre-determined loss function with respect to the computed activations of the network, wherein an adjustment parameter is applied to at least a subset of values in the neural network, the values comprising at least one of: the network weights, the activations computed in the forward pass, the gradients with respect to activations computed in the backward pass, and the gradients with respect to weights computed in the backward pass; updating the network weights in dependence on the computed gradients with respect to the weights; computing a proportion of the subset of values falling above a predefined threshold; and updating the adjustment parameter applied to the subset of machine learning parameters in dependence on the computed proportion.
In embodiments, the adjustment parameter is a scale factor, and wherein the scale factor is applied on the backward pass to at least a subset of the gradients with respect to the activations and/or the gradients with respect to the network weights, wherein the scale factor is updated in dependence on the proportion of the gradients of that subset that have a value falling above a pre-defined threshold.
In embodiments, the adjustment parameter is a scale factor, and the scale factor is applied on the backward pass to at least a subset of the gradients with respect to at least one of the activations and the gradients with respect to the network weights, wherein the scale factor is updated in dependence on the proportion of the gradients of that subset that have a value falling above a pre-defined threshold.
In embodiments, the method comprises applying the scale factor to at least one of gradients with respect to weights and gradients with respect to activations of all layers of the network by multiplying the loss function by the scale factor.
In embodiments, the method comprises constructing a histogram of gradients, the histogram comprising a plurality of bins, wherein the scale factor is updated based on a proportion of gradients occupying bins above a threshold value.
In embodiments, the method comprises constructing a respective histogram of gradients for each layer of the neural network, wherein the proportion of gradients occupying each of a set of bins for each histogram is input to an accumulator to obtain an aggregated proportion for each bin, the scale factor being derived by computing an aggregated proportion occupying bins above an overall threshold.
In embodiments, the method comprises constructing a respective histogram of gradients for each layer, wherein for each layer a respective layer-wise scale factor is applied during the backward pass, the layer-wise scale factor being updated based on a proportion of gradients in the histogram for the corresponding layer occupying bins above a corresponding layer-wise threshold value.
In embodiments, the method is implemented on a plurality of processors, wherein each processor processes a respective subset of the training data in each of the forward and backward passes, and computes a respective histogram of gradients for the corresponding subset of the training data, each histogram having defined a common set of bins, wherein the proportion of gradients occupying each bin of the set of bins defined for each histogram is aggregated to obtain an aggregated proportion for each bin, with a scale factor being derived by computing an aggregated proportion occupying bins above an overall threshold.
In embodiments, the method comprises storing at least a subset of the network weights, gradients and activations in computer memory in floating-point format.
In embodiments, the method comprises storing at least a subset of the network weights, gradients and activations in computer memory in eight-bit floating-point format.
In embodiments, the method comprises storing at least a subset of the network weights, gradients and activations in computer memory in sixteen-bit floating-point format.
In embodiments, the method comprises storing the subset of values in a floating-point format, and wherein the adjustment parameter is an exponent bias applied to the floating-point representations of the subset of weights, gradients and activations.
In embodiments, the subset of values in the neural network is a subset of network weights and activations and the adjustment parameter is an exponent bias applied to the subset of values of the network weights and activations in the forward pass.
In embodiments, a subset of network weights, activations and gradients which are inputs to compute operations in at least one of the forward and backward passes are stored in eight-bit floating-point format, the compute operations comprising at least one of a matrix operation and a convolution operation.
A second aspect herein provides a computer system comprising one or more processors configured to train a multi-layer neural network comprising a set of network weights, and memory holding the network weights, the processor configured to train the neural network by:
updating the adjustment parameter applied to the subset of machine learning parameters in dependence on the computed proportion.
In embodiments, the computer system comprises a plurality of processors, wherein each processor is configured to process a respective subset of the training data.
In embodiments, the adjustment parameter is updated in dependence on an aggregated proportion of values for all processors falling above a predefined threshold, the aggregated proportion computed by aggregating a computed proportion of the subset of values falling above the predefined threshold for each of the plurality of processors.
A further aspect of the present disclosure provides a non-transitory computer-readable storage medium storing computer program instructions which when executed perform a method of training, based on a set of training data, a multi-layer neural network comprising a set of network weights, the method comprising:
The present application claims priority to U.S. Provisional Patent Application No. 63/265,436 filed Dec. 15, 2021, the disclosure of which is hereby incorporated herein by reference.
Number | Date | Country
---|---|---
63265436 | Dec 2021 | US