This patent application claims the benefit and priority of Singaporean Patent Application No. 10202260574Y filed with the Intellectual Property Office of Singapore on Dec. 23, 2022 and claims the benefit and priority of Singaporean Patent Application No. 10202251220D filed with the Intellectual Property Office of Singapore on Sep. 29, 2022, the disclosures of which are incorporated by reference herein in their entireties as part of the present application.
Various aspects of this disclosure relate to systems and methods for training a machine learning model.
BACKGROUND
Federated learning (FL) provides general principles for decentralized clients to train a server model collectively without sharing local data. FL is a promising framework with practical applications, but its standard training paradigm requires the clients to back-propagate through the model to compute gradients. Since these clients are typically edge devices and not fully trusted, they may lack the computational and storage resources required to execute back-propagation. For example, any trusted execution environment may not have sufficient memory to store the data required to execute back-propagation. Performing FL with conventional techniques therefore may require accepting unreasonable constraints on the allowed size of the data model or executing training outside of a trusted environment and subjecting the model to white-box vulnerability (i.e. vulnerability against attacks where an attacker has high knowledge of the attacked application, including e.g. access to source code).
Accordingly, approaches for federated learning with less computational burden on the client devices and higher security are desirable.
A method for training a machine learning model is described, including receiving, for each perturbation of a plurality of perturbations of model parameters of a starting version of the machine learning model, a change of loss of the machine learning model caused by the perturbation for a set of training data, determined by feeding the set of training data to one or more perturbed versions of the machine learning model; estimating a gradient of the loss of the machine learning model with respect to the model parameters from the determined changes of loss; and updating the starting version of the machine learning model to an updated version of the machine learning model by changing the model parameters in a direction for which the estimated gradient indicates a reduction of loss.
According to one embodiment, the method includes distributing the model parameters of the machine learning model to a plurality of clients for the clients to determine one or more of the changes of loss.
According to one embodiment, the server transmits one or more seeds to the clients for the clients to determine the perturbations using the one or more seeds.
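A minimal sketch of how a shared seed can replace transmitting the perturbations themselves (the function name, shapes, and default σ are illustrative assumptions, not taken from the disclosure):

```python
import numpy as np

def perturbations_from_seed(seed, k, n, sigma=1e-4):
    # Server and client derive the same K Gaussian perturbations from the
    # shared seed, so only the seed needs to be transmitted instead of
    # K * n floating-point values.
    rng = np.random.default_rng(seed)
    return rng.normal(0.0, sigma, size=(k, n))

server_view = perturbations_from_seed(seed=42, k=4, n=3)
client_view = perturbations_from_seed(seed=42, k=4, n=3)
assert np.array_equal(server_view, client_view)
```

Because the generator is deterministic for a given seed, both sides reconstruct identical perturbations without any further communication.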
According to one embodiment, the method comprises estimating the gradient of the loss of the machine learning model with respect to the model parameters from the changes of loss determined by the clients and updating the starting version of the machine learning model to the updated version of the machine learning model by changing the model parameters in a direction for which the estimated gradient indicates a reduction of loss.
According to one embodiment, the method includes performing multiple iterations including, in each iteration from a first to a last iteration, receiving, for each perturbation of a plurality of perturbations of model parameters of a respective starting version of the machine learning model, a change of loss of the machine learning model caused by the perturbation for a set of training data determined by feeding the set of training data to one or more perturbed versions of the machine learning model, estimating a gradient of the loss of the machine learning model with respect to the model parameters from the determined changes of loss and updating the respective starting version of the machine learning model to a respective updated version of the machine learning model by changing the model parameters in a direction for which the estimated gradient indicates a reduction of loss, wherein, for each iteration but the last iteration, the respective updated version of the machine learning model of the iteration is the starting version of the machine learning model for the next iteration.
According to one embodiment, the method includes estimating the gradient of the loss of the machine learning model with respect to the model parameters from the determined changes of loss according to Stein's identity.
According to one embodiment, the machine learning model is a neural network and the model parameters are neural network weights.
According to one embodiment, a server is provided configured to perform the method of any one of the embodiments described above.
According to one embodiment, a computer program element is provided including program instructions, which, when executed by one or more processors, cause the one or more processors to perform the method of any one of the embodiments described above.
According to one embodiment, a computer-readable medium is provided including program instructions, which, when executed by one or more processors, cause the one or more processors to perform the method of any one of the embodiments described above.
The invention will be better understood with reference to the detailed description when considered in conjunction with the non-limiting examples and the accompanying drawings, in which:
The following detailed description refers to the accompanying drawings that show, by way of illustration, specific details and embodiments in which the disclosure may be practiced. These embodiments are described in sufficient detail to enable those skilled in the art to practice the disclosure. Other embodiments may be utilized, and structural and logical changes may be made without departing from the scope of the disclosure. The various embodiments are not necessarily mutually exclusive, as some embodiments can be combined with one or more other embodiments to form new embodiments.
Embodiments described in the context of one of the devices or methods are analogously valid for the other devices or methods. Similarly, embodiments described in the context of a device are analogously valid for a vehicle or a method, and vice-versa.
Features that are described in the context of an embodiment may correspondingly be applicable to the same or similar features in the other embodiments. Features that are described in the context of an embodiment may correspondingly be applicable to the other embodiments, even if not explicitly described in these other embodiments. Furthermore, additions and/or combinations and/or alternatives as described for a feature in the context of an embodiment may correspondingly be applicable to the same or similar feature in the other embodiments.
In the context of various embodiments, the articles “a”, “an” and “the” as used with regard to a feature or element include a reference to one or more of the features or elements.
As used herein, the term “and/or” includes any and all combinations of one or more of the associated listed items.
In the following, embodiments will be described in detail.
The system 100 includes a plurality of client devices 101 (also simply referred to as “clients” in the following) and a server device 102 (also simply referred to as “server” in the following) to which the clients 101 are connected via a communication network 103. The server 102 stores a machine learning model 104 which should be trained (also referred to as “server model” in the following). When the machine learning model 104 is considered to be centrally stored on the server 102 the clients can be seen as decentralized clients.
Federated learning (FL) allows decentralized clients to collaboratively train a server model. According to a standard training approach, in each of multiple training rounds (i.e. iterations), the clients 101 (or a selected subset of them) compute model gradients or (model) updates on their local private datasets, without explicitly exchanging sample points with the server 102. While FL with this training approach describes a promising blueprint and has several applications, it is gradient-based and thus requires the clients 101 to locally execute back-propagation, which leads to the following practical limitations:
In view of the above and in accordance with various embodiments, a system for back-propagation-free federated learning is provided in which back-propagation is replaced by multiple forward (or inference) processes to estimate gradients.
The system for back-propagation-free federated learning, in accordance with various embodiments, is
Experiments show that models trained by the system can achieve empirically comparable performance to conventional FL models.
As explained with reference to
As illustrated, the system for back-propagation-free federated learning according to various embodiments includes the following:
Experiments on the MNIST and CIFAR-10/100 datasets show that the system for back-propagation-free federated learning achieves comparable performance to conventional FL using a relatively small value of K (as determined by ablation studies), which shows that the system for back-propagation-free federated learning provides an effective back-propagation-free method for FL.
In the following, additional details of the system for back-propagation-free federated learning according to various embodiments are given.
It is assumed that there are C clients (e.g. C=10) and that the c-th client's private dataset is 𝒟_c := {(X_i^c, y_i^c)}_{i=1}^{N_c}.
As mentioned above, in the standard FL training framework, clients 101 locally compute gradients {∇_W ℒ(W; 𝒟_c)}_{c=1}^C or model updates through back-propagation and then upload them to the server. Federated averaging performs global aggregation using

W ← W + Σ_{c=1}^C (N_c/N) · ΔW_c (1)

where ΔW_c is the local update obtained via executing W_c ← W_c − η ∇_W ℒ(W_c; 𝒟_c) on the c-th client's local data, and N = Σ_{c=1}^C N_c.
Gradient-based optimization techniques (either first-order or higher-order) may be used to train deep networks. Zero-order optimization methods may also be used for training, particularly when exact derivatives cannot be obtained or backward processes are computationally prohibitive.
Zero-order approaches require only multiple forward processes, which may be executed in parallel. Along these lines, finite difference stems from the definition of derivatives and can be generalized to higher-order and multivariate cases by Taylor expansion. For any differentiable loss function ℒ(W; 𝒟) and a small perturbation δ ∈ ℝⁿ, finite difference employs the forward difference scheme

ℒ(W+δ; 𝒟) − ℒ(W; 𝒟) = δᵀ∇_W ℒ(W; 𝒟) + o(‖δ‖₂) (2)

where δᵀ∇_W ℒ(W; 𝒟) is a scaled directional derivative along δ. Furthermore, the central difference scheme can be used to obtain higher-order residuals as

ℒ(W+δ; 𝒟) − ℒ(W−δ; 𝒟) = 2δᵀ∇_W ℒ(W; 𝒟) + o(‖δ‖₂²) (3)
Both left hand side terms of equations (2) and (3) can be seen as changes of loss, wherein the one of equation (2) is determined by the difference of the loss of a perturbed version of the model and the loss of a starting version (of the current iteration) of the model and the one of equation (3) is determined by the difference of the losses of two perturbed versions of the model (wherein one is perturbed with a perturbation δ and the other with the opposite −δ of the perturbation).
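The two schemes can be illustrated with a toy differentiable loss standing in for ℒ(W; 𝒟) (all names and values here are illustrative, not from the disclosure):

```python
import numpy as np

def loss(w):
    # Toy stand-in for a differentiable loss L(W; D); its true gradient is w.
    return 0.5 * np.dot(w, w)

w = np.array([1.0, -2.0, 0.5])
delta = 1e-3 * np.random.default_rng(0).normal(size=3)

# Forward difference scheme (equation (2)): loss change between the
# perturbed version and the starting version of the model.
fwd = loss(w + delta) - loss(w)

# Central difference scheme (equation (3)): loss change between two
# oppositely perturbed versions; the residual is one order smaller.
ctr = loss(w + delta) - loss(w - delta)
```

Both quantities approximate (scaled) directional derivatives along δ; the central scheme costs one extra perturbed forward pass but removes the second-order residual for this quadratic loss.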
Finite difference formulas are typically used to estimate quantities such as gradient norm or Hessian trace, where δ is sampled from random projection vectors.
In the following, zero-order optimization techniques for FL and, in particular, the system for back-propagation-free federated learning are described in more detail. One possibility is to apply finite difference as the gradient estimator. To estimate the full gradients, each parameter ω ∈ W may be perturbed individually to approximate the corresponding partial derivative ∂ℒ(W; 𝒟)/∂ω, causing the forward computations to grow with n (recall that W ∈ ℝⁿ) and thus making it difficult to scale to large machine learning models. In light of this, according to various embodiments, Stein's identity is used to obtain an unbiased estimation of gradients from loss differences calculated on various perturbations. As explained with reference to
Deep neural networks can be effectively trained if the majority of gradients have proper signs. Thus, according to various embodiments, where the machine learning model 104 is a deep neural network, forward propagation is performed multiple times on perturbed parameters in order to obtain a stochastic estimation of gradients without back-propagation. Specifically, assuming that the loss function ℒ(W; 𝒟) is continuously differentiable with respect to W given any dataset 𝒟, which is true (almost everywhere) for deep networks using non-linear activation functions, a smoothed loss function
ℒ_σ(W; 𝒟) := 𝔼_{δ∼𝒩(0, σ²I)}[ℒ(W + δ; 𝒟)]
is defined, where the perturbation δ follows a Gaussian distribution with zero mean and covariance σ²I. Given this, Stein's identity states that

∇_W ℒ_σ(W; 𝒟) = 𝔼_{δ∼𝒩(0, σ²I)}[δ/(2σ²) · Δ(W, δ; 𝒟)] (4)
where Δ(W, δ; 𝒟) := ℒ(W+δ; 𝒟) − ℒ(W−δ; 𝒟) is the loss difference. It should be noted that computing a loss difference only requires the execution of two forward processes (e.g., forward passes through the machine learning model) to compute ℒ(W+δ; 𝒟) and ℒ(W−δ; 𝒟), without back-propagation. It is straightforward to show that ℒ_σ(W; 𝒟) is continuously differentiable for any σ>0 and that ∇_W ℒ_σ(W; 𝒟) converges uniformly as σ→0. Hence, it follows that
∇_W ℒ(W; 𝒟) = lim_{σ→0} ∇_W ℒ_σ(W; 𝒟) (5)
Therefore, a stochastic estimation of gradients can be obtained using Monte Carlo approximation by 1) selecting a small value of σ; 2) randomly sampling K perturbations {δ_k}_{k=1}^K from 𝒩(0, σ²I); and 3) utilizing Stein's identity to calculate

ĝ(W; 𝒟) := (1/K) Σ_{k=1}^K δ_k/(2σ²) · Δ(W, δ_k; 𝒟) (6)
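A minimal NumPy sketch of this Monte Carlo estimator, with illustrative choices for K, σ, and the loss (not values from the disclosure):

```python
import numpy as np

def estimate_gradient(loss, w, k=200, sigma=1e-3, seed=0):
    # Forward-only gradient estimate: average delta_k * Delta_k / (2 sigma^2)
    # over K Gaussian perturbations, using only loss evaluations.
    rng = np.random.default_rng(seed)
    g = np.zeros_like(w)
    for _ in range(k):
        delta = rng.normal(0.0, sigma, size=w.shape)
        diff = loss(w + delta) - loss(w - delta)  # central loss difference
        g += delta * diff / (2.0 * sigma ** 2)
    return g / k

w = np.array([1.0, -2.0, 0.5])
g_hat = estimate_gradient(lambda v: 0.5 * np.dot(v, v), w)
# For this quadratic loss the true gradient is w itself, so the estimate
# should correlate positively with it (a descent-correlated direction).
assert np.dot(g_hat, w) > 0
```

No backward pass is ever computed; each of the 2K loss evaluations is a plain forward pass, which is what makes the estimator suitable for memory-constrained clients.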
In the following, an exemplary algorithm is given.
The algorithm is based on the forward-only gradient estimator ĝ(W; 𝒟) according to equation (6). The algorithm includes
to the server 202, where each output is a floating-point number and the noise ϵ_c is negotiated by all clients to be zero-sum (i.e. to sum to zero over all the clients; for example, one client may receive an indication of the noises used by the other clients and set its own noise such that the sum of all noises, including its own, is zero). The number of bytes uploaded for the K noisy outputs is 4×K;
Since the noises {ϵ_c}_{c=1}^C are zero-sum, it holds that

Σ_{c=1}^C (Δ(W_t, δ_k; 𝒟_c) + ϵ_c) = Σ_{c=1}^C Δ(W_t, δ_k; 𝒟_c)

and equation (7) holds. Thus, the server 202 can correctly aggregate Δ(W_t, δ_k) and protect client privacy without recovering the individual Δ(W_t, δ_k; 𝒟_c).
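A small sketch of the zero-sum noise idea; the negotiation step is simplified here to one client cancelling the others, which is purely illustrative:

```python
import numpy as np

rng = np.random.default_rng(1)
num_clients = 4
local_diffs = rng.normal(size=num_clients)  # each client's Delta(W_t, delta_k; D_c)

# Clients negotiate zero-sum noises; here the last client simply sets its
# noise to cancel the sum of the others.
noises = rng.normal(size=num_clients)
noises[-1] = -noises[:-1].sum()

uploads = local_diffs + noises   # what the server actually receives
aggregate = uploads.sum()        # the noises cancel in the aggregate

# The server recovers the exact aggregated loss difference without seeing
# any individual client's true value.
assert abs(aggregate - local_diffs.sum()) < 1e-9
```

Each individual upload is masked by its client's noise, yet the sum over all clients is exact, which is all the server needs for equation (6).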
It should be noted that after calculating the gradient estimation ĝ(W_t), the server 202 updates the parameters to W_{t+1} using techniques such as gradient descent with learning rate η. The form of the system for back-propagation-free federated learning presented in the above algorithm corresponds to a federated optimization algorithm where lines 11-12 are executed once for each round t. The system for back-propagation-free federated learning can be generalized to an approach in which each client updates its local parameters in multiple steps using the gradient estimator ĝ(W_t; 𝒟_c) derived from Δ(W_t, δ_k; 𝒟_c) via equation (6) using gradient descent and uploads model updates to the server 202, which combines these updates into an aggregated update of the model.
Regarding convergence, it can be shown that ĝ(W; 𝒟) provides an unbiased estimation for the true gradients with convergence rate
It should be noted that an extremely small σ will cause an underflow problem and a large K increases computational cost. So, for example, σ is set to 10⁻⁴ because it is a small value that does not cause numerical problems in exemplary use cases and works well on edge devices with half-precision floating-point numbers. K may be chosen in a broad range like 100 to 5000. It may be small (e.g., K=500) relative to the number of model parameters (which is e.g. 3.0×10⁵).
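The underflow concern can be checked directly in half precision (the specific σ values below are illustrative):

```python
import numpy as np

# float16 has a smallest positive subnormal of roughly 6e-8, so a sigma
# of 1e-4 remains representable, while a much smaller sigma underflows
# to zero on half-precision edge devices.
sigma_ok = np.float16(1e-4)
sigma_too_small = np.float16(1e-9)

assert float(sigma_ok) > 0.0
assert float(sigma_too_small) == 0.0
```

Since the loss differences scale roughly with σ, choosing σ well above the half-precision subnormal threshold keeps the estimator's inputs nonzero.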
Various embodiments may be used in different computing environments with different entities (e.g., client/server implementations), constraints, and/or use cases. Depending on the computing environment, various techniques may be used to improve accuracy, computational efficiency, or both.
Although either scheme may be used, according to some embodiments, the forward difference scheme according to equation (2) is applied twice ("twice forward difference", twice-FD) rather than the central difference scheme according to equation (3). Experiments show that the central scheme produces smaller residuals than the forward scheme, but only by executing twice as many forward inferences (i.e. for both W+δ and W−δ), whereas the linearity of the forward difference scheme reduces the impact of second-order residuals.
In some embodiments, the Hardswish activation function may be used as an alternative to ReLU in the machine learning model to overcome the issue of a value jump when the sign of a feature changes after perturbation, i.e. h(W+δ)·h(W)<0, where h(·) denotes the feature mapping of the machine learning model.
Further, in some embodiments, exponential moving average (EMA) may be used to reduce oscillations caused by white noise. Regarding normalization, GroupNorm may be used as opposed to BatchNorm since on edge devices, the dataset size is typically small, which leads to inaccurate batch statistics estimation and degrades performance when using BatchNorm.
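A minimal sketch of exponential moving averaging over successive gradient estimates (the decay value and function name are illustrative assumptions):

```python
import numpy as np

def ema(avg, new, decay=0.9):
    # Smooths successive zeroth-order gradient estimates to damp the
    # oscillations caused by the white-noise perturbations.
    return decay * avg + (1.0 - decay) * new

avg = np.zeros(3)
estimates = [np.array([1.0, 0.0, -1.0]), np.array([0.8, 0.2, -1.2])]
for g in estimates:
    avg = ema(avg, g)
```

The running average changes slowly relative to any single noisy estimate, trading a small amount of lag for substantially reduced variance.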
Since the system for back-propagation-free federated learning according to various embodiments only requires forward propagation, it can be executed in a TEE because it requires little memory. In general, model inference techniques in TEE may be exploited by slicing the computation graph and executing the per-layer forward calculation with constrained memory.
The trusted execution environment (TEE) 300 serves to protect against white-box attacks by preventing any model exposure. The TEE 300 protects both data and model security with three components: physical secure storage 301 to ensure the confidentiality, integrity, and tamper-resistance of stored data; a root of trust 302 to load trusted code; and a separation kernel 303 to execute code in an isolated environment. Using TEEs, the federated learning system 100 is able to train deep models without revealing any model specifics. The memory of a TEE is usually much smaller (e.g., 90 MB) than what deep models require for back-propagation (e.g., ≥5 GB), but it is sufficient for forward propagation according to various aspects of the subject technology.
Membership inference attacks and model inversion attacks are two attack methods that require an attacker to be able to repeatedly perform model inference on specified data and obtain the results, such as confidence values or classification scores. Given that various aspects of the subject technology expose only stochastic loss differences Δ(W, δ; 𝒟) associated with the random perturbation δ, it is difficult to perform such inference attacks on systems implemented according to various aspects of the subject technology. In particular, it is difficult to distinguish real data from random noise based on these outputs, indicating that attackers cannot obtain any useful information from them.
In each round's communication, each client 201 uploads a K-dimensional vector to the server 202 and downloads the updated global parameters. Since K is much less than the number of model parameters (e.g., 500 compared to 0.3 million), the system for back-propagation-free federated learning reduces data transfer by roughly half when compared to the pipeline of a back-propagation-based FL system. As to the epoch-level communication settings, a standard back-propagation-based FL system requires each client to perform model optimization on the local training dataset and upload the model updates to the server after a number of local epochs in order to reduce communication costs.
The system for back-propagation-free federated learning can also communicate at the epoch level with O(n) additional memory. Additional memory may be employed to store the perturbation in each forward process and to estimate the local gradient using equation (6). After several epochs, each client 201 optimizes the local model with SGD (stochastic gradient descent) and uploads local updates. Compared to back-propagation-based FL, good performance can be achieved with a relatively modest value of K.
In summary, according to various embodiments, a method is provided as illustrated in
In 401, for each perturbation of a plurality of perturbations of model parameters of a starting version of the machine learning model, a change of loss of the machine learning model caused by the perturbation for a set of training data is received, wherein the change of loss is determined by feeding the set of training data to one or more perturbed versions of the machine learning model (which are versions of the starting version of the machine learning model perturbed in accordance with the perturbation (or its opposite, i.e. negative) and at least include the version of the starting version of the machine learning model perturbed in accordance with the perturbation).
In 402, a gradient of the loss of the machine learning model with respect to the model parameters is estimated from the determined changes of loss.
In 403, the starting version of the machine learning model is updated to an updated version of the machine learning model by changing the model parameters in a direction for which the estimated gradient indicates a reduction of loss.
According to various embodiments, in other words, rather than performing back-propagation, a machine learning model is trained according to an estimate of a gradient which is determined from the changes of loss caused by perturbations of the model parameters (and observed from forward passes through the perturbed versions of the machine learning model).
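Putting the pieces together, one training round might be sketched as follows; client/server communication is collapsed into a single function, and the names, toy loss, and hyperparameters are illustrative assumptions rather than values from the disclosure:

```python
import numpy as np

def loss(w, data):
    # Toy per-client loss: mean squared error of a linear model.
    X, y = data
    return 0.5 * np.mean((X @ w - y) ** 2)

def bp_free_round(w, client_datasets, k=100, sigma=1e-3, lr=0.1, seed=0):
    # One back-propagation-free round: for each shared perturbation, sum the
    # clients' forward-only loss differences, estimate the gradient in the
    # style of equation (6), and take a descent step on the server.
    rng = np.random.default_rng(seed)
    g = np.zeros_like(w)
    for _ in range(k):
        delta = rng.normal(0.0, sigma, size=w.shape)
        diff = sum(loss(w + delta, d) - loss(w - delta, d)
                   for d in client_datasets)
        g += delta * diff / (2.0 * sigma ** 2)
    g /= k * len(client_datasets)
    return w - lr * g  # step opposite the estimated gradient

rng = np.random.default_rng(2)
datasets = [(rng.normal(size=(8, 3)), rng.normal(size=8)) for _ in range(3)]
w0 = np.array([0.5, -1.0, 2.0])
w1 = bp_free_round(w0, datasets)
```

Only forward evaluations of the per-client losses are performed; the resulting step is correlated with the true descent direction of the aggregated objective.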
The perturbations are randomly generated (e.g., computed based on output generated by a random number generator).
The method of
A client, e.g. one of the clients 101, may for example carry out a method for training a machine learning model, comprising:
Determining, for each perturbation of a plurality of perturbations of model parameters of a starting version of the machine learning model, a change of loss of the machine learning model caused by the perturbation for a set of training data by feeding the set of training data to one or more perturbed versions of the machine learning model.
Optionally, the method may comprise estimating a gradient of the loss of the machine learning model with respect to the model parameters from the determined changes of loss.
The method may further include transmitting the determined changes of loss to a federated learning server or (in case the method comprises estimating the gradient of the loss) transmitting the estimated gradient to a federated learning server (or both).
According to one embodiment, the method comprises determining the change of loss by determining a perturbed version of the machine learning model whose model parameters are perturbed with respect to the starting version of the machine learning model in accordance with the perturbation and determining the change of loss as the difference of a loss of the starting version of the machine learning model and a loss of the perturbed version of the machine learning model.
According to one embodiment, the method comprises determining the change of loss by determining a first perturbed version of the machine learning model whose model parameters are perturbed with respect to the starting version of the machine learning model in accordance with the perturbation and a second perturbed version of the machine learning model whose model parameters are perturbed with respect to the starting version of the machine learning model in accordance with the opposite of the perturbation and determining the change of loss as the difference of a loss of the first perturbed version of the machine learning model and a loss of the second perturbed version of the machine learning model.
According to one embodiment, the method comprises updating the starting version of the machine learning model to a respective updated version of the machine learning model.
According to one embodiment, the method comprises transmitting the updated version of the machine learning model to a (e.g. federated learning) server for the server to combine the updated version of the machine learning model with one or more updated versions from other clients of the machine learning model to an aggregate update of the machine learning model.
According to one embodiment, the set of training data for different ones of the clients (including the client performing the method and the one or more other clients) are different.
After training the machine learning model, it may for example be used (e.g. by a corresponding controlling device) to control a technical system such as a computer-controlled machine, like a robot (or robotic system), a vehicle, a domestic appliance or a manufacturing machine. Depending on the use case, the machine learning model's input may be sensor data of different types such as images, radar data, lidar data, thermal imaging data, motion data, sonar data etc. The training data includes training inputs according to the machine learning model's input data type as well as labels (i.e. ground truth information) to determine the loss (and the changes of the loss).
The methods described herein may be performed and the various processing or computation units and the devices and computing entities described herein may be implemented by one or more circuits. In an embodiment, a “circuit” may be understood as any kind of a logic implementing entity, which may be hardware, software, firmware, or any combination thereof. Thus, in an embodiment, a “circuit” may be a hard-wired logic circuit or a programmable logic circuit such as a programmable processor, e.g. a microprocessor. A “circuit” may also be software being implemented or executed by a processor, e.g. any kind of computer program, e.g. a computer program using a virtual machine code. Any other kind of implementation of the respective functions which are described herein may also be understood as a “circuit” in accordance with an alternative embodiment.
While the disclosure has been particularly shown and described with reference to specific embodiments, it should be understood by those skilled in the art that various changes in form and detail may be made therein without departing from the spirit and scope of the invention as defined by the appended claims. The scope of the invention is thus indicated by the appended claims and all changes which come within the meaning and range of equivalency of the claims are therefore intended to be embraced.
Number | Date | Country | Kind
---|---|---|---
10202251220D | Sep 2022 | SG | national
10202260574Y | Dec 2022 | SG | national