The subject matter disclosed herein relates to deep learning techniques and, more particularly, to systems and methods for deep learning techniques utilizing continuous federated learning with a distributed selective local re-tuning process.
Deep learning models have proven successful in addressing problems involving sufficiently large, balanced, and labeled datasets that appear in computer vision, speech processing, image processing, and other domains. Ideally, it is desired that these models continuously learn and adapt with new data, but this remains a challenge for neural network models since most of these models are trained with static large batches of data. Retraining with incremental data generally leads to catastrophic forgetting (i.e., training a model with new information interferes with previously learned knowledge).
Ideally, artificial intelligence (AI) learning systems should adapt and learn continuously with new knowledge while refining existing knowledge. Current AI learning schemes assume that all samples are available during the training phase and, therefore, require retraining of the network parameters on the entire dataset in order to adapt to changes in the data distribution. Although retraining from scratch pragmatically addresses catastrophic forgetting, in many practical scenarios, data privacy concerns do not allow for sharing of training data. In those cases, retraining with incremental new data can lead to significant loss of accuracy (catastrophic forgetting).
In accordance with an embodiment of the present technique, a deep learning-based continuous federated learning network system is provided. The system includes a global site comprising a global model; and a plurality of local sites having a respective local model derived from the global model and a plurality of model tuning modules. Each of the plurality of model tuning modules includes a processing system programmed to receive incremental data and select one or more layers of the local model for tuning based on the incremental data. The selected layers in the local model are then tuned to generate a retrained model.
In accordance with another embodiment of the present technique, a method is provided. The method includes receiving, at a plurality of local sites, a global model from a global site and deriving a local model from the global model at each of the plurality of local sites. The method further includes tuning the respective local model at the plurality of local sites. For tuning the respective local model, incremental data is received from the local sites and one or more layers of the local model are selected for tuning based on the incremental data. Based on the tuning of the selected layers in the local model, a retrained model is generated.
In accordance with yet another embodiment of the present technique, a non-transient, computer-readable medium storing instructions to be executed by a processor to perform a method is provided. The method includes receiving, at a plurality of local sites, a global model from a global site and deriving a local model from the global model at each of the plurality of local sites. The method further includes tuning the respective local model at the plurality of local sites by receiving incremental data and selecting one or more layers of the local model for tuning based on the incremental data. A retrained model is generated based on tuning of the selected layers in the local model.
These and other features, aspects, and advantages of the present invention will become better understood when the following detailed description is read with reference to the accompanying drawings in which like characters represent like parts throughout the drawings, wherein:
One or more specific embodiments will be described below. In an effort to provide a concise description of these embodiments, not all features of an actual implementation are described in the specification. It should be appreciated that in the development of any such actual implementation, as in any engineering or design project, numerous implementation-specific decisions must be made to achieve the developers' specific goals, such as compliance with system-related and business-related constraints, which may vary from one implementation to another. Moreover, it should be appreciated that such a development effort might be complex and time consuming, but would nevertheless be a routine undertaking of design, fabrication, and manufacture for those of ordinary skill having the benefit of this disclosure.
When introducing elements of various embodiments of the present invention, the articles “a,” “an,” “the,” and “said” are intended to mean that there are one or more of the elements. The terms “comprising,” “including,” and “having” are intended to be inclusive and mean that there may be additional elements other than the listed elements. Furthermore, any numerical examples in the following discussion are intended to be non-limiting, and thus additional numerical values, ranges, and percentages are within the scope of the disclosed embodiments.
Some generalized information is provided below both to give general context for aspects of the present disclosure and to facilitate understanding and explanation of certain of the technical concepts described herein.
In deep learning (DL), a computer model learns to perform classification tasks directly from images, text, or sound. Deep neural networks combine feature representation learning and classifiers in a unified framework. It is noted that the term “deep” typically refers to the number of hidden layers in the neural network. Traditional neural networks contain only 2-3 hidden layers, while deep networks can have as many as 150. Models are trained by using a large set of labeled data and a neural network architecture that contains many layers, where the model learns features directly from the data without the need for manual feature extraction. The neural network is organized in layers, each consisting of a set of interconnected nodes. Outputs from a layer represent features that may have data values associated therewith. As a non-exhaustive example, a feature may be a combination of shape, color, appearance, texture, aspect ratio, etc.
A convolutional neural network (CNN) is a process used in deep learning, where the CNN may find patterns in data. CNNs learn directly from the data, using patterns to classify items and eliminating the need for manual feature extraction. The CNN may have tens or hundreds of layers that learn to detect different features in an image, text, sound, etc., for example. Like other neural networks, the CNN is composed of an input layer, an output layer, and many hidden layers in between. These hidden layers perform operations that alter the data with the intent of learning features specific to the data. An example of a layer is a convolutional layer, which puts the input data through a set of convolutional filters, each of which activates certain features from the images. The filters are applied to each item of training data at different resolutions, for example, and the output of each convolved data is used as the input to the next layer. These operations are repeated over tens or hundreds of layers, with each layer learning to identify different features. After learning the features in many layers, the CNN shifts to classification, and the classification output can be provided.
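The convolutional-filter operation described above can be sketched in a few lines of Python. The following minimal example (using NumPy; the function names are illustrative and not part of the disclosed system) applies a single vertical-edge filter to a small image and passes the result through a ReLU activation, producing a feature map:

```python
import numpy as np

def conv2d_valid(image, kernel):
    """Apply one convolutional filter (valid padding, stride 1)."""
    kh, kw = kernel.shape
    out_h = image.shape[0] - kh + 1
    out_w = image.shape[1] - kw + 1
    out = np.zeros((out_h, out_w))
    for i in range(out_h):
        for j in range(out_w):
            out[i, j] = np.sum(image[i:i + kh, j:j + kw] * kernel)
    return out

def relu(x):
    """Activation: keep only positive ("activated") feature responses."""
    return np.maximum(x, 0.0)

# A 4x4 input containing a vertical edge, and a 2x2 vertical-edge filter.
image = np.array([[0., 0., 1., 1.],
                  [0., 0., 1., 1.],
                  [0., 0., 1., 1.],
                  [0., 0., 1., 1.]])
kernel = np.array([[-1., 1.],
                   [-1., 1.]])

feature_map = relu(conv2d_valid(image, kernel))
# The filter responds strongly only where the edge is located.
```

Each filter in a real convolutional layer produces such a feature map, and the stack of feature maps becomes the input to the next layer.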
It would be desirable for the models to continuously learn and adapt with new data, but this is a challenge for standard neural network models. This is a particular challenge with respect to healthcare or in-flight monitoring, where there is limited data, diversity in sample distribution, and limited or no access to training data. Transfer learning is a conventional framework to retrain models given new incoming data, but such models suffer from catastrophic forgetting. As will be known to one skilled in the art, catastrophic forgetting occurs when a model is trained with new information and this interferes with the previously learned knowledge. With catastrophic forgetting, the model “forgets” what it had learned before and retunes itself only to the incoming data. As such, the model is only being trained on the new information, so it is learning on a much smaller scale. Catastrophic loss of previously learned responses whenever an attempt is made to train the network with a single new (additional) response is particularly undesirable.
Standard models are typically trained with static large batches of data. The conventional models assume that all samples are available during the training phase and, therefore, require retraining of the network parameters on the entire dataset in order to adapt to changes in the data distribution. Although retraining from scratch pragmatically addresses catastrophic forgetting, this process is very inefficient and hinders the learning of novel data in real time. Further, in many practical scenarios, data privacy concerns do not allow for sharing of training data. In those cases, retraining with incremental new data may lead to significant loss of accuracy (catastrophic forgetting).
Additionally, standard DL models may be trained on centralized training data. Performance of a DL model may be adversely affected by site-specific variables like machine make, software versions, patient demographics, and site-specific clinical preferences. Continuous federated learning enables incremental site-specific tuning of the global model to create local versions. In a continuous federated learning scenario, a global model is deployed across multiple sites that cannot export data. Site-specific ground truth is generated using auto-curation models that may use segmentation, registration, machine learning, and/or deep learning models. Such ground truth may have to be refined depending on local preferences of the expert.
Conventionally, a model may be retrained based on the last layers of the model. The decision on which layers to retrain is typically done in an iterative fashion, which is time-consuming and may not lead to a unique solution.
To address these concerns, one or more embodiments provide a data generation framework having a model tuning module that trains a local model. New incoming incremental data received by the model tuning module may affect some aspects of the local model and not other aspects. As such, the model tuning module may retrain the layers of the local model affected by the new incremental data, instead of retraining the entire model. Consider, by way of example, a model trained to recognize an orange: for the shape feature, if, with the new data, the shape is similar to what one would expect to see (e.g., round), then the corresponding layer does not need to be retrained. If, however, the image of the orange shows an ellipse due to distortion from a new camera, this layer may need to be retrained (i.e., weights associated with shape in this layer may be adjusted) to be able to identify the shape of an orange.
The model tuning module may determine which nodes are useful to be retrained and which nodes should be retained and not retrained. The model tuning module may partially retrain the model for new incoming incremental data while maintaining performance on the previously trained task/data by determining which layers to retrain or “tune” by analyzing model features and then inferring which layer of the model to tune. The layer determination may be based on feature values which inform layer weights. One or more embodiments may provide for local learning and faster adaptation to new data without catastrophic forgetting, while also providing for retraining in scenarios where the training data cannot be shared.
With the preceding in mind, and by way of providing useful context,
In the continuous federated learning scenario 10, the global model 16 is deployed across multiple sites 14 that cannot export data. A site-specific ground truth is generated using auto-curation models that may use segmentation, registration, machine learning, and/or deep learning models. The site-specific ground truth may have to be refined depending on local preferences of the expert. The automatically generated and refined ground truth is then further used for local training of the models. Selective local updates of the weights of the global model 16 create a local mutant 18 of the global model 16. The weights of the local models 18 are then encrypted and sent to the central server for selective updating of the global model 16, as indicated by block 20. These local updates or site-specific preferences (e.g., weights) from the local sites 14 are combined when updating the global model 16 at the global site 12. The global model update would be strategic and would be dependent on domain- and industry-specific requirements.
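The combination of local updates into the global model can be illustrated with a weighted-averaging sketch in the style of federated averaging. This is only one plausible aggregation rule (the passage notes that the actual update strategy is domain- and industry-specific), all names are hypothetical, and encryption of the weights in transit is omitted for brevity:

```python
import numpy as np

def aggregate_local_updates(global_weights, local_weights_list, site_sizes):
    """Combine site-specific weights into a new global model, each site
    contributing in proportion to its local sample count (a federated-
    averaging-style rule; the actual strategy is domain-specific)."""
    total = sum(site_sizes)
    return {
        name: sum((n / total) * local_w[name]
                  for local_w, n in zip(local_weights_list, site_sizes))
        for name in global_weights
    }

# Two local sites return retuned weights for one (hypothetical) layer.
global_w = {"layer1": np.array([1.0, 1.0])}
site_a = {"layer1": np.array([2.0, 0.0])}   # 100 local samples
site_b = {"layer1": np.array([0.0, 2.0])}   # 300 local samples

updated = aggregate_local_updates(global_w, [site_a, site_b],
                                  site_sizes=[100, 300])
```

Here the larger site contributes three quarters of the combined weight, reflecting its larger incremental dataset.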
Initially, at step 110, a global model is received at a plurality of local sites. In one embodiment, the global model is a gold test trained model as will be explained below with respect to
Next, at step 114, incremental data is received at a model tuning module located at the local site. The incremental data is received from the respective local sites and is different than the gold test data. The incremental data may include, for example, more pediatric brain scans compared to the global data. At step 116, one or more layers of the local model are selected for tuning based on the incremental data. In general, based on the incremental data, a few epochs are run on the local model. Thereafter, the model tuning module compares a first output of a layer of the local model with a second output of the corresponding layer of the global model. Based on the variance between the first output and the second output, the model tuning module selects the layers of the local model that need tuning. For example, if the variance between the first output and the second output of a particular layer exceeds a threshold, then the model tuning module will select that layer for tuning; otherwise, the model tuning module will freeze that layer, i.e., the layer will not be changed or retuned to a different value.
In other words, to determine which layers should be tuned, the model tuning module may compare the trained feature distribution (i.e., the second output) for each layer of the global model with the feature distribution (i.e., the first output) output from the corresponding layer of the local model. Layers of the local model whose output feature distribution is close to the trained feature distribution (as determined from feature statistics like the mean, variance, etc.) may be frozen and not trained, while layers whose output feature distribution is not close to the trained feature distribution may be selected by the model tuning module for training. In one or more embodiments, to determine whether the feature distribution output from a layer of the local model is close to the trained feature distribution for the corresponding layer, the difference between the distributions may be compared to a user-defined threshold value.
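As a concrete illustration of this comparison, the following Python sketch (all layer names are hypothetical, and a simple sum of mean and variance differences stands in for whatever divergence measure an embodiment might use) selects layers whose local feature statistics drift beyond a user-defined threshold and freezes the rest:

```python
import numpy as np

def select_layers_for_tuning(global_feats, local_feats, threshold=0.1):
    """Compare per-layer feature statistics of the global model with
    those of the local model run on incremental data; layers that
    diverge beyond the threshold are selected for tuning, the rest
    are frozen."""
    selected, frozen = [], []
    for layer in global_feats:
        g_mean, g_var = np.mean(global_feats[layer]), np.var(global_feats[layer])
        l_mean, l_var = np.mean(local_feats[layer]), np.var(local_feats[layer])
        divergence = abs(g_mean - l_mean) + abs(g_var - l_var)
        (selected if divergence > threshold else frozen).append(layer)
    return selected, frozen

# Per-layer feature activations after a few epochs on incremental data.
global_feats = {"conv1": np.array([0.5, 0.5, 0.5]),
                "conv2": np.array([0.2, 0.4, 0.6])}
local_feats  = {"conv1": np.array([0.5, 0.5, 0.5]),   # unchanged -> freeze
                "conv2": np.array([1.0, 1.2, 1.4])}   # shifted   -> tune

selected, frozen = select_layers_for_tuning(global_feats, local_feats)
```

Only the layer whose feature distribution shifted under the incremental data is selected; the unchanged layer is frozen and its weights left untouched.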
Finally, at step 118, the selected layers of the local model are tuned to generate a retrained model. In one or more embodiments, tuning a layer includes adjusting the weights associated with nodes in the layer. The adjusted weights replace their corresponding weights in the global model to become a new retrained model. The retrained model is then tested on the incremental data to ensure model accuracy. In one or more embodiments, the application of the incremental data to the retrained model may confirm the retraining was accurate within a given value or range of values (e.g., the model is expected to operate with 95% accuracy). However, if the retrained model is not accurate enough, then the weights of the model are further changed to update the retrained model to meet the accuracy requirement. Finally, the weights of the retrained models of all the local sites are then combined, encrypted, and sent to the central server for further training of the global model as explained with respect to
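The selective tuning of step 118 can be sketched as follows, assuming hypothetical layer names and a single gradient step for brevity; in practice the adjustment would repeat until the retrained model meets the accuracy requirement on the incremental data:

```python
import numpy as np

def tune_selected_layers(weights, selected, grads, lr=0.1):
    """Adjust weights only in the selected layers; frozen layers are
    copied unchanged into the retrained model."""
    retrained = {}
    for layer, w in weights.items():
        if layer in selected:
            retrained[layer] = w - lr * grads[layer]   # gradient step
        else:
            retrained[layer] = w.copy()                # frozen layer
    return retrained

# Local model weights and gradients computed from the incremental data.
local_weights = {"conv1": np.array([1.0, -1.0]),
                 "conv2": np.array([0.5, 0.5])}
grads = {"conv2": np.array([1.0, -1.0])}   # only the selected layer

retrained = tune_selected_layers(local_weights, selected={"conv2"}, grads=grads)
```

The adjusted weights of the selected layer would then replace their counterparts in the global model copy to form the retrained model.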
In general, the convolutional neural network 202 consists of the input layer 208, hidden layers 212, 214, and the output layer 216. In one or more embodiments, the nodes in a layer are weighted based on the importance of that variable/node to the task. During training, the weights of the nodes in the hidden layers (212, 214) are optimized. The training process ensures optimization of the weights. In one or more embodiments, the technique is not dependent on the model architecture and may be generalizable to more or fewer layers. It is noted that the embodiments and non-exhaustive examples included herein may be described in terms of multiple layers, as deep learning typically describes architectures with multiple layers.
In one or more embodiments, once the convolutional neural network has been trained, the weights of the nodes in the initial input layer 208, any intervening layers (e.g., second layer 212, third layer 214, etc.), and the output layer 216 are optimized to output a trained model 202. In one or more embodiments, the trained model 202 may then be executed with a set of gold test data 218 to confirm the accuracy of the trained model 202. Gold test data 218 is data that has been verified by a suitable party. When the output of the execution of the trained model 202 with the gold test data 218 matches an expected output within a given threshold, the trained model 202 may be classified as the global model 220.
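The gold-test promotion check might look like the following sketch, in which a toy threshold classifier stands in for the trained model 202 and the accuracy threshold (95%, echoing the accuracy example given earlier) is an assumption; all names are illustrative:

```python
import numpy as np

def promote_to_global(trained_model, gold_inputs, gold_labels, threshold=0.95):
    """Execute the trained model on verified gold test data and promote
    it to global-model status only if accuracy meets the threshold."""
    predictions = np.array([trained_model(x) for x in gold_inputs])
    accuracy = np.mean(predictions == np.array(gold_labels))
    return bool(accuracy >= threshold)

# A toy "trained model": a threshold classifier standing in for model 202.
trained_model = lambda x: int(x > 0.5)
gold_inputs = [0.1, 0.9, 0.8, 0.2]
gold_labels = [0, 1, 1, 0]

is_global = promote_to_global(trained_model, gold_inputs, gold_labels)
```

Only a model that reproduces the verified gold outputs within the threshold is classified as the global model.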
It should be noted that the DL model that was simulated herein included hidden layers and an output layer with three nodes (one for each class). The DL model was trained with 50% of the first data distribution for 100 epochs to assign labels. The class labeling accuracy for this DL model was 91% with validation data and 92% with the 50% hold-out test data for the first data distribution. As will be appreciated by those skilled in the art, during training, the data is divided into a training dataset and a validation dataset. The trained model performance is validated using the validation dataset, while another dataset is completely held back from the training and validation process. This held-back dataset, called the hold-out test dataset, is used to test the generalizability of the validated model on a new dataset. The performance of the DL model dropped to 80% for the second data distribution when validated with the 50% of samples in the hold-out test dataset. Thus, it can be seen that the original DL model accuracy drops when it is tested on the new data. Therefore, the technique presented herein selects certain layers of the original DL model for tuning and updates their weights according to the second data distribution, and thus can improve performance of the DL model for data classification.
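The train/validation/hold-out division described above can be sketched as follows. The 50/25/25 split and the function name are illustrative assumptions, not the exact partition used in the simulation:

```python
import random

def split_dataset(samples, train_frac=0.5, val_frac=0.25, seed=0):
    """Divide data into training, validation, and hold-out test sets;
    the hold-out set never participates in training or validation and
    is reserved for testing the generalizability of the validated model."""
    rng = random.Random(seed)
    shuffled = samples[:]
    rng.shuffle(shuffled)
    n_train = int(len(shuffled) * train_frac)
    n_val = int(len(shuffled) * val_frac)
    train = shuffled[:n_train]
    val = shuffled[n_train:n_train + n_val]
    holdout = shuffled[n_train + n_val:]
    return train, val, holdout

train, val, holdout = split_dataset(list(range(100)))
```

The three subsets are disjoint and together cover the full dataset, so hold-out performance genuinely measures behavior on data the model never saw.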
One of the advantages of the present technique is that retraining only particular layers of the DL model based on features may ensure that the model quickly adapts to local data. This is specifically applicable in scenarios where (1) there is a need to adapt to site-specific customization, and (2) local data may not be shared with a global source to allow retraining of a global model with all new and old data, as described further below.
The techniques presented and claimed herein are referenced and applied to material objects and concrete examples of a practical nature that demonstrably improve the present technical field and, as such, are not abstract, intangible or purely theoretical. Further, if any claims appended to the end of this specification contain one or more elements designated as “means for [perform]ing [a function] . . . ” or “step for [perform]ing [a function] . . . ”, it is intended that such elements are to be interpreted under 35 U.S.C. 112(f). However, for any claims containing elements designated in any other manner, it is intended that such elements are not to be interpreted under 35 U.S.C. 112(f).
This written description uses examples to disclose the present subject matter, including the best mode, and also to enable any person skilled in the art to practice the subject matter, including making and using any devices or systems and performing any incorporated methods. The patentable scope of the subject matter is defined by the claims, and may include other examples that occur to those skilled in the art. Such other examples are intended to be within the scope of the claims if they have structural elements that do not differ from the literal language of the claims, or if they include equivalent structural elements with insubstantial differences from the literal languages of the claims.