METHOD AND SYSTEM FOR FEDERATED LEARNING

Information

  • Patent Application
  • Publication Number
    20240135194
  • Date Filed
    November 17, 2023
  • Date Published
    April 25, 2024
  • CPC
    • G06N3/098
  • International Classifications
    • G06N3/098
Abstract
Broadly speaking, embodiments of the present techniques provide a method for training a machine learning, ML, model to update global and local versions of a model. We propose a novel hierarchical Bayesian approach to Federated Learning (FL), in which the model describes the generative process of clients' local data via hierarchical Bayesian modelling: each client has random variables for its local model, and these are governed by a higher-level global variate. Variational inference in this Bayesian model leads to an optimisation problem whose block-coordinate descent solution is a distributed algorithm that is separable over clients and does not require clients to reveal their private data, making it fully compatible with FL.
Description
BACKGROUND
1. Field

The present application generally relates to a method and system for federated learning. In particular, the present application provides a method for training a machine learning, ML, model to update global and local versions of a model without a central server having to access user data.


2. Description of Related Art

Many clients/client devices (e.g. smartphones) now contain a significant amount of data that can be useful for training machine learning, ML, models. Suppose there are N clients with their own private data Di, i=1, . . . , N. The client devices are usually less powerful computing devices than a central server, and each holds a small dataset Di. In traditional centralised machine learning, a powerful computer collects all client data D = D1 ∪ . . . ∪ DN and trains a model with D. Federated Learning (FL) aims to enable a set of clients to collaboratively train a model in a privacy-preserving manner, without sharing data with each other or with a central server. That is, in FL it is prohibited to share clients' local data, as the data are confidential and private. Instead, clients are permitted to train/update their own models with their own data and share the local models with others (e.g. with a global server). FL is therefore concerned with how to train clients' local models and aggregate them to build a global model that is as powerful as the centralised model (global prediction) and flexible enough to adapt to unseen clients (personalisation). Compared to conventional centralised optimisation problems, FL comes with a host of statistical and systems challenges, such as communication bottlenecks and sporadic participation. The key statistical challenge is non-independent and non-identically distributed (non-i.i.d.) data across clients, each of which has a different data collection bias and potentially a different data labelling function (e.g. in user preference learning). Moreover, even when a global model can be learned, it often underperforms on each client's local data distribution in scenarios of high heterogeneity. Previous studies have attempted to alleviate this by personalising learning at each client, allowing each local model to deviate from the shared global model. However, this remains challenging given that each client may have only a limited amount of local data for personalised learning.


The applicant has therefore identified the need for an improved method of performing federated learning.


SUMMARY

According to an embodiment of the disclosure, there is provided a method for training, using federated learning, a global machine learning, ML, model for use by a plurality of client devices. The method comprises: defining, at a server, a Bayesian hierarchical model which links a global random variable with a plurality of local random variables, one for each of the plurality of client devices, wherein the Bayesian hierarchical model comprises a posterior distribution which is suitable for predicting the likelihood of the global random variable and the local random variables given each individual dataset at the plurality of client devices; and approximating, at the server, the posterior distribution using a global ML model and a plurality of local ML models, wherein the global ML model is parameterised by a global parameter which is updated at the server, each of the plurality of local ML models is associated with one of the plurality of client devices, and each local ML model is parameterised by a local parameter which is updated at the client device which is associated with the local ML model. The method further comprises: sending, from the server, the global parameter to a predetermined number of the plurality of client devices; receiving, at the server from each of the predetermined number of the plurality of client devices, an updated local parameter, wherein each updated local parameter has been determined by training, on the client device, the local ML model using a local dataset, and wherein during training of the local ML model, the global parameter is fixed; and training, at the server, the global ML model using each of the received updated local parameters to determine an updated global parameter, wherein during training of the global model, each local parameter is fixed.


In other words, there is provided a method for training, on a server, a machine learning, ML, model using a hierarchical Bayesian approach to federated learning. The method described above may be considered to use block-coordinate optimization, because it alternates two steps: (i) updating/optimizing all local parameters while fixing the global parameter and (ii) updating the global parameter with all local parameters fixed. The local parameters are updated using the local datasets, whereas the global parameter is updated using the local parameters only, not the local datasets. Thus, the local datasets remain on the client devices and no data are shared between the client devices and the server, so data privacy can be respected. The local parameters are not sent from the server to the client device when the global parameter is sent. In other words, only the global parameter is sent to each client device involved in a round of federated learning.
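
The alternating structure can be illustrated with a minimal Python sketch. It assumes a hypothetical scalar toy problem rather than the Bayesian objectives of the present techniques: each client fits a local value mi to its own scalar data with a quadratic penalty (weight `lam`) towards the fixed global value m0, and the server then sets m0 to a penalised average of the received mi with the locals fixed. The function names and the hyperparameters `lam` and `prior_prec` are illustrative assumptions. Only m0 is sent to the sampled clients and only mi is returned; raw data never leave the clients.

```python
import random
from typing import Dict, List


def client_update(data: List[float], m0: float, lam: float) -> float:
    """Local step with m0 fixed: argmin_m 0.5*sum((y - m)^2) + 0.5*lam*(m - m0)^2."""
    return (sum(data) + lam * m0) / (len(data) + lam)


def server_update(local_means: List[float], lam: float, prior_prec: float) -> float:
    """Global step with the m_i fixed: argmin_m0 0.5*lam*sum((m_i - m0)^2) + 0.5*prior_prec*m0^2."""
    return lam * sum(local_means) / (lam * len(local_means) + prior_prec)


def run_rounds(client_data: Dict[int, List[float]], rounds: int, clients_per_round: int,
               lam: float = 1.0, prior_prec: float = 0.1, seed: int = 0) -> float:
    rng = random.Random(seed)
    m0 = 0.0  # global parameter held by the server
    for _ in range(rounds):
        # Send only m0 to a predetermined number of randomly sampled clients.
        chosen = rng.sample(sorted(client_data), clients_per_round)
        # Each sampled client trains on its own data with m0 fixed and returns only m_i.
        local_means = [client_update(client_data[c], m0, lam) for c in chosen]
        # The server updates m0 from the received m_i alone; it never sees client data.
        m0 = server_update(local_means, lam, prior_prec)
    return m0
```

For example, `run_rounds({0: [1.0, 1.0], 1: [3.0, 3.0], 2: [5.0]}, rounds=50, clients_per_round=3)` converges to the fixed point of the two alternating updates, 31/11.6 ≈ 2.672 for these numbers; sampling fewer clients per round makes m0 fluctuate from round to round.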


According to an embodiment of the disclosure, there is provided a system for training, using federated learning, a global machine learning, ML, model, the system comprising: a server comprising a processor coupled to memory, and a plurality of client devices each comprising a processor coupled to memory. The processor at the server is configured to: define a Bayesian hierarchical model which links a global random variable with a plurality of local random variables, one for each of the plurality of client devices, wherein the Bayesian hierarchical model comprises a posterior distribution which is suitable for predicting the likelihood of the global random variable and the local random variables given each individual dataset at the plurality of client devices; and approximate the posterior distribution using a global ML model and a plurality of local ML models, wherein the global ML model is parameterised by a global parameter which is updated at the server, each of the plurality of local ML models is associated with one of the plurality of client devices, and each local ML model is parameterised by a local parameter which is updated at the client device which is associated with the local ML model. The processor at the server is further configured to: send the global parameter to a predetermined number of the plurality of client devices; receive, from each of the predetermined number of the plurality of client devices, an updated local parameter; and train the global ML model using each of the received updated local parameters to determine an updated global parameter, wherein during training of the global model, each local parameter is fixed. The processor at each of the client devices is configured to: receive the global parameter; train the local ML model using a local dataset on the client device to determine an updated local parameter, wherein during training of the local ML model, the global parameter is fixed; and send the updated local parameter to the server.


The following features apply to both aspects.


A Bayesian hierarchical model (also termed hierarchical Bayesian model) is a statistical model written in multiple levels (hierarchical form) that estimates parameters of a posterior distribution using the Bayesian method. A Bayesian hierarchical model makes use of two important concepts in deriving the posterior distribution, namely: hyperparameters (parameters of the prior distribution) and hyperpriors (distributions of hyperparameters). The global random variable may be termed shared knowledge, a hyperparameter or a higher-level variable and may be represented by ϕ. The local random variables may be termed individual latent variables or network weights and may be represented by {θi}i=1N where N is the number of the plurality of client devices. The global random variable may link the local random variables by governing a distribution of the local random variables. For example, the distribution (also termed prior) for the global random variable ϕ and the local random variables {θi}i=1N may be formed in a hierarchical manner as:

$$p(\phi, \theta_1, \ldots, \theta_N) = p(\phi) \prod_{i=1}^{N} p(\theta_i \mid \phi).$$

The posterior distribution may be defined as

$$p(\phi, \theta_{1:N} \mid D_{1:N}),$$

where Di is the dataset at each client device. Applying Bayes' rule, the posterior distribution is proportional to a product of separate distributions for the global random variable and the local random variables. For example, the posterior distribution may be expressed as being proportional to

$$p(\phi) \prod_{i=1}^{N} p(\theta_i \mid \phi)\, p(D_i \mid \theta_i),$$

where p(ϕ) is the distribution for ϕ, p(θi|ϕ) is the distribution for θi given ϕ and p(Di|θi) is the distribution of each dataset given θi. Although the posterior distribution is well defined, its normalising constant requires integrating over ϕ and all of θ1, . . . , θN, which is generally intractable; an approximation is therefore used.
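
To make the factorisation concrete, the following sketch assumes hypothetical univariate Gaussian choices that are not specified above: p(ϕ) = N(0, s_phi²), p(θi|ϕ) = N(ϕ, s_theta²) and p(y|θi) = N(θi, s_y²) for each scalar data point y in Di. It draws (ϕ, θ1, . . . , θN) by ancestral sampling from the hierarchical prior and evaluates the unnormalised log posterior, i.e. the logarithm of the product above. The function names are illustrative only.

```python
import math
import random
from typing import List, Tuple


def log_normal(x: float, mean: float, std: float) -> float:
    """Log density of the univariate normal N(mean, std^2) at x."""
    return -0.5 * math.log(2 * math.pi * std ** 2) - (x - mean) ** 2 / (2 * std ** 2)


def sample_prior(n_clients: int, s_phi: float, s_theta: float,
                 rng: random.Random) -> Tuple[float, List[float]]:
    """Ancestral sampling: phi ~ p(phi), then theta_i ~ p(theta_i | phi) for i = 1..N."""
    phi = rng.gauss(0.0, s_phi)
    thetas = [rng.gauss(phi, s_theta) for _ in range(n_clients)]
    return phi, thetas


def log_unnormalised_posterior(phi: float, thetas: List[float], datasets: List[List[float]],
                               s_phi: float, s_theta: float, s_y: float) -> float:
    """log p(phi) + sum_i [log p(theta_i | phi) + log p(D_i | theta_i)]."""
    total = log_normal(phi, 0.0, s_phi)
    for theta_i, data_i in zip(thetas, datasets):
        total += log_normal(theta_i, phi, s_theta)
        total += sum(log_normal(y, theta_i, s_y) for y in data_i)
    return total
```

The missing normalising constant is exactly the integral that makes the true posterior hard to compute, which motivates the variational approximation that follows.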


Approximating the posterior distribution may comprise using variational inference, for example the posterior distribution may be approximated using a density distribution q(ϕ,θ1, . . . , θN; L) which is parameterized by L. Approximating the posterior distribution may further comprise factorising the density distribution, e.g.

$$q(\phi, \theta_1, \ldots, \theta_N; L) := q(\phi; L_0) \prod_{i=1}^{N} q_i(\theta_i; L_i),$$

where q(ϕ; L0) is the global ML model, which is parameterised by the global parameter L0, and qi(θi; Li) is the local ML model for each client device, which is parameterised by the local parameter Li. The global parameter may thus be termed a global variational parameter. Similarly, the local parameter may be termed a local variational parameter. It will be appreciated that each global parameter and each local parameter may comprise a plurality of parameters (e.g. a vector of parameters).


By separately modelling the global parameter and the local parameters, it is possible to have different structures for at least some of the global ML model and the local ML models. In other words, the global ML model may have a different backbone from one or more of the local ML models. Similarly, the local ML models may have the same or different backbones. These different structures and/or backbones can be flexibly chosen using prior or expert knowledge about the problem domains at hand.


It will be appreciated that the sending, receiving and determining steps may be repeated multiple times (e.g. there are multiple rounds) until there is convergence of the global ML model and the local ML models. The number of client devices which receive the global parameter at each sending step may be lower than or equal to the total number of client devices.


Training the global ML model comprises optimising using a regularisation term which penalises deviation between the updated global parameter and the global parameter which was sent to the client devices (i.e. between the updated global parameter and the previous version of the global parameter). Training the global ML model may comprise optimising using a regularisation term which penalises deviation between the updated global parameter and each of the received local parameters. The optimisation may be done using any suitable technique, e.g. stochastic gradient descent (SGD). The regularisation term may be any suitable term, e.g. a Kullback-Leibler divergence. Thus, one possible expression for the training of the global ML model may be:

$$\min_{L_0} \; \sum_{i=1}^{N_f} \mathbb{E}_{q(\phi; L_0)}\Big[\mathrm{KL}\big(q_i(\theta_i; L_i) \,\big\|\, p(\theta_i \mid \phi)\big)\Big] + \mathrm{KL}\big(q(\phi; L_0) \,\big\|\, p(\phi)\big),$$

where 𝔼q(ϕ; L0)[·] denotes the expectation with respect to the global ML model q(ϕ; L0) (taken together, the two terms are the L0-dependent part of the negative evidence lower bound, summed over the participating clients), qi(θi; Li) is the local ML model for each client device, θi is the local random variable for the ith client device, Li is the local parameter for the ith client device, p(θi|ϕ) is the prior for each local random variable θi given the global random variable ϕ, Nf is the number of client devices which received the global parameter L0, KL represents each regularisation term using a Kullback-Leibler divergence, q(ϕ; L0) is the global ML model, i.e. a distribution over the global random variable ϕ parameterised by the global parameter L0, and p(ϕ) is the prior for ϕ.
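
As a sanity check of the server objective, consider a hypothetical scalar case (an assumption, not part of the disclosure) in which q(ϕ; L0) = N(r, eps²) with L0 = r, qi(θi; Li) = N(mi, t²), p(θi|ϕ) = N(ϕ, s²) and p(ϕ) = N(0, s_phi²). Both terms then have closed forms, because 𝔼q(ϕ)[(mi − ϕ)²] = (mi − r)² + eps². The sketch below, with illustrative function names, evaluates the objective and its closed-form minimiser, which depends only on the received local means mi and not on any client data.

```python
import math
from typing import List


def kl_normal(mu1: float, sd1: float, mu2: float, sd2: float) -> float:
    """KL(N(mu1, sd1^2) || N(mu2, sd2^2)) for univariate normals."""
    return math.log(sd2 / sd1) + (sd1 ** 2 + (mu1 - mu2) ** 2) / (2 * sd2 ** 2) - 0.5


def server_objective(r: float, local_means: List[float], t: float, s: float,
                     eps: float, s_phi: float) -> float:
    """sum_i E_{q(phi; r)}[KL(q_i || p(theta_i | phi))] + KL(q(phi; r) || p(phi)).

    Hypothetical scalar families: q(phi) = N(r, eps^2), q_i = N(m_i, t^2),
    p(theta_i | phi) = N(phi, s^2), p(phi) = N(0, s_phi^2).
    """
    # E_phi[(m_i - phi)^2] = (m_i - r)^2 + eps^2 gives the expected KL in closed form.
    expected_kl = sum(
        math.log(s / t) + (t ** 2 + (m - r) ** 2 + eps ** 2) / (2 * s ** 2) - 0.5
        for m in local_means
    )
    return expected_kl + kl_normal(r, eps, 0.0, s_phi)


def server_argmin(local_means: List[float], s: float, s_phi: float) -> float:
    """Setting d/dr of the objective to zero gives a precision-weighted mean of the m_i."""
    return (sum(local_means) / s ** 2) / (len(local_means) / s ** 2 + 1 / s_phi ** 2)
```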


Training, using a local dataset on the client device, may comprise optimising using a loss function to fit each local parameter to the local dataset. Any suitable function which fits the local ML model to the local dataset may be used. Training, using a local dataset on the client device, may comprise optimising using a regularisation term which penalises deviation between each updated local parameter and a previous local parameter. As for the training at the server, the optimisation may be done using any suitable technique, e.g. stochastic gradient descent (SGD). The regularisation term may be any suitable term, e.g. a Kullback-Leibler divergence.


Thus, one possible expression for the training of the local ML model may be:

$$\min_{L_i} \; \mathbb{E}_{q_i(\theta_i; L_i)}\big[-\log p(D_i \mid \theta_i)\big] + \mathbb{E}_{q(\phi; L_0)}\Big[\mathrm{KL}\big(q_i(\theta_i; L_i) \,\big\|\, p(\theta_i \mid \phi)\big)\Big],$$

where 𝔼q(ϕ; L0)[·] and 𝔼qi(θi; Li)[·] denote expectations with respect to the global ML model and the local ML model respectively (together, the two terms are the Li-dependent part of the negative evidence lower bound), qi(θi; Li) is the local ML model for each client device, θi is the local random variable for the ith client device, Li is the local parameter for the ith client device, p(Di|θi) is the likelihood of each dataset Di given each local random variable θi, p(θi|ϕ) is the distribution for each local random variable θi given the global random variable ϕ, and KL represents the regularisation term using a Kullback-Leibler divergence.
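
The local objective can be checked in the same kind of hypothetical scalar setting (assumed for illustration only): qi(θi; Li) = N(m, t²) with the mean m learned and t fixed, a Gaussian likelihood p(y|θi) = N(θi, s_y²) for each y in Di, and a fixed global factor q(ϕ; L0) = N(r, eps²) with p(θi|ϕ) = N(ϕ, s²). The expected negative log-likelihood uses 𝔼[(y − θi)²] = (y − m)² + t², and the minimiser is a precision-weighted combination of the local data and the global value r. Function names are illustrative.

```python
import math
from typing import List


def local_objective(m: float, data: List[float], t: float, s_y: float,
                    r: float, eps: float, s: float) -> float:
    """E_{q_i}[-log p(D_i | theta_i)] + E_{q(phi)}[KL(q_i || p(theta_i | phi))].

    Hypothetical scalar families: q_i = N(m, t^2), p(y | theta_i) = N(theta_i, s_y^2),
    q(phi) = N(r, eps^2) held fixed (received from the server), p(theta_i | phi) = N(phi, s^2).
    """
    # E_{theta ~ N(m, t^2)}[(y - theta)^2] = (y - m)^2 + t^2
    expected_nll = sum(
        0.5 * math.log(2 * math.pi * s_y ** 2) + ((y - m) ** 2 + t ** 2) / (2 * s_y ** 2)
        for y in data
    )
    expected_kl = math.log(s / t) + (t ** 2 + (m - r) ** 2 + eps ** 2) / (2 * s ** 2) - 0.5
    return expected_nll + expected_kl


def local_argmin(data: List[float], s_y: float, r: float, s: float) -> float:
    """Closed-form minimiser over m: data and global value r weighted by their precisions."""
    return (sum(data) / s_y ** 2 + r / s ** 2) / (len(data) / s_y ** 2 + 1 / s ** 2)
```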


Approximating the posterior distribution may comprise using a Normal-Inverse-Wishart model. For example, a Normal-Inverse-Wishart model may be used as the global ML model, and a global mean parameter and a global covariance parameter may be used as the global parameter. A mixture of two Gaussian functions may be used as the local ML model, and a local mean parameter may be used as the local parameter. When using a Normal-Inverse-Wishart model, the training of the global ML model may be expressed as:

$$m_0^* = \frac{p}{N+1} \sum_{i=1}^{N} m_i,$$

$$V_0^* = \frac{n_0}{N+d+2} \Big( (1 + N\epsilon^2) I + m_0^* (m_0^*)^\top + \sum_{i=1}^{N} \rho(m_0^*, m_i, p) \Big).$$

In other words, the updated global mean parameter m*0 may be calculated from a sum of the local mean parameters mi over the client devices: m*0 is this sum multiplied by the factor p/(N+1), where p is a user-specified hyperparameter (1−p corresponds to the dropout probability) and N is the total number of client devices. The updated global covariance parameter V*0 may be calculated from the expression above, where n0 is a scalar parameter at the server, N is the total number of client devices, d is the dimension, ϵ is a small positive constant, I is the identity matrix, m*0 is the updated global mean parameter, m0 is the current global mean parameter, mi is the local mean parameter for each of the client devices and p is the user-specified hyperparameter, and

$$\rho(m_0, m_i, p) = p\, m_i m_i^\top - p\, m_0 m_i^\top - p\, m_i m_0^\top + m_0 m_0^\top.$$
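
The closed-form server update above can be transcribed directly. The following minimal pure-Python sketch uses nested lists for the d×d matrix (adequate only for small d) and a hypothetical function name; it computes m*0 first and then evaluates V*0 using ρ(m*0, mi, p), as in the expression for V*0.

```python
from typing import List, Tuple

Vector = List[float]
Matrix = List[List[float]]


def niw_server_update(local_means: List[Vector], p: float, n0: float,
                      eps: float) -> Tuple[Vector, Matrix]:
    """Closed-form server update (m0*, V0*) from the N received local mean vectors."""
    n = len(local_means)
    d = len(local_means[0])
    # m0* = p / (N + 1) * sum_i m_i
    m0 = [p / (n + 1) * sum(m[k] for m in local_means) for k in range(d)]
    # Start from (1 + N eps^2) I + m0* (m0*)^T
    v = [[(1 + n * eps ** 2) * (1.0 if r == c else 0.0) + m0[r] * m0[c]
          for c in range(d)] for r in range(d)]
    # Add rho(m0*, m_i, p) = p m_i m_i^T - p m0 m_i^T - p m_i m0^T + m0 m0^T for each client
    for m in local_means:
        for r in range(d):
            for c in range(d):
                v[r][c] += (p * m[r] * m[c] - p * m0[r] * m[c]
                            - p * m[r] * m0[c] + m0[r] * m0[c])
    scale = n0 / (n + d + 2)
    return m0, [[scale * x for x in row] for row in v]
```

As a worked check with d = 1, local means 1 and 3, p = 1, n0 = 1 and ϵ = 0, the update gives m*0 = 4/3 and V*0 = 17/15.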


When using a Normal-Inverse-Wishart model, the training of the local ML model may be expressed as:

$$\min_{m_i} \; \mathcal{L}_i(m_i) := -\log p(D_i \mid \tilde{m}_i) + \frac{p}{2}(n_0 + d + 1)\,(m_i - m_0)^\top V_0^{-1} (m_i - m_0),$$

where ℒi is the local objective, written as a function of the local mean parameter mi (the local parameter), p(Di|m̃i) is the distribution of the local dataset Di given a dropout version m̃i of the local mean parameter, p is the user-specified hyperparameter, n0 is a scalar parameter, d is the dimension, m0 is the current global mean parameter (which is fixed) and V0 is the current global covariance parameter (which is fixed). In other words, the training (i.e. optimisation) at both the server and each client device is greatly simplified by the use of the Normal-Inverse-Wishart model: the server update has a closed form, and the client objective reduces to a data-fit term plus a quadratic penalty on the deviation of mi from m0. In summary, each client i a priori gets its own network parameters θi as a Gaussian-perturbed version of the shared global mean parameters μ from the server, namely θi|ϕ ∼ 𝒩(μ, Σ). This is intuitively appealing, but may not be well suited to capturing more drastic diversity or heterogeneity of local data distributions across clients.
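
A minimal sketch of one stochastic evaluation of ℒi(mi) follows. It assumes, for illustration, that V0⁻¹ has been precomputed, that the dropout version m̃i keeps each coordinate of mi with probability p and zeroes it otherwise (consistent with 1−p being the dropout probability), and that `neg_log_lik` is a hypothetical user-supplied callable returning −log p(Di|m̃i). A single dropout sample is used; all function names are illustrative.

```python
import random
from typing import Callable, List

Vector = List[float]
Matrix = List[List[float]]


def dropout(m: Vector, p: float, rng: random.Random) -> Vector:
    """Dropout version of m: keep each coordinate with probability p, otherwise set it to 0."""
    return [x if rng.random() < p else 0.0 for x in m]


def niw_penalty(m_i: Vector, m0: Vector, v0_inv: Matrix, p: float, n0: float) -> float:
    """(p / 2) * (n0 + d + 1) * (m_i - m0)^T V0^{-1} (m_i - m0), with m0 and V0 fixed."""
    d = len(m_i)
    diff = [a - b for a, b in zip(m_i, m0)]
    quad = sum(diff[r] * v0_inv[r][c] * diff[c] for r in range(d) for c in range(d))
    return 0.5 * p * (n0 + d + 1) * quad


def niw_local_objective(m_i: Vector, m0: Vector, v0_inv: Matrix, p: float, n0: float,
                        neg_log_lik: Callable[[Vector], float], rng: random.Random) -> float:
    """Single-sample estimate of L_i(m_i) = -log p(D_i | dropout(m_i)) + quadratic penalty."""
    return neg_log_lik(dropout(m_i, p, rng)) + niw_penalty(m_i, m0, v0_inv, p, n0)
```

In practice the data-fit term would come from the client's network and the whole objective would be minimised with SGD over many dropout samples.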


As an alternative to using a Normal-Inverse-Wishart model, the method may use a mixture model which comprises multiple different prototypes (e.g. K prototypes), where each prototype is associated with a separate global random variable so that ϕ={μ1, . . . , μK}. In other words, a prototype is a component in the mixture. Multiple different global mixture components can represent different client data statistics/features. Such a model may be more useful where clients' local data distributions, as well as their domains and class label semantics, are highly heterogeneous. When using the mixture model, the global ML model may be defined as a product of a fixed number of multivariate Normal distributions, wherein the fixed number is determined by the number of prototypes which cover the client devices' data distributions. The global model may be defined, for example, using

$$q(\phi; L_0) = \prod_{j=1}^{K} \mathcal{N}(\mu_j; r_j, \epsilon^2 I),$$

where 𝒩 is a multivariate normal distribution, μj is the global random variable for each network (in other words ϕ={μ1, . . . , μK}), {rj}j=1K are variational parameters representing the global parameter L0 and ϵ is near 0. Each local model may then be chosen from a network (which may also be termed a prototype), and the local ML model may be defined as

$$q_i(\theta_i) = \mathcal{N}(\theta_i; m_i, \epsilon^2 I),$$

where 𝒩 is a multivariate normal distribution, θi is the local parameter for each client device, mi is the local mean parameter for each client device, and ϵ is near 0.


When using a mixture model, the training of the global ML model may be expressed as:

$$\min_{\{r_j\}_{j=1}^{K}} \; \frac{1}{2} \sum_{j=1}^{K} \lVert r_j \rVert^2 - \sum_{i=1}^{N} \log \sum_{j=1}^{K} \exp\Big(-\frac{\lVert m_i - r_j \rVert^2}{2\sigma^2}\Big),$$

where mi is the local mean parameter for each client device, {rj}j=1K are variational parameters representing the global parameter L0, σ is the scale (standard deviation) of each Gaussian mixture component, and there are K networks.
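
A sketch of the server-side mixture objective and its gradient follows, using a numerically stable log-sum-exp. The gradient with respect to rj is rj − Σi wij (mi − rj)/σ², where wij is the softmax weight (responsibility) of prototype j for client mean mi. Plain gradient descent is shown as one possible optimiser; the function names and the learning rate `lr` are illustrative assumptions.

```python
import math
from typing import List

Vec = List[float]


def _sq_dist(a: Vec, b: Vec) -> float:
    return sum((x - y) ** 2 for x, y in zip(a, b))


def _logsumexp(xs: List[float]) -> float:
    """Numerically stable log(sum(exp(x) for x in xs))."""
    mx = max(xs)
    return mx + math.log(sum(math.exp(x - mx) for x in xs))


def mixture_server_objective(protos: List[Vec], local_means: List[Vec], sigma: float) -> float:
    """0.5 * sum_j ||r_j||^2 - sum_i log sum_j exp(-||m_i - r_j||^2 / (2 sigma^2))."""
    reg = 0.5 * sum(_sq_dist(r, [0.0] * len(r)) for r in protos)
    fit = sum(_logsumexp([-_sq_dist(m, r) / (2 * sigma ** 2) for r in protos])
              for m in local_means)
    return reg - fit


def mixture_server_grad(protos: List[Vec], local_means: List[Vec], sigma: float) -> List[Vec]:
    """Gradient w.r.t. each r_j: r_j - sum_i w_ij (m_i - r_j) / sigma^2."""
    grads = [list(r) for r in protos]  # gradient of 0.5 * ||r_j||^2
    for m in local_means:
        logits = [-_sq_dist(m, r) / (2 * sigma ** 2) for r in protos]
        lse = _logsumexp(logits)
        for j, r in enumerate(protos):
            w = math.exp(logits[j] - lse)  # responsibility of prototype j for client mean m
            for k in range(len(r)):
                grads[j][k] -= w * (m[k] - r[k]) / sigma ** 2
    return grads


def server_step(protos: List[Vec], local_means: List[Vec], sigma: float, lr: float) -> List[Vec]:
    """One gradient-descent step on the prototypes r_1..r_K (local means held fixed)."""
    g = mixture_server_grad(protos, local_means, sigma)
    return [[x - lr * gx for x, gx in zip(r, gr)] for r, gr in zip(protos, g)]
```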


When using a mixture model, the training of the local ML model may be expressed as:

$$\min_{m_i} \; \mathbb{E}_{q_i(\theta_i)}\big[-\log p(D_i \mid \theta_i)\big] - \log \sum_{j=1}^{K} \exp\Big(-\frac{\lVert m_i - r_j \rVert^2}{2\sigma^2}\Big),$$

where the minimisation over mi of the expected negative log-likelihood 𝔼qi(θi)[−log p(Di|θi)] represents the client update optimisation (the log-sum-exp term acts as a regulariser that pulls mi towards nearby prototypes rj), and, as before, mi is the local mean parameter for each client device, {rj}j=1K are variational parameters representing the global parameter L0, Di is the local dataset at each client device, θi is the local parameter for each client device, and there are K networks.
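
On the client, the log-sum-exp term acts on mi alone, with the prototypes rj fixed. Its gradient is Σj wj (mi − rj)/σ², with wj the softmax responsibility of prototype j for mi, so mi is pulled mostly towards the prototypes that are already close to it. The sketch below, with hypothetical function names, evaluates the term, its gradient and the responsibilities; the data-fit term would be added by the client's own training loop.

```python
import math
from typing import List

Vec = List[float]


def _logsumexp(xs: List[float]) -> float:
    """Numerically stable log(sum(exp(x) for x in xs))."""
    mx = max(xs)
    return mx + math.log(sum(math.exp(x - mx) for x in xs))


def _logits(m: Vec, protos: List[Vec], sigma: float) -> List[float]:
    return [-sum((a - b) ** 2 for a, b in zip(m, r)) / (2 * sigma ** 2) for r in protos]


def responsibilities(m: Vec, protos: List[Vec], sigma: float) -> List[float]:
    """Softmax weights w_j proportional to exp(-||m - r_j||^2 / (2 sigma^2))."""
    z = _logits(m, protos, sigma)
    lse = _logsumexp(z)
    return [math.exp(x - lse) for x in z]


def mixture_regulariser(m: Vec, protos: List[Vec], sigma: float) -> float:
    """-log sum_j exp(-||m - r_j||^2 / (2 sigma^2)), with the prototypes r_j held fixed."""
    return -_logsumexp(_logits(m, protos, sigma))


def mixture_regulariser_grad(m: Vec, protos: List[Vec], sigma: float) -> Vec:
    """Gradient w.r.t. m: sum_j w_j (m - r_j) / sigma^2."""
    w = responsibilities(m, protos, sigma)
    return [sum(wj * (m[k] - r[k]) for wj, r in zip(w, protos)) / sigma ** 2
            for k in range(len(m))]
```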


In other words, there is a system for training a machine learning, ML, model using a hierarchical Bayesian approach to federated learning, the system comprising: a plurality of client devices, each client device having a set of personal training data and a local version of the ML model; and a central server for centrally training the ML model.


Each client device may locally train the local version of the ML model and transmit at least one network weight (i.e. local parameter) to the server.


The server may: link the at least one network weight to a higher-level variable (i.e. the global random variable); and train the ML model to optimise a function dependent on the weights from the client devices and a function dependent on the higher-level variable.


The client device may be a resource-constrained device, but one which has the minimum hardware capabilities to use a trained neural network/ML model. The client device may be any one of: a smartphone, tablet, laptop, computer or computing device, virtual assistant device, a vehicle, an autonomous vehicle, a robot or robotic device, a robotic assistant, image capture system or device, an augmented reality system or device, a virtual reality system or device, a gaming system, an Internet of Things device, or a smart consumer device (such as a smart fridge). It will be understood that this is a non-exhaustive and non-limiting list of example client devices.


Once the global ML model has been trained, each client device can use the current global ML model as a basis for predicting an output given an input. For example, the input may be an image and the output may be a classification for the input image, e.g. for one or more objects within the image. In this example, the local dataset on each client device is a set of labelled/classified images. There are two ways the global ML model may be used: global prediction and personalised prediction.


According to an embodiment of the disclosure, there is provided a computer-implemented method for generating, using a client device, a personalised model using the global ML model which has been trained as described above. The method comprises receiving, at the client device from the server, the global parameter for the trained global ML model; optimising, at the client device, a local parameter using the received global parameter, by applying a regularisation term which penalises deviation between the optimised local parameter and the received global parameter, and outputting the optimised local parameter as the personalised model. During the optimising step, sampling may be used to generate estimates for the local parameter. This method is useful when there is no personal data on the client device which can be used to train the global ML model.


According to an embodiment of the disclosure, there is provided a computer-implemented method for generating, using a client device, a personalised model using the global ML model which has been trained as described above. The method comprises receiving, at the client device from the server, the global parameter for the trained global ML model; obtaining a set of personal data; optimising, at the client device, a local parameter using the received global parameter, by applying a regularisation term which penalises deviation between the optimised local parameter and the received global parameter and by applying a loss function over the set of personal data, and outputting the optimised local parameter as the personalised model. This method is useful when there is personal data on the client device which can be used to train the global ML model.


An example of optimisation during personalisation can be defined as:

$$\min_{v} \; \mathbb{E}_{v(\theta)}\big[-\log p(D_p \mid \theta)\big] + \mathrm{KL}\big(v(\theta) \,\big\|\, p(\theta \mid \phi^*)\big),$$

where Dp is the data for personalised training, ϕ* denotes the FL-trained server model parameters and v(θ) is the variational distribution that is optimised to approximate the personalised posterior p(θ|Dp, ϕ*), i.e. the posterior over θ given the personal data and the FL-trained prior p(θ|ϕ*).
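
For intuition, consider a hypothetical conjugate scalar case (an assumption for illustration): p(θ|ϕ*) = N(mu0, s0²), a Gaussian likelihood p(y|θ) = N(θ, s_y²) for each y in Dp, and v(θ) = N(m, s²). The objective then has a closed form, and its minimiser is the exact posterior p(θ|Dp, ϕ*), whose precision is 1/s0² + |Dp|/s_y². The sketch evaluates both so that the minimisation can be checked numerically; the function names are illustrative.

```python
import math
from typing import List, Tuple


def personalisation_objective(m: float, s: float, data: List[float],
                              mu0: float, s0: float, s_y: float) -> float:
    """E_v[-log p(D_p | theta)] + KL(v || p(theta | phi*)) with v = N(m, s^2).

    Hypothetical scalar model: p(theta | phi*) = N(mu0, s0^2), p(y | theta) = N(theta, s_y^2).
    """
    # E_{theta ~ N(m, s^2)}[(y - theta)^2] = (y - m)^2 + s^2
    expected_nll = sum(
        0.5 * math.log(2 * math.pi * s_y ** 2) + ((y - m) ** 2 + s ** 2) / (2 * s_y ** 2)
        for y in data
    )
    kl = math.log(s0 / s) + (s ** 2 + (m - mu0) ** 2) / (2 * s0 ** 2) - 0.5
    return expected_nll + kl


def exact_posterior(data: List[float], mu0: float, s0: float, s_y: float) -> Tuple[float, float]:
    """Conjugate posterior p(theta | D_p, phi*) = N(mean, sd^2); it minimises the objective above."""
    precision = 1 / s0 ** 2 + len(data) / s_y ** 2
    mean = (mu0 / s0 ** 2 + sum(data) / s_y ** 2) / precision
    return mean, 1 / math.sqrt(precision)
```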


When using the Normal-Inverse-Wishart model, an example of optimisation during personalisation can be defined as:

$$\min_{m_i} \; -\log p(D_i \mid \tilde{m}_i) + \frac{p}{2}(n_0 + d + 1)\,(m_i - m_0)^\top V_0^{-1} (m_i - m_0),$$

where m̃i is the dropout version of the local mean parameter mi, the global parameters are L0=(m0, V0), m0 and V0 are fixed during the optimisation, Di is the set of personal data, and

$$v(\theta) = \prod_{l} \Big( p \cdot \mathcal{N}\big(\theta_i[l]; m_i[l], \epsilon^2 I\big) + (1 - p) \cdot \mathcal{N}\big(\theta_i[l]; 0, \epsilon^2 I\big) \Big)$$
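
Sampling from this v(θ) amounts to Gaussian dropout around mi: each block keeps its learned mean with probability p and is centred at zero otherwise. The minimal sketch below treats each θi[l] as a scalar (so ϵ²I reduces to ϵ²), which is a simplifying assumption; the function name is hypothetical.

```python
import random
from typing import List


def sample_v(m: List[float], p: float, eps: float, rng: random.Random) -> List[float]:
    """Draw theta from v(theta) = prod_l [p N(m[l], eps^2) + (1 - p) N(0, eps^2)]."""
    theta = []
    for m_l in m:
        # Choose the mixture component: the learned mean with probability p, else zero.
        centre = m_l if rng.random() < p else 0.0
        theta.append(rng.gauss(centre, eps))
    return theta
```

Averaged over many draws, each coordinate has mean p·mi[l], which is why the dropout version m̃i appears in the data-fit term.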








When using the mixture model, an example of optimisation during personalisation can be defined as:

$$\min_{m} \; \mathbb{E}_{v(\theta)}\big[-\log p(D_p \mid \theta)\big] - \log \sum_{j=1}^{K} \exp\Big(-\frac{\lVert m - r_j \rVert^2}{2\sigma^2}\Big),$$

where v(θ)=𝒩(θ; m, ϵ²I), ϵ is a small positive constant and m are the only parameters that are learnt.
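
Because v(θ) = 𝒩(θ; m, ϵ²I), for a Gaussian likelihood the expected negative log-likelihood equals the negative log-likelihood at m up to an additive constant (𝔼[‖y − θ‖²] = ‖y − m‖² + dϵ²). Under that assumption (a Gaussian likelihood p(y|θ) = 𝒩(θ, s_y²I) for vector-valued personal data, chosen for illustration), the personalisation problem can be solved by plain gradient descent on m, as in this sketch with a hypothetical function name, learning rate and step count.

```python
import math
from typing import List

Vec = List[float]


def personalise_mixture(data: List[Vec], protos: List[Vec], sigma: float, s_y: float,
                        lr: float = 0.05, steps: int = 500) -> Vec:
    """Gradient descent on 0.5 * sum_y ||y - m||^2 / s_y^2 - log sum_j exp(-||m - r_j||^2 / (2 sigma^2))."""
    d = len(protos[0])
    m = [sum(y[k] for y in data) / len(data) for k in range(d)]  # start at the personal-data mean
    for _ in range(steps):
        logits = [-sum((a - b) ** 2 for a, b in zip(m, r)) / (2 * sigma ** 2) for r in protos]
        mx = max(logits)
        w = [math.exp(z - mx) for z in logits]
        total = sum(w)
        w = [x / total for x in w]  # responsibilities of the K prototypes for m
        grad = [
            sum(m[k] - y[k] for y in data) / s_y ** 2                         # data-fit term
            + sum(wj * (m[k] - r[k]) for wj, r in zip(w, protos)) / sigma ** 2  # prototype pull
            for k in range(d)
        ]
        m = [mk - lr * gk for mk, gk in zip(m, grad)]
    return m
```

With a single prototype at the origin, one personal data point at (2, 2) and s_y = σ = 1, the optimum is the midpoint (1, 1); a distant second prototype barely changes this, because its responsibility is negligible.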


Once the model has been updated on the client device, the updated local model can be used to process data, e.g. an input. According to another aspect of the present techniques there is provided a computer-implemented method for using, at a client device, a personalised model to process data, the method comprising generating a personalised model as described above; receiving an input; and predicting, using the personalised model, an output based on the received input.


According to an embodiment of the disclosure, there is provided a computer-readable storage medium comprising instructions which, when executed by a processor, cause the processor to carry out any of the methods described herein.


As will be appreciated by one skilled in the art, the present techniques may be embodied as a system, method or computer program product. Accordingly, the present techniques may take the form of an entirely hardware embodiment, an entirely software embodiment, or an embodiment combining software and hardware aspects.


According to an embodiment of the disclosure, there is provided a client device comprising a processor coupled to memory, wherein the processor is configured to: receive, from a server, a global parameter for a trained global ML model which has been trained as described above; determine whether there is a set of personal data on the client device; when there is no set of personal data, optimise a local parameter of a local ML model using the received global parameter, by applying a regularisation term which penalises deviation between the optimised local parameter and the received global parameter, and when there is a set of personal data, optimise a local parameter of a local ML model using the received global parameter by applying a regularisation term which penalises deviation between the optimised local parameter and the received global parameter and by applying a loss function over the set of personal data; output the optimised local parameter as a personalised model; and predict, using the personalised model, an output based on a newly received input.
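
The branching performed by the client device can be sketched for a hypothetical scalar model in which the regularisation term is 0.5·lam·(m − g)², with g the received global parameter, and the loss over personal data is squared error. Both optima then have closed forms: without personal data the optimised local parameter equals g, and with personal data it is a weighted average of the data and g. `personalise_on_device`, `lam` and the toy `predict` function are illustrative assumptions, not the disclosed implementation.

```python
from typing import List, Optional


def personalise_on_device(global_param: float, personal_data: Optional[List[float]],
                          lam: float = 1.0) -> float:
    """Return the optimised local parameter (the personalised model) for a scalar toy model."""
    if not personal_data:
        # Only the regulariser 0.5*lam*(m - g)^2 is present; its minimiser is g itself.
        return global_param
    # Minimise 0.5*sum((y - m)^2) + 0.5*lam*(m - g)^2 in closed form.
    return (sum(personal_data) + lam * global_param) / (len(personal_data) + lam)


def predict(local_param: float, x: float) -> float:
    """Hypothetical predictor: a toy model that offsets the input by the learned parameter."""
    return x + local_param
```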


Furthermore, the present techniques may take the form of a computer program product embodied in a computer readable medium having computer readable program code embodied thereon. The computer readable medium may be a computer readable signal medium or a computer readable storage medium. A computer readable medium may be, for example, but is not limited to, an electronic, magnetic, optical, electromagnetic, infrared, or semiconductor system, apparatus, or device, or any suitable combination of the foregoing.


Computer program code for carrying out operations of the present techniques may be written in any combination of one or more programming languages, including object oriented programming languages and conventional procedural programming languages. Code components may be embodied as procedures, methods or the like, and may comprise sub-components which may take the form of instructions or sequences of instructions at any of the levels of abstraction, from the direct machine instructions of a native instruction set to high-level compiled or interpreted language constructs.


Embodiments of the present techniques also provide a non-transitory data carrier carrying code which, when implemented on a processor, causes the processor to carry out any of the methods described herein.


The techniques further provide processor control code to implement the above-described methods, for example on a general purpose computer system or on a digital signal processor (DSP). The techniques also provide a carrier carrying processor control code to, when running, implement any of the above methods, in particular on a non-transitory data carrier. The code may be provided on a carrier such as a disk, a microprocessor, CD- or DVD-ROM, programmed memory such as non-volatile memory (e.g. Flash) or read-only memory (firmware), or on a data carrier such as an optical or electrical signal carrier. Code (and/or data) to implement embodiments of the techniques described herein may comprise source, object or executable code in a conventional programming language (interpreted or compiled) such as Python, C, or assembly code, code for setting up or controlling an ASIC (Application Specific Integrated Circuit) or FPGA (Field Programmable Gate Array), or code for a hardware description language such as Verilog (RTM) or VHDL (Very high speed integrated circuit Hardware Description Language). As the skilled person will appreciate, such code and/or data may be distributed between a plurality of coupled components in communication with one another. The techniques may comprise a controller which includes a microprocessor, working memory and program memory coupled to one or more of the components of the system.


It will also be clear to one of skill in the art that all or part of a logical method according to embodiments of the present techniques may suitably be embodied in a logic apparatus comprising logic elements to perform the steps of the above-described methods, and that such logic elements may comprise components such as logic gates in, for example, a programmable logic array or application-specific integrated circuit. Such a logic arrangement may further be embodied in enabling elements for temporarily or permanently establishing logic structures in such an array or circuit using, for example, a virtual hardware descriptor language, which may be stored and transmitted using fixed or transmittable carrier media.


In an embodiment, the present techniques may be realised in the form of a data carrier having functional data thereon, said functional data comprising functional computer data structures to, when loaded into a computer system or network and operated upon thereby, enable said computer system to perform all the steps of the above-described method.


The method described above may be wholly or partly performed on an apparatus, i.e. an electronic device, using a machine learning or artificial intelligence model. The model may be processed by an artificial intelligence-dedicated processor designed in a hardware structure specified for artificial intelligence model processing. The artificial intelligence model may be obtained by training. Here, “obtained by training” means that a predefined operation rule or artificial intelligence model configured to perform a desired feature (or purpose) is obtained by training a basic artificial intelligence model with multiple pieces of training data by a training algorithm. The artificial intelligence model may include a plurality of neural network layers. Each of the plurality of neural network layers includes a plurality of weight values and performs neural network computation by computation between a result of computation by a previous layer and the plurality of weight values.


As mentioned above, the present techniques may be implemented using an AI model. A function associated with AI may be performed through the non-volatile memory, the volatile memory, and the processor. The processor may include one or a plurality of processors. At this time, one or a plurality of processors may be a general purpose processor, such as a central processing unit (CPU), an application processor (AP), or the like, a graphics-only processing unit such as a graphics processing unit (GPU), a visual processing unit (VPU), and/or an AI-dedicated processor such as a neural processing unit (NPU). The one or a plurality of processors control the processing of the input data in accordance with a predefined operating rule or artificial intelligence (AI) model stored in the non-volatile memory and the volatile memory. The predefined operating rule or artificial intelligence model is provided through training or learning. Here, being provided through learning means that, by applying a learning algorithm to a plurality of learning data, a predefined operating rule or AI model of a desired characteristic is made. The learning may be performed in a device itself in which AI according to an embodiment is performed, and/or may be implemented through a separate server/system.


The AI model may consist of a plurality of neural network layers. Each layer has a plurality of weight values, and performs a layer operation based on the output of a previous layer and its plurality of weight values. Examples of neural networks include, but are not limited to, convolutional neural network (CNN), deep neural network (DNN), recurrent neural network (RNN), restricted Boltzmann Machine (RBM), deep belief network (DBN), bidirectional recurrent deep neural network (BRDNN), generative adversarial networks (GAN), and deep Q-networks.


The learning algorithm is a method for training a predetermined target device (for example, a robot) using a plurality of learning data to cause, allow, or control the target device to make a determination or prediction. Examples of learning algorithms include, but are not limited to, supervised learning, unsupervised learning, semi-supervised learning, or reinforcement learning.





BRIEF DESCRIPTION OF THE DRAWINGS

Implementations of the present techniques will now be described, by way of example only, with reference to the accompanying drawings, in which:



FIG. 1a is a schematic diagram of independent and identically distributed (iid) client devices, according to an embodiment of this disclosure;



FIG. 1b is a variation of the diagram of FIG. 1a showing the client data and the input and output, according to an embodiment of this disclosure;



FIG. 1c is a schematic diagram showing a plurality of client devices of FIG. 1a and illustrating global prediction, according to an embodiment of this disclosure;



FIG. 1d is a schematic diagram showing a plurality of client devices of FIG. 1a and illustrating personalised prediction, according to an embodiment of this disclosure;



FIG. 2a is an example algorithm which is suitable for implementing the general training framework, according to an embodiment of this disclosure;



FIG. 2b is a flowchart showing steps carried out by the client device(s) and the server in each round implementing the general training framework, according to an embodiment of this disclosure;



FIG. 2c is a schematic representation of the update steps carried out by each client device during the method of FIG. 2b;



FIG. 2d is a schematic representation of the update steps carried out by the server during the method of FIG. 2b;



FIG. 2e is an example algorithm which is suitable for implementing the general training framework of FIG. 2a using block-coordinate descent, according to an embodiment of this disclosure;



FIG. 2f is an example algorithm which is suitable for implementing the general global prediction framework, according to an embodiment of this disclosure;



FIG. 2g is a flowchart showing steps carried out by the client device(s) and the server in the general global prediction framework such as that shown in FIG. 2f;



FIG. 2h is a schematic representation of the method of FIGS. 2f and 2g;



FIG. 2i is an example algorithm which is suitable for implementing the personalisation framework, according to an embodiment of this disclosure;



FIG. 2j is a flowchart showing steps carried out by the client device(s) and the server in the general personalisation framework such as that shown in FIG. 2i;



FIG. 2k is a schematic representation of the method of FIGS. 2i and 2j;



FIG. 3a is an example of pseudo code which is suitable for implementing a training algorithm using the normal-inverse-Wishart case, according to an embodiment of this disclosure;



FIG. 3b illustrates schematically how global prediction may be applied using the normal-inverse-Wishart case, according to an embodiment of this disclosure;



FIG. 3c is example pseudo code to implement global prediction using the normal-inverse-Wishart case, according to an embodiment of this disclosure;



FIG. 3d illustrates schematically how personalisation may be applied using the normal-inverse-Wishart case, according to an embodiment of this disclosure;



FIG. 3e is example pseudo code to implement personalisation using the normal-inverse-Wishart case, according to an embodiment of this disclosure;



FIG. 4a is an example of pseudo code which is suitable for implementing a training algorithm using the mixture case, according to an embodiment of this disclosure;



FIG. 4b is example pseudo code to implement global prediction using the mixture case, according to an embodiment of this disclosure;



FIG. 4c is example pseudo code to implement personalisation using the mixture case, according to an embodiment of this disclosure;



FIGS. 5a and 5b are tables showing results of experiments testing global prediction performance (initial accuracy) and personalisation performance on CIFAR-100;



FIGS. 5c and 5d are tables showing results of experiments testing global prediction performance (initial accuracy) and personalisation performance on MNIST/FMNIST/EMNIST;



FIG. 5e comprises tables showing results of experiments testing global prediction performance (initial accuracy) and personalisation performance on CIFAR-C-100;



FIGS. 6a and 6b plot the accuracy against dropout probability for global prediction and personalisation, respectively, for both a known technique and the current technique;



FIGS. 6c and 6d plot the accuracy against the number of networks for global prediction and personalisation, respectively, for three known techniques and the current technique;



FIG. 6e includes tables showing the global prediction and personalisation accuracy on CIFAR-100;



FIG. 6f includes tables showing the global prediction and personalisation accuracy on CIFAR-C-100;



FIGS. 7a to 7d are tables comparing the complexity and running times of the proposed algorithms with FedAvg, and



FIG. 8 is a block diagram of a system for federated learning, according to an embodiment of this disclosure.





DETAILED DESCRIPTION

Broadly speaking, the present disclosure provides a method for training a machine learning, ML, model to update global and local versions of a model without the server having to access user data.


The two most popular existing federated learning, FL, algorithms are Fed-Avg, which is described for example in "Communication-Efficient Learning of Deep Networks from Decentralized Data" by McMahan et al., published at Artificial Intelligence and Statistics (AISTATS) in 2017, and Fed-Prox, which is described in "Federated Optimization in Heterogeneous Networks" by Li et al., published in arXiv. Their learning algorithms are simple and intuitive: they repeat several rounds of local update and aggregation. At each round, the server maintains a global model θ and distributes it to all clients. The clients then update the model (initialised to the server-sent model) with their own data (local update) and upload the updated local models to the server. The server then takes the average of the clients' local models, which becomes the new global model (aggregation). During the local update stage, Fed-Prox imposes an additional regularisation term that keeps the updated model close to the global model.
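The Fed-Avg/Fed-Prox loop described above can be sketched as follows. This is a minimal illustration under assumptions that are not taken from the document: a scalar model θ, a toy squared-error loss, plain gradient descent for the local update, and averaging weighted by client data size (as in McMahan et al.). The function names `local_update` and `federated_round` are hypothetical. Setting `mu=0` gives a Fed-Avg-style local update, while `mu>0` adds a Fed-Prox-style proximal term.

```python
import random


def local_update(theta_global, data, lr=0.1, steps=20, mu=0.0):
    """Local update on a toy 1-D squared-error loss (hypothetical example).

    Minimises mean((theta - y)^2) / 2 + (mu / 2) * (theta - theta_global)^2
    by gradient descent, starting from the server-sent model.
    mu = 0: Fed-Avg-style update; mu > 0: Fed-Prox-style proximal term.
    """
    theta = theta_global
    for _ in range(steps):
        grad = sum(theta - y for y in data) / len(data)
        grad += mu * (theta - theta_global)  # proximal pull towards the global model
        theta -= lr * grad
    return theta


def federated_round(theta_global, client_data, mu=0.0):
    """One round: local update on every client, then data-size-weighted averaging."""
    local_models = [local_update(theta_global, d, mu=mu) for d in client_data]
    sizes = [len(d) for d in client_data]
    return sum(n * t for n, t in zip(sizes, local_models)) / sum(sizes)


if __name__ == "__main__":
    rng = random.Random(0)
    # Hypothetical non-iid clients whose data are centred on different values.
    clients = [[rng.gauss(c, 1.0) for _ in range(50)] for c in (-2.0, 0.0, 3.0)]
    theta = 0.0
    for _ in range(10):
        theta = federated_round(theta, clients, mu=0.1)
    print(f"global model after 10 rounds: {theta:.3f}")
```

Only model parameters cross the client/server boundary in this loop: the server sends θ, each client returns its locally updated θ, and the raw data never leave the clients.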


Several attempts have been made to model the FL problem from a Bayesian perspective.


Introducing distributions on model parameters θ has enabled various schemes for estimating a global model posterior p(θ|D1:N) from clients' local posteriors p(θ|Di), or for regularising the learning of local models given a prior defined by the global model. Although some recent FL algorithms approach the FL problem with Bayesian methods, they cannot be fully interpreted as principled Bayesian models, and often resort to ad-hoc treatments. The key difference between our approach and these previous methods is that they treat the network weights θ as a single random variable shared across all clients, while our approach assigns an individual θi to each client i and links the random variables θi via another, higher-level variable ϕ. That is, what is introduced is a hierarchical Bayesian model that assigns each client its own random variable θi for the model weights, and these are linked via a higher-level random variable ϕ as $p(\theta_{1:N},\phi)=p(\phi)\prod_{i=1}^{N}p(\theta_i|\phi)$. This has several crucial benefits. Firstly, given this hierarchy, variational inference in our framework decomposes into separable optimisation problems over the θi's and ϕ, enabling a practical Bayesian learning algorithm to be derived that is fully compatible with FL constraints, without resorting to ad-hoc treatments or strong assumptions. Secondly, this framework can be instantiated with different assumptions on p(θi|ϕ) to deal elegantly and robustly with different kinds of statistical heterogeneity, as well as for principled and effective model personalisation. The main drawback of the shared-θ modeling is that the variational inference problem for approximating the posterior p(θ|D1:N) usually does not decompose into separable optimisation problems over individual clients, and thus easily violates the FL constraints. To remedy this issue, either strong assumptions have to be made or ad hoc strategies have to be employed to perform client-wise optimisation with aggregation.


We propose a novel hierarchical Bayesian approach to Federated Learning (FL), in which our models describe the generative process of clients' local data via hierarchical Bayesian modeling: the local models of the clients are random variables governed by a higher-level global variate. Interestingly, variational inference in our Bayesian model leads to an optimisation problem whose block-coordinate descent solution is a distributed algorithm that is separable over clients and does not require them to reveal their private data at all, and is thus fully compatible with FL. Beyond introducing novel modeling and derivations, we also offer a convergence analysis showing that our block-coordinate FL algorithm converges to a (local) optimum of the objective at the rate of $O(1/\sqrt{t})$, the same rate as regular (centralised) SGD, as well as a generalisation error analysis proving that the test error of our model on unseen data is guaranteed to vanish as the training data size increases, so that the model is asymptotically optimal.


The hierarchical Bayesian models (NIW and Mixture, explained in more detail below) are a canonical formalisation for modeling heterogeneous data, including personalisation. They offer a principled way to decompose shared (global) and local (personalised) knowledge and to learn both jointly. By making specific choices about the distributions involved, hierarchical Bayesian models can be explicitly configured to model different types of data heterogeneity. For example, when users group into clusters, the mixture model provides a good solution. This kind of transparent mapping between the algorithm and the nature of the data heterogeneity is not provided by other, non-hierarchical methods.


Bayesian FL: General Framework


FIGS. 1a to 1d show graphical models of the general framework for Bayesian Federated Learning. Specifically, FIG. 1a shows a plate view of independent and identically distributed (iid) client devices 112, each of which has a dataset Di and a local (individual) random variable θi. As explained below, these local random variables are governed by a single random variable ϕ. FIG. 1b shows the individual client data D, with input images x given and only p(y|x) modeled. FIG. 1c shows global prediction at a client device 112 as a probabilistic inference problem, where x* is the test input, y* is the output to be inferred, and θ is the local model. FIG. 1d shows personalisation at a client device 112 using personal data Dp before inference of the output yp from the test input xp. In each of FIGS. 1c and 1d, the shaded nodes represent evidence, e.g. the individual data Di on a client device or the test input x*.



FIGS. 1a to 1d show two types of latent random variables, ϕ and $\{\theta_i\}_{i=1}^{N}$, where a latent variable is a variable which is inferred from observed data using models. The higher-level variable ϕ determines the prior distribution of the individual client models (random variables) θi. Since every client's prior distribution shares the same ϕ, this variable plays the role of the knowledge shared across clients. It is exactly this principled decomposition into a global shared variable and local client-specific random variables that makes the hierarchical Bayesian framework a principled and empirically effective solution to both global model learning and personalised model learning.


Typically, each θi, one for each local client i=1, . . . , N, will be deployed as the network parameters to client i's backbone. The variable ϕ can be viewed as a globally shared variable that is responsible for linking the individual client parameters θi. In our modeling, we assume conditionally independent and identical priors, that is,






$$p(\theta_1,\ldots,\theta_N|\phi)=\prod_{i=1}^{N}p(\theta_i|\phi) \tag{1}$$


where each p(θi|ϕ) is the same conditional distribution p(θ|ϕ). Thus the prior for the latent variables (ϕ, $\{\theta_i\}_{i=1}^{N}$) is formed in a hierarchical manner as:






$$p(\phi,\theta_1,\ldots,\theta_N)=p(\phi)\prod_{i=1}^{N}p(\theta_i|\phi). \tag{2}$$


Here, the prior distribution (or prior probability distribution) of an uncertain quantity is the probability distribution assumed before any evidence is taken into account. The terms prior distribution and prior may be used interchangeably. The prior over all the latent variables is thus fully specified by p(ϕ) and p(θ|ϕ).
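To make the hierarchical prior (2) concrete, the sketch below draws (ϕ, θ1, ..., θN) by ancestral sampling: first ϕ, then each θi conditionally independently given ϕ, as in (1). The Gaussian choices p(ϕ) = N(0, σ0² I) and p(θi|ϕ) = N(ϕ, τ² I), and the names `sample_hierarchical_prior`, `sigma0` and `tau`, are illustrative assumptions only; they are not the NIW or mixture instantiations described in this document.

```python
import random


def sample_hierarchical_prior(num_clients, dim, sigma0=1.0, tau=0.1, rng=None):
    """Ancestral sample from p(phi) * prod_i p(theta_i | phi), eq. (2).

    Illustrative (assumed) choices, not the document's:
      p(phi)         = N(0, sigma0^2 I)
      p(theta_i|phi) = N(phi, tau^2 I), the same for every client, eq. (1)
    """
    rng = rng or random.Random()
    # 1) Draw the shared higher-level variable once.
    phi = [rng.gauss(0.0, sigma0) for _ in range(dim)]
    # 2) Draw each client's parameters independently given phi.
    thetas = [[rng.gauss(p, tau) for p in phi] for _ in range(num_clients)]
    return phi, thetas


phi, thetas = sample_hierarchical_prior(num_clients=5, dim=3, rng=random.Random(0))
```

A small τ makes the sampled client parameters cluster tightly around ϕ (strongly shared knowledge), whereas a large τ lets them differ more, which is one way such a prior can express client heterogeneity.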


The local data for client i, denoted by Di, is generated by the local model θi, with likelihood:






$$p(D_i|\theta_i)=\prod_{(x,y)\in D_i}p(y|x,\theta_i), \tag{3}$$


where p(y|x, θi) is a conventional neural network likelihood model (e.g., a softmax likelihood/link after a neural network feed-forward pass for classification tasks). Note that, as per definition (3), we do not perform generative modeling of the input images x; that is, input images are always given, and only the conditionals p(y|x) are modeled. FIG. 1b illustrates this graphically for each individual client.
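Because the likelihood (3) factorises over the (x, y) pairs in Di, −log p(Di|θi) is a sum of per-example softmax cross-entropy terms. The sketch below assumes, purely for illustration, a linear softmax classifier in place of a neural network; the parameter layout θ = (W, b) and the function names are hypothetical.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax of a list of logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_z for z in logits]


def neg_log_likelihood(theta, data):
    """-log p(D_i | theta_i) from eq. (3) for a linear softmax classifier.

    theta: (W, b), with W a list of per-class weight vectors and b per-class biases.
    data:  iterable of (x, y) pairs; x is a feature list, y an integer class label.
    Only p(y | x, theta) is modelled; inputs x are treated as given.
    """
    W, b = theta
    total = 0.0
    for x, y in data:
        logits = [sum(w_j * x_j for w_j, x_j in zip(w, x)) + b_k for w, b_k in zip(W, b)]
        total -= log_softmax(logits)[y]  # log of a product = sum of logs
    return total
```

With all weights and biases at zero, every class is equally likely, so each example contributes log K for K classes; this gives a quick sanity check.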


Given the local training data D1, . . . , DN, we can in principle infer the posterior distribution of the latent variables. The posterior distribution is the conditional probability that results from updating the prior probability with the information summarized by the likelihood, via an application of Bayes' rule. The posterior distribution, i.e. the conditional distribution of the latent variables given the local data, may be written as:






$$p(\phi,\theta_{1:N}|D_1,\ldots,D_N)\propto p(\phi)\prod_{i=1}^{N}p(\theta_i|\phi)\,p(D_i|\theta_i) \tag{4}$$


In other words, the posterior distribution of ϕ, θ1:N given all the datasets is proportional to the product of the prior distribution for ϕ, the conditional distributions of each θi given ϕ, and the likelihood of each dataset Di given θi. However, the posterior distribution is intractable in general, so posterior inference must be approximated. We adopt variational inference, approximating the posterior (4) by a tractable density q(ϕ, θ1, . . . , θN; L) parameterized by L. We specifically consider a fully factorized density over all variables, that is,






$$q(\phi,\theta_1,\ldots,\theta_N;L):=q(\phi;L_0)\prod_{i=1}^{N}q_i(\theta_i;L_i), \tag{5}$$


where the variational parameters L consist of L0 (the parameters of q(ϕ)) and $\{L_i\}_{i=1}^{N}$ (the parameters of the $q_i(\theta_i)$'s of the individual clients). Note that although the θi's are independent across clients under (5), they are modelled differently (emphasized by the subscript i in the notation qi), reflecting the different posterior beliefs originating from the different/heterogeneous local data Di's. We show below that this factorized variational density leads to a fully separable block-coordinate ascent ELBO optimization, which allows us to optimize q(ϕ) and the $q_i(\theta_i)$'s without accessing the local data of other parties, leading to viable federated learning algorithms.


The main motivations for our hierarchical Bayesian modeling are twofold: i) introducing client-wise different model parameters θi provides a way to deal with non-iid heterogeneous client data, as reflected in the posteriors $q_i(\theta_i)$, while the shared knowledge is still taken into account during posterior inference through the shared prior p(θi|ϕ); ii) as discussed in the next section, it enables a more principled learning algorithm by separating the two types of variables, ϕ (shared) and θi (local).


From Variational Inference to Federated Learning Algorithm

Using standard variational inference techniques, we can derive the ELBO objective function. The ELBO objective function may also be termed the evidence lower bound, the variational lower bound, or the negative variational free energy. We denote the negative ELBO by $\mathcal{L}$ (to be minimized over L) as follows:






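$$\mathcal{L}(L):=\sum_{i=1}^{N}\Big(\mathbb{E}_{q_i(\theta_i)}\big[-\log p(D_i|\theta_i)\big]+\mathbb{E}_{q(\phi)}\big[\mathrm{KL}\big(q_i(\theta_i)\,\|\,p(\theta_i|\phi)\big)\big]\Big)+\mathrm{KL}\big(q(\phi)\,\|\,p(\phi)\big) \tag{6}$$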


where $\mathbb{E}_{q_i(\theta_i)}[\cdot]$ denotes the expectation with respect to $q_i(\theta_i)$, the variational posterior for each client device's parameters θi; p(Di|θi) is the likelihood of the client device data Di given parameters θi; N is the number of client devices; KL represents the Kullback-Leibler divergence; p(θi|ϕ) is the prior for the client parameters θi given ϕ; q(ϕ) is the variational posterior for ϕ; and p(ϕ) is the prior for ϕ. In (6), we drop the dependency on L in the notation for simplicity. That is, q(ϕ) and $q_i(\theta_i)$ indicate q(ϕ; L0) and $q_i(\theta_i;L_i)$, respectively.


Equation (6) could be optimised over the parameters (L0, {Li}), i.e. over L, jointly using centralised learning. However, as described below, we consider block-wise optimization, also known as block-coordinate optimization, specifically alternating two steps: (i) updating/optimizing all Li's, i=1, . . . , N, while fixing L0, and (ii) updating L0 with all Li's fixed. The objective functions for the two steps are as follows:


Optimization over L1, . . . , LN (L0 fixed).












$$\min_{\{L_i\}_{i=1}^{N}}\ \sum_{i=1}^{N}\Big(\mathbb{E}_{q_i(\theta_i)}\big[-\log p(D_i|\theta_i)\big]+\mathbb{E}_{q(\phi)}\big[\mathrm{KL}\big(q_i(\theta_i)\,\|\,p(\theta_i|\phi)\big)\big]\Big). \tag{7}$$








In other words, in this step, the final term KL(q(ϕ)∥p(ϕ)) of equation (6) may be dropped, since it does not depend on any of the Li's.


The objective function in (7) is completely separable over i, and we can optimize each summand independently as:













$$\min_{L_i}\ \mathcal{L}_i(L_i):=\mathbb{E}_{q_i(\theta_i;L_i)}\big[-\log p(D_i|\theta_i)\big]+\mathbb{E}_{q(\phi;L_0)}\big[\mathrm{KL}\big(q_i(\theta_i;L_i)\,\|\,p(\theta_i|\phi)\big)\big]. \tag{8}$$








So (8) constitutes the local update/optimization for client i. Note that each client i needs to access only its own private data Di, without data from others, so this approach is fully compatible with FL. Once the first step of optimisation has been done, we fix the Li's and perform the second step:


Optimization over L0 (L1, . . . , LN fixed).













$$\min_{L_0}\ \mathcal{L}_0(L_0):=\mathrm{KL}\big(q(\phi;L_0)\,\|\,p(\phi)\big)-\sum_{i=1}^{N}\mathbb{E}_{q(\phi;L_0)\,q_i(\theta_i;L_i)}\big[\log p(\theta_i|\phi)\big]. \tag{9}$$








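In other words, in this step, the first term $\mathbb{E}_{q_i(\theta_i)}[-\log p(D_i|\theta_i)]$ of equation (6) may be dropped, since it does not depend on L0; likewise, expanding the KL term in (6) leaves an entropy term $\mathbb{E}_{q_i(\theta_i)}[\log q_i(\theta_i)]$ that does not depend on L0, which is why only $\log p(\theta_i|\phi)$ remains in (9). This amounts to the server update/optimization criterion, with the latest updates $q_i(\theta_i;L_i)$ from the local clients held fixed. Remarkably, the server does not need to access any local data at all, which makes this step suitable for FL. This property originates from the independence assumption in our approximate posterior (5).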


Interpretation. First, the server's loss function (9) tells us that the server needs to update the posterior q(ϕ; L0) in such a way that (i) it puts mass on those ϕ that have high compatibility scores log p(θi|ϕ) with the current local models θi ∼ qi(θi) for i=1, . . . , N, thus aiming to be aligned with the local models, and (ii) it does not deviate much from the prior p(ϕ). Second, the clients' loss function (8) indicates that each client i needs to minimize the class prediction error on its own data Di (first term) and, at the same time, stay close to the current global standard ϕ ∼ q(ϕ) by reducing the KL divergence from p(θi|ϕ) (second term).
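The server step (9) can be worked through in closed form for a deliberately simple, assumed scalar model that is not one of the document's embodiments: p(ϕ) = N(0, σ0²), p(θi|ϕ) = N(ϕ, τ²), q(ϕ; L0) = N(μ0, v0) and qi(θi; Li) = N(mi, si²). Setting the derivatives of (9) to zero gives a precision-weighted average, $\mu_0=\frac{\sum_i m_i/\tau^2}{1/\sigma_0^2+N/\tau^2}$ and $v_0=(1/\sigma_0^2+N/\tau^2)^{-1}$, which mirrors both parts of the interpretation above: μ0 is pulled towards the client means mi and shrunk towards the prior mean 0. Only the Li's are used, never local data. The function names are hypothetical.

```python
import math


def server_objective(mu0, v0, client_means, client_vars, sigma0=10.0, tau=1.0):
    """L_0(L_0) of eq. (9) in an assumed scalar Gaussian model.

    KL(N(mu0, v0) || N(0, sigma0^2)) minus sum_i E[log N(theta_i; phi, tau^2)],
    where phi ~ N(mu0, v0) and theta_i ~ N(m_i, s_i^2) independently.
    """
    kl = 0.5 * (math.log(sigma0**2 / v0) + (v0 + mu0**2) / sigma0**2 - 1.0)
    expected_neg_log_prior = sum(
        0.5 * math.log(2 * math.pi * tau**2) + ((m - mu0) ** 2 + s2 + v0) / (2 * tau**2)
        for m, s2 in zip(client_means, client_vars)
    )
    return kl + expected_neg_log_prior


def server_update(client_means, sigma0=10.0, tau=1.0):
    """Closed-form minimiser (mu0, v0) of server_objective; needs only the m_i's."""
    precision = 1.0 / sigma0**2 + len(client_means) / tau**2
    return (sum(client_means) / tau**2) / precision, 1.0 / precision
```

For a very broad prior (large σ0), μ0 approaches the plain average of the client means, so Fed-Avg-style averaging appears as a limiting case of this toy model.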



FIG. 2a illustrates an algorithm which is suitable for implementing the general framework above. FIG. 2b is a flowchart showing the key steps in each round r=1, 2, . . . , R of the method. In a first step S200 of each round r, the server selects a predetermined number Nf(≤N) of client devices to participate in the round. In a next step S202, the server sends the parameters L0 of the global posterior q(ϕ; L0), to the participating client devices. These parameters are received at each client device at step S204.


In a next step S206, each local client device updates its local model $q_i(\theta_i;L_i)$ using an appropriate technique. For example, the optimization in equation (8) above may be applied, i.e.










$$\min_{L_i}\ \mathbb{E}_{q_i(\theta_i;L_i)}\big[-\log p(D_i|\theta_i)\big]+\mathbb{E}_{q(\phi;L_0)}\big[\mathrm{KL}\big(q_i(\theta_i;L_i)\,\|\,p(\theta_i|\phi)\big)\big]$$






It is noted that during this optimisation, L0 is fixed. The first term, the expectation under $q_i(\theta_i;L_i)$, fits the model to the client device's data. The second term, the expectation under $q(\phi;L_0)$, is a regulariser that keeps the updated θi close to the current global ϕ. This second term may be termed a regularisation or penalty term, and it prevents the client device update from deviating far from the current global estimate of the model. The optimisation may be solved using stochastic gradient descent (SGD). Once the updating has been done, at step S208 each local device i sends its updated local posterior parameters Li to the server.
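For intuition, the client step can be run as plain gradient descent in an assumed scalar model that is not one of the document's embodiments: each local point y ~ N(θi, 1), p(θi|ϕ) = N(ϕ, τ²), the fixed server posterior q(ϕ; L0) = N(μ0, v0), and qi(θi; Li) = N(m, s²) with Li = (m, ρ = log s). In this model both expectations in the objective above have closed forms, so the sketch uses exact full-batch gradients rather than the stochastic (mini-batch or Monte Carlo) gradients an SGD implementation would use. The function name `client_update` is hypothetical.

```python
import math


def client_update(data, mu0, tau=1.0, lr=0.02, steps=2000):
    """Minimise the client objective of eq. (8) by gradient descent (toy model).

    Assumptions (illustrative only):
      p(y | theta)     = N(theta, 1) for each local data point y
      p(theta | phi)   = N(phi, tau^2)
      q(phi; L_0)      = N(mu0, v0), fixed during this step
      q_i(theta; L_i)  = N(m, s^2), with L_i = (m, rho) and rho = log s
    The server variance v0 only adds a constant to eq. (8), so it does not
    appear in the gradients below.
    """
    n = len(data)
    m, rho = mu0, 0.0  # start the local mean at the current global estimate
    for _ in range(steps):
        s2 = math.exp(2.0 * rho)
        # d/dm: data-fit term + regulariser pulling m towards mu0
        grad_m = sum(m - y for y in data) + (m - mu0) / tau**2
        # d/drho: data-fit term (n * s^2) + KL term (-1 + s^2 / tau^2)
        grad_rho = n * s2 - 1.0 + s2 / tau**2
        m -= lr * grad_m
        rho -= lr * grad_rho
    return m, math.exp(2.0 * rho)  # updated L_i, reported as (m_i, s_i^2)
```

In this toy model the minimiser is also available in closed form, m = (Σy + μ0/τ²)/(n + 1/τ²) and s² = 1/(n + 1/τ²), which gives a convenient check on the gradient-based solution: the fewer local points a client has, the more strongly its mean is pulled towards the global estimate.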



FIG. 2c is a schematic representation of steps S204 and S206. There are a plurality of client devices 212a, 212i, 212N, each of which has its own local data Di and local model parameters Li. The client devices are connected to a server 202 which has a global model with global model parameters L0. As shown, the global model parameters L0 are sent to at least one client device 212i. The received global parameters are used to optimise the local model parameters Li using an optimisation which has two components. One component is a regulariser which receives as inputs the local model parameters Li and the global model parameters L0 and is used to penalise discrepancy between these two inputs during the optimisation. The other component receives as inputs the local model parameters Li and the local data set Di and, during the optimisation, fits the local model parameters Li to the local data set, e.g. using a loss function.


The server receives each of the updated local posterior parameters Li from the client devices at step S210. At step S212, the server then updates the global posterior q(ϕ; L0) by a suitable technique. For example, the optimization in equation (9) above may be applied over the Nf participating clients, written equivalently (up to terms that do not depend on L0) as:










$$\min_{L_0}\ \sum_{i=1}^{N_f}\mathbb{E}_{q(\phi;L_0)}\big[\mathrm{KL}\big(q_i(\theta_i;L_i)\,\|\,p(\theta_i|\phi)\big)\big]+\mathrm{KL}\big(q(\phi;L_0)\,\|\,p(\phi)\big)$$






It is noted that during this optimisation, each Li is fixed. The first term, the sum over clients of expectations under q(ϕ; L0), measures the consensus over the client updates, in other words the compatibility between the global ϕ and the local models θi. The second term is a regularization term that prevents the server update to L0 from moving too far from the prior knowledge/information. In this way, an abrupt update caused by a spurious client update can potentially be avoided. The optimisation may be solved using stochastic gradient descent (SGD).



FIG. 2d is a schematic representation of steps S210 and S212. As in FIG. 2c, there are a plurality of client devices 212a, 212i, 212N, each of which has its own local data Di and local model parameters Li. The client devices are connected to the server 202 which has a global model with global model parameters L0. As shown, the local model parameters L1, Li, . . . , LN are sent to the server 202. The received local model parameters are used to optimise the global parameters L0 using an optimisation which has two components. One component is a regulariser which receives as inputs the prior knowledge (I) and the global model parameters L0 and is used to penalise discrepancy between these two inputs during the optimisation. The other component is a sum of separate regularisers, one for each client device, each of which prevents the updated global model parameters L0 from drifting too far from the corresponding local model parameters Li.


The round is then complete, and the next round can begin with the random selection at step S200, repeating all the other steps. Once all rounds are complete, the trained parameters L0 are output. FIG. 2e is example pseudo-code which gives more detail of the overall Bayesian Federated Learning algorithm as block-coordinate descent, and in particular uses SGD as described above. In summary, the method above tackles communication constraints and privacy in federated learning by developing a hierarchical Bayesian model for which variational inference with block-coordinate descent naturally decomposes over the client device and server posteriors (i.e. over the target parameters L0 and Li to be learned).
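The alternation of FIGS. 2b and 2e can be sketched end to end in an assumed scalar Gaussian model used only for illustration (local points y ~ N(θi, 1), p(θi|ϕ) = N(ϕ, τ²), p(ϕ) = N(0, σ0²), q(ϕ; L0) = N(μ0, v0), qi(θi; Li) = N(mi, si²)). Here both steps are solved in closed form rather than by SGD, purely to keep the example short; the function names and the uniform client-selection scheme are hypothetical. Only (mi, si²) and (μ0, v0) cross the client/server boundary.

```python
import random


def client_step(data, mu0, tau):
    """Closed-form minimiser of eq. (8) in the assumed toy model -> (m_i, s_i^2)."""
    precision = len(data) + 1.0 / tau**2
    return (sum(data) + mu0 / tau**2) / precision, 1.0 / precision


def server_step(client_means, sigma0, tau):
    """Closed-form minimiser of eq. (9) over the participating clients -> (mu0, v0)."""
    precision = 1.0 / sigma0**2 + len(client_means) / tau**2
    return (sum(client_means) / tau**2) / precision, 1.0 / precision


def train(client_data, rounds=50, clients_per_round=2, sigma0=10.0, tau=1.0, seed=0):
    rng = random.Random(seed)
    mu0, v0 = 0.0, sigma0**2      # initialise q(phi; L_0) at the prior
    local = {}                    # latest L_i = (m_i, s_i^2) for each client
    for _ in range(rounds):
        # S200: select N_f participating clients for this round
        chosen = rng.sample(range(len(client_data)), clients_per_round)
        # S202-S208: each selected client updates L_i with L_0 fixed, on its own data
        for i in chosen:
            local[i] = client_step(client_data[i], mu0, tau)
        # S210-S212: the server updates L_0 from the received L_i's only
        mu0, v0 = server_step([local[i][0] for i in chosen], sigma0, tau)
    return (mu0, v0), local
```

Each call to `client_step` touches only one client's data, and `server_step` sees only the returned means, which is the separability property that makes the block-coordinate scheme compatible with FL.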


Formalisation of Global Prediction and Personalisation Tasks

Returning to FIGS. 1c and 1d, as illustrated, there are two important tasks in FL: global prediction and personalisation. Global prediction measures how well the trained FL model performs on novel test data points sampled from a domain/distribution which is possibly different from that of the training data. Personalisation is the task of adapting the trained FL model to a specific dataset called personalised data. Specifically, we are given a training split of the personalised data to update the FL-trained model. We then measure the performance of the adapted model on the test split, which conforms to the same distribution as the training split.


In existing non-Bayesian FL approaches, these tasks are mostly handled straightforwardly, since there is a single point estimate of the global model obtained from training. For global prediction, the test points are simply fed forward through the global model; for personalisation, the global model is usually finetuned with the personalised training data and tested on the test split. Consequently, previous FL approaches may have difficulty dealing with non-iid client data in a principled manner, often resorting to ad-hoc treatments. In our Bayesian treatment/model, these two tasks can be formally defined as Bayesian inference problems in a more principled way. Our hierarchical framework introduces client-wise different model parameters θi to deal with non-iid heterogeneous client data more flexibly, as reflected in the different client-wise posteriors $q_i(\theta_i)$.


Global prediction. The task is to predict the class label of a novel test input x* which may or may not originate from the same distributions/domains as the training data D1, . . . , DN. It can be turned into a probabilistic inference problem p(y*|x*, D1, . . . , DN). Under our Bayesian model, we let θ be the local model that generates the output y* given x*. See FIG. 1c for the graphical model diagram. Exploiting conditional independence, we can derive the predictive probability as follows from FIG. 1c:
















$$p(y^*|x^*,D_1,\ldots,D_N)=\iint p(y^*|x^*,\theta)\,p(\theta|\phi)\,p(\phi|D_1,\ldots,D_N)\,d\theta\,d\phi \tag{10}$$

$$\approx\iint p(y^*|x^*,\theta)\,p(\theta|\phi)\,q(\phi)\,d\theta\,d\phi \tag{11}$$

$$=\int p(y^*|x^*,\theta)\Big(\int p(\theta|\phi)\,q(\phi)\,d\phi\Big)\,d\theta, \tag{12}$$








where in (11) we use our variational approximate posterior q(ϕ). In our specific model choice of Normal-Inverse-Wishart (see below for more details), the inner integral in (12) can be succinctly written as a closed form (multivariate Student-t). Alternatively, the inner integral can be approximated (e.g. using a Monte-Carlo estimate).
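The inner integral in (12) is the predictive distribution of a new client's parameters θ. For the NIW case the document states it is a multivariate Student-t; in a simpler, assumed scalar model with q(ϕ) = N(μ0, v0) and p(θ|ϕ) = N(ϕ, τ²) (not one of the document's embodiments), it is the Gaussian N(μ0, v0 + τ²). The sketch below compares that closed form with the Monte Carlo alternative, which draws ϕ ~ q(ϕ) and then θ ~ p(θ|ϕ). The function names are hypothetical.

```python
import random
import statistics


def sample_new_client_theta(mu0, v0, tau, num_samples, seed=0):
    """Monte Carlo draws from the inner integral of eq. (12) by ancestral sampling."""
    rng = random.Random(seed)
    samples = []
    for _ in range(num_samples):
        phi = rng.gauss(mu0, v0 ** 0.5)       # phi ~ q(phi; L_0)
        samples.append(rng.gauss(phi, tau))   # theta ~ p(theta | phi)
    return samples


def closed_form_marginal(mu0, v0, tau):
    """Mean and variance of the same marginal in this toy model: N(mu0, v0 + tau^2)."""
    return mu0, v0 + tau**2


draws = sample_new_client_theta(mu0=1.0, v0=0.25, tau=1.0, num_samples=100_000)
print(statistics.fmean(draws), statistics.variance(draws))  # close to 1.0 and 1.25
print(closed_form_marginal(1.0, 0.25, 1.0))                  # (1.0, 1.25)
```

The variance of the marginal adds the uncertainty about the shared variable (v0) to the spread of clients around it (τ²), which is why predictions for an unseen client are less certain than the posterior over ϕ alone would suggest.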



FIG. 2f gives example pseudo code which is suitable for implementing the general global prediction framework. FIG. 2g is a flowchart summarizing the steps of FIG. 2f. In a first step S222, the client device receives an input x*, for example an input image to be classified or edited (for example altered to include the classification). The client device also obtains the parameters for the global model at step S224. The current global model could be obtained from the server; for example, as shown at step S226, the server sends the parameters for the global model to the participating client(s). As explained above, the learned parameters L0 are used in the variational posterior q(ϕ; L0). This global model is received from the server at the client device at step S228. These steps to obtain the global model could be done before receiving the input or could be triggered by receipt of the input.


Before using the received global model, the client device then personalises it at step S230. One way, as shown in the pseudocode of FIG. 2f, is to use Monte Carlo sampling, with S being the number of Monte Carlo samples. At step S232, the client device infers the output y* (e.g. a class label) using the updated personalised model. For example, when using sampling to personalise the model, the output may be generated using:









$$p(y^*|x^*,D_1,\ldots,D_N)\approx\frac{1}{S}\sum_{s=1}^{S}p(y^*|x^*,\theta^{(s)}),$$








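where $\theta^{(s)}\sim\int p(\theta|\phi)\,q(\phi;L_0)\,d\phi$, i.e. each sample is obtained by first drawing $\phi^{(s)}\sim q(\phi;L_0)$ and then $\theta^{(s)}\sim p(\theta|\phi^{(s)})$.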
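The Monte Carlo estimate above can be sketched as follows for an assumed binary, one-parameter logistic likelihood p(y=1|x, θ) = sigmoid(θx), with q(ϕ; L0) = N(μ0, v0) and p(θ|ϕ) = N(ϕ, τ²); these choices and the function names are illustrative only, not the document's model. Each θ^(s) is obtained by ancestral sampling, and the predictive probabilities, not the parameters, are averaged over the S samples.

```python
import math
import random


def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def mc_predictive(x, mu0, v0, tau, num_samples=1000, seed=0):
    """Estimate p(y*=1 | x*, D_1..D_N) as (1/S) * sum_s p(y*=1 | x*, theta^(s))."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(num_samples):
        phi = rng.gauss(mu0, math.sqrt(v0))   # phi^(s) ~ q(phi; L_0)
        theta = rng.gauss(phi, tau)           # theta^(s) ~ p(theta | phi^(s))
        total += sigmoid(theta * x)           # average probabilities, not parameters
    return total / num_samples
```

With v0 = τ = 0 the estimate reduces to plugging in the single global value μ0; with non-zero variances the averaged probability typically moves towards 0.5, reflecting uncertainty about which local model applies to the new input.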


In a final step S230, an output is provided. The output could be, for example, a class label or an edited image, e.g. an image which has been edited to include the class label or otherwise altered based on the class label.



FIG. 2h is a schematic representation of the method of FIGS. 2f and 2g. The server 202 sends the parameters for the global model L0 to the participating client device 212p. An updated local model with local model parameters Lp is generated at the client device 212p. As indicated, one technique for doing this when there is no local client data is to optimise the global model parameters using sampling. The optimisation includes a regulariser (regularization penalty) similar to those used in the training framework. In this case, the regulariser penalises the discrepancy between the "guess" for the local model parameters Lp and the received global model parameters L0. Once the local model is updated, it can be used to predict the output y based on the input x.


Personalisation formally refers to the problem of learning a prediction model p̂(y|x) given an unseen (personal) training dataset Dp that comes from some unknown distribution pp(x, y), so that the personalised model p̂ performs well on novel (in-distribution) test points (xp, yp) ∼ pp(x, y). Evidently we need to exploit (and benefit from) the model that we trained during the federated learning stage. To this end, many existing approaches simply resort to finetuning, that is, training the network on Dp with the FL-trained model as the initial iterate. However, a potential issue with finetuning is the lack of a solid principle on how to balance the initial FL-trained model against fitting the personal data, so as to avoid both underfitting and overfitting. FIG. 2i gives example pseudo code suitable for implementing the general personalisation framework. FIG. 2j is a flowchart summarizing the steps of FIG. 2i. In our Bayesian framework, personalisation can be seen as another posterior inference problem with additional evidence given by the personal training data Dp.


In a first step S240, the personal training data Dp is obtained by any suitable technique, e.g. by separating out a portion of the client data which is not used in the general training described above. In a next step S242 (which could be simultaneous with or before the previous steps), the client device receives an input xp, for example an input image to be classified or edited (for example altered to include the classification). As in the global prediction, the client device then obtains the global model parameters: for example, as shown in step S244, the server sends the parameters for the global model to the participating client(s). As explained above, the learned parameters L0 are used in the variational posterior q(ϕ; L0). This global model is received from the server at the client device at step S246.


Then the prediction on a test point xp amounts to inferring the posterior predictive distribution,






$$p(y_p \mid x_p, D_p, D_1, \ldots, D_N) = \int p(y_p \mid x_p, \theta)\, p(\theta \mid D_p, D_1, \ldots, D_N)\, d\theta. \tag{13}$$


So, it boils down to the task of posterior inference p(θ|Dp, D1, . . . , DN) given both the personal data Dp and the FL training data D1, . . . , DN. Under our hierarchical model, by exploiting conditional independence from the graphical model shown in FIG. 1d, we can link the posterior to our FL-trained q(ϕ) as follows:










$$\begin{aligned} p(\theta \mid D_p, D_1, \ldots, D_N) &\approx \int p(\theta \mid D_p, \phi)\, p(\phi \mid D_1, \ldots, D_N)\, d\phi \qquad (14)\\ &\approx \int p(\theta \mid D_p, \phi)\, q(\phi)\, d\phi \qquad (15)\\ &\approx p(\theta \mid D_p, \phi^*), \qquad (16) \end{aligned}$$







where in (14) we disregard the impact of Dp on the higher-level variable ϕ given the joint evidence, i.e. p(ϕ|Dp, D1, ..., DN) ≈ p(ϕ|D1, ..., DN), since D1:N dominates the much smaller Dp. In (15) we replace this posterior by the FL-trained variational posterior q(ϕ), and in (16) we approximate the integral by evaluation at the mode ϕ* of q(ϕ), which is reasonable for the spiky q(ϕ) in both of our modeling choices discussed below. Since dealing with p(θ|Dp, ϕ*) involves the difficult marginalisation p(Dp|ϕ*)=∫p(Dp|θ)p(θ|ϕ*)dθ, we adopt variational inference, introducing a tractable variational distribution v(θ)≈p(θ|Dp, ϕ*). Following the usual variational inference (VI) derivations, we obtain the negative ELBO objective function for personalisation:











$$\min_{v}\; \mathbb{E}_{v(\theta)}\big[-\log p(D_p \mid \theta)\big] + \mathrm{KL}\big(v(\theta)\,\|\,p(\theta \mid \phi^*)\big). \tag{17}$$







Thus, at step S248, we personalize the global model using the optimisation above. Once we have the optimised v, at step S250, we infer an output. One way, as shown in the pseudocode of FIG. 2i, is to use Monte Carlo sampling, with S being the number of Monte Carlo samples. Our predictive distribution becomes:











$$p(y_p \mid x_p, D_p, D_1, \ldots, D_N) \approx \frac{1}{S} \sum_{s=1}^{S} p\big(y_p \mid x_p, \theta^{(s)}\big), \tag{18}$$

where θ(s) ∼ v(θ),




which simply requires feed-forwarding test input xp through the sampled networks θ(s) and averaging.


In a final step S252, the output is provided. The output could be, for example, a class label or an edited image, e.g. an image edited to include the class label or otherwise altered based on the class label.



FIG. 2k is a schematic representation of the method of FIGS. 2i and 2j. The server 202 sends the parameters for the global model L0 to the participating client device 212p. An updated local model with local model parameters Lp is generated at the client device 212p. In this case, there is extra personal data Dp which can be used to personalize the model. The optimisation includes a regulariser (regularization penalty) similar to the ones which are used during the training framework. In this case, the regulariser penalises the discrepancy between the “guess” for the local model parameters Lp and the received global model parameters L0. The optimisation also includes a term (e.g. a loss function) which fits the local model parameters Lp to the extra personal data Dp, for example using the tractable variational distribution v(θ) as described above. Once the local model is updated, it can then be used to predict the output y based on the input x.


Thus far, we have discussed a general framework for our Bayesian FL, deriving how the variational inference for our general Bayesian model fits gracefully into the FL framework. The next step is to define specific distribution families for the priors (p(ϕ), p(θi|ϕ)) and the posteriors (q(ϕ), qi(θi)). We propose two different model choices that we find the most interesting:


Normal-Inverse-Wishart (NIW) model: good for general models; admits closed forms in most cases; requires no extra computational cost.


Mixture model: Good for more drastic distribution/domain shift, heterogeneity, non-iid data.


Normal-Inverse-Wishart Model. We define the prior as a conjugate pair of a Gaussian and a Normal-Inverse-Wishart. More specifically, each local client has a Gaussian prior p(θi|ϕ)=𝒩(θi; μ, Σ), where 𝒩 is a multivariate normal distribution, μ is the mean and Σ is the covariance matrix, and the global latent variable ϕ=(μ, Σ) is distributed according to its conjugate prior, the Normal-Inverse-Wishart (NIW):






$$p(\phi) = \mathcal{NIW}(\mu, \Sigma; \Lambda) = \mathcal{N}(\mu; \mu_0, \lambda_0^{-1}\Sigma) \cdot \mathcal{IW}(\Sigma; \Sigma_0, \nu_0), \tag{19}$$






$$p(\theta_i \mid \phi) = \mathcal{N}(\theta_i; \mu, \Sigma), \quad i = 1, \ldots, N, \tag{20}$$


where Λ={μ0, Σ0, λ0, ν0} are the parameters of the NIW. Although Λ could be learned via data marginal likelihood maximisation (e.g., empirical Bayes), for simplicity we fix it as μ0=0, Σ0=I, λ0=1, and ν0=d+2, where d is the number of parameters (or dimension) in θi or μ. Note that we set the degree of freedom (d.o.f.) ν0 of the Inverse-Wishart to the smallest integer value that leads to the least informative prior with a finite mean. This choice ensures that the mean of Σ equals I, and that μ is distributed as a zero-mean Gaussian with covariance Σ.
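As a worked check of this claim, using the standard fact that the mean of an Inverse-Wishart distribution 𝒲⁻¹(Σ0, ν0) over d×d matrices is Σ0/(ν0−d−1), which is finite only for ν0>d+1 (so ν0=d+2 is the smallest integer choice):

$$\mathbb{E}[\Sigma] = \frac{\Sigma_0}{\nu_0 - d - 1} = \frac{I}{(d+2) - d - 1} = I, \qquad \mu \mid \Sigma \sim \mathcal{N}(\mu_0, \lambda_0^{-1}\Sigma) = \mathcal{N}(0, \Sigma).$$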


Next, our choice of the variational density family for q(ϕ) is the NIW, not only because it is the most popular parametric family for a pair of mean vector and covariance matrix ϕ=(μ, Σ), but also because it admits closed-form expressions in the ELBO function due to conjugacy, as derived below.






$$q(\phi) := \mathcal{NIW}\big(\phi; \{m_0, V_0, l_0, n_0\}\big) = \mathcal{N}(\mu; m_0, l_0^{-1}\Sigma) \cdot \mathcal{IW}(\Sigma; V_0, n_0). \tag{21}$$


where m0 and V0 are the variational parameters associated with the mean and the covariance matrix, respectively: m0 gives the mode of the mean, μ*=m0, and V0 determines the mode Σ* of the covariance matrix, given in (24) below.







Although the scalar parameters l0, n0 could be optimized together with m0, V0, their impact is less significant and we find that they make the ELBO optimization somewhat cumbersome. We therefore fix l0, n0 to near-optimal values by exploiting the conjugacy of the NIW prior-posterior under the Gaussian likelihood. For each θi, we pretend that we have instance-wise representative estimates θi(x, y), one for each (x, y)∈Di. For instance, one can view θi(x, y) as the network parameters optimized with the single training instance (x, y). This amounts to observing |D| (=Σi=1N|Di|) Gaussian samples θi(x, y)∼𝒩(μ, Σ) for (x, y)∈Di and i=1, ..., N. Applying the NIW conjugacy, the posterior is the NIW with l0=λ0+|D|=|D|+1 and n0=ν0+|D|=|D|+d+2. This gives good approximate estimates of the optimal l0, n0, and we fix them throughout the variational optimization. Note that this is only a heuristic for quickly estimating the scalar parameters l0, n0; the parameters m0, V0 are determined by the principled ELBO optimization as the variational parameters L0={m0, V0}. Since the dimension d (the number of neural network parameters) is large, we restrict V0 to be diagonal for computational tractability.


The density family for the qi(θi)'s could be Gaussian, but we find that it is computationally more attractive and numerically more stable to adopt the mixture of two spiky Gaussians that leads to MC-Dropout, for example as described in "Dropout as a Bayesian Approximation: Representing Model Uncertainty in Deep Learning" by Gal et al., published at the International Conference on Machine Learning 2016. That is,






$$q_i(\theta_i) = \prod_{l} \Big( p \cdot \mathcal{N}\big(\theta_i[l]; m_i[l], \epsilon^2 I\big) + (1-p) \cdot \mathcal{N}\big(\theta_i[l]; 0, \epsilon^2 I\big) \Big), \tag{22}$$


where mi is the only variational parameter (Li={mi}) and corresponds to the mean, ·[l] indicates a specific column/layer of the neural network parameters, where l ranges over the layers and the columns of the weight matrices, p is a (user-specified) hyperparameter such that 1−p is the dropout probability, and ϵ is a tiny constant (e.g., 10^−6) that makes the two Gaussians spiky, i.e. close to delta functions. We now provide more detailed derivations for the client optimization and the server optimization.
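To make the reparameterised sampling from (22) concrete, here is a minimal NumPy sketch that draws one dropout sample m̃i for a single weight matrix. It assumes, for illustration, that each column of the matrix is one group [l] (the text allows columns or layers), and the function name `dropout_sample` is hypothetical. Following (22) as written, kept columns are not rescaled by 1/p as some dropout implementations do.

```python
import numpy as np

def dropout_sample(m, p, eps=1e-6, rng=None):
    """Draw one reparameterised sample from the spiky mixture (22).

    Each column l of the weight matrix m is kept, i.e. drawn from
    N(m[:, l], eps^2 I), with probability p, and otherwise drawn from
    N(0, eps^2 I), i.e. dropped, with probability 1 - p.
    """
    rng = np.random.default_rng() if rng is None else rng
    keep = rng.random(m.shape[1]) < p            # one Bernoulli(p) draw per column
    noise = eps * rng.standard_normal(m.shape)   # tiny Gaussian jitter of scale eps
    return m * keep[np.newaxis, :] + noise

# Usage: averaging many samples, each column's mean approaches p * m.
rng = np.random.default_rng(1)
m = np.ones((3, 4))
mean_sample = np.mean([dropout_sample(m, 0.7, rng=rng) for _ in range(2000)], axis=0)
```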


Detailed Derivations for NIW Model


FIG. 3a is example pseudo code suitable for implementing a training algorithm for the Normal-Inverse-Wishart case described above. The general steps of the training algorithm are the same as in FIG. 2b. In other words, in a first step S200 of each round r, the server selects a predetermined number Nf (≤N) of client devices to participate in the round. In a next step S202, the server sends the parameters L0 of the global posterior q(ϕ; L0) to the participating client devices. These parameters are received at each client device at step S204.


Client update. In a next step S206, each local client device updates its local posterior qi(θi; Li) using an appropriate technique. In this example, we apply the general client update optimisation (8). We note that q(ϕ) is spiky since our pre-estimated NIW parameters l0 and n0 are large (the entire training data size |D| is added to the initial prior parameters). Due to the spiky q(ϕ), we can accurately approximate the second term in (8) as:






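$$\mathbb{E}_{q(\phi)}\Big[\mathrm{KL}\big(q_i(\theta_i)\,\|\,p(\theta_i \mid \phi)\big)\Big] \approx \mathrm{KL}\big(q_i(\theta_i)\,\|\,p(\theta_i \mid \phi^*)\big), \tag{23}$$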


where ϕ*=(μ*,Σ*) is the mode of q(ϕ), which has closed forms for the NIW distribution:











$$\mu^* = m_0, \qquad \Sigma^* = \frac{V_0}{n_0 + d + 1}. \tag{24}$$







In (23) we have the KL divergence between a mixture of Gaussians (22) and a Gaussian (20). We apply the approximation KL(Σi αi 𝒩i ∥ 𝒩) ≈ Σi αi KL(𝒩i ∥ 𝒩), together with reparameterised sampling for (22), which allows us to rewrite (8) as:













$$\min_{m_i}\; \mathcal{L}_i(m_i) := -\log p(D_i \mid \tilde{m}_i) + \frac{p}{2}(n_0 + d + 1)(m_i - m_0)^\top V_0^{-1} (m_i - m_0), \tag{25}$$







where m̃i is the dropout version of mi, i.e., a reparameterised sample from (22). The optimisation (25) is solved at step S206 to update Li for each client device. We use a minibatch version of the first term for a tractable SGD update, which amounts to replacing the first term by the batch average 𝔼(x,y)∼Batch[−log p(y|x, m̃i)] while downweighting the second term by the factor 1/|Di|. Note that m0 and V0 are fixed during the optimisation. Interestingly, (25) generalises the well-known Fed-Avg and Fed-Prox: with p=1 (i.e., no dropout) and V0=αI for some constant α, (25) reduces to the client update formula of Fed-Prox, where the constant α controls the impact of the proximal term. In a next step S208, each local client device sends its updated local parameters Li=mi to the server.
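The minibatch client objective can be sketched as follows, assuming (as the text does) a diagonal V0 stored as a vector. The function names are hypothetical, the per-example negative log-likelihoods are assumed to be already computed with the dropout sample m̃i, and only the regulariser's gradient is shown. The final check illustrates the Fed-Prox special case p=1, V0=αI, where the regulariser becomes (n0+d+1)/(2α)·‖mi−m0‖².

```python
import numpy as np

def niw_regulariser(m_i, m0, V0_diag, p, n0, d):
    """Second term of (25): (p/2)(n0 + d + 1)(m_i - m0)^T V0^{-1} (m_i - m0), V0 diagonal."""
    diff = m_i - m0
    return 0.5 * p * (n0 + d + 1) * np.sum(diff**2 / V0_diag)

def niw_regulariser_grad(m_i, m0, V0_diag, p, n0, d):
    """Gradient of niw_regulariser with respect to m_i."""
    return p * (n0 + d + 1) * (m_i - m0) / V0_diag

def niw_minibatch_loss(batch_nll, m_i, m0, V0_diag, p, n0, d, num_local):
    """Minibatch surrogate of (25): batch-average NLL + regulariser / |D_i|."""
    return np.mean(batch_nll) + niw_regulariser(m_i, m0, V0_diag, p, n0, d) / num_local

# Fed-Prox special case: p = 1 and V0 = alpha * I (toy sizes, not realistic values).
m_i, m0 = np.array([1.0, 2.0]), np.zeros(2)
alpha, n0, d = 2.0, 5, 2
reg = niw_regulariser(m_i, m0, np.full(2, alpha), p=1.0, n0=n0, d=d)
prox = (n0 + d + 1) / (2 * alpha) * np.sum((m_i - m0) ** 2)
```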


Server update. The server receives each of the updated local posterior parameters Li from the client devices at step S210. The general server optimisation (9) involves two terms, both of which admit closed-form expressions thanks to the conjugacy. Furthermore, we show that the optimal solution (m0, V0) of (9) has an analytic form. First, the KL term in (9) is decomposed as:






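$$\mathrm{KL}\big(\mathcal{IW}(\Sigma; V_0, n_0)\,\|\,\mathcal{IW}(\Sigma; \Sigma_0, \nu_0)\big) + \mathbb{E}_{\mathcal{IW}(\Sigma; V_0, n_0)}\Big[\mathrm{KL}\big(\mathcal{N}(\mu; m_0, l_0^{-1}\Sigma)\,\|\,\mathcal{N}(\mu; \mu_0, \lambda_0^{-1}\Sigma)\big)\Big] \tag{26}$$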


After some algebra, and dropping the terms that do not depend on m0, V0, (26) equals the following up to a constant:





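$$\frac{1}{2}\Big( n_0\, \mathrm{Tr}(\Sigma_0 V_0^{-1}) + \nu_0 \log|V_0| + \lambda_0 n_0 (\mu_0 - m_0)^\top V_0^{-1} (\mu_0 - m_0) \Big). \tag{27}$$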


Next, the second term of (9) also admits a closed form as follows:










$$-\mathbb{E}_{q(\phi)\, q_i(\theta_i)}\big[\log p(\theta_i \mid \phi)\big] = \frac{n_0}{2}\Big( p\, m_i^\top V_0^{-1} m_i - p\, m_0^\top V_0^{-1} m_i - p\, m_i^\top V_0^{-1} m_0 + m_0^\top V_0^{-1} m_0 + \frac{1}{n_0}\log|V_0| + \epsilon^2\, \mathrm{Tr}(V_0^{-1}) \Big) + \mathrm{const}. \tag{28}$$







That is, the server's loss function ℒ0 is the sum of (27) and (28), the latter summed over the clients i=1, ..., N. Plugging in μ0=0, Σ0=I, λ0=1, ν0=d+2, the gradients of the loss with respect to m0 and V0^−1 are:

















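$$\frac{\partial \mathcal{L}_0}{\partial m_0} = n_0 V_0^{-1}\Big( (N+1)\, m_0 - p \sum_{i=1}^{N} m_i \Big), \tag{29}$$

$$\frac{\partial \mathcal{L}_0}{\partial V_0^{-1}} = \frac{1}{2}\Big( n_0 (1 + N\epsilon^2)\, I - (N + d + 2)\, V_0 + n_0\, m_0 m_0^\top + n_0 \sum_{i=1}^{N} \rho(m_0, m_i, p) \Big), \tag{30}$$

where

$$\rho(m_0, m_i, p) = p\, m_i m_i^\top - p\, m_0 m_i^\top - p\, m_i m_0^\top + m_0 m_0^\top.$$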








Setting the gradients to zero and solving yields the optimal solution:











$$m_0^* = \frac{p}{N+1} \sum_{i=1}^{N} m_i, \tag{31}$$

$$V_0^* = \frac{n_0}{N + d + 2}\Big( (1 + N\epsilon^2)\, I + m_0^* (m_0^*)^\top + \sum_{i=1}^{N} \rho(m_0^*, m_i, p) \Big). \tag{32}$$







Note that the mi's are fixed to the clients' latest variational parameters. The closed-form solution (31)-(32) is computed at step S212 to update L0 at the server.


Since ρ(m0*, mi, p=1)=(mi−m0*)(mi−m0*)ᵀ when p=1, we can see that V0* in (32) essentially estimates the sample scatter matrix of (N+1) samples, namely the clients' mi's and the server's prior mean μ0=0, measuring how much they deviate from the center m0*. Dropout is known to help regularise the model and lead to better generalisation, and with p<1, (31)-(32) form a principled optimal solution.
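A minimal NumPy sketch of the closed-form server step (31)-(32) follows, again assuming a diagonal V0 so that only the diagonal of (32) is computed. The function name is hypothetical and d is taken as the length of each mi. The check uses p=1, where V0* reduces (up to the factor n0/(N+d+2) and the tiny ϵ² term) to the scatter of the N client means plus the zero prior mean around m0*.

```python
import numpy as np

def niw_server_update(client_means, p, n0, eps=1e-6):
    """Closed-form (31)-(32) with diagonal V0.

    client_means: array (N, d) of the latest local parameters m_i.
    Returns (m0_star, V0_star_diag).
    """
    M = np.asarray(client_means, dtype=float)
    N, d = M.shape
    m0 = p * M.sum(axis=0) / (N + 1)                           # (31)
    # Diagonal of rho(m0, m_i, p) = p m_i m_i^T - p m0 m_i^T - p m_i m0^T + m0 m0^T.
    rho_diag = p * M**2 - 2.0 * p * m0 * M + m0**2
    scatter = m0**2 + rho_diag.sum(axis=0)
    V0_diag = n0 / (N + d + 2) * ((1.0 + N * eps**2) + scatter)  # (32)
    return m0, V0_diag

# N = 2 clients, d = 2 parameters; n0 = 6 makes the prefactor n0 / (N + d + 2) equal 1.
M = np.array([[1.0, 3.0], [2.0, 0.0]])
m0_star, V0_star = niw_server_update(M, p=1.0, n0=6)
```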


Returning to FIG. 2b, once the round is completed, the next round can begin with the random selection at step S200, repeating all other steps. In this example, the trained parameters L0=(m0, V0) are output. The input is an initial L0=(m0, V0) in q(ϕ; L0)=NIW(ϕ; {m0, V0, l0=|D|+1, n0=|D|+d+2}), where |D|=Σi=1N|Di| and d is the number of parameters in the backbone network p(y|x, θ).


Global prediction. Returning to FIG. 1c, an important task in FL is global prediction. FIG. 3b illustrates schematically how global prediction may be applied in the present example. FIG. 3c is example pseudo code to implement the global prediction. The inner integral of (12) in the general predictive distribution becomes a multivariate Student-t distribution:













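$$\begin{aligned} \int p(\theta \mid \phi)\, q(\phi)\, d\phi &= \int \mathcal{N}(\theta; \mu, \Sigma) \cdot \mathcal{NIW}\big(\phi; \{m_0, V_0, l_0, n_0\}\big)\, d\phi \qquad (33)\\ &= t_{n_0 - d + 1}\Big(\theta;\, m_0,\, \frac{(l_0 + 1)\, V_0}{l_0 (n_0 - d + 1)}\Big), \qquad (34) \end{aligned}$$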
)







where t_ν(a, B) is the multivariate Student-t distribution with location a, scale matrix B, and d.o.f. ν. The predictive distribution for a new test input x* can then be estimated as:










$$\begin{aligned} p(y^* \mid x^*, D_1, \ldots, D_N) &= \int p(y^* \mid x^*, \theta) \cdot t_{n_0 - d + 1}\Big(\theta;\, m_0,\, \frac{(l_0 + 1)\, V_0}{l_0 (n_0 - d + 1)}\Big)\, d\theta \qquad (35)\\ &\approx \frac{1}{S} \sum_{s=1}^{S} p\big(y^* \mid x^*, \theta^{(s)}\big), \quad \text{where } \theta^{(s)} \sim t_{n_0 - d + 1}\Big(\theta;\, m_0,\, \frac{(l_0 + 1)\, V_0}{l_0 (n_0 - d + 1)}\Big). \qquad (36) \end{aligned}$$







Here, as shown in FIG. 3b, the inputs are the global model parameters L0=(m0, V0) from the FL-trained model. These are input to the Student-t sampler t_ν(a, B) to give a sample θ(s) of the global parameters θ. In practice, S=1 is normally used. The sampled parameters are used in the deep model to predict the output y*, for example a class label for the test input.
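For illustration, the heuristic scalars l0=|D|+1, n0=|D|+d+2 and one draw θ(s) from the Student-t in (36) can be sketched in NumPy as below, assuming a diagonal V0. The sampler uses the standard scale-mixture construction of a multivariate Student-t (a Gaussian draw divided by the square root of an independent χ²ν/ν draw); the function names are hypothetical. With n0=|D|+d+2, the degrees of freedom n0−d+1 equal |D|+3, so for realistic data sizes the t is very close to a Gaussian.

```python
import numpy as np

def niw_scalars(local_sizes, d):
    """Heuristic scalars from the text: l0 = |D| + 1, n0 = |D| + d + 2."""
    total = int(np.sum(local_sizes))
    return total + 1, total + d + 2

def sample_global_theta(m0, V0_diag, l0, n0, rng=None):
    """One draw theta^(s) ~ t_{n0-d+1}(m0, (l0+1) V0 / (l0 (n0-d+1))), V0 diagonal."""
    rng = np.random.default_rng() if rng is None else rng
    d = m0.shape[0]
    dof = n0 - d + 1
    scale_diag = (l0 + 1) * V0_diag / (l0 * dof)     # diagonal of the scale matrix B
    z = rng.standard_normal(d) * np.sqrt(scale_diag)  # z ~ N(0, B)
    u = rng.chisquare(dof)                            # u ~ chi^2 with dof degrees of freedom
    return m0 + z * np.sqrt(dof / u)                  # multivariate Student-t draw

l0, n0 = niw_scalars([10, 20], d=3)                   # |D| = 30
theta_s = sample_global_theta(np.array([1.0, 2.0, 3.0]), np.full(3, 1e-12), l0, n0,
                              rng=np.random.default_rng(0))
```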


Personalisation. Returning to FIG. 1d, an important task in FL is personalisation. FIG. 3d illustrates schematically how personalisation may be applied in the present example. FIG. 3e is example pseudo code for implementing personalisation. Given the personalisation training data Dp, we follow the general framework in (17) to find v(θ)≈p(θ|Dp, ϕ*) variationally, where ϕ* is obtained from (24). For the density family of v(θ) we adopt the same spiky mixture form as (22), which leads to an MC-dropout-like learning objective similar to (25). Once v is trained, our predictive distribution follows the MC sampling (18). Thus, looking at FIG. 3c, we have the deep model:





$$p(y_p \mid x_p, D_p, D_{1:N}),$$

where yp is the predicted output given the input xp, the personal data Dp and all data sets D1:N. This can be approximated as

$$\frac{1}{S} \sum_{s=1}^{S} p\big(y_p \mid x_p, \theta^{(s)}\big),$$

where the θ(s) are the parameters sampled via MC-dropout, defined by:

$$\theta^{(s)} \sim v(\theta; m),$$

where m is obtained from an optimizer that solves the problem below for the input global model parameters L0=(m0, V0):

$$\min_{m}\; -\log p(D_p \mid \tilde{m}) + \frac{p}{2}(n_0 + d + 1)(m - m_0)^\top V_0^{-1} (m - m_0)$$

and

$$v(\theta) = \prod_{l} \Big( p \cdot \mathcal{N}\big(\theta[l]; m[l], \epsilon^2 I\big) + (1-p) \cdot \mathcal{N}\big(\theta[l]; 0, \epsilon^2 I\big) \Big),$$

with m̃ denoting a dropout sample from v(θ).





Mixture Model. The NIW prior model above expresses our prior belief that each client i a priori gets its own network parameters θi as a Gaussian-perturbed version of the shared parameters μ from the server, namely θi|ϕ∼𝒩(μ, Σ), as in (20). This is intuitively appealing, but not optimal for capturing more drastic diversity or heterogeneity of local data distributions across clients. In situations where clients' local data distributions, as well as their domains and class label semantics, are highly heterogeneous (possibly even set up for adversarial purposes), it is more reasonable to consider multiple different prototypes for the network parameters, diverse enough to cover the heterogeneity in data distributions across clients. Motivated by this idea, we introduce a mixture prior model as follows.


First, we consider K network parameters (prototypes) that broadly cover the clients' data distributions. They are denoted as high-level latent variables, ϕ={μ1, ..., μK}. We consider:






$$p(\phi) = \prod_{j=1}^{K} \mathcal{N}(\mu_j; 0, I), \tag{37}$$


where 𝒩 is a multivariate normal distribution and μj is the global random variable (also termed the mean) for the j-th prototype network. We note a clear distinction from the NIW prior: whereas the NIW prior (19) only controls the mean μ and covariance Σ of the Gaussian from which the local models θi are drawn, the mixture prior (37) is far more flexible in covering highly heterogeneous distributions.


Each local model is then assumed to be chosen from one of these K prototypes. Thus the prior distribution for θi can be modeled as a mixture,











$$p(\theta_i \mid \phi) = \sum_{j=1}^{K} \frac{1}{K}\, \mathcal{N}(\theta_i; \mu_j, \sigma^2 I), \tag{38}$$







where σ is the hyperparameter that captures the perturbation scale and can be chosen by users or learned. Note that we use equal mixing proportions 1/K a priori due to symmetry; that is, each client can take any of the μj's equally likely a priori.


We then describe our choice of the variational density q(ϕ)∏i qi(θi) to approximate the posterior p(ϕ, θ1, ..., θN|D1, ..., DN). First, qi(θi) is chosen as a spiky Gaussian, i.e., its probability density is concentrated at its mean:






$$q_i(\theta_i) = \mathcal{N}(\theta_i; m_i, \epsilon^2 I), \tag{39}$$


with tiny ϵ, which corresponds to the MC-Dropout model with near-0 dropout probability. For q(ϕ) we consider a Gaussian factorized over μj's, but with near-0 variances, that is,






$$q(\phi) = \prod_{j=1}^{K} \mathcal{N}(\mu_j; r_j, \epsilon^2 I), \tag{40}$$


where {rj}j=1K are the variational parameters (L0) and ϵ is near 0 (e.g., 10^−6). The main reason we make q(ϕ) spiky is that the resulting near-deterministic q(ϕ) allows computationally efficient and accurate MC sampling during ELBO optimization as well as test-time (global) prediction, avoiding difficult marginalization. Although Bayesian inference in general encourages keeping as many plausible latent states as possible under the given evidence (observed data), we aim to retain this uncertainty by having many (possibly redundant) spiky prototypes μj rather than imposing larger variances on individual ones (e.g., a finite-sample approximation of a smooth distribution). Note that the number of prototypes K is itself a latent (hyper)parameter, and in principle one can achieve the same uncertainty effect by trading off K against ϵ: either small K with large ϵ, or large K with small (near-0) ϵ. A gating network g(x; β) is introduced so that client data are dominantly explained by the most relevant model rj. The gating network is described in more detail below.


With the full specification of the prior distribution and the variational density family, we are ready to derive the client objective function (8) and the server objective function (9). FIG. 4a is example pseudo code suitable for implementing a training algorithm for the mixture case. The general steps of the training algorithm are the same as in FIG. 2b. In other words, in a first step S200 of each round r, the server selects a predetermined number Nf (≤N) of client devices to participate in the round. In a next step S202, the server sends the parameters L0 of the global posterior q(ϕ; L0) to the participating client devices. These parameters are received at each client device at step S204.


Client update. In a next step S206, each local client device updates its local posterior qi(θi; Li) using an appropriate technique. In this example, due to the spiky q(ϕ), we can accurately approximate the third term of (8) as:






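$$\mathbb{E}_{q(\phi)\, q_i(\theta_i)}\big[\log p(\theta_i \mid \phi)\big] \approx \mathbb{E}_{q_i(\theta_i)}\big[\log p(\theta_i \mid \phi^*)\big], \quad \text{where } \phi^* = \{\mu_j^* = r_j\}_{j=1}^{K}. \tag{41}$$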


Then the last two terms of (8) boil down to KL(qi(θi)∥p(θi|ϕ*)), the KL divergence between a Gaussian and a mixture of Gaussians. Since qi(θi) is spiky, the KL divergence can be approximated with high accuracy using the single mode sample mi ∼ qi(θi), that is,

















$$\begin{aligned} \mathrm{KL}\big(q_i(\theta_i)\,\|\,p(\theta_i \mid \phi^*)\big) &\approx \log q_i(m_i) - \log p(m_i \mid \phi^*) \qquad (42)\\ &= -\log \sum_{j=1}^{K} \mathcal{N}(m_i; r_j, \sigma^2 I) + \mathrm{const} \qquad (43)\\ &= -\log \sum_{j=1}^{K} \exp\Big( -\frac{\|m_i - r_j\|^2}{2\sigma^2} \Big) + \mathrm{const}. \qquad (44) \end{aligned}$$







Here we use the fact that log qi(mi) does not depend on mi, since it is the log-density of 𝒩(θi; mi, ϵ²I) evaluated at its own mean. Plugging this into (8) yields the following optimization for client i:











$$\min_{m_i}\; \mathbb{E}_{q_i(\theta_i)}\big[-\log p(D_i \mid \theta_i)\big] - \log \sum_{j=1}^{K} \exp\Big( -\frac{\|m_i - r_j\|^2}{2\sigma^2} \Big). \tag{45}$$







The optimisation (45) is solved at step S206 to update Li for each client device. Since log-sum-exp is approximately equal to max, the regularization term in (45) focuses only on the global prototype rj closest to the current local model mi, which is well aligned with our initial modeling motivation, namely that each local data distribution is explained by one of the global prototypes.


Lastly, in the SGD optimization setting where we can only access a minibatch B∼Di during the optimization of (45), we follow the conventional practice: we replace the first (negative log-likelihood) term by the stochastic estimate 𝔼qi(θi)𝔼(x,y)∼B[−log p(y|x, θi)] and multiply the second (regularization) term by 1/|Di|.




In a next step S208, each local client device sends its updated local parameters Li=mi to the server.
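The regulariser in (45) and its gradient can be computed stably with the log-sum-exp trick, as in the minimal NumPy sketch below (the function name is hypothetical). The gradient (mi − Σj wj rj)/σ², with softmax weights wj, makes the "closest prototype" behaviour explicit: when one prototype dominates, the term acts like a proximal pull towards it. In the minibatch setting described above, both the value and the gradient would be multiplied by 1/|Di|.

```python
import numpy as np

def mixture_regulariser(m_i, R, sigma):
    """-log sum_j exp(-||m_i - r_j||^2 / (2 sigma^2)) from (45), with its gradient.

    m_i: local parameters, shape (D,); R: prototypes {r_j}, shape (K, D).
    """
    a = -np.sum((m_i[None, :] - R) ** 2, axis=1) / (2.0 * sigma**2)
    a_max = a.max()                  # log-sum-exp trick for numerical stability
    w = np.exp(a - a_max)
    value = -(a_max + np.log(w.sum()))
    w /= w.sum()                     # softmax weights: soft assignment to prototypes
    grad = (m_i - w @ R) / sigma**2
    return value, grad

R = np.array([[0.0, 0.0], [10.0, 10.0]])
val_far, grad_far = mixture_regulariser(np.zeros(2), R, sigma=1.0)   # sits on r_1
val_one, grad_one = mixture_regulariser(np.ones(2), R[:1], sigma=1.0)  # single prototype
```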


Server update. The server receives each of the updated local posterior parameters Li from the client devices at step S210. At step S212, the global posterior is then updated using the received local posterior parameters. This is done using the optimization below which is derived as follows: First, the KL term in (9) can be easily derived as:






$$\mathrm{KL}\big(q(\phi)\,\|\,p(\phi)\big) = \frac{1}{2} \sum_{j=1}^{K} \|r_j\|^2 + \mathrm{const}. \tag{46}$$


Now, we can approximate the second term of (9) as follows:











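$$\begin{aligned} \mathbb{E}_{q(\phi)\, q_i(\theta_i)}\big[\log p(\theta_i \mid \phi)\big] &\approx \mathbb{E}_{q(\phi)}\big[\log p(m_i \mid \phi)\big] \qquad (47)\\ &\approx \log \sum_{j=1}^{K} \frac{1}{K}\, \mathcal{N}(m_i; r_j, \sigma^2 I) \qquad (48)\\ &= \log \sum_{j=1}^{K} \exp\Big( -\frac{\|m_i - r_j\|^2}{2\sigma^2} \Big) + \mathrm{const}. \qquad (49) \end{aligned}$$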







where the approximations in (47) and (48) are quite accurate due to spiky qii) and q(ϕ), respectively. Combining the two terms leads to the optimization problem for the server:











$$\min_{\{r_j\}_{j=1}^{K}}\; \frac{1}{2} \sum_{j=1}^{K} \|r_j\|^2 - \sum_{i=1}^{N} \log \sum_{j=1}^{K} \exp\Big( -\frac{\|m_i - r_j\|^2}{2\sigma^2} \Big). \tag{50}$$







The optimisation (50) is solved at step S212 to update L0 at the server. The σ² term in the denominator of the resulting update (53) below can be explained by incorporating an extra zero local model m0=0 (interpreted as a neutral model) with discounted weight σ² rather than 1.


Although (50) can be solved for K>1 by the standard gradient descent method, we apply the Expectation-Maximization (EM) algorithm instead. Using Jensen's bound with convexity of the negative log function, we have the following alternating steps:


E-step: With the current {rj}j=1K fixed, compute the prototype assignment probabilities for each local model mi:











$$c(j \mid i) = \frac{e^{-\|m_i - r_j\|^2 / (2\sigma^2)}}{\sum_{j'=1}^{K} e^{-\|m_i - r_{j'}\|^2 / (2\sigma^2)}}, \tag{51}$$







where λ is a small non-negative number (smoother) to avoid prototypes with no assignment.


M-step: With the current assignments c(j|i) fixed, solve:












$$\min_{\{r_j\}}\; \frac{1}{2} \sum_{j} \|r_j\|^2 + \frac{1}{2\sigma^2} \sum_{i,j} c(j \mid i) \cdot \|m_i - r_j\|^2, \tag{52}$$







which admits the closed form solution:











$$r_j^* = \frac{\frac{1}{N} \sum_{i=1}^{N} c(j \mid i) \cdot m_i}{\frac{\sigma^2}{N} + \frac{1}{N} \sum_{i=1}^{N} c(j \mid i)}, \quad j = 1, \ldots, K. \tag{53}$$







The server update equation (53) has an intuitive meaning: the new prototype rj becomes the (weighted) average of the local models mi that are close to rj (those i with non-negligible c(j|i)), which can be seen as an extension of the Fed-Avg aggregation step to the multiple-prototype case. However, (53) requires storing all the latest local models {mi}i=1N, which may be an overhead for the server. It can be more reasonable to use only the up-to-date local models of the clients that participated in the latest round. We therefore use a stochastic, (exponentially) smoothed approximation of the update equation,











$$r_j^* \leftarrow (1 - \gamma) \cdot r_j^{\mathrm{old}} + \gamma \cdot \frac{\mathbb{E}_{i \sim R}\big[c(j \mid i) \cdot m_i\big]}{\frac{\sigma^2}{N} + \mathbb{E}_{i \sim R}\big[c(j \mid i)\big]}, \quad j = 1, \ldots, K, \tag{54}$$







where rjold is the prototype from the previous round, i∼R ranges over the clients that participated in the current round, and γ is the smoothing weight.
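The E-step (51), the closed-form M-step (53) and the smoothed round update (54) can be sketched in NumPy as follows. This is an illustrative sketch under stated assumptions: the function names are hypothetical, the smoother λ mentioned after (51) is omitted because its exact placement in (51) is not given, and in `smoothed_update` the expectation over i∼R is taken as a plain average over the local models received in the current round, with N the total number of clients.

```python
import numpy as np

def e_step(M, R, sigma):
    """Assignment probabilities c(j|i) of (51). M: (N, D) local models, R: (K, D) prototypes."""
    a = -((M[:, None, :] - R[None, :, :]) ** 2).sum(axis=2) / (2.0 * sigma**2)  # (N, K)
    a -= a.max(axis=1, keepdims=True)        # stabilise before exponentiating
    c = np.exp(a)
    return c / c.sum(axis=1, keepdims=True)

def m_step(M, C, sigma):
    """Closed form (53): r_j = sum_i c(j|i) m_i / (sigma^2 + sum_i c(j|i))."""
    return (C.T @ M) / (sigma**2 + C.sum(axis=0))[:, None]

def smoothed_update(R_old, M_round, sigma, gamma, num_clients):
    """Exponentially smoothed update (54) using only this round's local models."""
    C = e_step(M_round, R_old, sigma)
    num = C.T @ M_round / M_round.shape[0]           # E_{i~R}[c(j|i) m_i]
    den = sigma**2 / num_clients + C.mean(axis=0)    # sigma^2/N + E_{i~R}[c(j|i)]
    return (1.0 - gamma) * R_old + gamma * num / den[:, None]

# Two well-separated groups of local models and two initial prototypes.
M = np.array([[0.0, 0.0], [0.1, 0.0], [10.0, 10.0], [10.1, 10.0]])
R0 = np.array([[1.0, 1.0], [9.0, 9.0]])
R1 = m_step(M, e_step(M, R0, sigma=0.1), sigma=0.1)
```

With γ=1 and all clients participating, (54) reproduces (53); with γ=0 the prototypes are left unchanged.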


Returning to FIG. 2b, once the round is completed, the next round can begin with the random selection at step S200, repeating all other steps. In this example, the trained parameters L0={rj}j=1K and β are output. The input is an initial L0={rj}j=1K in q(ϕ; L0)=∏j=1K 𝒩(μj; rj, ϵ²I) and β in the gating network g(x; β).


Global prediction. Returning to FIG. 1c, an important task in FL is global prediction. FIG. 4b is example pseudo code to implement global prediction in this example. By plugging the mixture prior p(θ|ϕ) of (38) and the factorized spiky Gaussian q(ϕ) of (40) into the inner integral of (12), we obtain a predictive distribution that is approximately an equal average over {rj}j=1K, that is,










$$p(y^* \mid x^*, D_1, \ldots, D_N) \approx \frac{1}{K} \sum_{j=1}^{K} p(y^* \mid x^*, r_j).$$






Unfortunately, this is not ideal for our original intention that only one specific model rj out of the K candidates be dominantly responsible for the local data. To meet this intention, we extend our model so that the input point x* can affect θ together with ϕ; with this modification our predictive probability can be derived as:










$$\begin{aligned} p(y^* \mid x^*, D_1, \ldots, D_N) &= \iint p(y^* \mid x^*, \theta)\, p(\theta \mid x^*, \phi)\, p(\phi \mid D_1, \ldots, D_N)\, d\theta\, d\phi \qquad (55)\\ &\approx \iint p(y^* \mid x^*, \theta)\, p(\theta \mid x^*, \phi)\, q(\phi)\, d\theta\, d\phi \qquad (56)\\ &\approx \int p(y^* \mid x^*, \theta)\, p\big(\theta \mid x^*, \{r_j\}_{j=1}^{K}\big)\, d\theta \qquad (57) \end{aligned}$$







To deal with the tricky part of inferring p(θ|x*, {rj}j=1K), we introduce a fairly practical strategy of fitting a gating function. The idea is to regard p(θ|x*, {rj}j=1K) as a mixture of experts, with the prototypes rj serving as the experts,






$$p\big(\theta \mid x^*, \{r_j\}_{j=1}^{K}\big) := \sum_{j=1}^{K} g_j(x^*) \cdot \delta(\theta - r_j), \tag{58}$$


where δ(·) is the Dirac delta function, and g(x)∈ΔK−1 is a gating function that outputs a K-dimensional softmax vector. Intuitively, the gating function determines which of the K prototypes {rj}j=1K the model θ for the test point x* belongs to. With (58), the predictive probability in (57) can be written as:






p(y^*|x^*, D_1, \ldots, D_N) \approx \sum_{j=1}^{K} g_j(x^*)\, p(y^*|x^*, r_j).   (59)
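To make (59) concrete, the following Python sketch contrasts the equal-weight average over prototypes with the gated mixture of experts. It is only an illustration under stated assumptions: the per-prototype class probabilities p(y*|x*, r_j) and the gating logits are supplied as plain lists (hypothetical values standing in for the outputs of the deployed networks), and g(x*) is obtained from those logits with a softmax, as the text describes.

```python
import math

def softmax(logits):
    """Numerically stable softmax, returning a point on the (K-1)-simplex."""
    top = max(logits)
    exps = [math.exp(z - top) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def uniform_predictive(expert_probs):
    """Equal averaging over prototypes: (1/K) * sum_j p(y|x, r_j)."""
    K = len(expert_probs)
    return [sum(p[c] for p in expert_probs) / K for c in range(len(expert_probs[0]))]

def gated_predictive(gate_logits, expert_probs):
    """Eq. (59): p(y|x) ~= sum_j g_j(x) * p(y|x, r_j), with g = softmax(gate_logits)."""
    g = softmax(gate_logits)
    return [sum(g[j] * expert_probs[j][c] for j in range(len(g)))
            for c in range(len(expert_probs[0]))]

# Hypothetical outputs for one test input x*: r_1 favours class 0, r_2 favours class 2.
experts = [[0.8, 0.15, 0.05], [0.1, 0.2, 0.7]]
print(uniform_predictive(experts))              # blends both prototypes
print(gated_predictive([4.0, -4.0], experts))   # dominated by prototype r_1
```

With equal gate logits the gated prediction reduces to the uniform average, which is exactly the limitation of the equal-weight predictive distribution discussed above.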


However, since we do not have access to this oracle g(x), we introduce a neural network and fit it to the local training data during the training stage. Let g(x; β) be the gating network with parameters β. To train it, we follow the FedAvg strategy. In the client update stage at each round, while we update the local model m_i with a minibatch B ∼ D_i, we also find the prototype closest to m_i, namely j* := argmin_j ∥m_i − r_j∥. We then form another minibatch of samples {(x, j*)}_{x∼B} (input x with class label j*) and update g(x; β) by SGD. The updated (local) β's from the clients are then aggregated by the server (by simple averaging) and distributed back to the clients as the initial iterate for the next round.
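One round of the gating-network training described above can be sketched in Python as follows. This is a minimal illustration, assuming a linear softmax gate g(x; β) with one weight vector per prototype (the described method instead gives the gate the same architecture as the backbone), models and inputs represented as flat lists, and squared Euclidean distance for the argmin, which yields the same j* as ∥·∥. All helper names are hypothetical.

```python
import math

def closest_prototype(m_i, prototypes):
    """j* := argmin_j ||m_i - r_j||; squared distance gives the same argmin."""
    dists = [sum((a - b) ** 2 for a, b in zip(m_i, r)) for r in prototypes]
    return min(range(len(prototypes)), key=lambda j: dists[j])

def gating_minibatch(batch_inputs, m_i, prototypes):
    """Label every input x of the client minibatch B with the pseudo class j*."""
    j_star = closest_prototype(m_i, prototypes)
    return [(x, j_star) for x in batch_inputs]

def sgd_step_linear_gate(beta, pairs, lr=0.1):
    """One SGD step on the mean cross-entropy of a linear softmax gate.

    beta holds K weight vectors; the gradient for row k is (p_k - 1[k == j*]) * x.
    """
    K, dim = len(beta), len(beta[0])
    grads = [[0.0] * dim for _ in range(K)]
    for x, label in pairs:
        logits = [sum(w * xi for w, xi in zip(beta[k], x)) for k in range(K)]
        top = max(logits)
        exps = [math.exp(z - top) for z in logits]
        total = sum(exps)
        for k in range(K):
            coef = exps[k] / total - (1.0 if k == label else 0.0)
            for d in range(dim):
                grads[k][d] += coef * x[d] / len(pairs)
    return [[w - lr * g for w, g in zip(beta[k], grads[k])] for k in range(K)]

def server_average(client_betas):
    """Server step: simple (unweighted) average of the clients' gate parameters."""
    n = len(client_betas)
    K, dim = len(client_betas[0]), len(client_betas[0][0])
    return [[sum(b[k][d] for b in client_betas) / n for d in range(dim)] for k in range(K)]

# Toy round: K = 2 prototypes in a 2-D parameter space, 2-D gate inputs.
prototypes = [[0.0, 0.0], [5.0, 5.0]]
beta0 = [[0.0, 0.0], [0.0, 0.0]]
pairs_a = gating_minibatch([[1.0, 0.5], [0.8, 1.2]], [4.2, 4.9], prototypes)  # j* = 1
pairs_b = gating_minibatch([[-1.0, 0.3]], [0.4, -0.2], prototypes)            # j* = 0
beta_next = server_average([sgd_step_linear_gate(beta0, pairs_a),
                            sgd_step_linear_gate(beta0, pairs_b)])
```

The averaged `beta_next` would then be sent back to the clients as the starting point for the next round, mirroring the FedAvg-style aggregation in the text.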


Personalisation. Returning to FIG. 1d, an important task in FL is personalisation. FIG. 4c is example pseudo code for implementing personalisation. With v(θ) of the same form as q_i(θ_i), the VI learning becomes similar to (45). Starting from p(θ|D_p, ϕ*) in the general framework (14)-(16), we define the variational distribution v(θ) ≈ p(θ|D_p, ϕ*) as:






v(\theta) = \mathcal{N}(\theta; m, \epsilon^2 I),   (60)


where ϵ is a tiny positive constant and m is the only parameter that we learn. Our personalisation training amounts to ELBO optimisation for v(θ) as in (17), which reduces to:











\min_{m} \; \mathbb{E}_{v(\theta)}\left[ -\log p(D_p|\theta) \right] - \log \sum_{j=1}^{K} \exp\left( -\frac{\|m - r_j\|^2}{2\sigma^2} \right).   (61)







Once we have the optimal m (i.e., v(θ)), our predictive model becomes:





p(y_p|x_p, D_p, D_1, \ldots, D_N) \approx p(y_p|x_p, m),   (62)


which amounts to feeding the test input x_p forward through the network deployed with the parameters m.
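The personalisation objective (61) can be sketched in Python as plain gradient descent on m. Two assumptions are made purely for illustration: because ϵ is tiny, the expectation under v(θ) is approximated by evaluating the data term at θ = m, and −log p(D_p|θ) is replaced by a toy unit-variance Gaussian loss 0.5·Σ∥m − y∥² over hypothetical target vectors y. The prior term is the log-sum-exp penalty of (61), whose gradient is a responsibility-weighted pull Σ_j w_j (m − r_j)/σ² towards the prototypes.

```python
import math

def sq_dist(a, b):
    return sum((ai - bi) ** 2 for ai, bi in zip(a, b))

def prior_penalty_and_grad(m, prototypes, sigma2):
    """Second term of (61), -log sum_j exp(-||m - r_j||^2 / (2 sigma^2)), and its gradient."""
    scores = [-sq_dist(m, r) / (2.0 * sigma2) for r in prototypes]
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    penalty = -(top + math.log(total))            # stable log-sum-exp
    weights = [e / total for e in exps]           # responsibility of each prototype
    grad = [sum(w * (m[d] - r[d]) for w, r in zip(weights, prototypes)) / sigma2
            for d in range(len(m))]
    return penalty, grad

def personalise(m0, data, prototypes, sigma2=0.1, lr=0.01, steps=500):
    """Gradient descent on (61) with a toy data term 0.5 * sum_y ||m - y||^2.

    Since epsilon is tiny, E_v[-log p(D_p|theta)] is approximated at theta = m.
    """
    m = list(m0)
    for _ in range(steps):
        data_grad = [sum(m[d] - y[d] for y in data) for d in range(len(m))]
        _, prior_grad = prior_penalty_and_grad(m, prototypes, sigma2)
        m = [m[d] - lr * (data_grad[d] + prior_grad[d]) for d in range(len(m))]
    return m

prototypes = [[0.0, 0.0], [3.0, 3.0]]          # hypothetical prototypes r_1, r_2
data = [[2.4, 2.6], [2.6, 2.4], [2.5, 2.5]]    # hypothetical personal targets
m = personalise([2.0, 2.0], data, prototypes)  # start from the FL-trained model
```

Here m settles between the data mean (2.5, 2.5) and the nearest prototype r_2. The starting point matters because the objective is non-convex: initialising at r_1 = (0, 0) would instead settle near r_1.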


Theoretical Analysis

We provide two theoretical results for our Bayesian FL algorithm: a convergence analysis and a generalisation error bound. As a special block-coordinate optimisation algorithm, it converges to a (local) optimum of the training objective (6); we also show theoretically how well this optimal model, trained on empirical data, performs on unseen test data points. The computational complexity (including wall-clock running times) and communication cost of the proposed algorithms are also analysed and summarised. Our methods incur only constant-factor extra cost compared to the minimal-cost FedAvg (“Communication-Efficient Learning of Deep Networks from Decentralized Data” by McMahan et al published in AI and Statistics (AISTATS) in 2017, reference [44]).


Convergence Analysis. Our (general) FL algorithm is a special block-coordinate SGD optimisation of the ELBO function in (6) with respect to the (N+1) parameter groups: L_0 (of q(ϕ; L_0)), L_1 (of q_1(θ_1; L_1)), . . . , and L_N (of q_N(θ_N; L_N)). In this section we provide a theorem that guarantees convergence of the algorithm to a local optimum of the ELBO objective function under some mild assumptions, and we also analyse the convergence rate. Although our FL algorithm is a special case of general block-coordinate SGD optimisation, we may not directly apply the existing convergence results for regular block-coordinate SGD methods, since they mostly rely on non-overlapping blocks with cyclic or uniform random block selection strategies. As the block selection strategy in our FL algorithm is unusual, with overlapping blocks and non-uniform random block selection, we provide our own analysis here. Promisingly, we show that, in accordance with regular block-coordinate SGD (cyclic/uniform non-overlapping block selection), our FL algorithm has an O(1/√t) convergence rate, which is also asymptotically the same as that of (holistic, non-block-coordinate) SGD optimisation. Note that this section concerns the convergence of our algorithm to a (local) optimum of the training objective (ELBO); the question of how well this optimal model, trained on empirical data, performs on unseen data points is discussed in the generalisation error bound section below.
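The block structure described above can be illustrated with a toy Python sketch. Everything in it is an assumption chosen for readability rather than the method itself: scalar parameters, a hypothetical quadratic objective Σ_i [½(x_i − a_i)² + ½ρ(x_i − x_0)²] in place of the ELBO, and exact block minimisation in place of SGD steps. What it is meant to mirror, under our reading of the text, is the selection pattern: every round updates the global block x_0 plus N_f randomly chosen local blocks x_i, so blocks overlap across rounds and are selected non-uniformly (x_0 always, each x_i with probability N_f/N).

```python
import random

def objective(x0, xs, targets, rho):
    """Toy stand-in for (6): sum_i 0.5*(x_i - a_i)^2 + 0.5*rho*(x_i - x0)^2."""
    return sum(0.5 * (x - a) ** 2 + 0.5 * rho * (x - x0) ** 2 for x, a in zip(xs, targets))

def run_rounds(targets, rho=1.0, n_f=2, rounds=200, seed=0):
    """Block-coordinate rounds: a random subset of local blocks, then the global block."""
    rng = random.Random(seed)
    n = len(targets)
    x0, xs = 0.0, [0.0] * n      # xs: latest local parameters known to the server
    history = [objective(x0, xs, targets, rho)]
    for _ in range(rounds):
        for i in rng.sample(range(n), n_f):
            # Client block: exact minimiser over x_i with x0 held fixed.
            xs[i] = (targets[i] + rho * x0) / (1.0 + rho)
        # Global block: exact minimiser over x0 with all x_i held fixed.
        x0 = sum(xs) / n
        history.append(objective(x0, xs, targets, rho))
    return x0, xs, history

x0, xs, history = run_rounds([1.0, 2.0, 3.0, 6.0])
```

For this toy objective the optimum has x_0 equal to the mean of the a_i (here 3.0) and x_i = (a_i + ρx_0)/(1 + ρ); the objective history is non-increasing because every block step is an exact minimisation.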


Theorem (Convergence analysis) We denote the objective function in (6) by f(x), where x = [x_0, x_1, . . . , x_N] corresponds to the variational parameters x_0 := L_0, x_1 := L_1, . . . , x_N := L_N. Let η_t = L + √t for some constant L, and









\bar{x}_T = \frac{1}{T} \sum_{t=1}^{T} x_t,




where t is the batch iteration counter, x_t is the iterate at step t of our FL algorithm, and N_f (≤ N) is the number of participating clients at each round. The following holds for any T:













\mathbb{E}\left[ f(\bar{x}_T) \right] - f(x^*) \le \frac{N + N_f}{N_f} \cdot \frac{\frac{\sqrt{T} + L}{2} D^2 + R_f^2 \sqrt{T}}{T} = O\left( \frac{1}{\sqrt{T}} \right),   (63)







where x* is the (local) optimum, D and R_f are some constants, and the expectation is taken over the randomness in the minibatches and in the selection of participating clients.


The theorem states that the averaged iterate x̄_T converges to the optimal point x* in expectation, in terms of the objective value, at the rate of O(1/√T). This convergence rate asymptotically equals that of the conventional (non-block-coordinate, holistic) SGD algorithm.
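As a quick numerical reading of (63), the Python sketch below evaluates its right-hand side for illustrative constants (N, N_f, L, D and R_f are arbitrary placeholders, not values from the analysis). Multiplying the bound by √T shows it levelling off at (N + N_f)/N_f · (D²/2 + R_f²), which is the O(1/√T) behaviour; the factor (N + N_f)/N_f also shows that fewer participating clients per round loosen the bound.

```python
import math

def convergence_bound(T, N, N_f, L, D, R_f):
    """Right-hand side of (63) after T iterations."""
    factor = (N + N_f) / N_f
    return factor * ((math.sqrt(T) + L) / 2.0 * D ** 2 + R_f ** 2 * math.sqrt(T)) / T

for T in (10**2, 10**4, 10**6):
    b = convergence_bound(T, N=100, N_f=10, L=1.0, D=1.0, R_f=1.0)
    print(T, round(b, 4), round(b * math.sqrt(T), 4))
```

With these placeholder constants the scaled value b·√T approaches 11 × 1.5 = 16.5 as T grows.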


Generalisation Error Bound. In this section we discuss the generalisation performance of our proposed algorithm, answering the question of how well the Bayesian FL model trained on empirical data performs on unseen data points. We aim to provide an upper bound on the generalisation error averaged over the posterior distribution of the model parameters (ϕ, {θ_i}_{i=1}^N), by linking it to the expected empirical error plus some additional complexity terms.


To this end, we first naturally consider PAC-Bayes bounds, because they relate the two error terms (generalisation and empirical), each expected over the posterior distribution, via the KL divergence between the posterior and the prior. However, the original PAC-Bayes bounds contain the square root of the KL term, which deviates from the ELBO objective, where the expected data loss and the KL term are simply summed (without a square root). There are, however, recent variants of PAC-Bayes bounds, specifically the PAC-Bayes-λ bound, which remove the square root of the KL and therefore suit the ELBO objective better.


To discuss it further, the objective function of our FL algorithm (6) can be viewed as a conventional variational inference ELBO objective with the prior p(B) and the posterior q(B), where B={ϕ, θ1, . . . , θN} indicates the set of all latent variables in our model. More specifically, the negative ELBO (function of the variational posterior distribution q) can be written as:












-\mathrm{ELBO}(q) = \mathbb{E}_{q(B)}\left[ \hat{l}_n(B) \right] + \frac{1}{n} \mathrm{KL}\left( q(B) \,\|\, p(B) \right),   (64)







where \hat{l}_n(B) is the empirical error/loss of the model B on the training data of size n. We then apply the PAC-Bayes-λ bound: for any λ ∈ (0, 2), the following holds with probability at least 1 − δ:












\mathbb{E}_{q(B)}\left[ l(B) \right] \le \frac{1}{1 - \lambda/2} \mathbb{E}_{q(B)}\left[ \hat{l}_n(B) \right] + \frac{1}{\lambda (1 - \lambda/2)} \cdot \frac{\mathrm{KL}\left( q(B) \,\|\, p(B) \right) + \log(2\sqrt{n}/\delta)}{n},   (65)







where l(B) is the generalisation error/loss of the model B. Thus, when λ = 1, the right-hand side of (65) reduces to −2·ELBO(q) plus a complexity term, which justifies why maximising the ELBO with respect to q can help reduce the generalisation error. Although this argument may look partially sufficient, strictly speaking the extra factor 2 on the ELBO (for the choice λ = 1) may be problematic, potentially making the bound trivial and less useful. Other choices of λ fail to recover the original ELBO, yielding slightly different coefficients for the expected loss and the KL term.
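The relation between (64) and (65) can be checked numerically with the Python sketch below. The inputs (empirical loss, KL value, n, δ) are made-up placeholders, and the functions simply transcribe the two formulas. For λ = 1 both coefficients equal 2, so the right-hand side of (65) equals 2·(−ELBO(q)) + 2·log(2√n/δ)/n, which is the factor-2 issue noted above.

```python
import math

def neg_elbo(emp_loss, kl, n):
    """Eq. (64): -ELBO(q) = E_q[empirical loss] + KL(q || p) / n."""
    return emp_loss + kl / n

def pac_bayes_lambda_rhs(emp_loss, kl, n, delta, lam):
    """Right-hand side of the PAC-Bayes-lambda bound (65); requires 0 < lam < 2."""
    if not 0.0 < lam < 2.0:
        raise ValueError("lambda must lie in (0, 2)")
    complexity = (kl + math.log(2.0 * math.sqrt(n) / delta)) / n
    return emp_loss / (1.0 - lam / 2.0) + complexity / (lam * (1.0 - lam / 2.0))

emp, kl, n, delta = 0.10, 50.0, 1000, 0.05
for lam in (0.5, 1.0, 1.5):
    print(lam, pac_bayes_lambda_rhs(emp, kl, n, delta, lam))
```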


In what follows, we state our new generalisation error bound for our FL algorithm, which relies not on PAC-Bayes but on a recent regression analysis technique for variational Bayes, one that has also recently been adopted in the analysis of a personalised FL algorithm.


Theorem (Generalisation error bound) Assume that the variational density family for q_i(θ_i) is rich enough to subsume Gaussians. Let d²(P_{θ_i}, P_i) be the expected squared Hellinger distance between the true class distribution P_i(y|x) and the model's P_{θ_i}(y|x) for client i's data. The optimal solution ({q_i^*(θ_i)}_{i=1}^N, q^*(ϕ)) of our FL-ELBO optimisation problem (6) satisfies:














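\frac{1}{N} \sum_{i=1}^{N} \mathbb{E}_{q_i^*(\theta_i)}\left[ d^2(P_{\theta_i}, P_i) \right] \le O\left( \frac{1}{n} \right) + C \cdot \epsilon_n^2 + C' \left( r_n + \frac{1}{N} \sum_{i=1}^{N} \lambda_i^* \right),   (66)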








with high probability, where C, C′ > 0 are constants, λ_i^* = min_{θ∈Θ} ∥f_θ − f_i∥² is the best error within our backbone network family Θ, and r_n, ϵ_n → 0 as the training data size n → ∞.


This theorem implies that the optimal solution of our FL-ELBO optimisation problem (attainable by our block-coordinate FL algorithm) is asymptotically optimal, since the right-hand side of (66) converges to 0 as the training data size n → ∞. Note that the last term (1/N) Σ_i λ_i^* can be made arbitrarily close to 0 by increasing the backbone capacity (MLPs are universal function approximators). In practice, however, for fixed n, enlarging the backbone capacity also increases ϵ_n and r_n, so it is important to choose the backbone network architecture properly. Note also that our assumption on the variational density family for q_i(θ_i) is easily met; for instance, the mixture-of-Gaussians families adopted in the NIW and mixture models obviously subsume a single Gaussian family.
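To estimate the quantity on the left-hand side of (66) on held-out data, one can average a squared Hellinger distance between class distributions over a sample of inputs, as in the Python sketch below. It assumes the common convention H²(P, Q) = ½ Σ_c (√p_c − √q_c)², which lies in [0, 1]; the source does not fix the normalisation. In practice the true P_i(y|x) is unknown, so the probability vectors here are hypothetical, and the expectation over q_i^* would additionally require sampling θ_i.

```python
import math

def squared_hellinger(p, q):
    """H^2(p, q) = 0.5 * sum_c (sqrt(p_c) - sqrt(q_c))^2 for two class-probability vectors."""
    return 0.5 * sum((math.sqrt(a) - math.sqrt(b)) ** 2 for a, b in zip(p, q))

def expected_squared_hellinger(model_probs, true_probs):
    """Monte Carlo estimate of d^2(P_theta, P_i): mean of H^2 over sampled inputs x."""
    pairs = list(zip(model_probs, true_probs))
    return sum(squared_hellinger(pm, pt) for pm, pt in pairs) / len(pairs)

def client_average(per_client_values):
    """Outer average (1/N) * sum_i used on the left-hand side of (66)."""
    return sum(per_client_values) / len(per_client_values)

# Hypothetical predictions vs. true conditionals for two inputs of one client.
d2 = expected_squared_hellinger([[0.7, 0.3], [0.2, 0.8]], [[0.6, 0.4], [0.1, 0.9]])
```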


Evaluation

We evaluate the proposed hierarchical Bayesian models on several FL benchmarks: CIFAR-100, MNIST, Fashion-MNIST, and EMNIST. We also report results on the challenging corrupted CIFAR (CIFAR-C-100), which renders the client data more heterogeneous in both input images and class distributions. Our implementation is based on FedBABU (“FedBABU: Towards Enhanced Representation for Federated Image Classification” by Oh et al published in International Conference on Learning Representations, 2022, reference [45]), where MobileNet (described in “MobileNets: Efficient convolutional neural networks for mobile vision applications” by Howard et al, arXiv preprint arXiv:1704.04861, 2017) is used as the backbone. The implementations follow the body-update strategy: the classification head (the last layer) is randomly initialized and fixed during training, with only the network body updated (both body and head are updated during personalisation). All reported results are based on this body-update strategy, since we observe that it considerably outperforms the full update for our models and the competing methods. The hyperparameters are: (NIW) ϵ = 10⁻⁴ and p = 1 − 0.001 (see the ablation study below for other values); (Mixture) σ² = 0.1, ϵ = 10⁻⁴, mixture order K = 2, and the gating network has the same architecture as the main backbone, with the output cardinality changed to K. Other hyperparameters, including batch size (50), learning rate (0.1 initially, decayed by 0.1) and the number of epochs in personalisation (5), are the same as those in the FedBABU paper.


Personalisation (CIFAR-100): Specifically, we are given a training split of the personalised data to update the FL-trained model. We then measure the performance of the adapted model on the test split, which follows the same distribution as the training split. Following the FedBABU paper, the client data distributions are heterogeneous non-iid, formed by sharding-based class sampling (described in “Communication-Efficient Learning of Deep Networks from Decentralized Data” by McMahan et al published in 2017 in AI and Statistics (AISTATS)). More specifically, we partition the data instances in each class into non-overlapping equal-sized shards, and assign s randomly sampled shards (over all classes) to each of the N clients. Thus the number of shards per user s controls the degree of data heterogeneity: a small s leads to more heterogeneity, and vice versa. The number of clients is N = 100 (each having 500 training and 100 test samples), and we denote by f the fraction of participating clients, so N_f = N·f clients are randomly sampled at each round to participate in training. A smaller f makes FL more challenging, and we test two settings: f = 1.0 and 0.1. Lastly, the number of epochs for the client local update at each round is denoted by τ, where we test τ = 1 and 10, and the total number of rounds is set to 320/τ for fairness. Note that a smaller τ incurs more communication cost but often leads to higher accuracy. For the competing methods FedBE (“FedBE: Making Bayesian Model Ensemble Applicable to Federated Learning” by Chen et al published in International Conference on Learning Representations, 2021, reference [16]) and FedEM (“Federated Multi-Task Learning under a Mixture of Distributions” by Marfoq et al published in Advances in Neural Information Processing Systems, 2021, reference [41]), we set the number of ensemble components or base models to 3.
For FedPA (described in “Federated Learning via Posterior Averaging: A New Perspective and Practical Algorithms” by Al-Shedivat et al published in International Conference on Learning Representations, 2021, reference [4]), we set the shrinkage parameter ρ = 0.01.
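The sharding-based split and the per-round client sampling described above can be sketched in Python as follows. This is an illustrative reimplementation under stated assumptions, not the FedBABU code: labels are a flat list of class ids, the total number of shards N·s is assumed to be divisible by the number of classes, leftover samples that do not fill a shard are dropped, and N_f = N·f is rounded to the nearest integer.

```python
import random

def shard_partition(labels, num_clients, shards_per_client, seed=0):
    """Sharding-based non-iid split: equal-sized, non-overlapping shards within each class."""
    rng = random.Random(seed)
    by_class = {}
    for idx, y in enumerate(labels):
        by_class.setdefault(y, []).append(idx)
    total_shards = num_clients * shards_per_client
    shards_per_class = total_shards // len(by_class)   # assumes exact divisibility
    shards = []
    for y in sorted(by_class):
        idxs = by_class[y]
        rng.shuffle(idxs)
        size = len(idxs) // shards_per_class            # leftover samples are dropped
        shards.extend(idxs[k * size:(k + 1) * size] for k in range(shards_per_class))
    rng.shuffle(shards)                                  # random shard-to-client assignment
    return [sum(shards[c * shards_per_client:(c + 1) * shards_per_client], [])
            for c in range(num_clients)]

def sample_participants(num_clients, fraction, rng):
    """Draw N_f = N * f distinct clients for one round."""
    n_f = max(1, round(num_clients * fraction))
    return rng.sample(range(num_clients), n_f)
```

Because every shard comes from a single class, each client sees at most s classes, which is why a small s yields stronger heterogeneity.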


MNIST/F-MNIST/EMNIST. Following the standard protocols, we set the number of clients N=100, the number of shards per client s=5, the fraction of participating clients per round f=0.1, and the number of local training epochs per round τ=1 (total number of rounds 100) or 5 (total number of rounds 20) for MNIST and F-MNIST. For EMNIST, we have N=200, f=0.2, τ=1 (total number of rounds 300). We follow the standard Dirichlet-based client data splitting. For the competing methods FedBE and FedEM, we use three-component models. The backbone is an MLP with a single hidden layer with 256 units for MNIST/F-MNIST, while we use a standard ConvNet with two hidden layers for EMNIST.


Main results and interpretation. In FIGS. 5a to 5e, we compare our methods (NIW and Mixture with K = 2) against popular FL methods, including FedAvg, FedBABU and FedProx (described in “Federated Optimization in Heterogeneous Networks” by Tian Li, Anit Kumar Sahu, Manzil Zaheer, Maziar Sanjabi, Ameet Talwalkar, and Virginia Smith, arXiv preprint arXiv:1812.06127, 2018, reference [34]), as well as recent Bayesian/ensemble methods: FedPA, FedBE, pFedBayes, FedEM, and FedPop (described in “FedPop: A Bayesian Approach for Personalised Federated Learning” by Kotelevskii et al published in Advances in Neural Information Processing Systems, 2022, reference [32]). We run the competing methods (implementations based on their public code, or our own implementation if unavailable) with default hyperparameters (e.g., μ = 0.01 for FedProx) and report the results. First of all, our two models (NIW and Mix.) consistently perform the best (by large margins most of the time) in terms of both global prediction and personalisation for nearly all FL settings on the two datasets. We attribute this to the principled Bayesian modelling of the underlying FL data generative process in our approaches, which can be seen as a rigorous generalisation and extension of existing intuitive algorithms such as FedAvg and FedProx. In particular, the superiority of our methods over the other Bayesian/ensemble approaches verifies the effectiveness of modelling client-wise latent variables θ_i rather than the commonly used shared θ.


CIFAR-100 Corrupted (CIFAR-C-100). The CIFAR-100-Corrupted dataset (published in “Benchmarking neural network robustness to common corruptions and perturbations” by Hendrycks et al, International Conference on Learning Representations, 2019) corrupts CIFAR-100's test split (10K images) with 19 different noise processes (e.g., Gaussian, motion blur, JPEG). For each corruption type there are 5 corruption levels, and we use the severest one. Ten randomly chosen corruption types are used (fixed) for training, and the remaining 9 types for personalisation. We divide the N = 100 clients into 10 groups, each group exclusively assigned one of the 10 training corruption types (we denote by D_c the corrupted data for group c = 1, . . . , 10). Each D_c is partitioned into 90%/10% training/test splits, and the clients in each group (N/10 clients) get non-iid train/test subsets from D_c's train/test splits following the sharding strategy with s = 100 or 50. This way, clients in different groups have a considerable distribution shift in input images, while there is also heterogeneity in class distributions even within the same group.


For the FL-trained models, we evaluate global prediction on two datasets: the clients' test splits from the 10 training corruption types, and the original (uncorrupted) CIFAR training split (50K images). For personalisation, we partition the clients into 9 groups and exclusively assign one of the 9 remaining corruption types to each group. Within each group we form non-iid sharding-based subsets similarly, and again split the data into a 90% training/finetuning split and a 10% test split. Note that this personalisation setting is more challenging than that of CIFAR-100, since the data for personalisation are entirely unseen during the FL training stage. We test the sharding parameter s = 100 or 50, the participating client fraction f = 1.0 or 0.1, and the number of local epochs τ = 1 or 4; the results are reported in Table 3 of FIG. 5e. As shown, our two models (NIW and Mix.) consistently perform the best (by large margins most of the time) in terms of both global prediction and personalisation for all FL settings. We attribute this to the principled Bayesian modelling of the underlying FL data generative process in our approaches. In particular, the superiority of our methods over the other Bayesian approach, pFedBayes, verifies the effectiveness of modelling client-wise latent variables θ_i rather than the commonly used shared θ. Our methods are especially robust in scenarios of significant client data heterogeneity, e.g., CIFAR-C-100 personalisation on data with unseen corruption types, as shown in the tables of FIG. 5e.


(Ablation) Hyperparameter sensitivity. We test the sensitivity to some key hyperparameters in our models. For NIW, we have p = 1 − p_drop, where p_drop is the MC-dropout probability; we used p_drop = 0.001 in the main experiments. In FIGS. 6a and 6b we report the performance of NIW for different values (p_drop = 0, 10⁻⁴, 10⁻²) on CIFAR-100 in the (s = 100, f = 0.1, τ = 1) setting. We see that the performance is not very sensitive to p_drop unless it is too large (e.g., 0.01).


For the Mixture model, different mixture orders K = 2, 5, 10 are contrasted in FIGS. 6c and 6d, with more information provided in the tables shown in FIGS. 6e and 6f. As explained above, our Mixture model maintains K backbone networks, where the mixture order is usually small but greater than 1 (e.g. K = 2). It therefore requires more computational resources than other methods (including our NIW model) that deal with only a single backbone. FIGS. 6c and 6d thus compare the Mixture model with a preset baseline extension and an ensemble baseline extension of FedBABU, which are constructed as follows.


In a first step, the server maintains K networks (denoted by θ_1, . . . , θ_K). In a second step, we partition the client devices into K groups of equal size and assign θ_j to each group j (j = 1, . . . , K). At each round, each participating client device i receives the current model θ_{j(i)} from the server, where j(i) denotes the index of the group to which client device i belongs. The client devices perform local updates as usual, warm-starting from the received models, and send the updated models back to the server. The server then collects the updated local models from the client devices and averages them within each group j to update θ_j.


After training, we have K trained networks. At test (inference) time, these K networks can be used in two different ways. In a first option, termed the Preset Baseline, each client device i uses the network assigned to its group, i.e. θ_{j(i)}, for both prediction and finetuning/personalisation. In a second option, termed the Ensemble Baseline, all K networks are used for prediction and finetuning.
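The two baselines can be summarised in a short Python sketch. It is a toy under explicit assumptions: each model is a single float, local training is a hypothetical callback that moves the received model halfway towards a client-specific target, clients are assigned to the K groups round-robin (one way to obtain equal proportions), and prediction is a callback applied to a model. Groups with no participating client in a round simply keep their model.

```python
import random

def train_grouped(client_ids, init_models, local_update, rounds, n_f, seed=0):
    """K server models; model j is averaged only over updates from group j."""
    rng = random.Random(seed)
    K = len(init_models)
    group = {c: k % K for k, c in enumerate(client_ids)}   # equal-sized groups j(i)
    models = list(init_models)
    for _ in range(rounds):
        updates = {j: [] for j in range(K)}
        for c in rng.sample(client_ids, n_f):              # participating clients
            j = group[c]
            updates[j].append(local_update(c, models[j]))   # warm start from theta_{j(c)}
        for j, received in updates.items():
            if received:                                     # within-group averaging
                models[j] = sum(received) / len(received)
    return models, group

def preset_predict(models, group, client, predict):
    """Preset Baseline: client i uses only theta_{j(i)}."""
    return predict(models[group[client]])

def ensemble_predict(models, predict):
    """Ensemble Baseline: average the predictions of all K networks."""
    outputs = [predict(m) for m in models]
    return sum(outputs) / len(outputs)

# Toy run: even-numbered clients prefer 1.0, odd-numbered clients prefer 5.0.
targets = {c: (1.0 if c % 2 == 0 else 5.0) for c in range(6)}

def local_update(c, model):
    return model + 0.5 * (targets[c] - model)   # hypothetical one-step local training

models, group = train_grouped(list(range(6)), [0.0, 0.0], local_update, rounds=100, n_f=2)
```

In this toy the two group models converge to the two group targets, and the Ensemble Baseline blends them, which illustrates why it pays K forward passes at inference while the Preset Baseline pays one.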


As seen, having more mixture components does no harm (no overfitting), but we see no further improvement over K = 2 in our experiments. In the last columns of the tables, we also report the performance of centralised (non-FL) training in which batch sampling follows the corresponding FL settings. That is, at each round, the minibatches for SGD (for conventional cross-entropy loss minimisation) are sampled from the data of the participating clients. Centralised training sometimes outperforms the best FL algorithms (our models), but can fail completely, especially when data heterogeneity is high (small s) and τ is large. This may be due to overtraining on biased client data over relatively few rounds. Our FL models perform consistently and stably, comparably to centralised training in its ideal settings (small τ and large s).



FIGS. 7a to 7d are tables comparing the complexity of the proposed algorithms with FedAvg. FIG. 7a compares the training complexity; all quantities are per-round, per-batch and per-client costs, where d is the number of parameters in the backbone network, F is the time for a feed-forward pass, B is the time for backpropagation, and N_f is the number of participating clients per round. FIG. 7b compares the global prediction complexity; all quantities are per-test-batch costs, where d and F are as above and S is the number of samples θ_s from the Student-t distribution in the NIW case (we use S = 1). FIG. 7c compares the personalisation complexity; all quantities are per-training-batch costs, with d, F and B as above. FIG. 7d compares the running times in seconds on CIFAR-100 with s = 100, f = 1.0 and τ = 1. All measurements are taken on the same machine, with a Xeon 2.20 GHz CPU and a single RTX 2080 Ti GPU.


As shown in FIGS. 7a to 7d, although the models proposed in the present techniques achieve significant improvements in prediction accuracy, they incur extra computational overhead compared to simpler federated learning methods such as FedBABU. For the NIW method, the extra cost in the local client update and personalisation (training) originates from the penalty term in equation (25), while squaring the model weights to compute V0 in equation (32) incurs additional cost in the server update. For the Mixture method, the increased training time is mainly due to the overhead of computing distances from the K server models in equations (45) and (50). Overall, however, the extra costs are not prohibitively large, so our methods remain practical.



FIG. 8 is a block diagram of a system 300 for using federated learning to train/update a ML model 310. The system comprises a server 302 and a plurality of client devices 312. Only one client device 312 is shown here for the sake of simplicity, but it will be understood that there may be tens, hundreds, thousands, or more, client devices.


The system 300 comprises a server 302 for training a global machine learning, ML, model using federated learning. The server 302 comprises at least one processor 304 coupled to memory 306. The at least one processor 304 may be arranged to: receive, at the server, at least one network weight from each client device; link the at least one network weight to a higher-level variable; and train the ML model 310 using the set of training data 308 to optimise a function dependent on the weights from the client devices and a function dependent on the higher-level variable.


To train/update the global model 310 using the data received from the client devices, the server may: analyse a loss function associated with the ML model to determine whether a posterior needs to be updated; and train the ML model by updating the posterior to put mass on the higher-level variables that have high compatibility scores with the network weights from the client devices, and to be close to a prior. Information about the posterior and prior is provided above.


The system 300 comprises a client device 312 for locally training a local version 318 of the global ML model using local/personal training data 320.


The client device comprises at least one processor 314 coupled to memory 316 arranged to: receive the updated posterior from the server; and train, using the updated posterior and the set of personal training data, the local version of the ML model to minimize a class prediction error on its own data and to be close to the current global standard.


The client device 312 may be any one of: a smartphone, tablet, laptop, computer or computing device, virtual assistant device, a robot or robotic device, a robotic assistant, image capture system or device, an Internet of Things device, and a smart consumer device. It will be understood that this is a non-limiting and non-exhaustive list of apparatuses.


The at least one processor 314 may comprise one or more of: a microprocessor, a microcontroller, and an integrated circuit. The memory 316 may comprise volatile memory, such as random access memory (RAM), for use as temporary memory, and/or non-volatile memory such as Flash, read only memory (ROM), or electrically erasable programmable ROM (EEPROM), for storing data, programs, or instructions, for example.


Comparison with known Bayesian or ensemble FL approaches.


Some recent studies have tried to tackle the FL problem using Bayesian or ensemble-based methods. As mentioned earlier, the key difference is that most of these methods do not introduce a Bayesian hierarchy in a principled manner. Instead, they ultimately treat the network weights θ as a random variable shared across all clients. Our approach, on the other hand, assigns an individual θ_i to each client i, governed by a common prior p(θ_i|ϕ). The non-hierarchical approaches mostly resort to ad hoc heuristics and/or strong assumptions in their algorithms. For instance, FedPA (described in “Federated Learning via Posterior Averaging: A New Perspective and Practical Algorithms” by Al-Shedivat et al published in International Conference on Learning Representations, 2021) aims to establish the product-of-experts decomposition p(θ|D_{1:N}) ∝ ∏_{i=1}^N p(θ|D_i) to allow client-wise inference of p(θ|D_i). However, this decomposition does not hold in general unless the strong assumption of an uninformative prior p(θ) ∝ 1 is made.
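The role of the prior in this decomposition follows from Bayes' rule: if the D_i are conditionally independent given θ, then p(θ|D_{1:N}) ∝ p(θ) ∏_i p(D_i|θ) ∝ p(θ)^{1−N} ∏_i p(θ|D_i), so the product of local posteriors counts the prior N times instead of once. The Python sketch below checks this for a conjugate scalar Gaussian model (an illustrative choice, not part of FedPA), where posterior precisions simply add: the product of experts overshoots the exact posterior precision by (N − 1) times the prior precision, and the two agree only for a flat prior (prior precision 0).

```python
def posterior_precision(prior_prec, noise_prec, n_obs):
    """Precision of p(theta | data) for a Gaussian mean theta with prior precision
    prior_prec (0 means the flat prior p(theta) ∝ 1) and n_obs observations of
    known precision noise_prec."""
    return prior_prec + n_obs * noise_prec

def exact_vs_product_of_experts(prior_prec, noise_prec, client_sizes):
    exact = posterior_precision(prior_prec, noise_prec, sum(client_sizes))
    # Multiplying Gaussian densities adds their precisions.
    poe = sum(posterior_precision(prior_prec, noise_prec, n_i) for n_i in client_sizes)
    return exact, poe

print(exact_vs_product_of_experts(1.0, 2.0, [3, 4, 5]))   # (25.0, 27.0): informative prior
print(exact_vs_product_of_experts(0.0, 2.0, [3, 4, 5]))   # (24.0, 24.0): flat prior
```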


Other approaches include FedBE (Fed Bayesian Ensemble), described in “FedBE: Making Bayesian Model Ensemble Applicable to Federated Learning” by Chen et al published in International Conference on Learning Representations, 2021, which aims to build the global posterior distribution p(θ|D_{1:N}) from the individual posteriors p(θ|D_i) in some ad hoc ways. FedEM, described in “Federated Multi-Task Learning under a Mixture of Distributions” by Marfoq et al published in Advances in Neural Information Processing Systems, 2021, forms a seemingly reasonable hypothesis that local client data distributions can be identified as mixtures of a fixed number of base distributions (with different mixing proportions). Although it uses sophisticated probabilistic modelling, this method is not a Bayesian approach. pFedBayes, described in “Personalized Federated Learning via Variational Bayesian Inference” by Zhang et al published in the 2022 International Conference on Machine Learning, can be seen as an implicit regularisation-based method to approximate p(θ|D_{1:N}) from the individual posteriors p(θ|D_i). To this end, it introduces the so-called global distribution w(θ), which essentially serves as a regulariser that prevents the local posteriors from deviating from it. The introduction of w(θ) and its update strategy appears to be a hybrid treatment rather than a purely Bayesian one. Finally, FedPop, described in “FedPop: A Bayesian Approach for Personalised Federated Learning” by Kotelevskii et al published in Advances in Neural Information Processing Systems, 2022, has a hierarchical Bayesian model structure similar to the method described above, but its model is limited to a linear deterministic model for the shared variate.


Other Bayesian FL algorithms. Other recent Bayesian methods adopt expectation-propagation (EP) approximations, for example as described in “Federated Learning as Variational Inference: A Scalable Expectation Propagation Approach” by Guo et al published in International Conference on Learning Representations, 2023, or in “Partitioned Variational Inference: A framework for probabilistic federated learning” by Ashman et al published in 2022. In particular, the EP update steps are performed locally with the client data. However, neither of these two works is a hierarchical Bayesian model: unlike our individual client modelling, they have a single model θ shared across clients, without individual modelling of the client data, thus following FedPA-like inference of p(θ|D_{1:N}). As a consequence, they lack a systematic way to distinctly model global and local parameters for global prediction and personalised prediction, respectively.


According to an embodiment of the disclosure, a method for training, using federated learning, a global machine learning, ML, model for use by a plurality of client devices, may comprise: defining, at a server, a Bayesian hierarchical model which links a global random variable with a plurality of local random variables, one for each of the plurality of client devices, wherein the Bayesian hierarchical model comprises a posterior distribution which is suitable for predicting the likelihood of the global random variable and the local random variables given each individual dataset at the plurality of client devices.


In an embodiment, the method may further comprise: approximating, at the server, the posterior distribution using a global ML model and a plurality of local ML models, wherein the global ML model is parameterised by a global parameter which is updated at the server, each of the plurality of local ML models is associated with one of the plurality of client devices and each local ML model is parameterised by a local parameter which is updated at the client device which is associated with the local ML model.


In an embodiment, the method may further comprise: sending, from the server, the global parameter to a predetermined number of the plurality of client devices.


In an embodiment, the method may further comprise: receiving, at the server from each of the number of the plurality of client devices, an updated local parameter, wherein each updated local parameter has been determined by training, on the client device, the local ML model using a local dataset, and wherein during training of the local ML model, the global parameter is fixed.


In an embodiment, the method may further comprise: training, at the server, the global ML model using each of the received updated local parameters to determine an updated global parameter, wherein during training of the global model, each local parameter is fixed.


In an embodiment, at least some of the local ML models and/or global ML model may have different structures.


In an embodiment, training the global ML model may comprise optimising using a regularization term which penalises deviation between the updated global parameter and the global parameter which was sent to the client devices.


In an embodiment, training the global ML model may comprise optimising using a regularisation term which penalises deviation between the updated global parameter and each of the local parameters received from the plurality of client devices.
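The two server-side regularisation terms above can be illustrated with a quadratic stand-in that has a closed-form minimiser. This sketch assumes, hypothetically, squared Euclidean penalties with weights `alpha` (deviation from each received local parameter) and `beta` (deviation from the global parameter that was sent); the disclosure does not fix these forms or names.

```python
from typing import List, Sequence


def regularised_global_update(sent_global: Sequence[float],
                              local_params: List[Sequence[float]],
                              alpha: float = 1.0,
                              beta: float = 1.0) -> List[float]:
    """Coordinate-wise closed-form minimiser of the hypothetical objective
        alpha/2 * sum_i ||phi - theta_i||^2  +  beta/2 * ||phi - sent_global||^2
    where theta_i are the received (fixed) local parameters."""
    if not local_params:
        raise ValueError("at least one local parameter is required")
    denom = alpha * len(local_params) + beta
    if alpha < 0 or beta < 0 or denom == 0:
        raise ValueError("weights must be non-negative and not both zero")
    # Setting the gradient to zero gives phi * (alpha*m + beta) = alpha*sum_i theta_i + beta*sent.
    return [
        (alpha * sum(theta[d] for theta in local_params) + beta * sent_global[d]) / denom
        for d in range(len(sent_global))
    ]


if __name__ == "__main__":
    phi_sent = [0.0, 0.0]
    thetas = [[1.0, 2.0], [3.0, 4.0]]
    print(regularised_global_update(phi_sent, thetas, alpha=1.0, beta=2.0))  # [1.0, 1.5]
```

With `beta = 0` the update reduces to the unweighted mean of the local parameters; increasing `beta` damps how far the global parameter can move from the value that was sent in a single round.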


In an embodiment, approximating the posterior distribution may comprise using a Normal-Inverse-Wishart model as the global ML model, with a global mean parameter and a global covariance parameter as the global parameter, and using a mixture of two Gaussian functions as the local ML model, with a local mean parameter as the local parameter.
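This embodiment names a two-Gaussian mixture as the local model but does not spell out its components here. The sketch below adopts one reading purely as an assumption: one component centred at the local mean m, the other at the origin, both with a shared isotropic standard deviation `eps` and mixing weight `p`. It shows how such a density can be evaluated stably (via log-sum-exp) and sampled; the function names are hypothetical, and the global Normal-Inverse-Wishart side is not reproduced.

```python
import math
import random
from typing import List, Optional, Sequence


def mixture_log_density(theta: Sequence[float], local_mean: Sequence[float],
                        p: float = 0.9, eps: float = 0.1) -> float:
    """log[p * N(theta; m, eps^2 I) + (1 - p) * N(theta; 0, eps^2 I)].

    Assumption: one component is centred at the local mean m, the other at the
    origin, and both share the isotropic standard deviation eps."""
    if not 0.0 < p < 1.0:
        raise ValueError("p must lie strictly between 0 and 1")
    d = len(theta)
    log_norm = -0.5 * d * math.log(2.0 * math.pi * eps ** 2)
    sq_to_mean = sum((t - m) ** 2 for t, m in zip(theta, local_mean))
    sq_to_zero = sum(t ** 2 for t in theta)
    a = math.log(p) + log_norm - 0.5 * sq_to_mean / eps ** 2
    b = math.log(1.0 - p) + log_norm - 0.5 * sq_to_zero / eps ** 2
    top = max(a, b)  # log-sum-exp for numerical stability
    return top + math.log(math.exp(a - top) + math.exp(b - top))


def sample_local(local_mean: Sequence[float], p: float = 0.9, eps: float = 0.1,
                 rng: Optional[random.Random] = None) -> List[float]:
    """Pick the mean component with probability p (else the origin), then add Gaussian noise."""
    rng = rng if rng is not None else random.Random()
    centre = local_mean if rng.random() < p else [0.0] * len(local_mean)
    return [rng.gauss(c, eps) for c in centre]
```

Only the local mean is learned in this reading; `p` and `eps` are treated as fixed hyperparameters.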


In an embodiment, approximating the posterior distribution may comprise using a mixture model which comprises multiple different prototypes and each prototype is associated with a separate global random variable.


In an embodiment, the method may further comprise using a product of multiple multivariate normal distributions as the global ML model, with variational parameters as the global parameter, and using one of the multiple multivariate normal distributions as the local ML model, with a local mean parameter as the local parameter.
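For the mixture-of-prototypes embodiments, a natural first step is to score how strongly a client's local parameter belongs to each prototype. The sketch below assumes isotropic Gaussian prototypes with a shared standard deviation `sigma` and mixing weights; the function names are hypothetical, and the computation is standard Gaussian-mixture responsibility, not necessarily the procedure of the disclosure.

```python
import math
from typing import List, Optional, Sequence


def prototype_responsibilities(theta: Sequence[float],
                               prototypes: List[Sequence[float]],
                               weights: Optional[List[float]] = None,
                               sigma: float = 1.0) -> List[float]:
    """Posterior probability that local parameter theta belongs to each prototype,
    assuming components N(prototype_k, sigma^2 I) with mixing weights pi_k.
    The Gaussian normalising constant is shared by all components and cancels."""
    k = len(prototypes)
    if weights is None:
        weights = [1.0 / k] * k
    logits = [
        math.log(w) - 0.5 * sum((t - c) ** 2 for t, c in zip(theta, proto)) / sigma ** 2
        for w, proto in zip(weights, prototypes)
    ]
    top = max(logits)  # subtract the max before exponentiating, for stability
    unnorm = [math.exp(z - top) for z in logits]
    total = sum(unnorm)
    return [u / total for u in unnorm]


def assign_prototype(theta: Sequence[float], prototypes: List[Sequence[float]],
                     weights: Optional[List[float]] = None, sigma: float = 1.0) -> int:
    """Index of the most responsible prototype (a hard assignment)."""
    r = prototype_responsibilities(theta, prototypes, weights, sigma)
    return max(range(len(r)), key=r.__getitem__)


if __name__ == "__main__":
    print(assign_prototype([9.0, 9.5], [[0.0, 0.0], [10.0, 10.0]]))  # 1
```

Soft responsibilities could weight each client's contribution to each prototype's update, while `assign_prototype` gives the hard choice of a single multivariate normal as the local model.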


In an embodiment, training, on the client device, may comprise optimising using a loss function to fit each local parameter to the local dataset.


In an embodiment, training, on the client device, may comprise optimising using a regularisation term which penalises deviation between each updated local parameter and a previous local parameter.
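The two client-side terms (a loss fitting the local parameter to the local dataset, and a penalty on deviation from the previous local parameter) can be sketched as plain gradient descent. This assumes, purely for illustration, a linear model with mean squared error and a quadratic penalty of weight `lam`; `local_update`, its step size `lr` and its iteration count `steps` are hypothetical choices, not taken from the disclosure.

```python
from typing import List, Sequence, Tuple


def local_update(prev_local: Sequence[float],
                 data: List[Tuple[Sequence[float], float]],
                 lam: float = 0.1, lr: float = 0.05, steps: int = 200) -> List[float]:
    """Gradient descent on the hypothetical client objective
        (1/n) * sum_j 0.5 * (w . x_j - y_j)^2  +  (lam/2) * ||w - prev_local||^2
    starting from the previous local parameter."""
    w = list(prev_local)
    n = len(data)
    for _ in range(steps):
        # Gradient of the regularisation term towards the previous local parameter.
        grad = [lam * (wi - pi) for wi, pi in zip(w, prev_local)]
        # Gradient of the mean squared error over the local dataset.
        for x, y in data:
            residual = sum(wi * xi for wi, xi in zip(w, x)) - y
            for d, xd in enumerate(x):
                grad[d] += residual * xd / n
        w = [wi - lr * g for wi, g in zip(w, grad)]
    return w


if __name__ == "__main__":
    local_data = [([1.0], 2.0), ([2.0], 4.0)]
    print(round(local_update([0.0], local_data, lam=1.0)[0], 6))  # 1.428571 (= 10/7)
```

A term tying the local parameter to the fixed global parameter, as described for the personalisation embodiments, could be added to `grad` in the same way.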


According to an embodiment of the disclosure, a method for generating, using a client device, a personalised model using a global machine learning, ML, model which has been trained at a server, may comprise: receiving, at the client device from the server, a global parameter for the trained global ML model.


In an embodiment, the method may further comprise: optimising, at the client device, a local parameter using the received global parameter, by applying a regularisation term which penalises deviation between the optimised local parameter and the received global parameter.


In an embodiment, the method may further comprise: outputting the optimised local parameter as the personalised model.


In an embodiment, the method may further comprise: obtaining a set of personal data, wherein optimising the local parameter using the received global parameter further comprises applying a loss function over the set of personal data.


In an embodiment, the method may further comprise: receiving an input; and predicting, using the personalised model, an output based on the received input.
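The personalisation steps above (receive the global parameter, optimise a local parameter with a penalty towards it, optionally add a loss over personal data, then predict) admit a closed form when the model is simple enough. A minimal sketch, assuming a hypothetical one-parameter linear model y = theta * x with squared-error loss and a quadratic penalty of weight `lam`; the names `personalise` and `predict` are illustrative only.

```python
from typing import List, Tuple


def personalise(global_param: float,
                personal_data: List[Tuple[float, float]],
                lam: float = 1.0) -> float:
    """Closed-form minimiser of the hypothetical objective
        sum_j 0.5 * (theta * x_j - y_j)^2  +  (lam/2) * (theta - global_param)^2
    With no personal data the loss term vanishes and theta equals global_param."""
    if lam <= 0:
        raise ValueError("lam must be positive")
    sxy = sum(x * y for x, y in personal_data)
    sxx = sum(x * x for x, _ in personal_data)
    # Zero gradient: theta * (sxx + lam) = sxy + lam * global_param
    return (sxy + lam * global_param) / (sxx + lam)


def predict(theta: float, x: float) -> float:
    """Output of the personalised model for a newly received input."""
    return theta * x


if __name__ == "__main__":
    print(personalise(1.5, []))  # 1.5: no personal data, so the global parameter is kept
    theta = personalise(0.0, [(1.0, 2.0), (2.0, 4.0)], lam=1.0)
    print(round(predict(theta, 3.0), 6))  # 5.0
```

The single weight `lam` trades off personal data against the global parameter: as `lam` shrinks the result approaches the least-squares fit to the personal data, and as it grows the result approaches the received global parameter.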


According to an embodiment of the disclosure, an electronic device for training, using federated learning, a global machine learning, ML, model for use by a plurality of client devices, may comprise at least one processor coupled to memory.


In an embodiment, the at least one processor may be configured to: define a Bayesian hierarchical model which links a global random variable with a plurality of local random variables, one for each of the plurality of client devices, wherein the Bayesian hierarchical model comprises a posterior distribution which is suitable for predicting the likelihood of the global random variable and the local random variables given each individual dataset at the plurality of client devices.


In an embodiment, the at least one processor may be configured to: approximate the posterior distribution using a global ML model and a plurality of local ML models, wherein the global ML model is parameterised by a global parameter which is updated at the electronic device, each of the plurality of local ML models is associated with one of the plurality of client devices and each local ML model is parameterised by a local parameter which is updated at the client device which is associated with the local ML model.


In an embodiment, the at least one processor may be configured to: send the global parameter to a predetermined number of the plurality of client devices.


In an embodiment, the at least one processor may be configured to: receive, from each of the predetermined number of the plurality of client devices, an updated local parameter, wherein each updated local parameter has been determined by training, on the client device, the local ML model using a local dataset, and wherein during training of the local ML model, the global parameter is fixed.


In an embodiment, the at least one processor may be configured to: train the global ML model (310) using each of the received updated local parameters to determine an updated global parameter, wherein during training of the global model, each local parameter is fixed.


In an embodiment, the at least one processor may be configured to: optimise using a regularisation term which penalises deviation between the updated global parameter and the global parameter which was sent to the client devices.


In an embodiment, the at least one processor may be configured to: optimise using a regularisation term which penalises deviation between the updated global parameter and each of the local parameters received from the plurality of client devices.


According to an embodiment of the disclosure, a system for training, using federated learning, a global machine learning, ML, model, may comprise: a server comprising a processor coupled to memory, and a plurality of client devices each comprising a processor coupled to memory, wherein the processor at the server is configured to: define a Bayesian hierarchical model which links a global random variable with a plurality of local random variables, one for each of the plurality of client devices, wherein the Bayesian hierarchical model comprises a posterior distribution which is suitable for predicting the likelihood of the global random variable and the local random variables given each individual dataset at the plurality of client devices; approximate the posterior distribution using a global ML model and a plurality of local ML models, wherein the global ML model is parameterised by a global parameter which is updated at the server, each of the plurality of local ML models is associated with one of the plurality of client devices and each local ML model is parameterised by a local parameter which is updated at the client device which is associated with the local ML model; send the global parameter to a predetermined number of the plurality of client devices; receive, from each of the predetermined number of the plurality of client devices, an updated local parameter; train the global ML model using each of the received updated local parameters to determine an updated global parameter, wherein during training of the global model, each local parameter is fixed; and wherein the processor at each of the client devices is configured to: receive the global parameter; train the local ML model using a local dataset on the client device to determine an updated local parameter, wherein during training of the local ML model, the global parameter is fixed.


According to an embodiment of the disclosure, a client device may comprise a processor coupled to memory, wherein the processor is configured to: receive, from a server, a global parameter for a global ML model which has been trained at the server; determine whether there is a set of personal data on the client device; when there is no set of personal data, optimise a local parameter of a local ML model using the received global parameter by applying a regularisation term which penalises deviation between the optimised local parameter and the received global parameter; when there is a set of personal data, optimise the local parameter of the local ML model using the received global parameter by applying a regularisation term which penalises deviation between the optimised local parameter and the received global parameter and by applying a loss function over the set of personal data; output the optimised local parameter as a personalised model; and predict, using the personalised model, an output based on a newly received input.


Those skilled in the art will appreciate that while the foregoing has described what is considered to be the best mode and, where appropriate, other modes of performing the present techniques, the present techniques should not be limited to the specific configurations and methods disclosed in this description of the preferred embodiment. Those skilled in the art will recognise that the present techniques have a broad range of applications, and that the embodiments may take a wide range of modifications without departing from any inventive concept as defined in the appended claims.

Claims
  • 1. A method for training, using federated learning, a global machine learning, ML, model for use by a plurality of client devices, the method comprising: defining, at a server, a Bayesian hierarchical model which links a global random variable with a plurality of local random variables, one for each of the plurality of client devices, wherein the Bayesian hierarchical model comprises a posterior distribution which is suitable for predicting the likelihood of the global random variable and the local random variables given each individual dataset at the plurality of client devices; approximating, at the server, the posterior distribution using a global ML model and a plurality of local ML models, wherein the global ML model is parameterised by a global parameter which is updated at the server, each of the plurality of local ML models is associated with one of the plurality of client devices and each local ML model is parameterised by a local parameter which is updated at the client device which is associated with the local ML model; sending, from the server, the global parameter to a predetermined number of the plurality of client devices; receiving, at the server from each of the predetermined number of the plurality of client devices, an updated local parameter, wherein each updated local parameter has been determined by training, on the client device, the local ML model using a local dataset, and wherein during training of the local ML model, the global parameter is fixed; and training, at the server, the global ML model using each of the received updated local parameters to determine an updated global parameter, wherein during training of the global model, each local parameter is fixed.
  • 2. The method of claim 1, wherein at least some of the local ML models and/or global ML model have different structures.
  • 3. The method of claim 1, wherein training the global ML model comprises optimising using a regularisation term which penalises deviation between the updated global parameter and the global parameter which was sent to the client devices.
  • 4. The method of claim 1, wherein training the global ML model comprises optimising using a regularisation term which penalises deviation between the updated global parameter and each of the local parameters received from the plurality of client devices.
  • 5. The method of claim 1, wherein approximating the posterior distribution comprises using a Normal-Inverse-Wishart model as the global ML model and using a global mean parameter and a global covariance parameter as the global parameter and using a mixture of two Gaussian functions as the local ML model and a local mean parameter as the local parameter.
  • 6. The method of claim 1, wherein approximating the posterior distribution comprises using a mixture model which comprises multiple different prototypes and each prototype is associated with a separate global random variable.
  • 7. The method of claim 6, further comprising using a product of multiple multivariate normal distributions as the global ML model and using variational parameters as the global parameter and using one of the multiple multivariate normal distributions as the local ML model and a local mean parameter as the local parameter.
  • 8. The method of claim 1, wherein training, on the client device, comprises optimising using a loss function to fit each local parameter to the local dataset.
  • 9. The method of claim 1, wherein training, on the client device, comprises optimising using a regularisation term which penalises deviation between each updated local parameter and a previous local parameter.
  • 10. A method for generating, using a client device, a personalised model using a global machine learning, ML, model which has been trained at a server, the method comprising: receiving, at the client device from the server, a global parameter for the trained global ML model; optimising, at the client device, a local parameter using the received global parameter, by applying a regularisation term which penalises deviation between the optimised local parameter and the received global parameter, and outputting the optimised local parameter as the personalised model.
  • 11. The method of claim 10, further comprising: obtaining a set of personal data, wherein optimising the local parameter using the received global parameter comprises optimising the local parameter using the received global parameter, by applying a loss function over the set of personal data.
  • 12. The method of claim 10, further comprising: receiving an input; and predicting, using the personalised model, an output based on the received input.
  • 13. An electronic device for training, using federated learning, a global machine learning, ML, model for use by a plurality of client devices, the electronic device comprising at least one processor coupled to memory, wherein the at least one processor is configured to: define a Bayesian hierarchical model which links a global random variable with a plurality of local random variables, one for each of the plurality of client devices (312), wherein the Bayesian hierarchical model comprises a posterior distribution which is suitable for predicting the likelihood of the global random variable and the local random variables given each individual dataset at the plurality of client devices; approximate the posterior distribution using a global ML model and a plurality of local ML models, wherein the global ML model is parameterised by a global parameter which is updated at the electronic device, each of the plurality of local ML models is associated with one of the plurality of client devices and each local ML model is parameterised by a local parameter which is updated at the client device which is associated with the local ML model; send the global parameter to a predetermined number of the plurality of client devices; receive, from each of the predetermined number of the plurality of client devices, an updated local parameter, wherein each updated local parameter has been determined by training, on the client device, the local ML model using a local dataset, and wherein during training of the local ML model, the global parameter is fixed; and train the global ML model using each of the received updated local parameters to determine an updated global parameter, wherein during training of the global model, each local parameter is fixed.
  • 14. The electronic device of claim 13, wherein the at least one processor is configured to: optimise using a regularisation term which penalises deviation between the updated global parameter and the global parameter which was sent to the client devices.
  • 15. The electronic device of claim 13, wherein the at least one processor is configured to: optimise using a regularisation term which penalises deviation between the updated global parameter and each of the local parameters received from the plurality of client devices.
  • 16. The electronic device of claim 13, wherein the at least one processor is configured to: use a Normal-Inverse-Wishart model as the global ML model and use a global mean parameter and a global covariance parameter as the global parameter and use a mixture of two Gaussian functions as the local ML model and a local mean parameter as the local parameter.
  • 17. The electronic device of claim 13, wherein the at least one processor is configured to: use a mixture model which comprises multiple different prototypes and each prototype is associated with a separate global random variable.
  • 18. The electronic device of claim 17, wherein the at least one processor is further configured to: use a product of multiple multivariate normal distributions as the global model and use variational parameters as the global parameter and use one of the multiple multivariate normal distributions as the local ML model and a local mean parameter as the local parameter.
  • 19. The electronic device of claim 13, wherein the at least one processor is configured to: optimise using a regularisation term which penalises deviation between each updated local parameter and a previous local parameter.
  • 20. A non-transitory storage medium storing a computer program that, when executed by at least one processor, causes the at least one processor to perform the method of claim 1.
Priority Claims (2)
Number Date Country Kind
2214033.9 Sep 2022 GB national
23198714.0 Sep 2023 EP regional
CROSS-REFERENCE TO RELATED APPLICATION(S)

This is a bypass continuation of PCT/KR2023/014863 filed on Sep. 26, 2023, which claims the benefit of GB 2214033.9 filed Sep. 26, 2022 and EP 23198714.0 filed Sep. 21, 2023, the disclosures of which are incorporated by reference in their entirety.

Continuations (1)
Number Date Country
Parent PCT/KR2023/014863 Sep 2023 US
Child 18512195 US