The present disclosure relates generally to machine learning. More particularly, the present disclosure relates to personalized federated learning (PFL) via sharable basis models.
Recent years have witnessed a gradual shift in computer vision and machine learning from simply building a stronger model (e.g., image classifier) to taking more users' aspects into account. For instance, more attention has been paid to data privacy and ownership in collecting data for model training. Building models that are tailored to users' data, preferences, and characteristics has been shown to greatly improve user experience. Personalized federated learning (PFL) is a relatively new machine learning paradigm that can potentially fulfill the demands of both worlds. On the one hand, it follows the setup of federated learning (FL): training models with decentralized data held by users (i.e., clients). On the other hand, it aims to construct customized models for individual clients that would perform well for their respective data distributions.
While appealing, existing work on PFL has mainly focused on how to train the personalized models, e.g., via federated multi-task learning, meta-learning, fine-tuning, etc. In contrast, less attention has been paid to how to maintain the personalized models. Specifically, existing algorithms mostly require saving, for each client, a whole or partial model (e.g., a ConvNet classifier or feature extractor). This implies a linear parameter complexity with respect to the number of clients, which is parameter-inefficient and unfavorable for personalized cloud service: the cloud server needs a linear amount of storage, not to mention the effort for profiling, versioning, and provenance.
Learning parameters of a whole or partial model for each client has another issue when individual clients' data are scarce and distributionally skewed across classes. For instance, it is unlikely that a client can collect images of all possible object classes that would eventually show up in her environment. While federated learning itself enables collaboration among clients (e.g., learning to recognize the missing objects from other clients' data), training model parameters specifically for each client is prone to over-fitting to each client's data distribution. In other words, the resulting personalized models are likely biased toward ignoring the rare or missing classes and thus are highly sensitive to class distribution changes during testing.
Aspects and advantages of embodiments of the present disclosure will be set forth in part in the following description, or can be learned from the description, or can be learned through practice of the embodiments. The embodiments are directed towards providing personalized federated learning (PFL) models via sharable basis models. A model architecture and learning algorithm for PFL models is disclosed. The embodiments learn a set of basis models, which can be combined layer by layer to form a personalized model for each client using specifically learned combination coefficients. The set of basis models may be shared with each client of a set of clients. Thus, the set of basis models is common to each client of the set of clients. However, each client may generate a unique PFL model based on their specifically learned combination coefficients. The unique combination of coefficients for each client may be encoded in a separate personalized vector for each of the clients.
One example aspect of the present disclosure is directed to a computer-implemented method. The method includes a server device providing, to each client device of a set of client devices, a set of untrained models. The server device causes each client device of the set of client devices to generate a separate set of trained models based on the set of untrained models. Each client device iteratively trains the set of untrained models based on a separate subset of a set of training data that is located locally on the client device. Each subset of the set of training data is inaccessible by the server device. Each subset of the set of training data is inaccessible by each client device except for the subset of training data that is located locally on that client device. The server device may receive a separate set of trained models from each client device. The server device may generate a set of basis models based on a combination of the separate set of trained models received from each of the client devices. The server device may provide the set of basis models to each client device of the set of client devices. The server device may cause each client device of the set of client devices to generate a personalized model based on a separate linear combination of the basis models of the set of basis models.
Other aspects of the present disclosure are directed to various systems, apparatuses, non-transitory computer-readable media, user interfaces, and electronic devices.
These and other features, aspects, and advantages of various embodiments of the present disclosure will become better understood with reference to the following description and appended claims. The accompanying drawings, which are incorporated in and constitute a part of this specification, illustrate example embodiments of the present disclosure and, together with the description, serve to explain the related principles.
Detailed discussion of embodiments directed to one of ordinary skill in the art is set forth in the specification, which makes reference to the appended figures, in which:
Reference numerals that are repeated across plural figures are intended to identify the same features in various implementations.
To address the deficiencies of conventional personalized federated learning (PFL) discussed throughout, a model architecture and learning algorithm for personalized federated learning (PFL) is disclosed. The embodiments learn a set of basis models, which can be combined layer by layer to form a personalized model for each client using specifically learned combination coefficients. The set of basis models may be shared with each client of a set of clients. However, each client may generate a unique PFL model based on their specifically learned combination coefficients. The unique combination of coefficients for each client may be encoded in a separate personalized vector for each of the clients. Thus, the set of basis models is common to each client of the set of clients. That is, although the basis models (or basis functions) are shared amongst the clients, each client of the set of clients may have their own linear combination of the basis models (e.g., the expansion coefficients of the linear combination may be encoded in the personalized vector for the client), which serves as a personalized model for that client.
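By way of a non-limiting illustration, the following sketch shows one way the layer-by-layer combination of shared basis models into a personalized model could be realized. The helper names (e.g., make_model, combine_bases), the placeholder architecture, and the use of state-dictionary arithmetic are assumptions of this illustration rather than requirements of the embodiments.

```python
# Illustrative sketch: combine K shared basis models into one personalized model
# using a single client's learned combination coefficients, layer by layer.
import torch
import torch.nn as nn

def make_model() -> nn.Module:
    # Placeholder architecture; any architecture shared by all bases may be used.
    return nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 7))

def combine_bases(bases, alpha):
    # Each parameter tensor of the personalized model is the weighted sum
    # sum_k alpha[k] * (corresponding parameter of basis k).
    personalized = make_model()
    with torch.no_grad():
        combined = {
            name: sum(a * basis.state_dict()[name] for a, basis in zip(alpha, bases))
            for name in personalized.state_dict()
        }
        personalized.load_state_dict(combined)
    return personalized

K = 4
bases = [make_model() for _ in range(K)]          # basis models shared by all clients
alpha = torch.softmax(torch.randn(K), dim=0)      # one client's combination coefficients
client_model = combine_bases(bases, alpha.tolist())
print(client_model(torch.randn(1, 32)).shape)     # torch.Size([1, 7])
```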
This model architecture bypasses the linear parameter complexity without increasing the inference time. In some embodiments, the basis models may be trained via a federated averaging procedure (e.g., iterating between local model training for multiple epochs and global aggregation). However, other embodiments may additionally include a coordinate descent style training algorithm with combination coefficient sharpening. This architecture and/or method(s) may be referred to as a federated basis model (e.g., a sharable federated basis model) throughout. A federated basis model not only enjoys its built-in parameter efficiency but also maintains high personalized classification accuracy. Moreover, by learning the shareable basis models and using them to construct personalized models, it is shown that the embodiments are significantly more robust in coping with class distribution changes.
A great portion of the work in conventional federated learning has focused on the generic setting (GFL): collaboratively training a single “global” model. The sharable federated basis model of the embodiments is an algorithm that iterates between local model training and global model aggregation. Most PFL models (e.g., excluding the federated basis models of the embodiments) have limitations arising from the potential discrepancy among clients' data distributions (i.e., the non-independent and identically distributed (non-IID) condition), which makes the local models diverge from each other. Conventional GFL-based algorithms have been proposed to improve various PFL models. For instance, global aggregation has been proposed to replace weight averaging by model ensembling and distillation. Local training has been proposed to employ regularization or control variates to adjust or correct local gradients.
In contrast to GFL, PFL models (e.g., including the personalized federated basis models of the embodiments) take into account the discrepancy among clients and learn for each client a personalized model tailored to her data distribution. Many conventional PFL approaches are based on multi-task learning (MTL), which leverages the clients' task relatedness to improve model generalizability. For instance, some conventional PFL approaches encourage related clients to learn similar models; regularize local models with a learnable global model, prior, or set of data logits; or design the model architecture to have personalized and shareable components. Some other conventional PFL approaches are based on mixture models, with (separately) learned global and personalized models that are mixed for prediction. Meta-learning is also applied to learn an initialized model that can be adapted to each client rapidly. It is worth noting that all these conventional PFL approaches require saving, for each client, the parameters of a whole or partial model.
In contrast to some conventional PFL approaches, the embodiments formulate each task model as a linear combination of a set of basis models. This makes the embodiments clearly different from conventional PFL approaches, as the embodiments bypass the linear parameter complexity. Some conventional PFL approaches learn basis models that can be used to initialize or regularize personalized models. However, these conventional PFL approaches still need a linear parameter complexity for the final personalized models.
Aspects of the present disclosure provide a number of technical effects and benefits. For instance, the embodiments include a novel model architecture (e.g., a sharable federated basis model) and learning algorithm, which alleviates the linear model complexity in PFL and improves the distributional robustness. More specifically, conventional PFL approaches have a linear parameter complexity and each client keeps a model. In the embodiments herein, the clients only keep coefficients for combining a few basis models. That is, the embodiments provide a separate set of expansion coefficients for a set of sharable basis models to each client of a set of clients. The embodiments bypass the linear parameter complexity in maintaining personalized models and overcome their vulnerability to class distribution changes. To this end, the embodiments are directed to a novel sharable federated basis model architecture and algorithm, which constructs each personalized model from a few shareable basis models. An enhanced training algorithm is designed, in a systematic and mathematically sound manner, to overcome the collapse problem of basis models due to local training. Along with federated basis models, a new and carefully designed PFL benchmark (PFLBed) is presented herein. Empirical studies presented herein demonstrate the effectiveness of federated basis models.
Each of the client devices of the set of client devices 110 may implement a personalized basis model learner 130. The server device 102 may implement a federated basis model learner 140. As discussed throughout, each of the personalized basis model learner 130 and the federated basis model learner 140 contributes to the training of each basis model of a set of basis models 150. In the example embodiment shown in the accompanying figures, the set of client devices 110 includes a first client device 112, a second client device 114, and a third client device 116.
Each client device of the set of client devices 110 may locally store a separate (e.g., a unique) subset of a set of training data 120. For instance, the first client device 112 may locally store a first subset of training data 122, the second client device 114 may locally store a second subset of training data 124, and the third client device 116 may locally store a third subset of training data 126. As shown in
Each client device of the set of client devices 110 may implement a personalized federated learning (PFL) model. The PFL model for a client device is a personalized linearly weighted combination of the set of basis models 150. To generate a PFL model, a client device may learn and/or generate a personalized vector. A personalized vector for a client device may indicate the linear combination of the basis models of the set of basis models 150 for the client device. The values of the components may indicate the weights for the linear combination of the set of basis models. As such, each component of a personalized vector may correspond to a particular basis model of the set of basis models 150. Thus, the dimensionality of a personalized vector may be equivalent to the cardinality of the set of basis models 150. In the non-limiting example of
In addition to at least partially training each basis model of the set of basis models 150, a client device may employ the personalized basis model learner 130 to learn its own personalized vector. For example, the first client device 112 may learn a first personalized vector 132, the second client device 114 may learn a second personalized vector 134, and the third client device 116 may learn a third personalized vector 136. In various embodiments, the data privacy of a personalized vector (and thus the data privacy of a PFL model) is ensured. For instance, the server device 102 and/or the federated basis model learner 140 may not have access to any of the personalized vectors of the client devices of the set of client devices 110. Furthermore, each client device may only have access to its own personalized vector. For instance, the first client device 112 may access its own personalized vector (i.e., first personalized vector 132); however, the second personalized vector 134 of the second client device 114 and the third personalized vector 136 of the third client device 116 may be inaccessible to the first client device 112.
As a general example, a classifier hθ=gw∘ƒϕ may be learned, where ƒϕ is the feature extractor parameterized by ϕ and gw is the classification head parameterized by w. θ is employed to denote {ϕ, w}.
In centralized learning, the training set D={(x1, y1), . . . , (xN, yN)} is given, where x is the input (e.g., an image) and y ϵ {1, . . . , C}=[C] is the ground-truth label. Given a loss function ℓ (e.g., a cross-entropy loss function), a typical way to learn θ is to minimize the empirical risk L(θ; D):

minθ L(θ; D)=(1/N) Σi ℓ(hθ(xi), yi). (1)
In generic federated learning (GFL), the goal remains the same—to train a “global” model hθ. However, the training data are now collected and separately stored by M clients: each client m ϵ [M] keeps a private set Dm={(xi, yi)}i=1, . . . , |Dm|, and Lm(θ)=L(θ; Dm) denotes client m's empirical risk. The overall objective may be written as

minθ Σm (|Dm|/Σm′|Dm′|)×Lm(θ). (2)
Unfortunately, equation (2) cannot be solved directly since the data are decentralized. One standard solution is federated averaging (FedAvg), which decomposes the optimization into a multi-round process. Within each round, the server first broadcasts the “global” model to the clients. The clients then perform local training in parallel to update the model by minimizing each client's empirical risk. The “local” models are then returned to the server and globally aggregated into an updated “global” model by element-wise averaging over local model parameters. Let θm(t) denote client m's local model after round t; the aggregated “global” model may be written as

θ̄(t)=Σm (|Dm|/Σm′|Dm′|)×θm(t). (3)
Local training is often implemented by stochastic gradient descent (SGD). The fewer the gradient steps are per round, the closer the resulting global model θ̄(t) approximates centralized training, at the cost of more frequent communication.
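By way of a non-limiting illustration, the following sketch shows a FedAvg-style round with toy clients and a toy model. The Client class, the weighted element-wise averaging, and the training hyper-parameters are assumptions of this illustration.

```python
# Illustrative FedAvg-style sketch; client data and model are placeholders.
import copy
import torch
import torch.nn as nn

class Client:
    def __init__(self, xs, ys):
        self.xs, self.ys = xs, ys
        self.num_examples = len(xs)

    def local_train(self, model, epochs=1, lr=0.1):
        # Minimize this client's empirical risk on its local data.
        opt = torch.optim.SGD(model.parameters(), lr=lr)
        for _ in range(epochs):
            opt.zero_grad()
            loss = nn.functional.cross_entropy(model(self.xs), self.ys)
            loss.backward()
            opt.step()

def fedavg_round(global_model, clients):
    # Broadcast, local training, then element-wise weighted averaging (cf. equation (3)).
    states, weights = [], []
    for c in clients:
        local = copy.deepcopy(global_model)
        c.local_train(local)
        states.append(local.state_dict())
        weights.append(c.num_examples)
    total = float(sum(weights))
    avg = {k: sum((w / total) * s[k] for w, s in zip(weights, states))
           for k in global_model.state_dict()}
    global_model.load_state_dict(avg)
    return global_model

model = nn.Linear(16, 7)
clients = [Client(torch.randn(20, 16), torch.randint(0, 7, (20,))) for _ in range(3)]
for _ in range(5):                      # a few communication rounds
    model = fedavg_round(model, clients)
```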
In contrast to GFL, personalized federated learning (PFL) aims to learn for each client m a customized model θm, whose goal is to perform well on client m's local training data. While there is no standard objective function, the optimization problem may be defined as:

min{θ1, . . . , θM} Σm Lm(θm)+R({θm}; Ω), (4)

where R is a regularizer and Ω is introduced to relate clients. The regularizer is imposed to encourage related clients to learn similar models, to overcome their limited data. In contrast to equation (2), equation (4) seeks to minimize each client's empirical risk (plus a regularizer) by the corresponding personalized model θm. In practice, PFL algorithms often run iteratively between the local and global steps as well, so as to update Ω periodically according to clients' models.
Some embodiments may perform fine-tuning to the global model of FedAvg to generate personalized models. In such embodiments, the global aggregation of FedAvg may serve as a strong implicit regularizer.
Personalized Federated Learning with Basis Models: Formulation
While both solving equation (4) and fine-tuning the FedAvg global model can lead to personalized models, they require learning and saving the parameters of a whole (or partial) model for each of the M clients—i.e., a linear parameter complexity (M×|θ|). This is particularly unfavorable for maintaining the models, especially when a huge number of clients are involved and the personalized models eventually operate on the cloud. Besides, model parameters learned specifically for each client would inevitably adapt to the client's data distribution, even with regularization. If the distribution is skewed across classes, the resulting personalized models would be vulnerable to class distribution changes.
To resolve these issues in PFL, the embodiments employ a novel method to construct personalized models to bypass the linear parameter complexity. Each personalized model (e.g., θm) may be represented as the linear combination of basis models:
θm=Σk αm[k]×vk, (5)
where V={v1, . . . , vK} is a set of K basis models shareable among clients, and αm ϵ ΔK−1 is a K-dimensional vector on the (K−1)-simplex that records the personalized convex combination coefficients. That is, each personalized model is a convex combination of a set of basis models.
With this representation, the total parameters to save for all clients become
O(K×|θ|+K×M)≃O(K×|θ|). (6)
Here, O(K×M) corresponds to all the combination coefficients A={α1, . . . , αM}, which is negligible since, for most modern neural network models, |θ|>>M. It is noted that when K=M and the combination vectors are all one-hot with αm[m]=1, this representation reduces to saving a model for each client. However, when clients' data share similarity—a common assumption made in multi-task learning—it is likely that K<<M can be used to construct high-quality personalized models and meanwhile largely reduce the number of parameters.
Objective function. Building upon the model representation in equation (5) and the optimization problem in equation (4), the PFL problem for the embodiments may be represented as:

min{V, α1, . . . , αM} Σm Lm(θm=Σk αm[k]×vk), s.t. αm ϵ ΔK−1 for all m. (7)
It is noted that both the basis models and combination coefficient vectors are to be learned. The regularization term in equation (4) may be dropped, as the convex combination representation itself may be a form of regularization.
Training. The optimization of equation (7) is discussed below.
Inference. The convex combination takes place in the model parameter space. Before making predictions, a single personalized model θm is first constructed by convexly combining the parameters of the basis models layer by layer. The inference time on each image thus remains the same as existing PFL methods. This is sharply different from the conventional mixture-of-experts procedure, which combines the predictions of expert models, not their parameters. Namely, it needs to perform inference multiple times on each image before the final prediction can be made. The various embodiments may be differentiated from conventional approaches at least because the embodiments extend such concepts to PFL, identify difficulties in optimization, and resolve them accordingly.
Personalized Federated Learning with Basis Models: Baseline Training
Similarly to equation (2), equation (7) cannot be solved directly since the clients' data are decentralized. In this subsection a baseline training algorithm is presented.
Baseline training algorithm. As the basis models are shared among clients, they can be conceptually considered as global models. A FedAvg-style training algorithm is developed by iterating between local and global steps. In the local step, the server broadcasts the current bases V̄(t−1) to the clients; each client m initializes α←αm(t−1) and V←V̄(t−1) and updates both by minimizing its local empirical risk, e.g.,

(αm(t), Vm(t))=arg min{α, V} Lm(α, V). (8)

In the global step, the server aggregates the locally updated bases {Vm(t)} into V̄(t), e.g., by element-wise (weighted) averaging, while each client keeps its own coefficients αm(t).
Here, Lm(α, V) is used as a concise notation for Lm(θ=Σk α[k]×vk). It is worth noting that in local training, client m only updates her own coefficients αm(t), not others'; every client can potentially update all basis models in V. α is implemented by a softmax function to ensure that it learns a convex combination. The final personalized model for client m, after T rounds, is θm=Σk αm(T)[k]×v̄k(T), where v̄k(T) denotes the k-th basis model after the final round.
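By way of a non-limiting illustration, the following sketch shows the baseline local step for one client, with each basis simplified to a linear classifier. Gradients flow to both the client's combination coefficients (via a softmax over alpha_logits) and all of the bases, mirroring the joint update described above; all names and shapes are illustrative assumptions.

```python
# Minimal sketch of the baseline local step for one client; each basis is a
# simple linear classifier, and alpha is parameterized through a softmax.
import torch

K, D, C = 4, 16, 7
bases_W = [torch.randn(C, D, requires_grad=True) for _ in range(K)]   # basis parameters
bases_b = [torch.zeros(C, requires_grad=True) for _ in range(K)]
alpha_logits = torch.zeros(K, requires_grad=True)                     # client-specific

x, y = torch.randn(32, D), torch.randint(0, C, (32,))                 # local mini-batch
opt = torch.optim.SGD(bases_W + bases_b + [alpha_logits], lr=0.1)

for _ in range(10):                                                    # local SGD steps
    alpha = torch.softmax(alpha_logits, dim=0)        # convex combination on the simplex
    W = sum(a * Wk for a, Wk in zip(alpha, bases_W))  # theta_m = sum_k alpha[k] * v_k
    b = sum(a * bk for a, bk in zip(alpha, bases_b))
    loss = torch.nn.functional.cross_entropy(x @ W.t() + b, y)
    opt.zero_grad()
    loss.backward()                                   # gradients reach alpha AND all bases
    opt.step()
```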
Brief experimental setup. To analyze the baseline algorithm, a PFL experiment is conducted. The PACS dataset is employed, which contains in total 7K training images from 7 classes. The procedure detailed in the benchmark discussion below is followed to split the training images into M=40 clients. Each client has images from one of the four domains (Photo, Art, Cartoon, Sketch); the class distribution P(y) of each client m is sampled from a Dirichlet distribution to make it skewed and not identical among clients. These strategies create highly heterogeneous clients' data. K=4 bases are used, each modeled by a ResNet-18. Each basis model is randomly initialized. For more details, please see the benchmark and experimental setup discussions below.
For evaluation, a class-balanced global test set is prepared for each image domain. Without loss of generality, it is assumed that each test set has the same number of test images N, and each test sample is indexed by j. To evaluate the accuracy of client m, the global test set of client m's domain is used. Specifically, two accuracies are calculated: a personalized accuracy and a balanced accuracy, described below.
The personalized accuracy weighs each test sample by Pm(yj) to reflect the class distribution of client m's training data. This can be considered as the standard personalized accuracy in the literature. The balanced accuracy, in contrast, treats each test sample of client m's domain equally. This is to simulate the situation in which a client does not have sufficient resources to collect training data that faithfully reflects the class distribution in her environment. In this case, a class-balanced test set may be used to assess the personalized model's distributional robustness. To summarize over clients, the average may be taken over their accuracies.
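By way of a non-limiting illustration, the following sketch computes the two accuracies described above. The dictionary class_prior, standing in for Pm(y), and the example labels are assumptions of this illustration.

```python
# Sketch of the two evaluation metrics described above; `class_prior` stands in
# for the client's training-class distribution P_m(y).
import numpy as np

def balanced_and_personalized_acc(y_true, y_pred, class_prior):
    correct = (np.asarray(y_true) == np.asarray(y_pred)).astype(float)
    balanced = correct.mean()                                   # every test sample counts equally
    weights = np.asarray([class_prior[c] for c in y_true])
    personalized = (weights * correct).sum() / weights.sum()    # weighted by P_m(y_j)
    return balanced, personalized

# Example on a class-balanced test set with a skewed client prior.
y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 0, 1, 0, 0, 2]
prior = {0: 0.7, 1: 0.2, 2: 0.1}
print(balanced_and_personalized_acc(y_true, y_pred, prior))
```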
Unlimited communication. In terms of the number of local gradient steps per round and the number of total rounds, an ideal case with unlimited communication is first considered. This allows global aggregation to be performed as soon as possible; i.e., after each mini-batch SGD step. It is noted that this training strategy is basically equivalent to mini-batch SGD in centralized learning.
In other words, under the ideal case, the capacity and capability of the convex combination representation of personalized models are justified. The distributional robustness may be attributed to the collaboratively learned and shared basis models, which are less likely to be biased/over-fitted to the skewed local training data.
Limited communication. In practice, due to communication constraints, it is infeasible to perform global aggregation after each mini-batch SGD step. Thus, the standard case is studied by performing local training for a few epochs per round. Table 300 (the “every 5 epochs” columns) summarizes the results (with 100 rounds). Almost all algorithms degrade. Specifically, the following further observations may be attributed to the data of table 300:
Analysis. To have a better understanding of the issue, the training dynamics are investigated. Specifically, both a) the average pairwise cosine similarity between the basis model parameters; and b) the entropy of the learned combination vectors are checked. A high entropy implies an almost uniform combination vector.
By taking a deeper look at the local training dynamics, the gradients of client m's empirical risk with respect to the k-th basis and the k-th combination coefficient are:

∇vk Lm(α, V)=α[k]×∇θ Lm(θ), (9)

∇α[k] Lm(α, V)=vk·∇θ Lm(θ). (10)
Interestingly, while with different magnitudes, it is found that ∇vk Lm(α, V) for all k point in the same direction, ∇θ Lm(θ), according to equation (9). That is, within a round of local training, all the bases are pushed toward similar updates, which makes them increasingly similar to each other and ultimately leads to the collapse problem.
It is noted that this phenomenon does not appear in the ideal case because global aggregation is performed right after each mini-batch SGD step, which prevents the above-mentioned similarity from accumulating.
Personalized Federated Learning with Basis Models: Enhanced Training
To prevent the collapse problem in the federated basis models of the embodiments, which is due to multiple steps or epochs of local training per round, the following treatments are proposed for the various embodiments.
Coordinate descent for the combination coefficients and bases. Within each round, it is proposed to first update α (for multiple SGD steps) while freezing V, and then update V (for multiple SGD steps) while freezing α. It is noted that at the beginning of each round of local training, vk·∇θ Lm(θ) is not necessarily positive. Updating α with V frozen thus could potentially enlarge the difference among elements in α, forcing the personalized model to attend to a subset of bases. After starting to update V, α may be frozen to prevent the collapse problem.
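By way of a non-limiting illustration, the following sketch shows one possible coordinate-descent schedule for a single round of local training, reusing the toy linear bases from the baseline sketch above. The function and variable names are assumptions of this illustration.

```python
# Coordinate-descent schedule within one round of local training:
# phase 1 updates alpha with the bases frozen; phase 2 updates the bases
# with alpha frozen.
import torch

def local_round(bases, alpha_logits, x, y, steps=5, lr=0.1):
    def combined_loss(freeze_alpha=False, freeze_bases=False):
        alpha = torch.softmax(alpha_logits, dim=0)
        if freeze_alpha:
            alpha = alpha.detach()                     # treat coefficients as constants
        Ws = [Wk.detach() if freeze_bases else Wk for Wk in bases]
        W = sum(a * Wk for a, Wk in zip(alpha, Ws))
        return torch.nn.functional.cross_entropy(x @ W.t(), y)

    opt_alpha = torch.optim.SGD([alpha_logits], lr=lr)
    for _ in range(steps):                             # phase 1: coefficients only
        opt_alpha.zero_grad()
        combined_loss(freeze_bases=True).backward()
        opt_alpha.step()

    opt_bases = torch.optim.SGD(bases, lr=lr)
    for _ in range(steps):                             # phase 2: bases only
        opt_bases.zero_grad()
        combined_loss(freeze_alpha=True).backward()
        opt_bases.step()

# Toy usage mirroring the baseline sketch above.
K, D, C = 4, 16, 7
bases = [torch.randn(C, D, requires_grad=True) for _ in range(K)]
alpha_logits = torch.zeros(K, requires_grad=True)
x, y = torch.randn(32, D), torch.randint(0, C, (32,))
local_round(bases, alpha_logits, x, y)
```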
Sharpening combination coefficients and regularizing bases. Since α[k]≥0, updating vk locally with ∇vk Lm(α, V)=α[k]×∇θ Lm(θ) pushes every basis along the same direction, in proportion to α[k]. Sharpening α so that its small entries approach zero therefore limits the local updates to the few bases the client actually attends to, helping the remaining bases preserve their knowledge.
During sharpening, a temperature parameter 1≥τ≥0 may be introduced, and α[k] may be re-calculated by re-normalizing the coefficients according to the temperature, such that the smaller τ is, the closer α becomes to a one-hot vector.
In some embodiments, this sharpening is only performed temporarily while calculating ∇vk Lm(α, V) for updating the bases; the unsharpened coefficients are otherwise retained.
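The present disclosure does not limit sharpening to a single formula; purely for illustration, the following sketch assumes one common form, namely raising the coefficients to the power 1/τ and re-normalizing, such that smaller temperatures yield more peaked (closer to one-hot) coefficients.

```python
# Illustrative sharpening of combination coefficients; the exact functional
# form (power 1/tau with renormalization) is an assumption of this sketch.
import torch

def sharpen(alpha: torch.Tensor, tau: float) -> torch.Tensor:
    if tau <= 0:                       # limit case: one-hot on the largest coefficient
        return torch.nn.functional.one_hot(alpha.argmax(), alpha.numel()).float()
    sharpened = alpha.clamp_min(1e-12) ** (1.0 / tau)
    return sharpened / sharpened.sum()

alpha = torch.tensor([0.4, 0.3, 0.2, 0.1])
print(sharpen(alpha, tau=0.5))         # more peaked than alpha
print(sharpen(alpha, tau=1.0))         # unchanged
```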
Another way to alleviate the collapse problem may be investigated as follows. Suppose that, at the beginning of each round of local training, the bases in the newly broadcast set V̄(t−1) are specialized; then one way to preserve their specialized knowledge during local training is via basis-wise regularization toward the broadcast bases.
Improved training algorithm. Putting these treatments together, an improved training algorithm for the embodiments, based on equation (8), is presented below. Within each round of local training, each client first updates α with V frozen, and then updates V with α frozen (and temporarily sharpened), minimizing a regularized local objective of the form

Lm(α, V)+(λ/2)×Σk ∥vk−v̄k(t−1)∥F²,

where v̄k(t−1) is the k-th basis in the newly broadcast set V̄(t−1) and λ is a regularization coefficient. Here, ∥·∥F is the Frobenius norm. The personalized model for client m, after T rounds, is θm=Σk αm(T)[k]×v̄k(T).
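By way of a non-limiting illustration, the following sketch shows one way the basis-wise regularization toward the broadcast bases could be added to the local objective. The coefficient lam, the toy linear bases, and the function name regularized_local_loss are assumptions of this illustration.

```python
# Illustrative regularized local objective with a Frobenius-norm penalty that
# keeps each locally trained basis close to its broadcast version.
import torch

def regularized_local_loss(bases, broadcast_bases, alpha, x, y, lam=0.1):
    W = sum(a * Wk for a, Wk in zip(alpha, bases))               # theta = sum_k alpha[k] * v_k
    task_loss = torch.nn.functional.cross_entropy(x @ W.t(), y)
    reg = sum(torch.linalg.norm(Wk - Wk0, ord="fro") ** 2
              for Wk, Wk0 in zip(bases, broadcast_bases))
    return task_loss + 0.5 * lam * reg

K, D, C = 4, 16, 7
broadcast = [torch.randn(C, D) for _ in range(K)]                # bases received from the server
bases = [b.clone().requires_grad_(True) for b in broadcast]      # local copies being trained
alpha = torch.softmax(torch.zeros(K), dim=0)
x, y = torch.randn(32, D), torch.randint(0, C, (32,))
regularized_local_loss(bases, broadcast, alpha, x, y).backward()
```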
Personalized Federated Learning with Basis Models: An Extension
So far, the same coefficients α[k] are applied to combine the bases vk into θm (cf. equation (5)). Such a formulation can be slightly generalized to decouple the coefficients for the feature extractors and the classification heads. For instance, some clients may have the same image styles but different class distributions. Recall that hθ=gw∘ƒϕ; one set of coefficients may thus be learned to combine the feature extractors of the bases, and a separate set of coefficients may be learned to combine the classification heads.
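By way of a non-limiting illustration, the following sketch shows how the combination coefficients could be decoupled for the feature-extractor and classification-head portions of the bases. The module names ("extractor", "head"), shapes, and helper names are assumptions of this illustration.

```python
# Illustrative decoupling of combination coefficients: alpha_f combines the
# feature extractors of the bases, alpha_g combines the classification heads.
import torch
import torch.nn as nn

def make_basis():
    return nn.ModuleDict({
        "extractor": nn.Sequential(nn.Linear(16, 32), nn.ReLU()),
        "head": nn.Linear(32, 7),
    })

def combine_decoupled(bases, alpha_f, alpha_g):
    target = make_basis()
    state = {}
    for name in target.state_dict():
        coeffs = alpha_f if name.startswith("extractor") else alpha_g
        state[name] = sum(a * b.state_dict()[name] for a, b in zip(coeffs, bases))
    target.load_state_dict(state)
    return target

bases = [make_basis() for _ in range(4)]
alpha_f = torch.softmax(torch.randn(4), dim=0)   # coefficients for the feature extractors
alpha_g = torch.softmax(torch.randn(4), dim=0)   # coefficients for the classification heads
personalized = combine_decoupled(bases, alpha_f, alpha_g)
```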
There have been many efforts on building datasets for generic FL, but how should a reliable dataset, along with evaluation protocols, be constructed for PFL algorithm development? Consider the following aspects:
Cross-domain with non-IID P(x, y). A challenging personalized dataset should have the joint distribution P(x, y) differ from client to client, not just P(x) (e.g., styles, domains) or P(y) (i.e., class labels). Both the training data sizes and the class distributions should be skewed among clients.
Sufficient test samples and matched training/test splits. The test set should be large enough for reliable evaluation. This is challenging when there are many clients, each with a small data size. For example, the popular 62-class hand-written character FEMNIST dataset only has 226 images per writer on average; many classes only have ≤1 image. It is unfaithful to split each client's data into train/test sets due to mismatches on P(y). Indeed, a large discrepancy between the training and test class distributions is found even with a 50%/50% split.
Distributional robustness. As discussed above, the balanced accuracy, inspired by the practice in class-imbalanced learning, may be included to evaluate distributional robustness for the case of object recognition. It is noted that only the testing P(y), not P(x) (i.e., the domains), is changed.
To achieve these desired properties, it is proposed to transform a cross-domain dataset into clients' sets {(Dm-train, Dm-test/val)} with the following procedures.
Examples. Three example object recognition datasets are considered, including PACS (7 classes for 10K images), VLCS (5 classes for 10K images), and Office-Home (65 classes for 15K images).
Method 500 begins at block 502, where a server device provides a set of untrained models to each client device of a set of client devices. At block 504, the server device may cause each client device of the set of client devices to generate a separate set of trained models based on the set of untrained models. Generating a set of trained models at a client device may include the client device iteratively training the set of untrained models based on a separate subset of a set of training data that is located locally on the client device. Each subset of the set of training data is inaccessible by the server device. Each subset of the set of training data is inaccessible by each client device except for the subset of training data that is located locally on that client device. At block 506, the server device may receive a separate set of trained models from each client device. At block 508, the server device may generate a set of basis models. The set of basis models may be based on a combination of the separate set of trained models received from each of the client devices. At block 510, the server device may provide the set of basis models to each client device of the set of client devices. At block 512, the server device may cause each client device of the set of client devices to generate a personalized model based on a separate linear combination of the set of basis models.
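By way of a non-limiting illustration only, the following end-to-end sketch walks through blocks 502-512 with toy linear models. The single communication pass, the averaging used at block 508, and all names are illustrative simplifications rather than requirements of method 500.

```python
# End-to-end sketch of method 500 with toy linear models.
import torch

D, C, K = 16, 7, 3

def client_train(models, x, y, steps=20, lr=0.1):
    # Each client trains local copies of the provided models on its own data.
    models = [m.clone().detach().requires_grad_(True) for m in models]
    opt = torch.optim.SGD(models, lr=lr)
    for _ in range(steps):
        loss = sum(torch.nn.functional.cross_entropy(x @ m.t(), y) for m in models)
        opt.zero_grad(); loss.backward(); opt.step()
    return [m.detach() for m in models]

# Block 502: the server provides the same set of untrained models to every client.
untrained = [torch.randn(C, D) * 0.01 for _ in range(K)]
client_data = [(torch.randn(40, D), torch.randint(0, C, (40,))) for _ in range(5)]

# Blocks 504-506: each client trains locally; the server receives the results.
trained_sets = [client_train(untrained, x, y) for x, y in client_data]

# Block 508: the server combines the clients' trained models into K basis models.
bases = [sum(ts[k] for ts in trained_sets) / len(trained_sets) for k in range(K)]

# Blocks 510-512: clients receive the bases and fit personalized coefficients.
def fit_alpha(bases, x, y, steps=50, lr=0.5):
    logits = torch.zeros(K, requires_grad=True)
    opt = torch.optim.SGD([logits], lr=lr)
    for _ in range(steps):
        alpha = torch.softmax(logits, dim=0)
        W = sum(a * b for a, b in zip(alpha, bases))
        loss = torch.nn.functional.cross_entropy(x @ W.t(), y)
        opt.zero_grad(); loss.backward(); opt.step()
    return torch.softmax(logits, dim=0).detach()

personalized_alphas = [fit_alpha(bases, x, y) for x, y in client_data]
```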
In various embodiments, the method further includes causing each client device of the set of client devices to iteratively generate a personalized vector while iteratively training the set of untrained models based on the separate subset of the set of training data. The personalized vector of a client device may indicate the separate linear combination of the set of basis models of the client device. The server device may cause each client device of the set of client devices to generate the personalized model further based on the personalized vector of the client device and the set of basis models.
In some embodiments, each untrained model of the set of untrained models has an identical model architecture. Each untrained model of the set of untrained models may be an image classifier model. In such embodiments, the set of training data may include labeled images. The separate linear combination of the set of basis models for each client device of the set of client devices may be a convex combination of the set of basis models.
In various embodiments, iteratively training the set of untrained models at a client device of the set of client devices may include iteratively determining components of a personalized vector for the client device. Iteratively determining the components of the personalized vector may be based on the client device's separate subset of a set of training data and a loss function. The personalized vector for the client device may indicate the separate linear combination of the set of basis models for the client device. Parameters for each untrained model of the set of untrained models may be determined based on the client device's separate subset of a set of training data and the components of the personalized vector for the client. Iteratively determining components of the personalized vector for the client device may include setting a threshold value for each component of the personalized vector. For each iterative determination of each component of the personalized vector, a determined value of the component may be zeroed-out when the determined value of the component is less than the threshold value for the determined value of the component.
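By way of a non-limiting illustration, the following sketch shows one way the threshold-based zeroing-out of personalized-vector components could be implemented. The specific threshold value and the re-normalization back onto the simplex are assumptions of this illustration.

```python
# Zero out small personalized-vector components; threshold and renormalization
# choices are illustrative assumptions.
import torch

def threshold_alpha(alpha: torch.Tensor, threshold: float = 0.05) -> torch.Tensor:
    pruned = torch.where(alpha < threshold, torch.zeros_like(alpha), alpha)
    if pruned.sum() == 0:                       # keep at least the largest component
        pruned = torch.zeros_like(alpha)
        pruned[alpha.argmax()] = 1.0
    return pruned / pruned.sum()                # re-normalize so the weights still sum to 1

alpha = torch.tensor([0.50, 0.30, 0.17, 0.03])
print(threshold_alpha(alpha))                   # the 0.03 entry is zeroed out
```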
In various embodiments, iteratively training the set of untrained models at a client device of the set of client devices may include coordinating a first stochastic gradient descent (SGD) process for the set of untrained models and a second SGD process for a personalized vector for the client device. Coordinating a first SGD process and a second SGD process may include holding constant components of the personalized vector while performing the first SGD process. Parameters of the set of untrained models may be held constant while performing the second SGD process. In at least one embodiment, the set of basis models may include a first subset of basis models and a second subset of basis models. The first subset of basis models may correspond to a feature extractor of the personalized model. The second subset of basis models may correspond to a classification head of the personalized model.
The technology discussed herein makes reference to servers, databases, software applications, and other computer-based systems, as well as actions taken, and information sent to and from such systems. The inherent flexibility of computer-based systems allows for a great variety of possible configurations, combinations, and divisions of tasks and functionality between and among components. For instance, processes discussed herein can be implemented using a single device or component or multiple devices or components working in combination. Databases and applications can be implemented on a single system or distributed across multiple systems. Distributed components can operate sequentially or in parallel.
While the present subject matter has been described in detail with respect to various specific example embodiments thereof, each example is provided by way of explanation, not limitation of the disclosure. Those skilled in the art, upon attaining an understanding of the foregoing, can readily produce alterations to, variations of, and equivalents to such embodiments. Accordingly, the subject disclosure does not preclude inclusion of such modifications, variations and/or additions to the present subject matter as would be readily apparent to one of ordinary skill in the art. For instance, features illustrated or described as part of one embodiment can be used with another embodiment to yield a still further embodiment. Thus, it is intended that the present disclosure cover such alterations, variations, and equivalents.
The present application claims the benefit of priority of U.S. Provisional Application Ser. No. 63/410,473, filed on Sep. 27, 2022, titled PERSONALIZED FEDERATED LEARNING VIA SHARABLE BASIS MODELS, which is incorporated herein by reference.