Embodiments of the present disclosure generally relate to the field of computers, and more specifically to a method, device and computer program product for multi-source domain adaptation.
Artificial neural networks have produced great advances in many prediction tasks. Such success depends on the availability of a large amount of labeled training data under a standard supervised learning setting, but labels are typically expensive and time-consuming to collect. Domain adaptation is a field associated with machine learning and transfer learning, and can reduce the labeling cost by exploiting existing labeled data in a source domain. Domain adaptation aims at transferring knowledge from the source domain to train a prediction model in a target domain.
Unsupervised domain adaptation (UDA) is a widely used domain adaptation setting, where data in the source domain is labeled while data in the target domain is unlabeled. Thus, UDA methods make predictions for the target domain while manual labels or annotations are only available in the source domain. Generally, UDA methods assume that the source domain data comes from the same source and has the same distribution, and leverage features from a labeled source domain to train a classifier for an unlabeled target domain.
Embodiments of the present disclosure provide a method, device and computer program product for multi-source domain adaptation.
According to one aspect of the present disclosure, there is provided a computer-implemented method. The method comprises generating a first representation of a target image in a target data through a first classifier, generating a second representation of the target image through a second classifier, and generating a third representation of the target image through a third classifier. The first classifier is trained using a first source data and the target data, the second classifier is trained using a second source data and the target data, and the third classifier is trained using at least the first and second source data and the target data. During the training, a mutual learning is conducted among the first, second and third classifiers. That is, the third classifier and the first classifier learn from each other, while the third classifier and the second classifier also learn from each other. The first and second source data comprise labeled images, while the target data comprises unlabeled images. The method further comprises determining a label of the target image based on the first, second and third representations.
According to one aspect of the present disclosure, there is provided an electronic device. The electronic device comprises a processing unit and a memory coupled to the processing unit and storing instructions thereon. The instructions, when executed by the processing unit, perform acts comprising generating a first representation of a target image in a target data through a first classifier, generating a second representation of the target image through a second classifier, and generating a third representation of the target image through a third classifier. The first classifier is trained using a first source data and the target data, the second classifier is trained using a second source data and the target data, and the third classifier is trained using at least the first and second source data and the target data. During the training, a mutual learning is conducted among the first, second and third classifiers. That is, the third classifier and the first classifier learn from each other, while the third classifier and the second classifier also learn from each other. The first and second source data comprise labeled images, while the target data comprises unlabeled images. The acts further comprise determining a label of the target image based on the first, second and third representations.
According to one aspect of the present disclosure, there is provided a computer program product. The computer program product comprises executable instructions. The executable instructions, when executed on a device, cause the device to perform acts comprising generating a first representation of a target image in a target data through a first classifier, generating a second representation of the target image through a second classifier, and generating a third representation of the target image through a third classifier. The first classifier is trained using a first source data and the target data, the second classifier is trained using a second source data and the target data, and the third classifier is trained using at least the first and second source data and the target data. During the training, a mutual learning is conducted among the first, second and third classifiers. That is, the third classifier and the first classifier learn from each other, while the third classifier and the second classifier also learn from each other. The first and second source data comprise labeled images, while the target data comprises unlabeled images. The acts further comprise determining a label of the target image based on the first, second and third representations.
This Summary is provided to introduce a selection of concepts in a simplified form that are further described below in the Detailed Description. This Summary is not intended to identify key features or essential features of the claimed subject matter, nor is it intended to be used to limit the scope of the claimed subject matter.
The above and other features, advantages and aspects of embodiments of the present disclosure will be made more apparent by describing the present disclosure in more detail with reference to drawings. In the drawings, the same or like reference signs represent the same or like elements, wherein:
Embodiments of the present disclosure will be described in more detail below with reference to figures. Although the drawings show some embodiments of the present disclosure, it should be appreciated that the present disclosure may be implemented in many forms and the present disclosure should not be understood as being limited to embodiments illustrated herein. On the contrary, these embodiments are provided herein to enable more thorough and complete understanding of the present disclosure. It should be appreciated that drawings and embodiments of the present disclosure are only used for exemplary purposes and not used to limit the protection scope of the present disclosure.
As used herein, the term “comprise” and its variants are to be read as open terms that mean “comprise, but not limited to.” The term “based on” is to be read as “based at least in part on.” The term “an embodiment” is to be read as “at least one embodiment.” The term “another embodiment” is to be read as “at least one other embodiment.” The term “some embodiments” is to be read as “at least some embodiments.” Definitions of other terms will be given in the text below.
Traditional unsupervised domain adaptation (UDA) methods generally assume the setting of a single source domain, where all the labeled source data come from the same distribution. However, in practice, the labeled images may come from multiple source domains with different distributions. In such scenarios, the single source domain adaptation methods may fail due to the existence of domain shifts across different source domains. Some multi-source domain adaptation methods may support multiple source domains, but fail to consider the differences and domain shifts between different source domains.
To this end, a new mutual learning network for multi-source domain adaptation is proposed, which can improve the accuracy of label prediction for images. Considering that multiple source domains have different distributions, embodiments of the present disclosure build one adversarial adaptation subnetwork (referred to as "branch subnetwork") for each source-target pair and a guidance adversarial adaptation subnetwork (referred to as "guidance subnetwork") for the combined multi-source-target pair. In addition, multiple branch subnetworks are aligned with the guidance subnetwork to achieve mutual learning, and the branch subnetworks and the guidance subnetwork can learn from each other during the training and make similar predictions in the target domain. Such a mutual learning network is expected to gather domain specific information from each source domain through the branch subnetworks and gather complementary common information through the guidance subnetwork, which can improve the information adaptation efficiency between the multiple source domains and the target domain.
Reference is made below to
At 202, a first representation of a target image in a target data is generated by a first classifier. For example, the first classifier may be trained by using a pair of a first source domain and a target domain as an input. At 204, a second representation of the target image is generated by a second classifier. For example, the second classifier may be trained by using a pair of a second source domain and the target domain as an input.
At 206, a third representation of the target image is generated by a third classifier. For example, the third classifier is trained by using a pair of the combined first and second source domains and the target domain as an input. In addition, during the training, a mutual learning is conducted among the first, second and third classifiers. That is, the third classifier and the first classifier learn from each other during the training, and the third classifier and the second classifier also learn from each other during the training.
At 208, a label of the target image is determined based on the first, second and third representations. For example, after training the model, multiple classifiers may be obtained from the model, and may be used to predict the label of the unlabeled image in the target domain. The final prediction probability result may be calculated according to the predicted label probability vectors of all the branch subnetworks and the guidance subnetwork.
Since the multiple source domains have different distributions, embodiments of the present disclosure train one branch subnetwork to align each source domain with the target domain, and train a guidance network to align the combined source domains with the target domain. In some embodiments, a guidance network centered prediction alignment may be performed by enforcing divergence regularizations over the prediction probability distributions of target images between the guidance subnetwork and each branch subnetwork so that all subnetworks can learn from each other and make similar predictions in the target domain. Such a mutual learning structure is expected to gather domain specific information from each single source domain through branch subnetworks and gather complementary common information through the guidance subnetwork, and thus embodiments of the present disclosure can improve both the information adaptation efficiency across domains and the robustness of network training.
Referring to
The mutual learning network 300 builds N+1 subnetworks 320-1 to 320-(N+1) for the N+1 source-target pairs 310-1 to 310-(N+1) for multi-source domain adaptation. The first N subnetworks 320-1 to 320-N perform domain adaptation from each source domain to the target domain, while the (N+1)-th subnetwork 320-(N+1) performs domain adaptation from the combined multi-source domains to the target domain. As the combined multi-source domains contain more information than each single source domain, they can reinforce the common information shared across the multiple source domains. As a result, the (N+1)-th subnetwork 320-(N+1) is used as a guidance subnetwork, while the first N subnetworks 320-1 to 320-N are used as branch subnetworks in the mutual learning network 300 of the present disclosure. The subnetworks may be various neural networks for image classification currently known or to be developed in the future, such as a convolutional neural network.
All the subnetworks in the mutual learning network 300 may have the same structure, but use different training data. Each subnetwork in the mutual learning network 300 comprises a feature generator G, a domain discriminator D, and a category classifier F. As shown in
For each subnetwork, the input image data first go through the feature generator (such as feature generator 321-1) to generate high level features. Conditional adversarial feature alignment is then conducted to align feature distributions between each specific source domain (or the combined multi-source domains) and the target domain using a separate domain discriminator (such as discriminator 322-1) as an adversary with an adversarial loss Ladv. The classifier (such as classifier 323-1) predicts the class labels of the input images based on the aligned features with classification losses LC and LE, while mutual learning is conducted by enforcing prediction distribution alignment between each branch subnetwork and the guidance subnetwork on the same target images with a prediction inconsistency loss LM. The classification losses LC and LE and the adversarial loss Ladv are considered on each subnetwork, while the prediction inconsistency loss LM is considered between each branch subnetwork and the guidance subnetwork.
In some embodiments, the subnetworks 320-1 to 320-(N+1) may have independent network parameters. Alternatively, some network parameters may be shared between the subnetworks 320-1 to 320-(N+1) so as to improve the training efficiency. For example, the first few layers of the feature generators may be shared across the subnetworks, while the last few layers remain independent.
Continuing to refer to
The feature generator G (such as feature generator 321-1) and the adversarial domain discriminator D (such as discriminator 322-1) are used to achieve multi-source conditional adversarial feature alignment, which exploits the adversarial learning of the generative adversarial network (GAN) in the domain adaptation setting. For the j-th subnetwork, the feature generator Gj and the domain discriminator Dj are adversarial, wherein Dj tries to maximally distinguish the source domain features Gj(XSj) from the target domain features Gj(XT), while Gj tries to generate features that confuse Dj.
Various adversarial losses for GAN may be used as adversarial loss Ladv in embodiments of the present disclosure. In some embodiments, to improve the discriminability of the induced features toward the final classification task, the label prediction results of the classifier Fj may be taken into account to perform the conditional adversarial domain adaptation with the adversarial loss Ladv of the j-th subnetwork with the example equation (1):
where pij denotes the prediction probability vector generated by the classifier Fj on image xij, and pit denotes the prediction probability vector generated by the classifier Fj on target image xit, as shown in example equation (2):

pij = Fj(Gj(xij)), pit = Fj(Gj(xit))    (2)

where pij is a length K vector with each entry indicating the probability that xij belongs to the corresponding classification, and Φ(.,.) denotes the conditioning strategy function, which may be a simple concatenation of its two elements.
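As an illustrative sketch only (the exact form of equation (1) is not reproduced here), a standard GAN-style binary adversarial loss over hypothetical discriminator outputs may be written as follows, assuming the discriminator emits a value in (0, 1) on the conditioned features Φ(G(x), p):

```python
import math

def adversarial_loss(d_source, d_target):
    """Hypothetical binary adversarial loss for one subnetwork: the
    discriminator ideally outputs values near 1 on conditioned source
    features and near 0 on conditioned target features."""
    loss_s = -sum(math.log(d) for d in d_source) / len(d_source)
    loss_t = -sum(math.log(1 - d) for d in d_target) / len(d_target)
    return loss_s + loss_t

# Toy discriminator outputs on source and target samples (assumptions).
src_out = [0.9, 0.8]
tgt_out = [0.2, 0.1]
loss = adversarial_loss(src_out, tgt_out)
assert loss > 0
```

A well-trained discriminator (outputs near 1 on source, near 0 on target) yields a small loss, which the feature generator then tries to increase by making the two feature distributions indistinguishable.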
In some embodiments, a multilinear conditioning function may be used so as to capture the cross covariance between feature representations and classifier predictions to help preserve the discriminability of the features. For example, the overall adversarial loss of all the N+1 subnetworks may be an average of adversarial losses of all subnetworks, as shown in the equation (3):
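The two conditioning strategies mentioned above, simple concatenation and the multilinear (outer product) map, can be sketched on toy vectors as follows; the function names and vectors are illustrative assumptions, and a real implementation would operate on batched tensors:

```python
def phi_concat(f, p):
    """Simple concatenation of the feature vector f and prediction vector p."""
    return f + p  # list concatenation

def phi_multilinear(f, p):
    """Flattened outer product f x p, capturing the cross covariance between
    feature representations and classifier predictions."""
    return [fi * pk for fi in f for pk in p]

features = [0.5, -1.2, 0.3]  # toy feature vector from a generator G
probs = [0.7, 0.2, 0.1]      # toy prediction vector from a classifier F

assert len(phi_concat(features, probs)) == 6       # dim(f) + dim(p)
assert len(phi_multilinear(features, probs)) == 9  # dim(f) * dim(p)
```

The multilinear map grows the conditioned representation multiplicatively, which is why it better preserves the interaction between features and predictions than plain concatenation.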
The classifier Fj is used to achieve a semi-supervised adaptive prediction loss. To increase the cross-domain adaptation capacity of the classifiers, the discriminability of the mutual learning network 300 on both the source domains and the target domain is taken into account. As shown in
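For the labeled source images, the supervised classification loss LC commonly takes a cross-entropy form; the following single-image sketch is an assumption about that standard form, not necessarily the exact equation of the disclosure:

```python
import math

def cross_entropy(p, y):
    """Supervised cross-entropy loss LC for one labeled source image:
    the negative log of the predicted probability of the true class y."""
    return -math.log(p[y])

# Toy predicted probability vector and ground-truth label (assumptions).
p = [0.7, 0.2, 0.1]
loss = cross_entropy(p, 0)
assert loss > 0
```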
An unsupervised entropy loss may be used as the classification loss LE in embodiments of the present disclosure. In some embodiments, for the unlabeled image from the target domain, the unsupervised entropy loss LE may be used to perform the training as shown in example equation (5).
The assumption is that if the source and target domains are well aligned, the classifier trained on the labeled source images should be able to make confident predictions on the target images and hence have small predicted entropy values. Therefore, embodiments of the present disclosure expect this entropy loss can help bridge domain divergence and induce useful discriminative features.
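The entropy intuition above can be illustrated with a minimal sketch: a confident prediction on a well-aligned target image yields a smaller entropy value than an uncertain, near-uniform one (vectors are toy assumptions):

```python
import math

def entropy_loss(p):
    """Unsupervised entropy loss LE for one unlabeled target image:
    H(p) = -sum_k p_k * log(p_k). Confident predictions give small values."""
    return -sum(pk * math.log(pk) for pk in p if pk > 0)

confident = [0.98, 0.01, 0.01]  # confident prediction -> low entropy
uncertain = [1/3, 1/3, 1/3]     # uniform prediction -> maximal entropy
assert entropy_loss(confident) < entropy_loss(uncertain)
```

Minimizing this loss over the target domain pushes the classifier toward confident decisions, which is meaningful only when the feature alignment has made those decisions trustworthy.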
According to embodiments of the present disclosure, the mutual learning network 300 can achieve guidance subnetwork centered mutual learning. With the adversarial feature alignment in each branch subnetwork, the target domain is aligned with each source domain separately. Due to the existence of domain shifts among various source domains, the domain invariant features extracted and the classifier trained in one subnetwork will be different from those in another subnetwork. Under effective domain adaptation, the divergence between each subnetwork's prediction result on the target images and the true labels should be small. By sharing the same target images, the prediction results of all the subnetworks in the target domain should be consistent. Thus, to improve the generalization performance of the mutual learning network 300 and increase the robustness of network training, embodiments of the present disclosure conduct mutual learning over all the subnetworks by minimizing their prediction inconsistency in the shared target images.
Since the guidance subnetwork 320-(N+1) uses the data from all the source domains, it contains more transferable information than each branch subnetwork. Accordingly, prediction consistency may be enforced by aligning each branch subnetwork with the guidance network in terms of predicted label distribution for each target image.
In some embodiments, Kullback Leibler (KL) Divergence may be used to align the predicted label probability vector for each target image from the j-th branch network with the predicted label probability vector for the same target image from the guidance network, where KL divergence is a measure of how one probability distribution is different from another reference probability distribution. In some embodiments, the KL divergence between the predicted label probability vector of the branch subnetwork and the predicted label probability vector of the guidance subnetwork may be determined via example equation (6).
KL(pitj || pit(N+1)) = Σk pitj[k] log(pitj[k] / pit(N+1)[k])    (6)

where pitj denotes the predicted label probability vector for the i-th target image from the j-th branch subnetwork, and pit(N+1) denotes the predicted label probability vector for the same target image from the guidance subnetwork.
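The KL divergence of example equation (6) can be sketched in a few lines; the probability vectors below are illustrative assumptions:

```python
import math

def kl_divergence(p, q):
    """KL(p || q) = sum_k p_k * log(p_k / q_k)."""
    return sum(pk * math.log(pk / qk) for pk, qk in zip(p, q) if pk > 0)

branch = [0.6, 0.3, 0.1]    # hypothetical branch subnetwork prediction
guidance = [0.5, 0.4, 0.1]  # hypothetical guidance subnetwork prediction

assert kl_divergence(branch, branch) == 0.0  # zero for identical distributions
# KL is asymmetric: KL(p || q) generally differs from KL(q || p).
assert kl_divergence(branch, guidance) != kl_divergence(guidance, branch)
```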
Alternatively, in other embodiments, a symmetric Jensen-Shannon Divergence loss may be used to address the asymmetry of the KL divergence metric. In probability theory and statistics, the Jensen-Shannon divergence is a method of measuring the similarity between two probability distributions. The Jensen-Shannon Divergence is based on the KL divergence, with some notable and useful differences, including that it is symmetric and always has a finite value. It is also known as the information radius or the total divergence to the average. In some embodiments, the prediction inconsistency loss LM may be represented through a symmetric Jensen-Shannon Divergence loss, as shown in example equation (7).
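A minimal sketch of the symmetric Jensen-Shannon divergence, built from the KL divergence via the midpoint distribution (the probability vectors are toy assumptions):

```python
import math

def kl(p, q):
    """KL(p || q) = sum_k p_k * log(p_k / q_k)."""
    return sum(pk * math.log(pk / qk) for pk, qk in zip(p, q) if pk > 0)

def js_divergence(p, q):
    """Symmetric Jensen-Shannon divergence: the average KL divergence of
    p and q to their midpoint m, which is symmetric and always finite."""
    m = [(pk + qk) / 2 for pk, qk in zip(p, q)]
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

p = [0.6, 0.3, 0.1]  # hypothetical branch subnetwork prediction
q = [0.5, 0.4, 0.1]  # hypothetical guidance subnetwork prediction
assert abs(js_divergence(p, q) - js_divergence(q, p)) < 1e-12  # symmetric
```

Because the midpoint m is nonzero wherever either distribution is nonzero, the divergence never blows up, unlike KL(p || q) when q has a zero entry where p does not.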
The prediction inconsistency loss LM can enforce regularizations on the prediction inconsistency on the target images across multiple subnetworks and promote mutual learning. Next, an overall training loss may be set based on the above losses Ladv, LC, LE and LM. In some embodiments, the overall training loss may be represented through equation (8) by integrating the adversarial loss Ladv, the supervised cross-entropy loss LC, the unsupervised entropy loss LE, and the prediction inconsistency loss LM.
where α, β and λ denote trade-off hyperparameters, and G, F and D denote the sets of N+1 feature generators, classifiers and domain discriminators, respectively.
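As a loose sketch of the integration described for equation (8): the exact weighting in the disclosure is not reproduced here, so the assignment below (LC left unweighted, the three hyperparameters attached to the other losses) is an assumption, and the min-max structure of the game is folded into a single scalar:

```python
def overall_loss(l_adv, l_c, l_e, l_m, alpha=1.0, beta=0.1, lam=0.1):
    """Hypothetical combination of the four component losses; G and F
    minimize this quantity while D maximizes the adversarial term l_adv."""
    return alpha * l_adv + l_c + beta * l_e + lam * l_m

total = overall_loss(l_adv=1.0, l_c=2.0, l_e=3.0, l_m=4.0)
assert abs(total - 3.7) < 1e-9  # 1.0 + 2.0 + 0.3 + 0.4
```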
At 502, N+1 source-target pairs are obtained, as discussed above, where the first N source-target pairs each comprises a single source domain and the target domain, while the (N+1)-th pair comprises the combined source domains and the target domain. Then, an iterative training may be performed on the mutual learning network 300.
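The assembly of the N+1 source-target pairs at 502 can be sketched as follows, with toy lists standing in for image domains (all names are illustrative assumptions):

```python
def build_pairs(source_domains, target_domain):
    """Build N branch pairs (one per source domain) plus one guidance pair
    that couples the combined source domains with the target domain."""
    pairs = [(src, target_domain) for src in source_domains]   # N branch pairs
    combined = [img for src in source_domains for img in src]  # merge all sources
    pairs.append((combined, target_domain))                    # guidance pair
    return pairs

sources = [["s1_img1", "s1_img2"], ["s2_img1"]]  # toy labeled source domains
target = ["t_img1", "t_img2"]                    # toy unlabeled target domain

pairs = build_pairs(sources, target)
assert len(pairs) == len(sources) + 1  # N + 1 pairs in total
```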
At 504, the discriminators D are trained. For example, the parameters of all the feature generators G and classifiers F are fixed, and the adversarial loss Ladv is maximized by optimizing and adjusting the parameters of the discriminators D.
At 506, the feature generators G and classifiers F are trained. For example, the parameters of all the discriminators D are fixed, and the adversarial loss Ladv, the supervised cross-entropy loss LC, the unsupervised entropy loss LE, and the prediction inconsistency loss LM are minimized by optimizing and adjusting the parameters of the feature generators G and classifiers F. For example, in each iteration, a plurality of images (e.g., 64 images) may be sampled from each source domain, the target domain and the combined multi-source domains.
At 508, it is determined whether the iteration terminates. For example, if each loss reaches the corresponding convergence value, the iteration may terminate. Alternatively, or in addition, if the number of repetitions of an iteration reaches a threshold, the iteration may terminate.
If a termination condition(s) for the iteration is not met, the method 500 may return to 504 and repeat training the discriminators D at 504 and training the feature generators G and classifiers F at 506 until the termination condition(s) is met. If the iteration terminates, at 510, the trained mutual learning network is obtained. During the training, each loss may be assigned a separate weight, and the weight of each loss may be adjusted to ensure the mutual learning network 300 is well optimized.
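The alternating optimization of 504-508 can be summarized with a minimal sketch, using a hypothetical stand-in network that merely records the update order (the class and function names are assumptions, not part of the disclosure):

```python
class DummyNetwork:
    """Hypothetical stand-in that records the alternating update order."""
    def __init__(self):
        self.log = []
    def update_discriminators(self, pairs):
        self.log.append("D")    # 504: maximize Ladv w.r.t. D, G and F fixed
    def update_generators_and_classifiers(self, pairs):
        self.log.append("GF")   # 506: minimize total loss w.r.t. G and F, D fixed

def train(network, pairs, max_iters, converged):
    for _ in range(max_iters):
        network.update_discriminators(pairs)
        network.update_generators_and_classifiers(pairs)
        if converged(network):  # 508: stop on convergence or iteration cap
            break
    return network

net = train(DummyNetwork(), pairs=[], max_iters=3,
            converged=lambda n: len(n.log) >= 4)
assert net.log == ["D", "GF", "D", "GF"]
```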
With the training, the N+1 subnetworks 320-1 to 320-(N+1) in the mutual learning network 300 have been trained. The trained subnetworks 320-1 to 320-(N+1) may be used to determine the labels of the target images in the target domain in a guidance subnetwork centered ensemble manner. For the i-th image in the target domain, its overall prediction probability may be determined based on the prediction probability vectors generated by all the subnetworks. For example, the overall prediction probability result may be determined via equation (9). In equation (9), the prediction result from the guidance subnetwork is given a weight equal to the average of the prediction results from the other N branch subnetworks.
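Under one reading of equation (9), where the guidance prediction and the average of the N branch predictions carry equal weight, the ensemble can be sketched as follows (toy probability vectors assumed):

```python
def ensemble(branch_preds, guidance_pred):
    """Guidance-centered ensemble: average the N branch predictions, then
    weight that average equally with the guidance subnetwork prediction."""
    n = len(branch_preds)
    k = len(guidance_pred)
    avg_branch = [sum(p[c] for p in branch_preds) / n for c in range(k)]
    return [0.5 * (guidance_pred[c] + avg_branch[c]) for c in range(k)]

branches = [[0.6, 0.4], [0.8, 0.2]]  # predictions from 2 branch subnetworks
guidance = [0.9, 0.1]                # prediction from the guidance subnetwork

p = ensemble(branches, guidance)
assert abs(sum(p) - 1.0) < 1e-9  # still a valid probability vector
```

The predicted label of the target image is then the class with the largest entry in the ensembled vector p.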
Embodiments of the present disclosure propose a novel mutual learning network architecture for multi-source domain adaptation, which enables guidance network centered information sharing in the multi-source domain setting. In addition, embodiments of the present disclosure propose dual alignment mechanisms at both the feature level and the prediction level, where the first alignment mechanism is conditional adversarial feature alignment across each source-target pair, and the second alignment mechanism is centered prediction alignment between each branch subnetwork and the guidance network. Thus, by use of the mutual learning network architecture and the dual alignment mechanisms, embodiments of the present disclosure can achieve a high accuracy for image label prediction.
In some embodiments, each source domain may comprise images captured through one type of camera. For example, images in the first source domain are captured by a normal camera and labeled, images in the second source domain are captured by a wide angle camera and labeled, and images in the third source domain are computer generated images and labeled. Images in the target domain may be captured by an ultra-wide angle camera and unlabeled. According to embodiments of the present disclosure, by using the labels in the first, second and third source domains, a mutual learning network for multi-source domain adaptation can be trained and be used to generate labels of images in the target domain. In addition, images captured under different weather conditions (such as sunny day, rainy day) may be also used as different source domains.
In some embodiments, the labels of the images in the source domain are driving scenarios for automatic driving, such as expressways, city roads, country roads, airports and so forth. Based on the driving scenario determined according to the image, the vehicle may be controlled to perform corresponding actions, such as changing the driving speed. Thus, the multi-source domain adaptation method with the mutual learning network of the present disclosure can facilitate automatic driving.
As shown in
The electronic device 700 typically includes various computer storage media. The computer storage media may be any media accessible by the electronic device 700, including but not limited to volatile and non-volatile media, or removable and non-removable media. The memory 720 can be a volatile memory (for example, a register, cache, Random Access Memory (RAM)), non-volatile memory (for example, a Read-Only Memory (ROM), Electrically Erasable Programmable Read-Only Memory (EEPROM), flash memory), or any combination thereof.
As shown in
The electronic device 700 may further include additional removable/non-removable or volatile/non-volatile storage media. Although not shown in
The communication unit 740 communicates with another computing device via communication media. Additionally, functions of components in the electronic device 700 may be implemented in a single computing cluster or a plurality of computing machines that communicate with each other via communication connections. Therefore, the electronic device 700 can be operated in a networking environment using a logical connection to one or more other servers, networked personal computers (PCs), or another network node.
The input device 750 may include one or more input devices such as a mouse, keyboard, tracking ball and the like. The output device 760 may include one or more output devices such as a display, loudspeaker, printer, and the like. The electronic device 700 can further communicate, via the communication unit 740, with one or more external devices (not shown) such as a storage device or a display device, one or more devices that enable users to interact with the electronic device 700, or any devices that enable the electronic device 700 to communicate with one or more other computing devices (for example, a network card, modem, and the like). Such communication can be performed via input/output (I/O) interfaces (not shown).
The functionality described herein can be performed, at least in part, by one or more hardware logic components. For example, and without limitation, illustrative types of hardware logic components that can be used include Field-Programmable Gate Arrays (FPGAs), Application-specific Integrated Circuits (ASICs), Application-specific Standard Products (ASSPs), System-on-a-chip systems (SOCs), Complex Programmable Logic Devices (CPLDs), and the like.
Program code for carrying out methods of the present disclosure may be written in any combination of one or more programming languages. These program codes may be provided to a processor or controller of a general purpose computer, special purpose computer, or other programmable data processing apparatus, such that the program codes, when executed by the processor or controller, cause the functions/operations specified in the flowcharts and/or block diagrams to be implemented. The program code may execute entirely on a machine, partly on the machine, as a stand-alone software package, partly on the machine and partly on a remote machine or entirely on the remote machine or server.
In the context of this disclosure, a machine readable medium may be any tangible medium that may contain, or store a program for use by or in connection with an instruction execution system, apparatus, or device. The machine readable medium may be a machine readable signal medium or a machine readable storage medium. A machine readable medium may include but is not limited to an electronic, magnetic, optical, electromagnetic, infrared, or semiconductor system, apparatus, or device, or any suitable combination of the foregoing. More specific examples of the machine readable storage medium would include an electrical connection having one or more wires, a portable computer diskette, a hard disk, a random access memory (RAM), a read-only memory (ROM), an erasable programmable read-only memory (EPROM or Flash memory), an optical fiber, a portable compact disc read-only memory (CD-ROM), an optical storage device, a magnetic storage device, or any suitable combination of the foregoing.
Further, while operations are depicted in a particular order, this should not be understood as requiring that such operations be performed in the particular order shown or in sequential order, or that all illustrated operations be performed, to achieve desirable results. In certain circumstances, multitasking and parallel processing may be advantageous. Likewise, while several specific implementation details are contained in the above discussions, these should not be construed as limitations on the scope of the present disclosure, but rather as descriptions of features that may be specific to particular embodiments. Certain features that are described in the context of separate embodiments may also be implemented in combination in a single implementation. Conversely, various features that are described in the context of a single implementation may also be implemented in multiple embodiments separately or in any suitable sub-combination.
Although the present disclosure has been described in language specific to structural features and/or methodological acts, it is to be understood that the subject matter specified in the appended claims is not necessarily limited to the specific features or acts described above. Rather, the specific features and acts described above are disclosed as example forms of implementing the claims.
The descriptions of the various embodiments of the present disclosure have been presented for purposes of illustration, but are not intended to be exhaustive or limited to the embodiments disclosed. Many modifications and variations will be apparent to those of ordinary skill in the art without departing from the scope and spirit of the described embodiments.