Embodiments described herein relate to methods and systems for remote training of a machine learning model.
Deep learning is a subset of Machine Learning (ML) in which large datasets are used to train ML models in the form of Neural Networks (NNs). NNs are a connected system of functions whose structure is inspired by the human brain. Multiple nodes are interconnected, with each connection able to transmit data in a similar way as signals are transmitted via synapses. Connections between nodes carry weights, which are the parameters being optimised, thereby training the model.
Federated Learning (FL) and Distributed Learning (DL) are more recent decentralised ML frameworks aiming to parallelise the training process using multiple devices simultaneously to train a single model. In such approaches, raw data collected by edge devices in the Internet of Things (IoT) may be communicated back to a server in order to train a global model at the server. The volume of data produced and communicated by such edge devices is only increasing, and so a strategy is required to cope with such an increase. One solution to this problem is provided by the increase in computing power of edge devices in the IoT, which allows for the training process to be moved from the cloud (server) to the edge. Training the model at the edge devices can help to reduce the amount of raw data being sent through the network, providing benefits with regards to decreasing communication and preservation of privacy.
In a conventional approach, the worker nodes ‘push’ model updates to the Parameter Server after each training iteration. This can be understood by reference to
The worker node receives the global model (step S207) and carries out a training iteration (step S209). During the training iteration, a set of training data is input to the machine learning model and is processed to produce an output. The training data is data for which an expected output is known in advance. For example, the machine learning algorithm may be a classifier algorithm used to distinguish between two types of image A and B, and the set of training data may comprise a set of images labelled as type A or type B; in this case, the expected outcome will be for images labelled as A to be duly classified as type A by the machine learning algorithm and images labelled as B to be classified as type B. If the parameters of the model are properly optimised, the machine learning model should have a high level of performance i.e. it will be able to classify each image as either type A or type B with a high degree of accuracy, meaning that the output of the model will be very close to the expected output. In other examples, the machine learning algorithm might be a regression algorithm, in which the performance is determined based on a mean squared error or mean absolute error between the output from the algorithm and the expected output. In other cases, the machine learning algorithm might be an unsupervised learning algorithm used for anomaly detection, effectively detecting a deviation from normal data through the use of deep autoencoders.
The difference between the output from the machine learning model and the expected output is monitored and used to compute updated values for the model parameters, with the updated values being chosen so as to reduce the difference between the output from the model and the expected output in subsequent training iterations (step S211). A number of known techniques may be used to compute the updated values for the model parameters. One example is Gradient Descent (GD), the most common optimisation approach for learning the weights to use in a machine learning model such as a Neural Network.
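As an illustrative sketch only (the embodiments do not prescribe a particular implementation), a single Gradient Descent update moves each parameter against its gradient; the learning rate and toy objective below are assumptions chosen for the example:

```python
def gradient_descent_step(weight, gradient, learning_rate=0.1):
    """One Gradient Descent update: move the weight against its gradient."""
    return weight - learning_rate * gradient

# Toy example: minimise f(w) = w**2, whose gradient is 2*w.
w = 4.0
for _ in range(100):
    w = gradient_descent_step(w, 2.0 * w)
# w converges towards the minimum at w = 0
```

In practice the gradient would be computed from the difference between the model output and the expected output over the training data, but the update rule itself takes this simple form.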
In step S213, having computed the updated values for the model parameters, these are sent over the communication network to the server. The server receives the updated model parameters (step S215) and aggregates the updates with the global model (step S217) in order to update the global model (step S219). The process then repeats from step S201.
In the example method shown in
To address the above problems, some methods have been proposed that aim to reduce the communication requirements within a distributed system. These methods include Distributed Selective Stochastic Gradient Descent (DSSGD) and AdaComp. These methods are an improvement to the method shown in
In general, it is desirable to provide new methods for Federated and/or Distributed Learning that can help to relieve the need for transmitting large volumes of data from the edge devices to the server.
Embodiments of the invention will now be described by way of example with reference to the accompanying drawings in which:
According to a first embodiment, there is provided a computer-implemented method for training a machine learning model, the method comprising:
The method may further comprise determining, for each training iteration, the value of a performance parameter for the model.
The change in the value of the performance parameter may be a change in the value of the performance parameter between the training iteration and the immediately preceding training iteration.
The value of the performance parameter in each training iteration may be reflective of the difference between the output from the model and an output expected from processing the training data.
The performance parameter may define the loss obtained for the training iteration.
The value of the performance parameter in each training iteration may define an extent to which the values of one or more parameters of the model are changed as a result of updating the parameters in the respective training iteration.
The updated parameters may only be sent to the remote computing device in the event that the change in the value of the performance parameter is below a first threshold.
The first threshold may be defined with respect to a degree of variance in the values of the performance parameter for two or more previous training iterations.
The first threshold may be weighted by a factor whose value reflects a degree of connectivity available in a network including the communication channel.
The updated parameters may only be sent to the remote computing device in the event that the change in the value of the performance parameter is also above a second threshold.
The second threshold may be defined with respect to a degree of variance in the values of the performance parameter for two or more previous training iterations.
The second threshold may be weighted by a factor whose value reflects a degree of connectivity available in a network including the communication channel.
The machine learning model may comprise a neural network. The one or more parameters of the model may comprise one or more weights or biases of the neural network.
The remote computing device may be configured to update a global machine learning model based on updates received from the computing device. In the event it is determined to send the updated model parameters to the remote computing device, the method may further comprise requesting an updated version of the global model from the remote computing device, and performing a next training iteration using the updated version of the global model.
According to a second embodiment, there is provided a non-transitory computer readable storage medium comprising computer executable instructions that when executed by one or more computer processors will cause the one or more processors to carry out a method according to the first embodiment.
According to a third embodiment, there is provided a computing device comprising:
According to a fourth embodiment, there is provided a system comprising:
Embodiments described herein are based on the intuition that, whilst the model parameters may be subject to continual updates, there will be some periods of time where the model's accuracy will change more significantly between consecutive iterations. When the difference between the expected output of the model and the actual output is large, this will indicate a need for more drastic alteration of the model parameters. At other times, the difference may be such as to only prompt a small change in the model parameters for the next training iteration. At these points it can be deduced that the model parameters are more stable. The model parameters may include one or more of the weights, biases and gradients of the model.
Although it may seem advantageous to update the global model when there has been a significant shift in the model parameters between training iterations, it is the parameter values associated with the model's beginning to stabilise that are more relevant for updating the global model. Applying this reasoning throughout the training process means that a worker node should only seek to forward the values of the model parameters to the server once the local model at the node is understood to be in a stable state. By continuing to monitor the degree of stability in the model as it progresses through training iterations, the worker node can make an informed decision regarding the status of the model and whether to forward updates to the server.
In step S307, the worker node receives the global model and in step S309 proceeds to update a local version of the model at the worker node with the global model parameters. The local model is stored by the worker node and referred to when making future decisions about when to update the global model. As discussed below, maintaining a copy of the model at the worker allows for multiple iterations of training to be carried out locally without having to forward the parameter values to the server after each iteration; this contrasts with the method shown in
In step S311, the worker node performs a training iteration using a set of training data. Based on the output from the model, the worker computes a set of updated values for the model parameters. The worker node then updates the local model with the new values of the model parameters (step S313).
In step S315, the worker node determines a stability of the model. The stability reflects the degree to which the model parameters are seen to vary across successive training iterations. Depending on the extent of variation, and hence the stability of the model, a decision is made as to whether or not to send the updated parameters from the latest iteration to the server (step S317).
In more detail, the accuracy of the local model for a given training iteration can be recorded by the worker. If using a Gradient Descent algorithm to optimise the model parameters, for example, the accuracy or performance of the model at a given point in the training process can be determined by reference to the value of the loss function for that iteration. Observing the recorded loss for consecutive computed models enables the worker to define a loss difference Δθ=θi−θi-1, where θi is the loss as measured for the current training iteration and θi-1 is the loss as measured for the preceding training iteration. Here, the loss θi acts as a performance parameter for the present training iteration, whilst the change in value Δθ of that performance parameter will give an insight into the current status of the model. This knowledge can be exploited by the worker to make an intelligent decision on an action to take. The possible actions are as follows:
1. If the loss difference Δθ is large, the performance of the model is changing significantly between iterations and therefore the model is unstable. Rather than updating the global model with parameters that will become obsolete shortly, the decision can be made not to forward the current values of the model parameters and instead to continue to train the model locally.
2. If the loss difference Δθ is small, the performance of the model is determined to be more consistent and hence it can be assumed that the model is in a stable state. The parameters for this model are not changing as drastically between iterations so it would be suitable to forward the values of the model parameters to the server as this information is unlikely to become outdated as quickly.
The worker may calculate the loss difference and take one of these two actions. The decision of which action to take is based on the stability and can be decided through a statically defined threshold between loss differences.
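The two-way decision above can be sketched as follows; the function name and the particular threshold values are assumptions for illustration, not part of the embodiments:

```python
def should_push_update(loss_current, loss_previous, threshold):
    """Decide whether to forward the model parameters to the server.

    A small loss difference between consecutive iterations suggests the
    model is stable and its parameters are worth forwarding; a large
    difference suggests the parameters would quickly become obsolete,
    so local training should continue instead.
    """
    delta = abs(loss_current - loss_previous)
    return delta < threshold

# Unstable phase: loss still falling sharply -> keep training locally.
decision_unstable = should_push_update(0.9, 1.5, threshold=0.1)   # False
# Stable phase: loss barely moving -> forward the update.
decision_stable = should_push_update(0.31, 0.30, threshold=0.1)   # True
```

A statically defined threshold is the simplest variant; the variance-based plateau detection described below replaces the fixed threshold with one derived from recent loss history.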
Thus, if the model is determined to be stable, the updates are sent to the server (step S319). The server receives the updated local model (step S321), aggregates the global model with the received local update from the worker node (step S323) and updates the global model accordingly (step S325). The method then returns to step S301. Referring back to step S317, if the decision is taken not to send the updated parameters to the server, the method returns to step S311 and a new training iteration is carried out. The process then continues to repeat steps S311 to S317 until a decision is reached to send the updated model parameters from a particular training iteration to the server.
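The aggregation in steps S321 to S325 is not specified in detail; as one hedged sketch, the server might blend the received worker parameters into the global model by weighted averaging, where the blending factor alpha is an assumption for this example:

```python
def aggregate(global_params, worker_params, alpha=0.5):
    """Blend a worker's updated parameters into the global model.

    alpha controls how strongly the worker update influences the global
    model; its value (and this averaging rule itself) is an assumption,
    as the embodiments leave the aggregation strategy open.
    """
    return [(1.0 - alpha) * g + alpha * w
            for g, w in zip(global_params, worker_params)]

updated_global = aggregate([1.0, 2.0], [3.0, 4.0])  # [2.0, 3.0]
```

With multiple workers, the same rule can be applied to each received update in turn, or the updates can be averaged together before blending.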
In some embodiments, the step of determining whether or not the model is stable may be made through a plateau detection mechanism. In this case, the loss θ measured at each iteration (i.e. the measure of the difference between the expected output of processing the training data and the actual output from the machine learning model) is recorded to obtain a time series of data showing the variation in loss as a function of time. An example of this is shown in
One way of implementing the above plateau detection is by observing the variance σloss2 in the losses of the local model at the worker. The variance σloss2 can be determined by observing the loss values θi obtained for training iterations within a certain time window, where the length of the window is defined as the number of training iterations required to accurately detect a plateau. From this, the stability of the model can be deduced whilst taking into account any fluctuations in model accuracy (performance) that may arise due to the continuous introduction of new training data. Having determined the variance σloss2, the next step will be to define a value j, such that in the event that the change in loss recorded for a subsequent training iteration is less than jσloss, the model can be considered to have stabilised. The value of j can be varied dependent on factors such as network quality and resource availability.
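The variance-based plateau detection can be sketched as below; the window length, the value of j, and the use of the population standard deviation are assumptions for the example:

```python
import statistics

def is_stable(loss_history, window, j):
    """Plateau detection over a sliding window of recorded losses.

    The model is considered stable when the latest loss change is
    smaller than j times the standard deviation (sigma_loss) of the
    losses observed over the preceding window of iterations.
    """
    if len(loss_history) < window + 1:
        return False  # not enough history to judge stability yet
    recent = loss_history[-(window + 1):-1]
    sigma_loss = statistics.pstdev(recent)
    delta = abs(loss_history[-1] - loss_history[-2])
    return delta < j * sigma_loss

# Losses levelling off -> plateau detected.
is_stable([1.0, 0.6, 0.4, 0.35, 0.33, 0.32], window=4, j=1.0)   # True
# Losses still falling steadily -> no plateau.
is_stable([5.0, 4.0, 3.0, 2.0, 1.0], window=3, j=1.0)           # False
```

Because sigma_loss is recomputed over the window at each iteration, fluctuations caused by newly introduced training data are absorbed into the threshold rather than triggering spurious stability decisions.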
By implementing the steps described above, embodiments can reduce the volume of data being sent via the communication channel between the node and the server. The frequency of updates will be reduced by only forwarding the updates when the local model is stable. The benefit of this strategy can be appreciated in IoT deployments with low communication bandwidth or limits on available energy for device transmission. If every worker is consistently pushing updates to the server, there will be a significant time delay between iterations of training. This will be even more noticeable if the number of workers is large and the channel is occupied for an extended period before a worker has the opportunity to send an update. Communicating less often allows workers to spend time computing local training steps that they would otherwise spend waiting for the channel to become available. The proposed methodology can be applied in the context of both Federated Learning (with data being generated by edge devices) and Distributed Learning (with the data being distributed by a centralised node) architectures.
In some embodiments, once the model reaches stability (as recognised by the plateau detection mechanism, for example) the frequency of updates can also be reduced. If the model is not changing significantly between consecutive iterations, it will not be necessary to communicate the update from each training iteration to the server. Nevertheless, the decision may be taken to forward an update where there is seen to be a change whilst in the stable phase of training; in other words, where the model parameters are seen to diverge from what were previously stable values. Changes that occur once the model has been seen to enter a stable phase will define the "consistency" of training.
The step of determining whether or not the training remains consistent can be implemented in a similar way to the plateau detection mechanism. As before, the variance σloss2 in the time series of observed losses of the local model can be determined over a pre-defined window. Having previously identified the model to be stable, if a loss encountered for a subsequent training iteration is then found to be greater than kσloss, where k is a constant, then this will signify that the model is training consistently i.e. it is changing enough for a further update to be considered worthwhile.
The consistency of training can be understood with reference to
The parameter values for detecting stability and consistency, j and k respectively, may be defined for a given training iteration i such that:
kσloss < Δθi < jσloss
where, as before, Δθi is the change in performance parameter θ i.e. the difference between the observed loss for the model in the training iteration i and the observed loss for the preceding iteration i−1. If the condition above is satisfied by a particular training iteration, then the model will be determined to be both stable and consistent at that point in time.
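The combined condition kσloss < Δθi < jσloss can be sketched by extending the plateau detection above; the window length and the values of j and k remain assumptions for illustration:

```python
import statistics

def stable_and_consistent(loss_history, window, j, k):
    """Check the combined condition k*sigma < |delta_theta| < j*sigma.

    The update is worth sending only when the latest loss change is
    small enough for the model to be stable (below j*sigma_loss) yet
    large enough for training to be consistent (above k*sigma_loss).
    Requires k < j for the condition to be satisfiable.
    """
    if len(loss_history) < window + 1:
        return False
    recent = loss_history[-(window + 1):-1]
    sigma_loss = statistics.pstdev(recent)
    delta = abs(loss_history[-1] - loss_history[-2])
    return k * sigma_loss < delta < j * sigma_loss

losses = [1.0, 0.6, 0.4, 0.35, 0.33, 0.32]
# Small but non-negligible change -> stable and consistent.
stable_and_consistent(losses, window=4, j=1.0, k=0.05)   # True
# With a larger k the same change is too small to be "consistent".
stable_and_consistent(losses, window=4, j=1.0, k=0.5)    # False
```

Choosing j and k therefore tunes how aggressively the worker suppresses updates: a larger k filters out insignificant drift, whilst a smaller j demands a tighter plateau before any update is sent.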
The effects of implementing the steps of
It will be appreciated that the model can still be considered stable and the training consistent if the loss rises over time, as well as falls. In the period from t7 to t8, for example, there is a rise in the loss with successive iterations. Here, the rate of growth is small enough for the model to be considered stable, but large enough for the training to be consistent i.e. kσloss<Δθi<jσloss. During the period t7 to t8, therefore, the outcome of steps S717 and S719 of
By virtue of the steps above, the system will only communicate updates when the model is stable and will only communicate important changes from then onwards.
In some embodiments, rather than defining the loss per se, the performance parameter may be defined based on an extent of change in the values of one or more of the model parameters following the updates made in the present training iteration. For example, if the values of the one or more parameters in the present iteration i are found to change by an average amount Δparamsi, and the values of the one or more parameters in a previous training iteration i−1 are found to change by an average amount Δparamsi-1, then depending on the difference Δparamsi−Δparamsi-1, a decision can be taken as to whether or not the model is stable. Here, the variance in the one or more model parameters can be used to analyse fluctuation in the model during training. For example, if the value Δparamsi−Δparamsi-1 is greater than pσparams, where p is a constant and σparams is the standard deviation of Δparams over previous training iterations, this again can indicate that training is consistent and it is worthwhile updating the global model with the current values of those parameters. The value of p can be related to the network quality and resource constraints. If Δparamsi−Δparamsi-1 is less than pσparams, then a decision may be taken not to forward an update and instead to continue to train locally until there is an update of significance.
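The parameter-based variant can be sketched in the same style as the loss-based check; the window length, the value of p, and the function name are assumptions for the example:

```python
import statistics

def consistent_by_params(delta_params_history, window, p):
    """Parameter-based consistency check.

    delta_params_history holds the average parameter change Delta_params
    recorded for each training iteration. Training is considered
    consistent (i.e. worth forwarding an update) when the change between
    the two most recent values exceeds p times the standard deviation
    sigma_params of the values in the preceding window.
    """
    if len(delta_params_history) < window + 1:
        return False
    recent = delta_params_history[-(window + 1):-1]
    sigma_params = statistics.pstdev(recent)
    change = abs(delta_params_history[-1] - delta_params_history[-2])
    return change > p * sigma_params

# A sudden jump in parameter movement -> forward an update.
consistent_by_params([0.5, 0.3, 0.2, 0.15, 0.5], window=3, p=1.0)    # True
# Parameter movement unchanged -> keep training locally.
consistent_by_params([0.5, 0.3, 0.2, 0.15, 0.16], window=3, p=1.0)   # False
```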
It will be appreciated that the particular performance parameter used may vary with the type of machine learning algorithm; in addition to the parameters θ and Δparams described above, other parameters may be chosen, such as a mean absolute error, mean squared error, categorical cross entropy, or accuracy performance on a test dataset for inference. The performance parameter may serve to reflect a difference between the output from the machine learning model and an expected output. In the case of supervised learning, for example, the expected output would be the ground truth labels associated with the training data.
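For concreteness, the two regression metrics mentioned above can be computed as follows; these are standard definitions rather than anything specific to the embodiments:

```python
def mean_squared_error(outputs, expected):
    """MSE between model outputs and expected (ground-truth) values."""
    return sum((o - e) ** 2 for o, e in zip(outputs, expected)) / len(outputs)

def mean_absolute_error(outputs, expected):
    """MAE between model outputs and expected (ground-truth) values."""
    return sum(abs(o - e) for o, e in zip(outputs, expected)) / len(outputs)

# One output differs from its ground-truth label by 2:
mean_squared_error([1.0, 2.0, 3.0], [1.0, 2.0, 5.0])    # 4/3
mean_absolute_error([1.0, 2.0, 3.0], [1.0, 2.0, 5.0])   # 2/3
```

Whichever metric is chosen, the same stability and consistency logic applies: it is the change in the metric across iterations, judged against its recent variance, that drives the decision to forward an update.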
Embodiments described herein provide significant improvements in model convergence time as updates are not sent from the worker node(s) to the server when the local model has not converged and is in an unstable state. By not including unnecessary updates, communication costs and times are also reduced.
The improvement in training time and communication costs as provided by embodiments described herein is further illustrated in
Referring to
Embodiments described herein have extensive applications in FL for the IoT. For example, in autonomous driving, vehicles can observe complex surroundings to train a global model for collision avoidance. In environments where network quality is low, it is crucial for the vehicle to remain connected, as detachment from the global model could be dangerous. Another prospective application for FL is predictive maintenance of edge devices. Each device uses the data it continuously collects to perform local training to update a global model. The network configurations connecting devices to the global server may vary, for example Ethernet, 5G or LoRaWAN. Embodiments described herein can be used to predict the health and possible failures of the device in order to plan for maintenance.
Implementations of the subject matter and the operations described in this specification can be realized in digital electronic circuitry, or in computer software, firmware, or hardware, including the structures disclosed in this specification and their structural equivalents, or in combinations of one or more of them. Implementations of the subject matter described in this specification can be realized using one or more computer programs, i.e., one or more modules of computer program instructions, encoded on computer storage medium for execution by, or to control the operation of, data processing apparatus. Alternatively or in addition, the program instructions can be encoded on an artificially generated propagated signal, e.g., a machine-generated electrical, optical, or electromagnetic signal that is generated to encode information for transmission to suitable receiver apparatus for execution by a data processing apparatus. A computer storage medium can be, or be included in, a computer-readable storage device, a computer-readable storage substrate, a random or serial access memory array or device, or a combination of one or more of them. Moreover, while a computer storage medium is not a propagated signal, a computer storage medium can be a source or destination of computer program instructions encoded in an artificially generated propagated signal. The computer storage medium can also be, or be included in, one or more separate physical components or media (e.g., multiple CDs, disks, or other storage devices).
While certain embodiments have been described, these embodiments have been presented by way of example only and are not intended to limit the scope of the invention. Indeed, the novel methods, devices and systems described herein may be embodied in a variety of forms; furthermore, various omissions, substitutions and changes in the form of the methods and systems described herein may be made without departing from the spirit of the invention. The accompanying claims and their equivalents are intended to cover such forms or modifications as would fall within the scope and spirit of the invention.