The disclosure relates to the field of machine learning (ML), particularly to a model training method and a face recognition method based on adaptive split learning-federated learning.
In recent years, ML has achieved remarkable performance in numerous tasks, such as computer vision, natural language processing, and speech recognition, owing to its excellent representation and learning capabilities. For instance, ML-based face recognition technology is widely applied in fields such as smart homes and security surveillance.
ML typically requires substantial amounts of data and computation resources to train models with good generalization performance. As a result, centralized learning has been widely adopted: a central server holds vast amounts of data and trains models using abundant computation resources. However, in face recognition tasks the training data is generated by users, and submitting this raw data (i.e., the training data) to a central server in the cloud can compromise user privacy. With the exponential growth in the computation and storage capabilities of user devices, it has become feasible to use local resources for learning tasks. Therefore, federated learning (FL), proposed by Google® in 2016, has garnered widespread attention and has been applied to some face recognition tasks.
In FL, ML models are trained on the user devices while the data stays local. By submitting updated local gradients, rather than the raw data, to the central server for aggregation, user privacy can be protected to some extent. However, the user devices participating in FL training may be heterogeneous, that is, they may vary significantly in computation capability, battery life, and data distribution, which can reduce the efficiency of FL. Additionally, in face recognition tasks with traditional FL, user data remains at risk of privacy leakage, because it can be reconstructed by eavesdropping on the transmitted model weights or gradients.
In response to the aforementioned challenges, existing solutions are incapable of addressing the heterogeneity of participating user devices when the data on each user device is assumed to be independent and identically distributed (IID). In face recognition tasks, if straggling devices (i.e., laggards) caused by computation resource or energy constraints are not properly scheduled, the FL model may be biased, leading to suboptimal performance on those laggards (particularly those with limited network bandwidth or access restrictions that prevent continued training). Moreover, for various reasons such as sensor placement, some user devices may consistently hold "more important" data. It is therefore crucial to include these devices' data in training even if their computation capabilities are weak, yet current solutions fail to resolve this issue effectively. Additionally, to enhance privacy protection in face recognition tasks, most existing solutions add extra mechanisms, often at the cost of system efficiency or model performance.
Therefore, devising a model training method that addresses the heterogeneity of participating user devices and strengthens privacy protection for face recognition tasks, and subsequently utilizing the trained model to achieve accurate face recognition, is an urgent problem that needs to be addressed.
To solve the above problems in the related art, the disclosure provides a model training method and a face recognition method based on adaptive split learning-federated learning. The technical problems to be solved by the disclosure are realized by the following technical solutions.
In a first aspect, an embodiment of the disclosure provides a model training method based on adaptive split learning-federated learning, applied to a ring structured federated learning (RingSFL) system including: a server and multiple user terminals; and the model training method includes:
In an embodiment, the steps of uploading, by each user terminal, device information thereof to the server, and allocating, by the server, a respective propagation step length and a respective aggregation weight to each user terminal based on the device information obtained from the user terminals, include:
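In an embodiment, the pre-established optimization problem regarding computation time includes:
$$\min_{p_1,\dots,p_N}\ \max_{1\le i\le N}\left\{\frac{p_i M N}{c_i C}\right\}\quad\text{s.t.}\quad\sum_{i=1}^{N}p_i=1,\quad 0\le p_i\le 1,$$
where N represents the number of the user terminals; pi represents a cutting ratio of the user terminal ui; $C=\sum_{i=1}^{N}C_i$ represents a total computation capability value of the user terminals, Ci being the computation capability value of ui; ci represents a ratio of the computation capability value of the user terminal ui to the total computation capability value of the user terminals; M represents a total amount of computation required for the start node to complete the local joint processing; max{·} represents obtaining a maximum value; and min represents minimization.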
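A solution result of the pre-established optimization problem regarding computation time includes:
$$p_i^{*}=c_i,\quad i=1,\dots,N,$$
where p*i represents the optimal cutting ratio of the user terminal ui, and the corresponding minimum computation time is MN/C.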
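In an embodiment, the pre-established propagation step length computation formula is expressed as follows:
$$L_i=p_i^{*}w=c_i w,$$
where Li represents the propagation step length of the user terminal ui, and w represents the total number of layers of the original network.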
In an embodiment, for each start node, the forward propagation in the local joint processing of each start node includes:
In an embodiment, for each start node, the backward propagation in the local joint processing of each start node includes:
In an embodiment, the step of updating the current-time local model of the start node based on weighted gradients generated by the start nodes in the local joint processing, includes:
In an embodiment, the step of uploading, by each user terminal, the locally-updated model parameter for the current-round training to the server for aggregation, includes:
In an embodiment, the server is a base station in a cellular network, and each user terminal is a user terminal device in the cellular network.
In a second aspect, an embodiment of the disclosure provides a face recognition method based on adaptive split learning-federated learning, applied to a target terminal, the face recognition method includes the following steps:
The disclosure has at least the following beneficial effects.
According to the embodiments of the disclosure, in the process of training a face recognition model, the model training method retains the ability of FL to utilize distributed computing in the whole training process, so that the computation efficiency and convergence speed can be improved. The server allocates the respective propagation step length to each user terminal based on the device information of all user terminals, which realizes the allocation of computation loads to the user terminals according to the characteristics of different user terminals, so it can better adapt to the heterogeneity of the system, significantly alleviate the laggard effect and improve the training efficiency of the system. At the same time, it is difficult for eavesdroppers to recover data from the mixed model because each user terminal only transmits its own output layer gradient in the backward propagation, so the privacy protection performance of data can be enhanced.
The face recognition method provided by the embodiment of the disclosure uses the face recognition model trained by the provided model training method, is suitable for various face recognition scenarios, and has the advantage of high recognition accuracy.
The disclosure will be further described in detail with accompanying drawings and specific embodiments.
The following will provide a clear and complete description of the technical solutions in the embodiments of the disclosure, in conjunction with the accompanying drawings. Apparently, the described embodiments are only a part of the embodiments of the disclosure, not all of the embodiments. Based on the embodiments in the disclosure, all other embodiments obtained by those skilled in the art without creative labor fall within the scope of protection of the disclosure.
In the related art, to improve the efficiency and security of distributed learning systems, O. Gupta et al. proposed split learning (SL), whose core idea is to split the network structure: each device keeps only a part of the network structure (i.e., a sub-network structure), and the sub-network structures of all devices together form the complete network model. During training, each device performs forward or backward computation only on its local network structure and transmits the computation result to the next device. The devices complete model training through the intermediate results exchanged at the layers where the sub-networks join, until the model converges. However, this scheme requires transmitting labels, so there is still a risk of data leakage. By integrating the respective advantages of FL and SL, the embodiment of the disclosure proposes a method, called RingSFL for short, that solves the heterogeneity problem in training a face recognition model by relying only on the learning mechanism itself. Here, Ring represents a ring topology and SFL represents SL+FL. A specific explanation follows.
In the first aspect, an embodiment of the disclosure provides a model training method based on adaptive split learning-federated learning, which is applied to a RingSFL system including a server and multiple user terminals. Please refer to
S1, each user terminal uploads its device information to the server, and the server allocates a respective propagation step length and a respective aggregation weight to each user terminal based on the device information obtained from all user terminals.
Specifically, the propagation step length represents the number of propagation network layers. The aggregation weight is used for subsequent gradient weighting computation, and meanings and functions of the propagation step length and the aggregation weight are described in detail later.
In the embodiment of the disclosure, the server and the user terminals can establish the RingSFL system in advance in an agreed way. The RingSFL system includes multiple devices existing in a network. For example, in an alternative embodiment, the server can be a base station in a cellular network, and each user terminal can be a user terminal device in the cellular network. Of course, an applicable network form of the RingSFL system is not limited to the cellular network.
A face recognition task scenario targeted by the trained model in the embodiment of the disclosure may be to use the trained face recognition model to perform operations related to identity confirmation for users in a specific area. For example, users can confirm their identity through face recognition and then operate related devices in the specific area for access control, clocking in (punching cards), face-scan payment when shopping, and so on. The specific area can be a campus, a residential community, a confidential unit, and so on.
Each user terminal has a face image training set. The face image training set includes multiple face images as training samples and corresponding sample labels. Each sample label includes attribute information corresponding to the face in the corresponding training sample, such as location information, identity information, etc.
In an alternative embodiment, S1 may include S11 to S14.
S11, each user terminal uploads its computation capability value and the number of the training samples corresponding to its face image training set to the server.
Each user terminal knows its own computation capability value. For the i-th user terminal ui (1≤i≤N, where N is the number of user terminals) in the RingSFL system, the computation capability value can be expressed as Ci; the higher the value of Ci, the stronger the computation capability of ui. The number of training samples in the face image training set of ui can be expressed as Di.
The user terminals can upload their respective computation capability values and the numbers of the training samples corresponding to their own face image training sets to the server in parallel.
S12, the server calculates the propagation step length of each user terminal by using a pre-established propagation step length computation formula based on the obtained computation capability values of the user terminals.
Specifically, the propagation step length computation formula is determined according to a pre-established optimization problem regarding computation time.
In the embodiment of the disclosure, because the computation capability values of the user terminals are different, the server needs to allocate a computation load of each user terminal according to the computation capability values of the user terminals, so as to better adapt to the heterogeneity of the system. Therefore, before S1, the embodiment of the disclosure constructs the optimization problem regarding computation time in advance. The corresponding analysis process is as follows.
In order to minimize model training time and determine split points of the model, the embodiment of the disclosure designs a model split scheme in the RingSFL system. The computation load of each user terminal should be determined according to its computation capability.
A cutting ratio (i.e., split ratio) of the user terminal ui is defined as pi, which represents the fraction of the computation load allocated to the user terminal ui. For all user terminals,
$$C=\sum_{i=1}^{N}C_i$$
is used to represent the total computation capability value of all user terminals in the RingSFL system. ci is introduced to re-express the computation capability value of ui as ciC, where
$$c_i=\frac{C_i}{C}$$
represents the ratio of the computation capability value of ui to the total computation capability value.
M is used to represent the total amount of computation required for a user terminal, acting as a start node, to complete one local joint processing in the ring topology using a batch of its own face image training set (the local joint processing is described below). The total amount of computation is measured in giga floating-point operations (GFLOPs), and its value is fixed for a given task and given training data. Because all N user terminals act as start nodes in the same batch and ui carries the fraction pi of each of these N local joint processings, the amount of computation of ui in this batch training is piMN, and the computation time consumed by ui to complete this batch training is
$$\frac{p_i M N}{c_i C}.$$
Experiments in the embodiment of the disclosure found that even user terminals with high computation capability can become laggards in training when an excessive amount of computation is allocated to them. The duration of each system training batch is therefore limited by the user terminal with the longest training time, and the computation load should be optimized to suppress the laggard effect and thereby minimize the training time.
Because there are N user terminals in the system, the computation time consumed by the laggard to train a batch is
$$\max_{1\le i\le N}\left\{\frac{p_i M N}{c_i C}\right\}.$$
To minimize the computation time of the laggard, the embodiment of the disclosure formulates the optimization problem regarding computation time.
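As a numerical illustration of this time model (a sketch, not part of the disclosed method), the per-terminal batch time t_i = p_i·M·N/(c_i·C), which equals p_i·M·N/C_i because c_i·C = C_i, and the laggard time max_i t_i can be computed directly. The names `batch_times` and `laggard_time`, the symbol t_i, and the capability values and M used below are hypothetical and chosen only for illustration.

```python
from typing import List, Sequence


def batch_times(p: Sequence[float], capabilities: Sequence[float], m_total: float) -> List[float]:
    """Per-terminal batch time t_i = p_i * M * N / C_i.

    p            -- cutting ratios p_i (assumed to sum to 1)
    capabilities -- computation capability values C_i (e.g. in GFLOPS)
    m_total      -- M, computation of one start node's local joint processing
    Dividing by C_i is the same as dividing by c_i * C.
    """
    n = len(capabilities)
    return [p_i * m_total * n / cap for p_i, cap in zip(p, capabilities)]


def laggard_time(p: Sequence[float], capabilities: Sequence[float], m_total: float) -> float:
    """Batch duration of the whole system: the slowest terminal's time."""
    return max(batch_times(p, capabilities, m_total))


if __name__ == "__main__":
    caps = [10.0, 20.0, 70.0]                 # hypothetical C_i values
    m_total = 30.0                            # hypothetical M
    total = sum(caps)
    uniform = [1 / 3] * 3                     # equal load for every terminal
    proportional = [c / total for c in caps]  # p_i = c_i
    print(laggard_time(uniform, caps, m_total))       # about 3.0
    print(laggard_time(proportional, caps, m_total))  # about 0.9, i.e. M*N/C
```

With equal loads the weakest terminal dominates; with p_i = c_i every terminal finishes at the same time, which is the behaviour the optimization problem formalizes.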
The optimization problem regarding computation time includes:
$$\min_{p_1,\dots,p_N}\ \max_{1\le i\le N}\left\{\frac{p_i M N}{c_i C}\right\}\quad\text{s.t.}\quad\sum_{i=1}^{N}p_i=1,\quad 0\le p_i\le 1,$$
where $C=\sum_{i=1}^{N}C_i$ represents the total computation capability value of all user terminals; ci represents the ratio of the computation capability value of ui to the total computation capability value; M represents the total amount of computation required for the start node to complete the local joint processing; max{·} represents obtaining the maximum value; and min represents minimization.
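The optimization problem is solved by introducing a new variable m, which converts it into the equivalent problem
$$\min_{m,\,p_1,\dots,p_N}\ m\quad\text{s.t.}\quad\frac{p_i M N}{c_i C}\le m,\ i=1,\dots,N;\quad\sum_{i=1}^{N}p_i=1.$$
Summing the constraints pi ≤ m·ci·C/(MN) over i and using Σpi = Σci = 1 gives m ≥ MN/C, with equality exactly when every constraint is tight. The solution result of the optimization problem regarding computation time therefore includes:
$$p_i^{*}=c_i,\qquad m^{*}=\frac{MN}{C}.$$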
From the solution result of the above optimization problem, the optimal value of pi equals the ratio ci of the computation capability value of ui to the total computation capability value. Since experiments in the embodiment of the disclosure found that the laggard effect can be significantly alleviated by optimizing pi, Li = p*i·w is set to reduce the training time, where w is the total number of layers of the original network. That is, the propagation step length computation formula is expressed as follows:
$$L_i=p_i^{*}w=c_i w.$$
When a given model is trained, the global model of every round uses the same known original network, so the total number w of layers of the original network is a known fixed value. For example, the original network can be any existing neural network for target classification, such as a convolutional neural network (CNN), the you only look once (YOLO) series, or the visual geometry group network (VGG16).
It should be noted that the optimal cutting ratio p*i of the user terminal ui is a fraction, so p*i·w is generally not an integer. Therefore, the server may need to round the resulting non-integer Li, either up or down.
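The step-length rule Li = ci·w followed by rounding can be sketched as below. The disclosure only states that non-integer values may be rounded up or down; the largest-remainder rounding, the requirement that the lengths sum to w (so that the terminals in the ring together cover every layer), and the minimum of one layer per terminal are assumptions of this sketch, and `step_lengths` is a hypothetical name.

```python
import math
from typing import List, Sequence


def step_lengths(capabilities: Sequence[float], w: int) -> List[int]:
    """Integer propagation step lengths from L_i = c_i * w.

    Assumptions of this sketch (not stated in the disclosure): lengths are
    rounded with the largest-remainder method so that they sum to w, and
    every terminal keeps at least one layer.
    """
    n = len(capabilities)
    if w < n:
        raise ValueError("w must be at least the number of terminals")
    total = sum(capabilities)
    exact = [cap / total * w for cap in capabilities]  # c_i * w = p*_i * w
    lengths = [math.floor(x) for x in exact]
    # Round up the terminals with the largest fractional parts until sum == w.
    leftover = w - sum(lengths)
    by_fraction = sorted(range(n), key=lambda i: exact[i] - lengths[i], reverse=True)
    for i in by_fraction[:leftover]:
        lengths[i] += 1
    # A terminal left with zero layers takes one layer from the largest share.
    for i in range(n):
        if lengths[i] == 0:
            donor = max(range(n), key=lambda j: lengths[j])
            lengths[donor] -= 1
            lengths[i] += 1
    return lengths


if __name__ == "__main__":
    print(step_lengths([10.0, 20.0, 70.0], 16))  # hypothetical C_i and w = 16
    print(step_lengths([1.0, 1.0, 98.0], 10))    # very weak terminals still get one layer
```

Largest-remainder rounding is just one way to keep the total at w; rounding every Li independently up or down could leave layers uncovered or assigned twice.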
S13, the server calculates the total number of the training samples of all user terminals, and determines a ratio of the number of training samples of each user terminal to the total number of training samples as the aggregation weight of the corresponding user terminal.
Specifically, the aggregation weight calculated by the server for the user terminal ui can be expressed as ai:
$$a_i=\frac{D_i}{\sum_{j=1}^{N}D_j}.$$
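The aggregation weight ai = Di / ΣDj is a simple normalization of the training-sample counts. A minimal sketch, assuming the counts are integers; `aggregation_weights` is a hypothetical name, and exact fractions are used only so that the weights visibly sum to 1.

```python
from fractions import Fraction
from typing import List, Sequence


def aggregation_weights(sample_counts: Sequence[int]) -> List[Fraction]:
    """a_i = D_i / sum_j D_j, kept as exact fractions."""
    total = sum(sample_counts)
    if total == 0:
        raise ValueError("at least one training sample is required")
    return [Fraction(d, total) for d in sample_counts]


if __name__ == "__main__":
    weights = aggregation_weights([300, 500, 200])  # hypothetical D_i
    print(weights)       # [Fraction(3, 10), Fraction(1, 2), Fraction(1, 5)]
    print(sum(weights))  # 1
```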
S14, the server sends the respective propagation step length and the respective aggregation weight to each user terminal.
In the same way, the server can send the respective propagation step length and the respective aggregation weight to each user terminal in parallel. At this point, the initialization process of propagation step length and aggregation weight required for training is completed. Because the computation capability value of each user terminal and the number of training samples corresponding to the face image training set of each user terminal are relatively fixed, the propagation step length and the aggregation weight of each user terminal remain unchanged in the subsequent training process.
Compared with the existing FL, the server of the embodiment of the disclosure adds the allocation step of propagation step length and aggregation weight to each user terminal, and the propagation step length of each user terminal is obtained by the server allocating the computation load according to the characteristics of different users' computation capability, so it can better adapt to the heterogeneity of the system, significantly alleviate the laggard effect and improve the training efficiency of the system.
S2, in a current-round training, each user terminal obtains a current-round global model from the server and, taking itself as a start node of the ring topology formed by all the user terminals, performs the local joint processing corresponding to that start node a preset number of times, so that a locally-updated model parameter of each start node for the current-round training is obtained.
The local joint processing of each start node includes: performing forward propagation and backward propagation on a current-time local model of the start node in the current-round training based on a batch of the face image training set of the start node; and updating the current-time local model of the start node based on weighted gradients generated by all the start nodes in the local joint processing. The current-time local model corresponding to the first local joint processing of each start node in the current-round training is the current-round global model. Both forward propagation and backward propagation are completed by combining partial network trainings (also referred to as partial network training continuation) performed by the user terminals in the ring topology using their respective propagation step lengths. In the backward propagation, each user terminal uses its propagation step length and the aggregation weight corresponding to the start node to obtain the corresponding weighted gradient and transmit its output layer gradient.
At the beginning of each training round, all user terminals obtain the same current-round global model from the server and set it as their first local model in that round. As the local joint processings proceed, the "current-time local model" of the round is constantly updated. For the first training round, the current-round global model is the original network, such as the VGG16 mentioned above.
For each user terminal, the number of training samples in its face image training set is often large, so it is impractical to input all the training samples into the model at once in each training round. To improve the training effect, a common practice in ML model training is to divide the training set into multiple batches that are input into the model in turn. Therefore, in each training round, each user terminal performs the local joint processing the preset number of times; each local joint processing uses an unused batch of the face image training set of each user terminal, and the local joint processings of the user terminals are carried out synchronously. In the embodiment of the disclosure, the specific value of the preset number is determined in advance according to the number of training samples in the face image training set and the batch size.
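The disclosure does not give the rule that turns the number of training samples and the batch size into the preset number. One plausible sketch, purely as an assumption: because the processings are synchronous and each one consumes an unused full batch on every terminal, cap the count at the smallest number of full batches any terminal can supply. Both function names below are hypothetical.

```python
from typing import List, Sequence


def preset_processings(sample_counts: Sequence[int], batch_size: int) -> int:
    """Hypothetical rule for the number of local joint processings per round.

    Every terminal must still have an unused full batch for each synchronous
    processing, so the count is limited by the terminal with the fewest batches.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    return min(d // batch_size for d in sample_counts)


def split_into_batches(samples: Sequence, batch_size: int, count: int) -> List[Sequence]:
    """Take `count` consecutive, non-overlapping batches from one training set."""
    return [samples[j * batch_size:(j + 1) * batch_size] for j in range(count)]


if __name__ == "__main__":
    k = preset_processings([300, 500, 200], batch_size=32)  # hypothetical D_i
    batches = split_into_batches(list(range(200)), 32, k)
    print(k, [len(b) for b in batches])
```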
For each user terminal, each local joint processing can be divided into three stages: forward propagation, backward propagation and parameter update. The local joint processing of a user terminal is completed jointly by all user terminals, and the parameter update stage uses the results of the local joint processings of all user terminals.
In the following, taking the local joint processing conducted by the user terminal as the start node of the ring topology in a round of training as an example, the three stages of forward propagation, backward propagation and parameter update are explained respectively.
For each start node, in the local joint processing of the start node, the process of the forward propagation includes the following steps.
The start node uses a current batch of its face image training set to propagate forward, starting from the first layer of the current-time local model, through the number of layers (at least one) given by its own propagation step length, and transmits the feature map output by the local network of its forward propagation, together with the serial number of its output layer (i.e., the output layer serial number), to the next user terminal along the forward direction of the ring topology starting from the start node.
Each forward current node, traversed sequentially along the forward direction of the ring topology, takes the layer following the output layer serial number received from the previous user terminal (along the forward direction of the network) as its start layer, uses the computation result transmitted by the previous user terminal to propagate forward from that start layer through the number of layers (at least one) given by its own propagation step length, and transmits the computation result of the local network of its forward propagation to the next user terminal along the forward direction of the ring topology. A forward current node is any user terminal in the ring topology other than the start node, and the end node is the last user terminal traversed in the ring topology. Every forward current node other than the end node also transmits its own output layer serial number to the next user terminal, and its computation result is the feature map output by its local network. The computation result of the end node is the face recognition result.
The start node compares the face recognition result transmitted by the end node with a sample label in the current batch of the face image training set to obtain a comparison result, and calculates a network loss value corresponding to the start node according to the comparison result.
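To make the layer bookkeeping concrete, the sketch below lists, for a chosen start node, which layers each terminal computes in the forward propagation and which output layer serial number it forwards. It assumes terminals numbered 1..N, a ring whose forward direction is u1→u2→…→uN→u1, and step lengths that sum to w; `forward_schedule` is a hypothetical helper, not part of the disclosure.

```python
from typing import List, Sequence, Tuple


def forward_schedule(start: int, steps: Sequence[int]) -> List[Tuple[int, int, int]]:
    """Forward-propagation layer assignment for one start node.

    Returns (terminal, first_layer, output_layer_serial) in ring order,
    beginning with `start`. steps[i - 1] is the propagation step length L_i
    of terminal u_i; layers are numbered from 1.
    """
    n = len(steps)
    schedule = []
    serial = 0  # output layer serial number received from the previous terminal
    for k in range(n):
        terminal = (start - 1 + k) % n + 1
        first = serial + 1                 # layer after the received serial number
        serial = serial + steps[terminal - 1]
        schedule.append((terminal, first, serial))
    return schedule


if __name__ == "__main__":
    steps = [2, 3, 1]  # hypothetical L1, L2, L3 with w = 6
    for start in (1, 2, 3):
        print(start, forward_schedule(start, steps))
```

Because every terminal starts at layer 1 when it is the start node, the layers a terminal computes differ from one start node to the next.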
For convenience of understanding, the embodiment of the disclosure takes a ring topology formed by three user terminals as an example. Please refer to
There are three user terminals u1, u2, u3 in
Specifically, u1 uses a current batch of its face image training set to propagate forward through L1 layers along the forward direction of the network, starting from the first layer of the current-time local model. Since the current-time local model has the same total number of layers as the original network, namely w, the input layer of the local network propagated by u1 is the first layer of the current-time local model, and its output layer is the L1-th layer. During the forward propagation of u1, the current batch of the face image training set is fed to the input layer of the local network, the output layer outputs a feature map ƒ1, and u1 transmits the feature map and its output layer serial number, (ƒ1, L1), to u2. For the first local joint processing in the first training round, the current-time local model is the original network.
The input layer of the local network of u2 is the (L1+1)-th layer of the current-time local model; u2 inputs ƒ1 into its input layer and propagates forward through L2 layers along the forward direction of the network. The output layer of the local network of u2 is the (L1+L2)-th layer of the current-time local model; it outputs a feature map ƒ2, and u2 transmits (ƒ2, L1+L2) to u3.
The input layer of the local network of u3 is the (L1+L2+1)-th layer of the current-time local model. u3 inputs ƒ2 into its input layer and propagates forward through L3 layers along the forward direction of the network. The output layer of the local network of u3 is the (L1+L2+L3)-th layer of the current-time local model. Because the propagation step lengths of all user terminals together cover the w layers of the model (here L1+L2+L3=w), the output layer of the end node of the ring topology is the output layer of the current-time local model. The output layer of the local network of u3 therefore outputs the face recognition result ƒ3, which is transmitted to u1.
u1 compares the face recognition result ƒ3 with the sample label in the current batch. The face recognition result represents a predicted value, and the sample label records the true value; the network loss value, i.e., loss1 of u1, is calculated from the difference between them. The loss value is calculated in the same way as in existing neural network training and is not explained in detail here.
Thus, the forward propagation with u1 as the start node of the ring topology is completed.
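The three-terminal forward propagation can be checked on a toy model: chaining the partial forward passes around the ring gives the same output as running the whole network on one device, whichever terminal is the start node. The scalar tanh layers below are a hypothetical stand-in for the layers of the current-time local model, and the step lengths are illustrative; the function names are not from the disclosure.

```python
import math
from typing import Callable, List, Sequence

Layer = Callable[[float], float]


def make_toy_layers(w: int) -> List[Layer]:
    """Hypothetical w-layer network; layer k computes tanh(0.5 * k * x + 0.1)."""
    return [lambda x, k=k: math.tanh(0.5 * k * x + 0.1) for k in range(1, w + 1)]


def run_layers(layers: Sequence[Layer], first: int, last: int, x: float) -> float:
    """Apply layers first..last (1-based, inclusive) to x."""
    for layer in layers[first - 1:last]:
        x = layer(x)
    return x


def ring_forward(layers: Sequence[Layer], steps: Sequence[int], start: int, x: float) -> float:
    """Forward propagation of one start node around the ring.

    Each terminal starts at the layer after the received output layer serial
    number, runs its own L_i layers and passes (feature map, serial) onward.
    """
    n = len(steps)
    serial = 0  # output layer serial number received so far
    for k in range(n):
        terminal = (start - 1 + k) % n + 1
        first, serial = serial + 1, serial + steps[terminal - 1]
        x = run_layers(layers, first, serial, x)
    return x  # end-node output, standing in for the face recognition result


if __name__ == "__main__":
    layers = make_toy_layers(6)
    steps = [2, 3, 1]  # hypothetical L1, L2, L3 with L1 + L2 + L3 = w = 6
    for start in (1, 2, 3):
        print(start, ring_forward(layers, steps, start, 0.3) == run_layers(layers, 1, 6, 0.3))
```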
In the process of the forward propagation of u1 with itself as the start node of the ring topology, u2 and u3 also perform forward propagation with themselves as the start nodes of the ring topology to obtain loss2 and loss3 respectively. The specific process of the forward propagation of u2 and u3 is similar to that of the forward propagation of u1 with itself as the start node of the ring topology. For the forward propagation carried out by multiple user terminals in parallel, please refer to
In
For each start node, in the local joint processing of the start node, the process of backward propagation includes the following steps.
Each start node transmits its network loss value and its own aggregation weight to the corresponding end node.
The end node uses the network loss value to propagate backward, starting from the last layer of the current-time local model, through the number of layers (at least one) given by its own propagation step length; calculates the local network gradient of this backward propagation; multiplies the local network gradient by the aggregation weight of the start node to obtain the weighted gradient of the end node and stores it; and transmits the calculated output layer gradient of its local network, together with its output layer serial number, to the next user terminal along the backward direction of the ring topology.
Each backward current node, traversed sequentially along the backward direction of the ring topology, takes the layer following the output layer serial number received from the previous user terminal (along the backward direction of the network) as its start layer; uses the output layer gradient transmitted by the previous user terminal to propagate backward from that start layer through the number of layers (at least one) given by its own propagation step length; calculates the local network gradient of this backward propagation; multiplies it by the aggregation weight of the start node to obtain the weighted gradient of the backward current node and stores it; and transmits the calculated output layer gradient of its local network, together with its output layer serial number, to the next user terminal along the backward direction of the ring topology. A backward current node is any user terminal in the ring topology other than the start node and the end node. Finally, the start node takes the layer following the output layer serial number received from the previous user terminal (along the backward direction of the network) as its start layer, uses the received output layer gradient to propagate backward from that start layer through the number of layers given by its own propagation step length, calculates the local network gradient of this backward propagation, and multiplies it by the aggregation weight of the start node to obtain the weighted gradient of the start node and stores it.
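The backward traversal order and layer ranges can be sketched in the same spirit. The sketch assumes terminals numbered 1..N, a ring whose forward direction is u1→u2→…→uN→u1 (so the backward direction is the reverse), and step lengths that sum to w. For each terminal in backward order it lists the layer where it starts, the layer where it stops, and the output layer serial number it forwards with its output layer gradient (None for the start node, which is visited last and forwards nothing). `backward_schedule` is a hypothetical helper.

```python
from typing import List, Optional, Sequence, Tuple


def backward_schedule(start: int, steps: Sequence[int]) -> List[Tuple[int, int, int, Optional[int]]]:
    """Backward-propagation layer assignment for one start node.

    Returns (terminal, start_layer, last_layer, serial_sent) from the end node
    back to the start node. Each terminal walks, in reverse, the same layers it
    covered in the forward propagation.
    """
    n = len(steps)
    top = sum(steps)  # w: the end node begins at the last layer
    schedule = []
    for k in range(n):
        terminal = (start - 2 - k) % n + 1  # end node first, start node last
        bottom = top - steps[terminal - 1] + 1
        serial_sent = None if terminal == start else bottom
        schedule.append((terminal, top, bottom, serial_sent))
        top = bottom - 1  # the next terminal starts at the layer below
    return schedule


if __name__ == "__main__":
    for row in backward_schedule(1, [2, 3, 1]):  # hypothetical L1, L2, L3 with w = 6
        print(row)
```

For start node u1 this reproduces the ranges in the three-terminal example: u3 covers layer w down to w−L3+1, u2 covers w−L3 down to w−L3−L2+1, and u1 covers the rest down to layer 1.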
For the convenience of understanding, the disclosure will continue to be described by taking the ring topology of
The start node u1 transmits loss1 and a1 to the end node u3.
u3 uses loss1 to propagate the L3 layers from the last layer of the current-time local model used for the forward propagation in the backward direction of the network. Since the total number of layers of the current-time local model is w, the input layer of the local network propagated by u3 at this time is the last layer of the current-time local model, and the output layer is the (w−L3+1)-th layer of the current-time local model. In the backward propagation process of u3, the gradient of the propagated L3 layers is calculated to obtain the local network gradient G3,1 of u3, where the first subscript 3 of G3,1 represents the identification number of the current user terminal traversed, and the second subscript 1 of G3,1 represents that the current local joint processing is performed by u1 as the start node, and the product (G3,1×a1) of the calculated local network gradient G3,1 and a1 is used as the weighted gradient of u3, which is stored for subsequent parameter update. At the same time, the output layer gradient g3 is calculated and the output layer gradient together with the output layer serial number (g3, w−L3+1) of u3 are sent to u2.
The input layer of the local network of u2 is the (w−L3)-th layer of the current-time local model, u2 uses g3 to propagate the L2 layers from its input layer in the backward direction of the network, the output layer of the local network of u2 is the (w−L3−L2+1)-th layer of the current-time local model. In the process of backward propagation of u2, the gradient of the propagated L2 layers is calculated to obtain the local network gradient G2,1 of u2, and the product (G2,1×a1) of the calculated local network gradient G2,1 and a1 is used as the weighted gradient of u2, which is stored for subsequent parameter update. At the same time, the output layer gradient g2 of u2 is calculated and the output layer gradient together with the output layer serial number (g2, w−L3−L2+1) of u2 are sent to u1.
The input layer of the local network of u1 is the (w−L3−L2)-th layer of the current-time local model. u1 uses g2 to back-propagate through L1 layers starting from its input layer, so the output layer of the local network of u1 is the (w−L3−L2−L1+1)-th layer of the current-time local model. Because u1 is the start node of the ring topology, this output layer is the first layer of the current-time local model. During this backward propagation, u1 calculates the gradients of the L1 propagated layers to obtain its local network gradient G1,1 and stores the product (G1,1×a1) of G1,1 and a1 as its weighted gradient for the subsequent parameter update. Since u1 is the last user terminal in the backward propagation, it does not need to calculate or transmit an output layer gradient.
This completes the backward propagation corresponding to u1. After the forward propagation and the backward propagation with u1 as the start node of the ring topology, each user terminal stores its corresponding weighted gradient: in the backward propagation corresponding to u1, the weighted gradient of u3 is (G3,1×a1), that of u2 is (G2,1×a1), and that of u1 is (G1,1×a1).
While u1 performs the forward propagation and the backward propagation as the start node of the ring topology, u2 and u3 likewise perform the forward propagation and the backward propagation with themselves as the start nodes of the ring topology, and the resulting weighted gradients are also stored after the backward propagation. The backward propagation processes of u2 and u3 are similar to that of u1. In the backward propagation corresponding to u2, the weighted gradient of u1 is (G1,2×a2), that of u3 is (G3,2×a2), and that of u2 is (G2,2×a2). In the backward propagation corresponding to u3, the weighted gradient of u2 is (G2,3×a3), that of u1 is (G1,3×a3), and that of u3 is (G3,3×a3).
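The layer bookkeeping of the backward propagation above can be sketched in a few lines. This is a minimal illustration that rests on assumptions not stated in this section: terminals are listed in ring order u1, u2, ..., the forward propagation from start node uj visits uj, uj+1, ... around the ring, each terminal always covers a number of layers equal to its own propagation step length, and the step lengths sum to w. The function name `backward_schedule` and the example step lengths are hypothetical.

```python
def backward_schedule(step_lengths, start):
    """Layer ranges back-propagated by each terminal for one start node.

    step_lengths[i - 1] is the propagation step length L_i of terminal u_i,
    listed in ring order u_1, u_2, ..., u_n; the model has w = sum(step_lengths)
    layers numbered 1..w. Returns (terminal, input_layer, output_layer) tuples
    in backward-propagation order, i.e. beginning with the end node.
    """
    n = len(step_lengths)
    w = sum(step_lengths)
    # Forward order around the ring, beginning at the start node (1-based ids).
    forward_order = [(start - 1 + k) % n + 1 for k in range(n)]
    schedule = []
    input_layer = w  # the end node starts from the last layer of the model
    for u in reversed(forward_order):
        output_layer = input_layer - step_lengths[u - 1] + 1
        schedule.append((u, input_layer, output_layer))
        # (g_u, output_layer) is what u sends on; the next terminal's input
        # layer is the one just below u's output layer.
        input_layer = output_layer - 1
    return schedule
```

For step lengths (L1, L2, L3) = (4, 5, 7), so that w = 16, `backward_schedule([4, 5, 7], 1)` returns `[(3, 16, 10), (2, 9, 5), (1, 4, 1)]`, matching w−L3+1 = 10, w−L3 = 9, w−L3−L2+1 = 5, and the first layer for u1. With start node 2 the backward order is u1, u3, u2, and with start node 3 it is u2, u1, u3, as described for u2 and u3.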
The current-time local model of each start node is updated based on the weighted gradients generated by all the start nodes in the local joint processing process, which includes the following steps.
1) A sum of the weighted gradients corresponding to each user terminal (i.e., each start node) is calculated from all the weighted gradients generated in the local joint processing performed by all the start nodes.
Continuing the previous example, after u1, u2, and u3 have each served as the start node of the ring topology for the forward propagation and the backward propagation, the weighted gradients obtained are as listed above. Summing the weighted gradients that each user terminal stored during the different backward propagation processes then gives the sum of the weighted gradients of that user terminal.
Thus, the sum of the weighted gradients corresponding to u1 is (G1,1×a1)+(G1,2×a2)+(G1,3×a3); the sum of the weighted gradients corresponding to u2 is (G2,1×a1)+(G2,2×a2)+(G2,3×a3); and the sum of the weighted gradients corresponding to u3 is (G3,1×a1)+(G3,2×a2)+(G3,3×a3).
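The summation above can be sketched as follows. Because each local network gradient Gi,j covers only the layers that ui propagated while uj was the start node, this sketch (an assumption about representation, not taken from the source) stores each gradient as a mapping from layer index to a scalar value and treats layers that were not propagated as contributing zero. The function name is hypothetical.

```python
from collections import defaultdict


def weighted_gradient_sums(local_grads, weights):
    """Sum the weighted gradients stored by each user terminal.

    local_grads[(i, j)] maps layer index -> gradient of that layer, i.e. the
    local network gradient G_{i,j} computed by u_i while u_j is the start node.
    weights[j] is the aggregation weight a_j of start node u_j.
    Returns {i: {layer: sum over j of a_j * G_{i,j}[layer]}}.
    """
    sums = defaultdict(lambda: defaultdict(float))
    for (i, j), layer_grads in local_grads.items():
        for layer, grad in layer_grads.items():
            # Layers that u_i did not propagate for this start node add nothing.
            sums[i][layer] += weights[j] * grad
    return {i: dict(per_layer) for i, per_layer in sums.items()}


# Toy values for u1 only; layer 2 is trained for two start nodes (an overlap).
grads = {(1, 1): {1: 1.0}, (1, 2): {2: 2.0, 3: 2.0}, (1, 3): {2: 4.0}}
print(weighted_gradient_sums(grads, {1: 0.5, 2: 0.25, 3: 0.25}))
# {1: {1: 0.5, 2: 1.5, 3: 0.5}}
```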
2) For each user terminal, the product of the sum of the weighted gradients corresponding to the user terminal and a preset learning rate is calculated and subtracted from the parameters of the current-time local model of the user terminal, giving the updated current-time local model of the user terminal. If the number of local joint processing operations performed has not yet reached the preset number, the updated current-time local model is used for the next local joint processing.
In the embodiment of the disclosure, the preset learning rate is a preset value denoted by η; for example, η can be 0.1.
For each user terminal, the parameters of the updated current-time local model are the parameters of its current-time local model minus the product of the preset learning rate and the sum of the weighted gradients corresponding to that user terminal. Setting these parameters into the current-time local model yields the updated current-time local model of the user terminal.
After each local joint processing, it is determined whether the total number of local joint processing operations completed so far has reached the preset number. If it has not, the updated current-time local model is used as the “current-time local model” for the next local joint processing, and the local joint processing continues. If it has, the current-round training of the user terminal ends, and the parameters of the updated current-time local model of the user terminal are uploaded to the server as the locally-updated model parameters for the current-round training.
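The update rule and the repetition up to the preset number can be sketched as below. This is a minimal sketch under stated assumptions: models are represented as layer-to-scalar mappings, and `compute_weighted_sums` is a hypothetical callable standing in for one full local joint processing (ring forward and backward propagation for every start node), which this sketch does not implement.

```python
def local_joint_training(params, compute_weighted_sums, eta, preset_number):
    """Repeat local joint processing a preset number of times.

    params: {terminal: {layer: value}}, the current-time local models.
    compute_weighted_sums: hypothetical callable that runs the ring forward and
        backward propagation for every start node on the given models and
        returns {terminal: {layer: sum of weighted gradients}}.
    eta: preset learning rate.
    Returns the locally-updated parameters to upload to the server.
    """
    for _ in range(preset_number):
        sums = compute_weighted_sums(params)
        # Update rule: new parameter = parameter - eta * (sum of weighted gradients).
        params = {
            i: {k: p - eta * sums.get(i, {}).get(k, 0.0) for k, p in model.items()}
            for i, model in params.items()
        }
    return params
```

With a stand-in that always returns a weighted-gradient sum of 1.0, η = 0.1, and a preset number of 3, a parameter starting at 1.0 ends at approximately 0.7.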
S3, each user terminal uploads its own locally-updated model parameter for the current-round training to the server for aggregation to obtain the current-round updated global model.
This step completes model aggregation, in which the server receives the training results uploaded by the user terminals and aggregates them. Note that the model parameters uploaded by the user terminals are, in essence, mixed models.
In an alternative embodiment, each user terminal uploads its own locally-updated model parameter for the current-round training to the server for aggregation, including the following steps: each user terminal uploads its own locally-updated model parameter for the current-round training to the server, and the server calculates the average value of the received locally-updated model parameters as the current-round updated global model parameter.
After the current round of training, the user terminals can upload their locally-updated model parameters for the current-round training to the server in parallel. The server obtains the current-round updated global model parameter and sets it into the current-round global model to obtain the current-round updated global model.
To illustrate the effectiveness of the aggregation scheme of the embodiment of the disclosure, it is analyzed below. Because the uploaded models are mixed models, the traditional federated averaging (FedAvg) algorithm cannot be used directly for model aggregation. Therefore, the embodiment of the disclosure provides a revised model aggregation scheme for the RingSFL system.
Unlike FedAvg, the weighting in the RingSFL system is performed by the user terminals during training. The aggregation weight of ui is passed between the user terminals along with the backward propagation and multiplied by the local network gradient calculated by each user terminal; the product is the weighted gradient stored by that user terminal.
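The server-side step described above reduces to an element-wise mean. The sketch below assumes, for illustration only, that each upload is a mapping from layer index to a scalar parameter and that all uploads cover the same layers; the function name is hypothetical.

```python
def average_parameters(uploads):
    """Server-side aggregation: element-wise mean of the uploaded parameters.

    uploads: list of {layer: value} dicts, one per user terminal, all with the
    same layers. No per-terminal weights are applied here, because in the
    scheme described above the weighting already happened on the terminals.
    """
    n = len(uploads)
    return {k: sum(model[k] for model in uploads) / n for k in uploads[0]}


print(average_parameters([{1: 1.0, 2: 3.0}, {1: 3.0, 2: 5.0}]))  # {1: 2.0, 2: 4.0}
```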
$W^{(r)} = [L_1^{(r)}, \ldots, L_k^{(r)}, \ldots, L_w^{(r)}]$ represents the current-round global model of the r-th training round for each user terminal, where $L_k^{(r)}$ represents the k-th layer of $W^{(r)}$. Because each layer of the model of $u_i$ can be trained with samples from more than one user terminal, the gradients of these user terminals accumulate in each layer of the model of $u_i$. $U_{i,k}$ represents the set of user terminals that train the k-th layer of the model of $u_i$,
where $g_{i,k}$ represents the gradient of the k-th layer calculated using the training samples of $u_i$, $a_i$ represents the aggregation weight of $u_i$, and $\eta$ represents the preset learning rate. The model obtained by server aggregation can be expressed as follows:
Analysis of the above formula is as follows:
It can be seen that the weighting of the gradients in the embodiment of the disclosure is realized by the user terminal during the backward propagation, and the server only needs to average the locally-updated model parameters uploaded by the user terminals for the current-round training.
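Why plain averaging suffices can be sketched under simplifying assumptions that are not stated in this section: one local joint processing per round, every terminal starting from $W^{(r)}$ (so all gradients are evaluated at the same model), $N$ user terminals, and $\sum_{j=1}^{N} a_j = 1$. The symbol $L_{i,k}^{(r+1)}$ for the k-th layer of the locally-updated model of $u_i$ is hypothetical notation.

```latex
% k-th layer of the locally-updated model of u_i after one local joint processing
L_{i,k}^{(r+1)} = L_k^{(r)} - \eta \sum_{j \in U_{i,k}} a_j \, g_{j,k}

% plain averaging by the server
L_k^{(r+1)} = \frac{1}{N} \sum_{i=1}^{N} L_{i,k}^{(r+1)}
            = L_k^{(r)} - \frac{\eta}{N} \sum_{i=1}^{N} \sum_{j \in U_{i,k}} a_j \, g_{j,k}
            = L_k^{(r)} - \frac{\eta}{N} \sum_{j=1}^{N} a_j \, g_{j,k}
```

The last equality holds because, for each start node $u_j$, layer k is propagated by exactly one terminal, so $U_{1,k}, \ldots, U_{N,k}$ partition $\{1, \ldots, N\}$. FedAvg with one local step gives $L_k^{(r)} - \eta \sum_{j} a_j g_{j,k}$, so under these assumptions the two updates share the same direction and differ only by the effective learning rate $\eta / N$.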
It should be noted that the aggregation scheme of the embodiment of the disclosure reduces the effective learning rate. Therefore, to compare its aggregation effect with that of an existing algorithm such as FedAvg at the same learning rate, the learning rate of the embodiment of the disclosure needs to be scaled up by a certain factor to ensure convergence performance. For example, where the existing algorithm uses a learning rate of 0.01, the embodiment of the disclosure can set the learning rate to 0.1; in an alternative embodiment, the learning rate can be set to the product of the traditional learning rate and the number of participating user terminals.
Verification shows that, at the same learning rate, the aggregation result of the embodiment of the disclosure is similar to that of FedAvg, and the aggregation effect is essentially the same.
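The scaling rule (learning rate multiplied by the number of participating user terminals) can be checked numerically on a toy setting. This sketch assumes scalar layers, gradients that are fixed for the round, aggregation weights summing to 1, a single local joint processing, and a ring assignment in which the start node covers the first layers and each terminal covers as many layers as its step length. All function names and numbers are hypothetical, and this is not the experiment reported in the disclosure.

```python
def layer_owner(step_lengths, start):
    """0-based index of the terminal that processes each layer when terminal
    `start` (0-based) is the start node; terminals are visited in ring order."""
    n = len(step_lengths)
    owner = []
    for step in range(n):
        i = (start + step) % n
        owner.extend([i] * step_lengths[i])
    return owner


def ringsfl_round(global_layers, grads, weights, step_lengths, eta):
    """One local joint processing on every terminal, then plain averaging.

    grads[j][k] is g_{j,k}: gradient of layer k from u_j's samples at W^(r).
    """
    n = len(weights)
    local = [list(global_layers) for _ in range(n)]  # every terminal starts from W^(r)
    for j in range(n):  # u_j acts as the start node
        for k, i in enumerate(layer_owner(step_lengths, j)):
            local[i][k] -= eta * weights[j] * grads[j][k]  # weighted gradient on u_i
    return [sum(model[k] for model in local) / n for k in range(len(global_layers))]


def fedavg_round(global_layers, grads, weights, eta):
    """FedAvg with one local step, assuming the weights a_j sum to 1."""
    return [w - eta * sum(a * g[k] for a, g in zip(weights, grads))
            for k, w in enumerate(global_layers)]


W = [0.5, -1.0, 2.0, 0.25]
grads = [[1.0, 2.0, 3.0, 4.0], [0.5, 0.5, 0.5, 0.5], [-1.0, 0.0, 1.0, 2.0]]
a = [0.5, 0.3, 0.2]
eta_fl = 0.01
ring = ringsfl_round(W, grads, a, [1, 2, 1], eta=len(a) * eta_fl)
fed = fedavg_round(W, grads, a, eta=eta_fl)
print(all(abs(x - y) < 1e-12 for x, y in zip(ring, fed)))  # True
```

Running `ringsfl_round` with the unscaled rate `eta_fl` instead yields a step one third as large in this three-terminal toy, which is the learning-rate reduction noted above.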
S4, the server determines whether the current-round updated global model meets the convergence condition.
Specifically, the server can input the test samples of the face image test set that it stores into the current-round updated global model to obtain the corresponding face recognition prediction results. When the difference between the prediction results and the sample labels of the input samples is less than a certain threshold, the current-round updated global model is determined to meet the convergence condition.
When the difference between the prediction results and the sample labels of the input samples is not less than the threshold, S5 is executed: the current-round updated global model is taken as the next-round global model, and the process returns to the step in which each user terminal obtains the current-round global model from the server in the current-round training.
Specifically, when the current-round updated global model does not meet the convergence condition, the process returns to S2 for the next round of training, and the “next-round global model” obtained in S5 serves as the “current-round global model” in S2 for that round.
When the global model meets the convergence condition, S6 is executed, that is, the current-round updated global model is determined as the trained face recognition model.
Specifically, when the current-round updated global model meets the convergence condition, the training ends and the current-round updated global model is determined as the trained face recognition model. Furthermore, the server can also send the trained face recognition model to the user terminals that need it.
In the process of training the face recognition model, the model training method based on adaptive split learning-federated learning provided by the embodiment of the disclosure retains the ability of FL to utilize distributed computing throughout the training process, which improves computation efficiency and convergence speed. The server allocates a propagation step length to each user terminal based on the device information of all user terminals, so that computation loads are allocated according to the characteristics of the different user terminals; the method therefore adapts better to system heterogeneity, significantly alleviates the straggler effect, and improves the training efficiency of the system. Moreover, because each user terminal transmits only its own output layer gradient in the backward propagation, it is difficult for eavesdroppers to recover data from the mixed model, which enhances data privacy protection.
Because a user's face image data involves identity characteristics and constitutes personal private information, it is difficult for the server to obtain the face image data of all users. With the model training method based on adaptive split learning-federated learning, the embodiment of the disclosure realizes joint training using local data, so that privacy is preserved. When new users appear in a specific area, the user terminals can use the new sample data to update the model starting from the previously trained model, which greatly facilitates face recognition tasks in specific fields and has broad application prospects.
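The round loop of S2 to S6 can be sketched as a small driver. This is a hedged sketch: the three callables are hypothetical stand-ins for local training (S2), server aggregation (S3), and the server's test-set evaluation (S4); using a misclassification rate as the "difference between the prediction results and the sample labels" is an assumption; and `max_rounds` is a safety cap that the described method does not specify.

```python
def misclassification_rate(predictions, labels):
    """One possible 'difference' between predictions and labels (an assumption)."""
    wrong = sum(1 for p, y in zip(predictions, labels) if p != y)
    return wrong / len(labels)


def train_until_converged(global_model, run_local_training, aggregate_fn,
                          test_error, threshold, max_rounds=100):
    """Outer loop over training rounds (S2 to S6).

    run_local_training(global_model) -> list of locally-updated parameters (S2)
    aggregate_fn(uploads) -> current-round updated global model (S3)
    test_error(model) -> difference between predictions and labels on the
        server's own test set (S4)
    Returns (model, number of rounds run).
    """
    for round_index in range(1, max_rounds + 1):
        uploads = run_local_training(global_model)   # S2
        global_model = aggregate_fn(uploads)         # S3
        if test_error(global_model) < threshold:     # S4 met, so S6: stop
            return global_model, round_index
        # S5: the updated model becomes the next-round global model.
    return global_model, max_rounds
```

In a toy run where each round halves a scalar "model", the test error equals the model value, and the threshold is 0.1, the loop stops after round 4 with a model value of 0.0625.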
In order to verify the effect of the model training method based on adaptive split learning-federated learning in the embodiment of the disclosure, the experimental results are described below.
1) For eavesdroppers, recovering user data from any partial or fragmented model is very difficult, because existing approaches require the complete model parameters or gradients. In the method of the embodiment of the disclosure, the communication between the user terminals includes only the last-layer output of each user terminal, and the model transmitted from each user terminal to the server is a mixed model. Because the eavesdropper does not know the cutting ratio, it is difficult for the eavesdropper to obtain the complete model parameters by eavesdropping. The only situation in which user data could be recovered by eavesdropping is when each user terminal trains only one network layer: by eavesdropping on all communication links in the system, an eavesdropper could obtain the gradient of each layer in the model and piece together a complete model. The possibility of privacy leakage in this case is analyzed below. The probability that the communication link between ui and ui-1 is eavesdropped is defined as e′i, and the probability that the communication link between ui and the server is eavesdropped is defined as ei. Then, the privacy leakage probability of the user terminal ui can be expressed as:
The influence of different eavesdropping probability and the number of user terminals on the privacy leakage probability is shown in
2) For a malicious server, because the server knows the cutting ratio of the model, recovering user terminal data is somewhat less difficult for the server than for an eavesdropper. User terminals with a high cutting ratio usually have overlapping layers during training, and recovering the gradient of a single user terminal from overlapping layers is very challenging. However, with some special cutting ratio settings there are no overlapping layers, so the server may be able to recover data from the uploaded model. To address this, the user terminals can negotiate with the server and deliberately introduce overlapping layers into the system by appropriately shifting the cutting points, thereby ensuring security.
The convergence performance of the embodiment of the disclosure can be further illustrated by simulation experiments.
The experiment is carried out on a server with four graphics processing units (GPUs) (GeForce RTX 3090, 24 GB). Each GPU simulates one user terminal participating in the training, while the server is simulated by a CPU (Intel® Xeon® Silver 4214R CPU @ 2.40 GHz). The software environment is Python 3.7.6, and PyTorch 1.8.1+cu111 is used for model building and model training. The CIFAR10 dataset is used to simulate the training effect on a face image dataset, and the original network used for training is the VGG16 model. In addition, to illustrate the influence of the data distribution on RingSFL, all experiments are carried out on both IID and non-IID datasets. FL and SL are used as the comparison algorithms.
This experiment compares the convergence performance of RingSFL with FL and SL and the influence of communication between the user terminals.
Please refer to
Please refer to
Please refer to
To sum up, compared with traditional FL, the model training method based on adaptive split learning-federated learning proposed by the embodiment of the disclosure (RingSFL for short) improves the security of the distributed learning system and achieves faster convergence without sacrificing model accuracy. In addition, RingSFL can be applied to scenarios with pronounced system heterogeneity to improve overall system efficiency. Therefore, it can be effectively used to train neural network models such as face recognition models.
In the second aspect, an embodiment of the disclosure provides a face recognition method based on adaptive split learning-federated learning, which is applied to a target terminal, as shown in
S01, a trained face recognition model and an image to be recognized are obtained.
Specifically, the face recognition model is trained according to the model training method based on adaptive split learning-federated learning in the first aspect. For the specific content of the model training method based on adaptive split learning-federated learning, please refer to the related description of the first aspect, and the description will not be repeated here.
The target terminal can be the server or any user terminal in the RingSFL system. In an alternative embodiment, the target terminal can also be a trusted server or a trusted user terminal outside the RingSFL system, where “trusted” means that all parties are bound by confidentiality agreements and there are no security risks such as privacy disclosure. Alternatively, the target terminal can be a related device in a specific area corresponding to the RingSFL system, such as an access control device, a clock-in (attendance) device, or a face-payment shopping device.
The image to be recognized can be an image acquired or shot by the target terminal, which may contain a human face.
S02, the image to be recognized is input into the face recognition model to obtain a face recognition result.
The face recognition result includes attribute information of a human face in the image to be recognized. The attribute information includes identity information. For example, the identity information can include name, gender, age, ID number, job number, student number, etc., or it can also include the user's financial account information, etc.
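Step S02 can be sketched as running the trained model and mapping its prediction to the stored attribute information. This is a minimal sketch under hypothetical interfaces: `model(image)` is assumed to return one score per enrolled identity, and `identity_records` is an assumed lookup table holding the attribute information (for example name or job number) of each identity. Neither interface is specified in the disclosure.

```python
def recognize_face(model, image, identity_records):
    """S02 sketch: run the trained model and map its prediction to attribute info.

    model(image) is assumed to return one score per enrolled identity, and
    identity_records[c] holds the attribute information of identity c.
    Both interfaces are hypothetical.
    """
    scores = model(image)
    best = max(range(len(scores)), key=lambda c: scores[c])  # highest-scoring identity
    return {"identity_index": best, "attributes": identity_records[best]}


toy_model = lambda img: [0.1, 0.7, 0.2]  # stand-in for the trained model
print(recognize_face(toy_model, None, [{"name": "A"}, {"name": "B"}, {"name": "C"}]))
# {'identity_index': 1, 'attributes': {'name': 'B'}}
```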
The face recognition method based on adaptive split learning-federated learning provided by the embodiment of the disclosure uses the face recognition model trained by the provided model training method based on adaptive split learning-federated learning. The training process retains the ability of FL to utilize distributed computing, which improves the convergence speed; it also adapts better to system heterogeneity, improves the training efficiency of the system, and enhances data privacy protection. This face recognition method is suitable for various face recognition scenarios and has the advantage of high recognition accuracy.
Number | Date | Country | Kind |
---|---|---|---|
2022103453239 | Apr 2022 | CN | national |
This application is a continuation of International Patent Application No. PCT/CN2023/081800, filed on Mar. 16, 2023, which claims priority to Chinese Patent Application No. 202210345323.9, filed on Apr. 2, 2022, both of which are incorporated herein by reference in their entirety.
Relation | Number | Date | Country |
---|---|---|---|
Parent | PCT/CN2023/081800 | Mar 2023 | WO |
Child | 18737953 | | US |