Unless otherwise indicated, the subject matter described in this section is not prior art to the claims of the present application and is not admitted as being prior art by inclusion in this section.
Deep neural networks (DNNs), which are machine learning (ML) models composed of multiple layers of interconnected nodes, are widely used to solve tasks in various fields such as computer vision, natural language processing, telecommunications, bioinformatics, and so on. A DNN is typically trained via a stochastic gradient descent (SGD)-based optimization procedure that involves (1) randomly sampling a batch (sometimes referred to as a “minibatch”) of labeled data instances from a training dataset, (2) forward propagating the batch through the DNN to generate a set of predictions, (3) computing a difference (i.e., “loss”) between the predictions and the batch's labels, (4) performing backpropagation with respect to the loss to compute a gradient, (5) updating the DNN's parameters in accordance with the gradient, and (6) iterating steps (1)-(5) until the DNN converges (i.e., reaches a state where the loss falls below a desired threshold). Once trained in this manner, the DNN can be applied during an inference phase to generate predictions for unlabeled data instances.
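For concreteness, the following is a minimal sketch of this SGD training loop written with the PyTorch library; the model architecture, dataset, batch size, learning rate, and convergence threshold shown are arbitrary placeholders for illustration rather than values prescribed by this disclosure.

```python
import torch
import torch.nn as nn

# Placeholder DNN and labeled training data (hypothetical shapes/values).
dnn = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 2))
data = torch.randn(1000, 16)            # data instances x1..xn
labels = torch.randint(0, 2, (1000,))   # labels y1..yn

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(dnn.parameters(), lr=0.01)

for step in range(10000):
    # (1) Randomly sample a minibatch of labeled data instances.
    idx = torch.randint(0, len(data), (32,))
    batch_x, batch_y = data[idx], labels[idx]

    # (2) Forward propagate the batch to generate predictions.
    preds = dnn(batch_x)

    # (3) Compute the loss between the predictions and the batch's labels.
    loss = loss_fn(preds, batch_y)

    # (4) Perform backpropagation to compute the gradient.
    optimizer.zero_grad()
    loss.backward()

    # (5) Update the DNN's parameters in accordance with the gradient.
    optimizer.step()

    # (6) Stop once the loss falls below a desired threshold (convergence).
    if loss.item() < 0.01:
        break
```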
Generally speaking, the use of larger datasets for training results in more accurate DNNs. However, as the amount of training data increases, the computational overhead and time needed to carry out the SGD training procedure also rise. To address this, importance sampling has been proposed as a technique for accelerating the training of DNNs. With importance sampling, each data instance in the training dataset is assigned a sampling probability that corresponds to the "importance" of the data instance to the training procedure, or in other words, the degree to which that data instance contributes to the progress of the training towards model convergence. Then, at each training iteration, data instances are sampled from the training dataset based on their respective sampling probabilities rather than uniformly at random, thereby causing more important data instances to be selected with higher likelihood than less important data instances and leading to an overall reduction in training time. It has been found that the optimal sampling probability for a given data instance is proportional to the norm (i.e., size) of the gradient computed for that data instance via SGD.
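As a brief illustration of this sampling step, the sketch below (assuming PyTorch and hypothetical per-instance gradient norms) converts gradient norms into sampling probabilities and draws a batch accordingly.

```python
import torch

# Hypothetical per-instance gradient norms for a 1,000-instance dataset.
grad_norms = torch.rand(1000)

# Sampling probabilities proportional to the gradient norms.
sampling_probs = grad_norms / grad_norms.sum()

# Importance sampling: draw a batch of 32 indices according to the
# probabilities instead of uniformly at random.
batch_idx = torch.multinomial(sampling_probs, num_samples=32, replacement=True)
```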
One challenge with implementing importance sampling is that it is impractical to compute exact gradient norms (and thus, optimal sampling probabilities) for an entire training dataset at each training iteration, because this requires time-consuming forward and backpropagation passes through the DNN for every data instance in the training dataset. Current importance sampling approaches attempt to work around this problem using various methods but suffer from their own set of limitations (e.g., reliance on outdated/stale gradient norm information, inability to support batches, etc.) that adversely affect training performance.
In the following description, for purposes of explanation, numerous examples and details are set forth in order to provide an understanding of various embodiments. It will be evident, however, to one skilled in the art that certain embodiments can be practiced without some of these details or can be practiced with modifications or equivalents thereof.
Embodiments of the present disclosure are directed to techniques for implementing importance sampling via ML-based gradient approximation. In one set of embodiments, these techniques include (1) training a DNN on a training dataset using SGD and (2) in parallel with (1), training a separate ML model (referred to herein as a “gradient approximation model” or “GAM”) that is designed to predict gradient norms (or gradients) for the data instances in the training dataset. The training of the gradient approximation model can be based on exact gradient norms/gradients computed for a subset of data instances via forward and backpropagation passes through the DNN.
The techniques further include (3) applying the gradient approximation model to the training dataset on a periodic basis to generate gradient norm/gradient predictions for the data instances in the training dataset and (4) using the gradient norm/gradient predictions to update sampling probabilities for the data instances. Steps (3) and (4) can be performed concurrently with (1) and (2). The updated sampling probabilities can then be accessed during the ongoing training of the DNN (i.e., step (1)) to perform importance sampling of data instances and thereby accelerate the training procedure.
As noted in the Background section, importance sampling is an enhancement to conventional SGD-based training that involves assigning a sampling probability to each data instance in the training dataset. This sampling probability indicates the importance, or degree of contribution, of the data instance to the training procedure. For instance, a data instance whose SGD gradient norm is large contributes more to the progress of the training and is therefore assigned a higher sampling probability than a data instance whose gradient norm is small.
However, implementing importance sampling in practice is difficult because determining the optimal sampling probability for each data instance—which is proportional to the gradient norm computed for that data instance via SGD—is a time-consuming task. Current importance sampling approaches employ a number of workarounds that mitigate the cost of updating sampling probabilities, but these approaches are susceptible to poor probability accuracy in some scenarios and/or introduce other performance problems.
To address the foregoing, embodiments of the present disclosure train a gradient approximation model (GAM) in parallel with the DNN and use the GAM's gradient norm predictions to keep the sampling probabilities in the training dataset up to date, thereby avoiding the need to perform full forward and backpropagation passes through the DNN for every data instance at every update.
Workflow 500 depicts the training of GAM 502 in parallel with the SGD-based training of DNN 104, while workflow 550 depicts the periodic application of GAM 502 to update the sampling probabilities in training dataset 106.
Starting with workflow 500, at steps 504 and 506, computer system 102 can sample a batch of data instances from training dataset 106 based on their current sampling probabilities and use this batch to train DNN 104 via the standard SGD-based training procedure described previously at steps 304-310.
Concurrently with steps 504 and 506, computer system 102 can sample a data instance from the batch used to train DNN 104 (step 508) and obtain a representation of the current state of DNN 104 (step 510). In one set of embodiments, this representation can include exact and up-to-date values for all of the DNN's parameters. In other embodiments, this representation can include an approximation or subset of the DNN's current parameter values, such as a sketch, a random subsample of parameters, etc.
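By way of a hedged example, the following PyTorch sketch produces both kinds of state representation for a placeholder network: the exact flattened parameter vector and a fixed random subsample of it (the subsample size of 256 is an arbitrary choice for illustration).

```python
import torch
import torch.nn as nn
from torch.nn.utils import parameters_to_vector

# Placeholder network standing in for DNN 104.
dnn = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 2))

# Exact state representation: a flat vector of all current parameter values.
exact_state = parameters_to_vector(dnn.parameters()).detach()

# Approximate state representation: a fixed random subsample of parameters.
# Using a fixed permutation keeps the subsample consistent across iterations.
torch.manual_seed(0)
subsample_idx = torch.randperm(exact_state.numel())[:256]
approx_state = exact_state[subsample_idx]
```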
At step 512, computer system 102 can forward propagate the data instance and the DNN state representation through GAM 502, resulting in a gradient norm prediction 514 for those inputs. In addition, at step 516, computer system 102 can perform a forward and backpropagation pass through DNN 104 with respect to the data instance, thereby computing a gradient norm 518 for the data instance.
Upon obtaining gradient norm prediction 514 and gradient norm 518, computer system 102 can compute a loss between these two values (step 520). Finally, computer system 102 can perform backpropagation through GAM 502 with respect to the loss determined at step 520 to compute a gradient and can update the parameters of GAM 502 based on the gradient (step 522). Computer system 102 can thereafter iterate steps 508-522 in order to further train GAM 502 until the training of DNN 104 is complete or some other termination criterion is fulfilled, such as GAM 502 reaching a desired accuracy or a maximum number of training iterations.
Turning now to workflow 550, at steps 552 and 554, computer system 102 can obtain the entirety of training dataset 106 (or specific data instances therein) and a representation of the current state of DNN 104 and provide these as inputs to GAM 502. As mentioned previously, this state representation can include current and exact values for all of the parameters of DNN 104 or some approximation/subset of those parameter values.
At step 556, computer system 102 can forward propagate training dataset 106 and the DNN state representation through GAM 502, resulting in a set of gradient norm predictions 558. Computer system 102 can then update the sampling probabilities for the data instances in training dataset 106 (i.e., {p1, . . . , pn}) based on their respective gradient norm predictions (step 560) and use the updated sampling probabilities as part of its ongoing training of DNN 104 (steps 504 and 506). Finally, although not explicitly shown, computer system 102 can repeat steps 552-560 on a periodic basis in order to ensure that the sampling probabilities in training dataset 106 are kept relatively up to date with the current state of DNN 104.
It should be noted that the training of GAM 502 via workflow 500 and the application of GAM 502 for importance sampling via workflow 550 can be performed mostly or entirely in parallel. In certain embodiments, GAM 502 can be trained for a number of iterations prior to being used to update sampling probabilities in training dataset 106. For instance, once the accuracy of GAM 502 reaches a desired level (or in other words, the loss computed at step 520 of workflow 500 falls below a threshold), workflow 550 can be initiated.
The remaining sections of this disclosure provide additional implementation details regarding the high-level workflows described above.
Further, although the foregoing description assumes that DNN 104, training dataset 106, and GAM 502 are all maintained by a single computer system 102, in other embodiments these components may be deployed across different computer systems, as discussed in further detail below.
Yet further, in certain embodiments GAM 502 may be configured to predict gradients, rather than gradient norms, for data instances in training dataset 106. The gradient predictions output by GAM 502 can then be used to compute gradient norm predictions 514 and 558 shown in workflows 500 and 550 (by applying a norm function to the gradient predictions). While this approach can increase the size and complexity of GAM 502, it can also be leveraged to increase the batch size used to train DNN 104 (and thus further accelerate its training) without significantly adding to the computational overhead of the training procedure.
For example, assume that the batch size for training DNN 104 is originally set at 50 data instances and increased to 100 data instances. In this scenario, 50 of the data instances may be forward and back propagated through DNN 104 in order to compute their exact gradients via SGD, while the remaining 50 data instances may be forward propagated through GAM 502 in order to generate predicted/approximated gradients for those data instances. The exact and predicted/approximated gradients can then be combined and applied to update the parameters of DNN 104. Because the forward pass through GAM 502 is less resource intensive than performing both forward and backpropagation passes through DNN 104, this approach will not be significantly more expensive than solely computing exact gradients for the original batch size of 50, and yet will likely achieve faster convergence of DNN 104 due to the consideration of 50 additional data instances per batch.
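The sketch below is one possible rendering of this mixed-batch approach under simplifying assumptions: the GAM is modeled as a network that maps a data instance directly to a flattened gradient prediction (the DNN state input is omitted for brevity), exact gradients are computed for the first 50 instances of a 100-instance batch, predicted gradients are generated for the remaining 50 via a GAM forward pass only, and the two are averaged before being applied as a plain gradient-descent update. All architectures and values are hypothetical placeholders.

```python
import torch
import torch.nn as nn
from torch.nn.utils import parameters_to_vector

# Placeholder DNN 104 and a hypothetical gradient-predicting GAM 502 whose
# output dimension equals the DNN's flattened parameter count.
dnn = nn.Sequential(nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 2))
num_params = parameters_to_vector(dnn.parameters()).numel()
gam = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, num_params))

loss_fn = nn.CrossEntropyLoss()
lr = 0.01

# Enlarged batch of 100 data instances: the first 50 receive exact gradients,
# the remaining 50 receive predicted/approximated gradients.
batch_x = torch.randn(100, 16)
batch_y = torch.randint(0, 2, (100,))

# Exact gradients for the first 50 instances via forward + backpropagation.
loss = loss_fn(dnn(batch_x[:50]), batch_y[:50])
exact_grads = torch.autograd.grad(loss, list(dnn.parameters()))
exact_flat = torch.cat([g.flatten() for g in exact_grads])

# Predicted gradients for the remaining 50 instances via a GAM forward pass
# only (no backpropagation through the DNN), averaged across those instances.
with torch.no_grad():
    predicted_flat = gam(batch_x[50:]).mean(dim=0)

# Combine the exact and predicted gradients and apply a gradient-descent
# update to the DNN's parameters.
combined = 0.5 * exact_flat + 0.5 * predicted_flat
with torch.no_grad():
    offset = 0
    for p in dnn.parameters():
        n = p.numel()
        p -= lr * combined[offset:offset + n].view_as(p)
        offset += n
```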
At step 602, computer system 102 can sample a data instance xj from a batch B of data instances used to train DNN 104. In addition, at step 604, computer system 102 can obtain a representation of the current state of DNN 104. As noted previously, this representation can include the entire/exact state of DNN 104 (i.e., exact versions of all of its current parameter values) or an approximation or subset thereof. For example, this approximation or subset may be obtained via sketching, random subsampling, sparsification, or quantization of the original parameter values.
At step 606, computer system 102 can forward propagate data instance xj and the DNN state representation through GAM 502, thereby generating a gradient norm prediction r′j for xj. Computer system 102 can further forward propagate data instance xj through DNN 104 to generate a prediction for xj (step 608), compute a loss between the prediction and xj's label yj (step 610), perform backpropagation through DNN 104 with respect to the loss to compute a gradient g (step 612), and compute the norm of the gradient (i.e., rj) (step 614).
At steps 616 and 618, computer system 102 can compute a loss between gradient norm prediction r′j and gradient norm rj and can perform backpropagation through GAM 502 with respect to this loss to compute a gradient g′. Finally, computer system 102 can update the parameters of GAM 502 in accordance with gradient g′ (step 620) and flowchart 600 can end.
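For illustration, a hedged PyTorch sketch of a single iteration of flowchart 600 follows; the GAM is assumed (as a simplification) to be a small regressor over the concatenation of a data instance and a random subsample of the DNN's parameters, and all architectures, shapes, and hyperparameters are hypothetical.

```python
import torch
import torch.nn as nn
from torch.nn.utils import parameters_to_vector

# Placeholders for DNN 104 and GAM 502 (architectures are hypothetical).
dnn = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 2))
state_dim = 256  # size of the subsampled DNN state representation
gam = nn.Sequential(nn.Linear(16 + state_dim, 64), nn.ReLU(), nn.Linear(64, 1))

dnn_loss_fn = nn.CrossEntropyLoss()
gam_loss_fn = nn.MSELoss()
gam_optimizer = torch.optim.SGD(gam.parameters(), lr=0.01)

torch.manual_seed(0)
state_idx = torch.randperm(parameters_to_vector(dnn.parameters()).numel())[:state_dim]

# Steps 602-604: sample a data instance x_j from batch B and obtain the
# (subsampled) DNN state representation.
batch_x = torch.randn(32, 16)
batch_y = torch.randint(0, 2, (32,))
j = torch.randint(0, 32, (1,)).item()
x_j, y_j = batch_x[j], batch_y[j]
dnn_state = parameters_to_vector(dnn.parameters()).detach()[state_idx]

# Step 606: forward propagate x_j and the state through the GAM to get r'_j.
r_pred = gam(torch.cat([x_j, dnn_state]))

# Steps 608-614: forward/backpropagate x_j through the DNN and compute the
# exact gradient norm r_j.
pred = dnn(x_j.unsqueeze(0))
dnn_loss = dnn_loss_fn(pred, y_j.unsqueeze(0))
grads = torch.autograd.grad(dnn_loss, list(dnn.parameters()))
r_exact = torch.cat([g.flatten() for g in grads]).norm().detach()

# Steps 616-620: compute the loss between r'_j and r_j, backpropagate through
# the GAM, and update the GAM's parameters.
gam_loss = gam_loss_fn(r_pred.squeeze(), r_exact)
gam_optimizer.zero_grad()
gam_loss.backward()
gam_optimizer.step()
```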
Starting with steps 702 and 704, computer system 102 can obtain the entirety of training dataset 106 (or a subset of data instances in the training dataset) and a representation of the current state of DNN 104. Computer system 102 can then forward propagate training dataset 106 and the DNN state representation through GAM 502, resulting in a set of gradient norm predictions {r′1, . . . , r′n} corresponding to data instances {x1, . . . , xn} (step 706).
At step 708, computer system 102 can enter a loop for each data instance xi in training dataset 106. Within this loop, computer system 102 can compute an updated sampling probability pi for data instance xi based on its corresponding gradient norm prediction r′i (step 710). For example, in one set of embodiments pi can be computed as follows:
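One formulation consistent with the proportionality described earlier, presented here as an illustrative assumption rather than the only possibility, normalizes the gradient norm predictions across training dataset 106: pi = r′i/(r′1 + . . . + r′n). Under this formulation the sampling probabilities sum to one, and a data instance with a larger predicted gradient norm is proportionally more likely to be sampled.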
Computer system 102 can then store updated sampling probability pi for xi in training dataset 106 (thereby overwriting the previous value for pi) (step 712) and reach the end of the current loop iteration (step 714). Once all of the data instances in training dataset 106 have been processed via this loop, flowchart 700 can end.
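A hedged PyTorch sketch of this probability update, assuming the simple normalization described above and the same hypothetical GAM input format used in the earlier sketches, is shown below; the resulting probabilities can then be consumed by the importance-sampled batch selection illustrated earlier.

```python
import torch
import torch.nn as nn
from torch.nn.utils import parameters_to_vector

# Placeholders for DNN 104, GAM 502, and training dataset 106.
dnn = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 2))
gam = nn.Sequential(nn.Linear(16 + 256, 64), nn.ReLU(), nn.Linear(64, 1))
dataset_x = torch.randn(1000, 16)

# Steps 702-704: obtain the training dataset and the DNN state representation.
torch.manual_seed(0)
state_idx = torch.randperm(parameters_to_vector(dnn.parameters()).numel())[:256]
dnn_state = parameters_to_vector(dnn.parameters()).detach()[state_idx]

# Step 706: forward propagate all data instances (paired with the DNN state)
# through the GAM to obtain gradient norm predictions {r'_1, ..., r'_n}.
with torch.no_grad():
    gam_inputs = torch.cat([dataset_x, dnn_state.expand(len(dataset_x), -1)], dim=1)
    # Clamp to a small positive value since gradient norms are nonnegative.
    r_preds = gam(gam_inputs).squeeze(1).clamp(min=1e-8)

# Steps 708-714 (vectorized): compute updated sampling probabilities p_i by
# normalizing the predictions, then store them alongside the dataset.
sampling_probs = r_preds / r_preds.sum()
```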
As mentioned in section (2), there are several ways in which DNN 104, training dataset 106, and GAM 502 may be deployed across different computer systems. For example, in a first scenario, a first computer system C1 may hold DNN 104 and a second computer system C2 may hold training dataset 106 and GAM 502. In a second scenario, computer system C1 may hold DNN 104 and training dataset 106 and computer system C2 may hold GAM 502. And in a third scenario, computer system C1 may hold training dataset 106, computer system C2 may hold GAM 502, and a third computer system C3 may hold DNN 104. In these various scenarios, the processing steps performed by computer system 102 on DNN 104 and GAM 502 can instead be performed by the computer systems holding these respective models.
Regarding the first scenario above, in some embodiments the computer system holding DNN 104 (i.e., C1) can send DNN parameter updates to the computer system holding GAM 502 (i.e., C2), rather than the entirety of the DNN's state (which is needed as an input to GAM 502 in both workflows 500 and 550). Computer system C2 can then reconstruct the full state of DNN 104 using the parameter updates and a local copy of the prior state of DNN 104 and provide the reconstructed state as input to GAM 502. This advantageously reduces the amount of data that needs to be transmitted between these computer systems.
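A minimal sketch of this reconstruction, assuming (purely for illustration) that the update is transmitted as a flat vector of parameter deltas, is shown below.

```python
import torch
import torch.nn as nn
from torch.nn.utils import parameters_to_vector, vector_to_parameters

# Local copy of the prior state of DNN 104 held by computer system C2.
local_dnn = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 2))
prior_state = parameters_to_vector(local_dnn.parameters()).detach()

# Hypothetical parameter update (delta) received from computer system C1,
# i.e., the difference between the DNN's new and prior parameter vectors.
received_delta = torch.randn_like(prior_state) * 0.01

# Reconstruct the full current state from the prior state and the delta, then
# load it into the local copy for use as input to GAM 502.
reconstructed_state = prior_state + received_delta
vector_to_parameters(reconstructed_state, local_dnn.parameters())
```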
Regarding the second and third scenarios above, in some embodiments the computer system holding GAM 502 (i.e., C2) can send a copy of the current state of GAM 502 to the computer system holding training dataset 106 (i.e., C1) at the start of workflow 550, rather than having C1 send training dataset 106 to C2. Computer system C1 can then perform the steps of workflow 550 (e.g., determination of gradient norm predictions and updating of sampling probabilities) on its local copy of GAM 502 and training dataset 106. This will generally be more efficient in terms of network bandwidth than sending training dataset 106 from C1 to C2 in order to carry out workflow 550 at C2, because in many real-world scenarios training dataset 106 will be very large in size.
Certain embodiments described herein can employ various computer-implemented operations involving data stored in computer systems. For example, these operations can require physical manipulation of physical quantities—usually, though not necessarily, these quantities take the form of electrical or magnetic signals, where they (or representations of them) are capable of being stored, transferred, combined, compared, or otherwise manipulated. Such manipulations are often referred to in terms such as producing, identifying, determining, comparing, etc. Any operations described herein that form part of one or more embodiments can be useful machine operations.
Further, one or more embodiments can relate to a device or an apparatus for performing the foregoing operations. The apparatus can be specially constructed for specific required purposes, or it can be a generic computer system comprising one or more general purpose processors (e.g., Intel or AMD x86 processors) selectively activated or configured by program code stored in the computer system. In particular, various generic computer systems may be used with computer programs written in accordance with the teachings herein, or it may be more convenient to construct a more specialized apparatus to perform the required operations. The various embodiments described herein can be practiced with other computer system configurations including handheld devices, microprocessor systems, microprocessor-based or programmable consumer electronics, minicomputers, mainframe computers, and the like.
Yet further, one or more embodiments can be implemented as one or more computer programs or as one or more computer program modules embodied in one or more non-transitory computer readable storage media. The term non-transitory computer readable storage medium refers to any storage device, based on any existing or subsequently developed technology, that can store data and/or computer programs in a non-transitory state for access by a computer system. Examples of non-transitory computer readable media include a hard drive, network attached storage (NAS), read-only memory, random-access memory, flash-based nonvolatile memory (e.g., a flash memory card or a solid state disk), persistent memory, NVMe device, a CD (Compact Disc) (e.g., CD-ROM, CD-R, CD-RW, etc.), a DVD (Digital Versatile Disc), a magnetic tape, and other optical and non-optical data storage devices. The non-transitory computer readable media can also be distributed over a network coupled computer system so that the computer readable code is stored and executed in a distributed fashion.
Finally, boundaries between various components, operations, and data stores are somewhat arbitrary, and particular operations are illustrated in the context of specific illustrative configurations. Other allocations of functionality are envisioned and may fall within the scope of the invention(s). In general, structures and functionality presented as separate components in exemplary configurations can be implemented as a combined structure or component. Similarly, structures and functionality presented as a single component can be implemented as separate components.
As used in the description herein and throughout the claims that follow, “a,” “an,” and “the” includes plural references unless the context clearly dictates otherwise. Also, as used in the description herein and throughout the claims that follow, the meaning of “in” includes “in” and “on” unless the context clearly dictates otherwise.
The above description illustrates various embodiments along with examples of how aspects of particular embodiments may be implemented. These examples and embodiments should not be deemed to be the only embodiments and are presented to illustrate the flexibility and advantages of particular embodiments as defined by the following claims. Other arrangements, embodiments, implementations, and equivalents can be employed without departing from the scope hereof as defined by the claims.