The present disclosure relates generally to machine learning models and neural networks, and more specifically, to contrastive learning with self-labeling refinement.
Supervised learning for neural models usually requires a large amount of manually annotated training data, which can be time-consuming and expensive to obtain. Self-supervised learning (SSL), or unsupervised visual representation learning, provides a training mechanism for a neural model to learn features without manual annotations. Such SSL methods can often be successful in many downstream tasks, e.g., image classification and object detection. Specifically, SSL constructs a pretext task whose labels can be obtained from the design of the task itself, and then builds a network to learn from these tasks. For instance, by constructing jigsaw puzzle solving, spatial arrangement identification, orientation prediction, or chromatic channel prediction as a pretext task, SSL learns high-quality features from the pretext task that can be well transferred to downstream tasks.
Contrastive learning is a recently developed SSL method, which constructs an instance discrimination pretext task to train a network so that the representations of different augmentations or crops of the same instance are pulled close to each other, while representations of different instances are pushed away from each other. Specifically, for an image crop query, contrastive learning randomly augments the same image to obtain a positive instance and views crops of other images as negatives. It then constructs a one-hot label for instance discrimination over the positive and negative instances to pull the positive pair closer while pushing away the negative instances in the feature space. The one-hot labels used in contrastive learning, however, can often be inaccurate and uninformative, because a query can be semantically similar to some of its negative instances, sometimes even more similar than to its corresponding positive instance.
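For illustration only, a minimal sketch of this instance-discrimination loss with one-hot labels is shown below; the function name, tensor shapes, and temperature value are assumptions for the sketch rather than details taken from this disclosure.

```python
# Minimal sketch of instance discrimination with one-hot labels (hypothetical
# illustration; not the exact formulation claimed in this disclosure).
import torch
import torch.nn.functional as F

def infonce_one_hot(q, k_pos, k_neg, tau=0.2):
    """q: (n, d) queries; k_pos: (n, d) positive keys; k_neg: (b, d) negative keys."""
    q = F.normalize(q, dim=1)
    k_pos = F.normalize(k_pos, dim=1)
    k_neg = F.normalize(k_neg, dim=1)
    # similarity of each query to its positive (n, 1) and to all negatives (n, b)
    l_pos = (q * k_pos).sum(dim=1, keepdim=True)
    l_neg = q @ k_neg.t()
    logits = torch.cat([l_pos, l_neg], dim=1) / tau
    # one-hot label: index 0 (the positive) is the "class" of every query
    target = torch.zeros(q.size(0), dtype=torch.long, device=q.device)
    return F.cross_entropy(logits, target)
```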
Therefore, there is a need to improve the accuracy of contrastive learning.
In the figures and appendix, elements having the same designations have the same or similar functions.
Contrastive learning is a self-supervised learning method, which usually learns from an augmented positive instance of a training instance, paired with negative instances of the training instance, all of which are input to the neural network. During training, the representations of different augmentations or crops of the same instance are pulled close to each other, while representations of different instances are pushed away from each other. Thus, without any pre-annotated training instances, contrastive learning methods often generate "artificially" labeled data by assuming that augmentations of the same instance are positives and augmentations of other instances are negatives. Such label assignments can be noisy and impair generalization performance, because negative instances can sometimes be semantically similar to the original instance, or even share the same semantic class as the original instance.
For example, a query could be semantically similar, or even more similar, to some negatives than to its positives. Indeed, some negatives may even belong to the same semantic class as the query. This is because, to achieve satisfactory performance, one often uses a number of negatives that is much larger than the number of semantic classes, which unavoidably leads to this issue with negatives. In addition, even for the same image, especially an image containing several different objects, as often occurs in ImageNet, random augmentations, e.g., crops, can produce crops with (slightly) different semantic information, and thus some of the many negatives could be (more) similar to the query. Hence, the one-hot label does not well reveal the semantic similarity between the query and its positives and "negatives", and thus cannot guarantee that semantically similar samples are pulled close to each other, leading to performance degradation of contrastive learning.
In view of the need to improve the accuracy of contrastive learning, embodiments described herein provide a contrastive learning mechanism with self-labeling refinement, which iteratively employs the network and data themselves to generate more accurate and informative soft labels for contrastive learning. Specifically, the contrastive learning framework includes a self-labeling refinery module to explicitly generate accurate labels, and a momentum mix-up module to increase similarity between a query and its positive, which in turn implicitly improves label accuracy.
For example, given a query, the self-labeling refinery module adopts a positive instance of the query to estimate the semantic similarity between the query and its keys (i.e., its positive and negative instances) by computing their feature similarity. This is because a query and its positive come from the same image and should exhibit similar semantic similarity to the same keys. The self-labeling refinery module then linearly combines the estimated similarity of a query with its vanilla one-hot label in contrastive learning to iteratively generate more accurate and informative soft labels. In this way, at the early training stage, the one-hot labels have greater combination weights and provide relatively accurate labels. As training progresses, the estimated similarity becomes more accurate and informative, and thus the combination weight for the similarity becomes larger. This is because the similarity captures useful underlying semantic information between the query and its keys, which can sometimes be missing from the one-hot labels. This strategy is both empirically and theoretically effective.
In this way, even when the semantic labels in the instance discrimination task for contrastive learning are corrupted, the generated self-labels may recover the true semantic labels of the training data. Thus, networks trained with self-labeling may more accurately predict the true semantic labels of test samples.
In one embodiment, the momentum mix-up module for contrastive learning further reduces possible label noise and also increases augmentation diversity. For example, a dataset of queries {x_i}_{i=1}^n and their corresponding positives {x̃_i}_{i=1}^n may be randomly combined using a random variable θ as the combination weight: x_i′ = θx_i + (1 − θ)x̃_k. The estimated label corresponding to an input x_i′ is y_i′ = θy_i + (1 − θ)ỹ_k, where ỹ_k denotes the one-hot label associated with the key x̃_k, as described in further detail below.
As used herein, the term “network” may comprise any hardware or software-based framework that includes any artificial intelligence network or system, neural network or system and/or any training or learning models implemented thereon or therewith.
As used herein, the term “module” may comprise hardware or software-based framework that performs one or more functions. In some embodiments, the module may be implemented on one or more neural networks.
Given a batch of training images {c_i}_{i=1}^s at each iteration, each original image sample c_i is randomly augmented into two views (x_i, x̃_i), with x_i being referred to as a query sample and x̃_i being referred to as the positive instance of the query sample. A set B = {b_i}_{i=1}^b denotes the negative keys of the current query samples {x_i}_{i=1}^s. For example, a large dictionary size b is often used to achieve satisfactory performance, e.g., 65,536. In one implementation, B may be updated by the minibatch features {g(x̃_i)}_{i=1}^s in first-in, first-out order.
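A brief sketch of such a first-in, first-out dictionary update is given below; it is modeled on common momentum-contrast implementations, and the class name, queue layout, and pointer bookkeeping are illustrative assumptions rather than details specified by this disclosure.

```python
import torch
import torch.nn.functional as F

class KeyQueue:
    """FIFO dictionary B of negative keys (illustrative sketch)."""
    def __init__(self, dict_size=65536, feat_dim=128):
        self.queue = F.normalize(torch.randn(dict_size, feat_dim), dim=1)
        self.ptr = 0

    @torch.no_grad()
    def enqueue_dequeue(self, keys):
        """Insert the minibatch features g(x~_i) and evict the oldest entries."""
        n = keys.size(0)
        idx = (self.ptr + torch.arange(n)) % self.queue.size(0)
        self.queue[idx] = keys
        self.ptr = (self.ptr + n) % self.queue.size(0)
```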
In one embodiment, the query sample x_i 102 may be input to the online network f( ) 110, while the set of positive instances {x̃_i}_{i=1}^s 104 and the set of negative instances B = {b_i}_{i=1}^b 106 may be input to the target network g( ) 120.
The online network f( ) 110 may in turn generate an encoded query representation q = f(x_i) 112, while the target network g( ) 120 may generate a set of encoded key representations {g(x̃_i)}_{i=1}^s ∪ {g(b_i)}_{i=1}^b 122. A similarity metric between the encoded query representation 112 and the set of encoded key representations 122 is then computed at similarity module σ(·,·) 125, which computes the similarity of two representations in the feature space, e.g.,

σ(x_i, b) = exp(⟨f(x_i), g(b)⟩/τ),

with a temperature parameter τ.
For example, the similarity module 125 may compute a similarity metric σ(x_i, x̃_i) between the query x_i and its positive key x̃_i, as well as similarity metrics σ(x_i, b_k) between the query and each negative key b_k in B.
The similarity metrics computed from module 125 may be sent to the contrastive loss module 130. Specifically, the contrastive loss module 130 computes a loss 135 based on the self-labels 129 from a self-label computation 128 and the similarity metrics, e.g., a cross-entropy of the form

ℒ_c(w) = −(1/s) Σ_{i=1}^s Σ_{k=1}^{s+b} ȳ_{ik} log[ σ(x_i, z_k) / Σ_{j=1}^{s+b} σ(x_i, z_j) ],

where w denotes the parameters of the online network f( ), {z_j}_{j=1}^{s+b} denotes the set of keys (the positives {x̃_i}_{i=1}^s and the negatives in B), and ȳ_i denotes the self-label 129 for query x_i.
The online network f( ) 110 is then updated by the computed loss 135, e.g., via backpropagation, by fixing the parameters of the target network g( ) 120. The target network g( ) 120 is then updated via exponential moving average (EMA), e.g., ξ = (1 − ι)ξ + ιω, where ξ denotes the parameters of g( ), ω denotes the parameters of f( ), and ι ∈ (0, 1) is a constant.
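The following is a hedged, PyTorch-style sketch of how such a soft-label contrastive loss and the two network updates could be realized; the cross-entropy form, function names, and the EMA coefficient handling are reconstructions under stated assumptions, not a verbatim implementation of modules 110-130.

```python
import torch
import torch.nn.functional as F

def soft_contrastive_loss(q, keys, soft_labels, tau=0.2):
    """q: (s, d) online-network queries f(x_i); keys: (s+b, d) target-network
    key features; soft_labels: (s, s+b) self-labels over the keys."""
    q = F.normalize(q, dim=1)
    keys = F.normalize(keys, dim=1)
    log_prob = F.log_softmax(q @ keys.t() / tau, dim=1)   # (s, s+b)
    return -(soft_labels * log_prob).sum(dim=1).mean()

@torch.no_grad()
def ema_update(target_net, online_net, iota=0.01):
    """xi <- (1 - iota) * xi + iota * w for each target-network parameter xi."""
    for p_t, p_o in zip(target_net.parameters(), online_net.parameters()):
        p_t.data.mul_(1.0 - iota).add_(iota * p_o.data)
```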
In view of the label-noise issue described above, a self-labeling refinery module 128 employs the network and data themselves to improve the quality of inaccurate labels during training, which generates more accurate and informative labels and improves the performance of contrastive learning. Specifically, to refine the one-hot label y_i of query x_i, the positive instance x̃_i is input to the online network f( ) 110, and its feature similarity to each instance in the key set is measured.
To this end, at the t-th iteration, the instance-class probability p_i^t ∈ ℝ^{s+b} of x_i on the key set is estimated from the semantic similarity between the positive x̃_i and each key, e.g., as a softmax over the corresponding feature similarities computed with a temperature parameter τ′.
On the other hand, as x̃_i is highly similar to itself, the entry p_{ii}^t could be much larger than the others and conceal the similarity of other semantically similar instances in the key set. Therefore, x̃_i is removed from the key set and a second instance-class probability q_i^t is computed over the remaining instances.
Then, the combination module 138 combines the one-hot label y_i and the two label estimations, i.e., p_i^t and q_i^t, to obtain the more accurate, robust and informative label

ȳ_i^t = (1 − α_t − β_t)y_i + α_t p_i^t + β_t q_i^t,
where α_t and β_t are two constants. In one implementation, α_t = μ max_k p_{ik}^t / z and β_t = μ max_k q_{ik}^t / z, where z = 1 + μ max_k p_{ik}^t + μ max_k q_{ik}^t, and the constants 1, max_k p_{ik}^t and max_k q_{ik}^t respectively denote the largest confidences of the labels y_i, p_i^t and q_i^t on a certain class. Here the hyperparameter μ controls the prior confidence of p_i^t and q_i^t. So the self-label refinery has only two parameters, τ′ and μ, to tune.
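A compact sketch of this self-label refinery is given below, covering the first estimate p_i^t, the second estimate q_i^t with the positive's own entry removed, and the combination with α_t and β_t; the function name, tensor layout, and default hyperparameter values are assumptions made for the sketch.

```python
import torch
import torch.nn.functional as F

def self_label_refinery(pos_feat, key_feats, one_hot, self_index,
                        tau_prime=0.8, mu=1.0):
    """pos_feat: (s, d) features f(x~_i); key_feats: (s+b, d) key features;
    one_hot: (s, s+b) vanilla one-hot labels; self_index: (s,) column of each
    query's own positive key inside the key set."""
    pos_feat = F.normalize(pos_feat, dim=1)
    key_feats = F.normalize(key_feats, dim=1)
    sims = pos_feat @ key_feats.t() / tau_prime              # (s, s+b)
    p = F.softmax(sims, dim=1)                               # first estimate p_i^t
    # second estimate q_i^t: mask out each positive's own (dominant) entry
    masked = sims.clone()
    rows = torch.arange(sims.size(0), device=sims.device)
    masked[rows, self_index] = float('-inf')
    q = F.softmax(masked, dim=1)
    # combination weights alpha_t, beta_t from the peak confidences
    a = mu * p.max(dim=1, keepdim=True).values
    b = mu * q.max(dim=1, keepdim=True).values
    z = 1.0 + a + b
    alpha, beta = a / z, b / z
    return (1.0 - alpha - beta) * one_hot + alpha * p + beta * q
```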
At step 302, a training batch of unlabeled queries is received, e.g., a mini-batch of queries {x_i}_{i=1}^s.
At step 304, for each unlabeled query, a positive instance x̃_i paired with the query sample is generated.
At step 306, a first instance probability distribution is computed based on a first semantic similarity between the first positive instance and a set of positive instances and negative instances generated from the training batch of query samples. For example, at the t-th iteration, the instance-class probability p_i^t ∈ ℝ^{s+b} of x_i on the key set is computed as described above.
At step 308, the first positive instance corresponding to the first unlabeled query sample is removed from the set of positive instances and negative instances.
At step 310, a second instance probability distribution is computed based on a second semantic similarity between the first positive instance and the remaining instances in the set of positive instances and negative instances. For example, x̃_i is removed from the key set and a second instance-class probability q_i^t is computed over the remaining instances.
At step 312, a first self-label is generated by combining a one-hot label of the first unlabeled query, the first label estimation and the second label estimation. For example, the one-hot label y_i and the two label estimations, i.e., p_i^t and q_i^t, are linearly combined to obtain the more accurate, robust and informative label ȳ_i^t.
At step 314, an encoded output based on a contrastive input of the set of positive instances and negative instances is generated by a machine learning model, e.g., networks f( ) and g( ).
At step 316, a contrastive loss objective may be computed based at least in part on the generated self-label, e.g., as described above in relation to contrastive loss module 130.
At step 318, the machine learning model is updated based on the contrastive loss objective via backpropagation. For example, the online network f( ) 110 is updated by the computed loss 135, e.g., via backpropagation, by fixing the parameters of the target network g( ) 120. The target network g( ) 120 is then updated via exponential moving average (EMA), e.g., ξ=(1−ι)ξ+ιω where ξ denotes the parameters of g( ) and ι∈(0,1) is a constant.
Method 400 uses momentum mix-up to further reduce possible label noise in realistic data and to increase the diversity of the data as well. Continuing on from step 304 described above, a virtual sample x_i′ and a corresponding virtual label y_i′ are generated as

x_i′ = θx_i + (1 − θ)x̃_k,  y_i′ = θy_i + (1 − θ)ỹ_k,

where x̃_k is randomly sampled from the key set {x̃_i}_{i=1}^s, ỹ_k denotes the one-hot label associated with the key x̃_k, and θ is a random combination weight, e.g., drawn from a Beta(κ, κ) distribution.
At step 408, a second contrastive loss objective may be computed by using the virtual sample and the virtual label, e.g., by replacing x_i and the self-label ȳ_i^t in the contrastive loss with x_i′ and y_i′, respectively.
Thus, momentum mix-up can further improve the accuracy of the label y_i′ compared with the traditionally used one-hot labels. The virtual sample x_i′ has two positive keys, x̃_i and x̃_k. Accordingly, the component x̃_k in x_i′ = θx_i + (1 − θ)x̃_k directly increases the similarity between the query x_i′ and its positive key x̃_k in the key set, which reduces possible label noise.
Another advantage of momentum mix-up is that it enables strong augmentation. It has been observed that directly using strong augmentation in contrastive learning may lead to performance degradation, as an instance obtained by strong augmentation often differs heavily from the one obtained with weak augmentation. As aforementioned, the component x̃_k in x_i′ = θx_i + (1 − θ)x̃_k increases the similarity between the query instance x_i′ and the key instance x̃_k in the key set, which compensates for this discrepancy and allows stronger augmentations to be used.
At step 410, a weighted sum of a contrastive loss objective computed based on one-hot labels and the second contrastive loss objective may be optionally computed as a training objective. For example, the combined training objective may be defined as:
ℒ(w) = (1 − λ)ℒ_c(w, {(x_i, y_i)}) + λℒ_c(w, {(x_i′, y_i′)}),
where ℒ_c(w, {(x_i, y_i)}) denotes the vanilla contrastive loss with the one-hot label y_i, ℒ_c(w, {(x_i′, y_i′)}) denotes the momentum mix-up loss with the label y_i′ estimated by the self-labeling refinery, and λ is a constant. Method 400 may then proceed to step 318 described above.
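The sketch below illustrates one possible realization of the momentum mix-up of method 400 and the weighted training objective above; the Beta-distributed mixing weight follows the hyperparameters reported later, while the function names and tensor layout are assumptions.

```python
import torch

def momentum_mixup(x, x_tilde, y_self, y_key_one_hot, kappa=2.0):
    """x: (s, C, H, W) queries; x_tilde: (s, C, H, W) positives; y_self: (s, s+b)
    self-labels for the queries; y_key_one_hot: (s, s+b) one-hot rows marking the
    column of each positive key x~_i inside the key set."""
    theta = torch.distributions.Beta(kappa, kappa).sample((x.size(0),)).to(x.device)
    perm = torch.randperm(x.size(0), device=x.device)   # pick a random key x~_k per query
    t_img = theta.view(-1, 1, 1, 1)
    x_mix = t_img * x + (1.0 - t_img) * x_tilde[perm]
    t_lab = theta.view(-1, 1)
    y_mix = t_lab * y_self + (1.0 - t_lab) * y_key_one_hot[perm]
    return x_mix, y_mix

# Combined objective: (1 - lam) * vanilla loss + lam * momentum mix-up loss,
# reusing a soft-label contrastive loss such as the earlier sketch, e.g.:
# loss = (1 - lam) * soft_contrastive_loss(f(x), keys, y_one_hot) \
#        + lam * soft_contrastive_loss(f(x_mix), keys, y_mix)
```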
Memory 520 may be used to store software executed by computing device 500 and/or one or more data structures used during operation of computing device 500. Memory 520 may include one or more types of machine readable media. Some common forms of machine readable media may include floppy disk, flexible disk, hard disk, magnetic tape, any other magnetic medium, CD-ROM, any other optical medium, punch cards, paper tape, any other physical medium with patterns of holes, RAM, PROM, EPROM, FLASH-EPROM, any other memory chip or cartridge, and/or any other medium from which a processor or computer is adapted to read.
Processor 510 and/or memory 520 may be arranged in any suitable physical arrangement. In some embodiments, processor 510 and/or memory 520 may be implemented on a same board, in a same package (e.g., system-in-package), on a same chip (e.g., system-on-chip), and/or the like. In some embodiments, processor 510 and/or memory 520 may include distributed, virtualized, and/or containerized computing resources. Consistent with such embodiments, processor 510 and/or memory 520 may be located in one or more data centers and/or cloud computing facilities.
In some examples, memory 520 may include non-transitory, tangible, machine readable media that includes executable code that when run by one or more processors (e.g., processor 510) may cause the one or more processors to perform the methods described in further detail herein. For example, as shown, memory 520 includes instructions for a contrastive learning module 530 that may be used to implement and/or emulate the systems and models, and/or to implement any of the methods described further herein. In some examples, the contrastive learning module 530, may receive an input 540, e.g., such as unlabeled image instances, via a data interface 515. The data interface 515 may be any of a user interface that receives a user uploaded image instance, or a communication interface that may receive or retrieve a previously stored image instance from the database. The contrastive learning module 530 may generate an output 550, such as classification result of the input 540.
In some embodiments, the contrastive learning module 530 may further include a self-labeling module 531 and a momentum mix-up module 532. Further functionality of the self-labeling module 531 and the momentum mix-up module 532 is described above.
A ResNet50 backbone with a 3-layer MLP head is used for both CIFAR10 and ImageNet. The contrastive learning with self-labeling refinement model, referred to as CLEAN, is first pretrained, and a linear classifier is then trained on top of the 2048-dimensional frozen features provided by ResNet50. With a dictionary size of 4,096, the model is pretrained for 2,000 epochs on CIFAR10. The dictionary size on ImageNet is 65,536. The linear classifier is trained for 200 and 100 epochs on CIFAR10 and ImageNet, respectively.
Standard data augmentations are used as described in He et al., Momentum contrast for unsupervised visual representation learning, in Proc. IEEE Conf. Computer Vision and Pattern Recognition, pp. 9729-9738, 2020, for pretraining and testing unless otherwise stated. For example, for testing, normalization is performed on CIFAR10, and center crop and normalization are employed on ImageNet. For CLEAN, τ=0.2, τ′=0.8 and κ=2 in Beta(κ,κ) are set on CIFAR10, and τ=0.2, τ′=1 and κ=0.1 on ImageNet. For the confidence μ, it is increased as μ_t = m_2 − (m_2 − m_1)(cos(πt/T) + 1)/2 with current iteration t and total training iterations T, with m_1=0 and m_2=1 on CIFAR10, and m_1=0.5 and m_2=10 on ImageNet. For KNN evaluation on CIFAR10, the neighborhood number is 50 and the temperature is 0.05.
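This confidence schedule can be written directly as a small helper; the function name is illustrative, and the defaults below correspond to the ImageNet setting mentioned above.

```python
import math

def mu_schedule(t, T, m1=0.5, m2=10.0):
    """Cosine ramp of the confidence mu from m1 (at t = 0) to m2 (at t = T)."""
    return m2 - (m2 - m1) * (math.cos(math.pi * t / T) + 1.0) / 2.0
```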
For CIFAR10, to fairly compare with Kim et al., MixCo: Mix-up contrastive learning for visual representation, arXiv preprint arXiv:2010.06300, 2020, each image is cropped into two views to construct the loss. For ImageNet, CLEAN follows CLSA (described in Wang & Qi, Contrastive learning with stronger augmentations, 2021) and is trained in two settings. CLEAN-Single uses a single crop in the momentum mix-up loss ℒ_c(w, {(x_i′, y_i′)}), which crops each image to a smaller size of 96×96, without much extra computational cost to process these small images. CLEAN-Multi crops each image into five sizes, 224×224, 192×192, 160×160, 128×128, and 96×96, and averages their momentum mix-up losses. This ensures a fair comparison with CLSA and SwAV. Moreover, the strong augmentation strategy in CLSA is used.
Specifically, for the above small images, an operation is randomly selected from the 14 augmentations used in CLSA and applied to the image with a probability of 0.5, and this is repeated 5 times. "(strong)" is used to mark whether strong augmentations are used on the small images in the momentum mix-up loss. Thus, CLEAN has almost the same training cost as CLSA. For the vanilla contrastive loss on ImageNet, weak augmentations are always used.
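The selection policy just described can be sketched as follows; the particular operations in the pool are placeholders standing in for the 14 CLSA augmentations, which are not enumerated in this disclosure.

```python
import random
from torchvision import transforms

# Hypothetical pool standing in for the 14 CLSA augmentations (the exact list is
# an assumption, not reproduced from this disclosure).
AUG_POOL = [
    transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
    transforms.RandomGrayscale(p=1.0),
    transforms.GaussianBlur(kernel_size=9),
    transforms.RandomSolarize(threshold=128, p=1.0),
]

def strong_augment(img, repeats=5, prob=0.5):
    """Randomly pick one operation and apply it with probability `prob`; repeat."""
    for _ in range(repeats):
        op = random.choice(AUG_POOL)
        if random.random() < prob:
            img = op(img)
    return img
```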
From Table 1 in
Table 2 in
The pretrained CLEAN model is also evaluated on VOC (described in Everingham et al., The pascal visual object classes (voc) challenge, Int'l. J. Computer Vision, 88(2):303-338, 2010) and COCO (Lin et al., Microsoft coco: Common objects in context, in Proc. European Conf. Computer Vision, pp. 740-755, Springer, 2014). For classification, a linear classifier is trained upon ResNet50 for 100 epochs by SGD. For object detection, the same protocol as in He et al., Momentum contrast for unsupervised visual representation learning, in Proc. IEEE Conf. Computer Vision and Pattern Recognition, pp. 9729-9738, 2020, is used to fine-tune the pretrained ResNet50 based on detectron2 (described in Wu et al., Detectron2, 2019) for fairness. On VOC, the detection head is trained with the VOC07+12 trainval data and tested on the VOC07 test data. On COCO, the head is trained on the train2017 set and evaluated on the val2017 set.
Table 3 in
CLEAN is trained for 1,000 epochs on CIFAR10 to investigate the effects of each component in CLEAN using strong augmentation. Table 4 in
Then the momentum mix-up is compared with the vanilla mix-up of the concurrent works (described in Kim et al., MixCo: Mix-up contrastive learning for visual representation, arXiv preprint arXiv:2010.06300, 2020; and Lee et al., i-Mix: A strategy for regularizing contrastive representation learning, arXiv preprint arXiv:2010.08887, 2020). Specifically, the one-hot label is used in MoCo and the key x̃_j in the momentum mix-up is replaced with the query x_j to obtain "MoCo+mix-up", while CLEAN with the one-hot label can be viewed as "MoCo+momentum mix-up". These methods are then trained for 1,000 epochs on CIFAR10 with weak/strong augmentation, and for 200 epochs on ImageNet with weak augmentations. Table 6 in
Performance of the self-labeling refinery on label-corrupted data is analyzed as follows. Let {c_i}_{i=1}^K ⊂ ℝ^d be K vanilla samples belonging to K̄ semantic classes.
Definition 1 ((ρ,ε,δ)-corrupted dataset). Let {(x_i, y_i*)}_{i=1}^n denote the pairs of crops (augmentations) and ground-truth semantic labels, where a crop x_i generated from the t-th sample c_t obeys ‖x_i − c_t‖_2 ≤ ε with a constant ε, and y_i* ∈ {γ_t}_{t=1}^{K̄}
with two constants c_l and c_u. Moreover, the classes are separated:
|γ_i − γ_k| ≥ δ,  ‖c_i − c_k‖_2 > 2ε  (∀ i ≠ k),
where δ is the label separation. A (ρ,ε,δ)-corrupted dataset {(x_i, y_i)}_{i=1}^n obeys the above conditions but with corrupted labels {y_i}_{i=1}^n. Specifically, for each sample c_i, at most ρn_i augmentations are assigned wrong labels in {γ_i}_{i=1}^{K̄}.
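To make the definition concrete, the following toy generator produces data that (approximately) satisfies these conditions; the center construction, class label values, and corruption rule are illustrative assumptions only.

```python
import torch

def make_corrupted_dataset(K=10, n_per=50, d=32, eps=0.1, delta=1.0, rho=0.2):
    """Toy (rho, eps, delta)-corrupted data: K separated vanilla samples c_i,
    augmentations within an eps-ball of their sample, class labels gamma_i
    separated by delta, and at most a rho fraction of labels corrupted."""
    assert d >= K
    c = 3.0 * eps * torch.eye(K, d)            # ||c_i - c_k||_2 = 3*eps*sqrt(2) > 2*eps
    gamma = delta * torch.arange(1, K + 1.0)   # |gamma_i - gamma_k| >= delta
    xs, ys, ys_true = [], [], []
    for i in range(K):
        noise = torch.randn(n_per, d)
        noise = eps * noise / noise.norm(dim=1, keepdim=True)  # ||x - c_i||_2 <= eps
        xs.append(c[i] + noise)
        y = gamma[i].repeat(n_per)
        ys_true.append(y.clone())
        n_bad = int(rho * n_per)               # corrupt at most rho * n_i labels
        y[:n_bad] = gamma[torch.randint(0, K, (n_bad,))]
        ys.append(y)
    return torch.cat(xs), torch.cat(ys), torch.cat(ys_true)
```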
Then a network with one hidden layer is studied as an example to investigate the label-refining performance of the method:

f(W, x) = v^T φ(Wx),  x ∈ ℝ^d,  (6)

where W ∈ ℝ^{k×d} and v ∈ ℝ^k are network parameters, and φ is an activation function. v is fixed to be a unit vector in which half of the entries are 1/√k and the other half are −1/√k to simplify exposition. So the optimization is only over W, which contains most of the network parameters and will be shown to be sufficient for label refinery. Then, given a (ρ,ε,δ)-corrupted dataset {(x_i, y_i)}_{i=1}^n, at the t-th iteration the network is trained via minimizing the quadratic loss:
ℒ_t(W) = ½ Σ_{i=1}^n (f(W, x_i) − ȳ_i^t)^2.

Here the label ȳ_i^t denotes the refined label of x_i produced by the self-labeling refinery at the t-th iteration.
W^{t+1} = W^t − η∇ℒ_t(W^t),  (7)
where η is a learning rate. Gradient descent and the quadratic loss are used for the convergence analysis, since (i) gradient descent is the expectation version of the stochastic variant and often reveals similar convergence behavior, and (ii) similar results can be expected for other losses, e.g., cross entropy, but the quadratic loss gives a simpler gradient computation. For the analysis, mild assumptions are imposed on the network and the self-labeling refinery, which are widely used in network analysis.
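Before stating the assumptions, the analysis setup above can be summarized in a short numerical sketch; the activation choice, width, learning rate, and step count below are placeholders, and the sketch is illustrative rather than part of the analysis.

```python
import torch

def analysis_setup(X, y_bar, k=128, eta=0.1, steps=200):
    """Sketch of the one-hidden-layer model f(W, x) = v^T phi(W x) of (6), trained
    by full-batch gradient descent (7) on the quadratic loss with refined labels
    y_bar. For simplicity y_bar is held fixed here, whereas in the method the
    refined labels are updated every iteration."""
    n, d = X.shape
    W = torch.randn(k, d, requires_grad=True)          # i.i.d. N(0, 1) initialization
    v = torch.cat([torch.full((k // 2,), 1.0), torch.full((k - k // 2,), -1.0)])
    v = v / k ** 0.5                                    # fixed output weights +/- 1/sqrt(k)
    phi = torch.nn.Softplus()                           # smooth activation (cf. Assumption 1)
    for _ in range(steps):
        out = phi(X @ W.t()) @ v                        # f(W, x_i) for all i
        loss = 0.5 * ((out - y_bar) ** 2).sum()         # quadratic loss
        loss.backward()
        with torch.no_grad():
            W -= eta * W.grad                           # W^{t+1} = W^t - eta * grad
            W.grad.zero_()
    return W
```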
Assumption 1. For the network (6), suppose the activation φ and its first- and second-order derivatives obey |φ(0)|, |φ′(z)|, |φ″(z)| ≤ Γ for all z and some Γ ≥ 1. Moreover, the entries of the initialization W^0 are i.i.d. N(0, 1).
Assumption 2. Define the network covariance matrix Σ(C) = (CC^T) ⊙ E_u[φ′(Cu)φ′(Cu)^T], where C = [c_1, …, c_K]^T, u ~ N(0, I), and ⊙ is the elementwise product. Let λ(C) > 0 be the minimum eigenvalue of Σ(C). For the label refinery, assume
with three constants ψ_1, ψ_2 and c_1. Here α_max = max_{1≤t≤t_0} α_t
Assumption 1 is mild, as most differentiable activation functions, e.g., softplus and sigmoid, satisfy it, and Gaussian initialization is used in practice. The Gaussian variance is assumed to be one for notational simplicity, but the technique is applicable to any constant variance. Assumption 2 requires that the discrepancy between α_t and α_{t+1} be bounded until some iteration number t_0, which holds by setting proper α_t. For λ(C), prior works empirically and theoretically show λ(C) > 0. Based on these assumptions, the results are stated in Theorem 2 with constants c_1~c_6.
Theorem 2. Assume {(x_i, y_i)}_{i=1}^n is a (ρ,ε,δ)-corrupted dataset with noiseless labels {y_i*}_{i=1}^n. Let
Suppose ε and the number k of hidden nodes satisfy
If step size
with probability 1 − 3/K^{100} − K exp(−100d), after
iterations, the gradient descent (7) satisfies:
(1) The discrepancy between the label
where ζ = 4ρ + c_5 ε ψ′ K Γ^3 ξ √(log K) / λ(C), and y* = [y_1*, …, y_n*]. Moreover, if
the estimated label
(2) By using the refined label
where f(W^t, X) = [f(W^t, x_1), …, f(W^t, x_n)]. If the assumptions on ρ and ε in (1) hold, then for each vanilla sample c_k (∀ k = 1, …, K), the network f(W^t, ·) predicts the true semantic label γ_k of any of its augmentations x obeying ‖x − c_k‖_2 ≤ ε: γ_{k*} = γ_k with k* = argmin_{1≤i≤K̄} |f(W^t, x) − γ_i|.
The first result in Theorem 2 shows that after t_0 training iterations, the discrepancy between the label
The second result in Theorem 2 shows that by using the refined label
Some examples of computing devices, such as computing device 500, may include non-transitory, tangible, machine readable media that include executable code that when run by one or more processors (e.g., processor 510) may cause the one or more processors to perform the processes of method 300. Some common forms of machine readable media that may include the processes of method 300 are, for example, floppy disk, flexible disk, hard disk, magnetic tape, any other magnetic medium, CD-ROM, any other optical medium, punch cards, paper tape, any other physical medium with patterns of holes, RAM, PROM, EPROM, FLASH-EPROM, any other memory chip or cartridge, and/or any other medium from which a processor or computer is adapted to read.
This description and the accompanying drawings that illustrate inventive aspects, embodiments, implementations, or applications should not be taken as limiting. Various mechanical, compositional, structural, electrical, and operational changes may be made without departing from the spirit and scope of this description and the claims. In some instances, well-known circuits, structures, or techniques have not been shown or described in detail in order not to obscure the embodiments of this disclosure. Like numbers in two or more figures represent the same or similar elements.
In this description, specific details are set forth describing some embodiments consistent with the present disclosure. Numerous specific details are set forth in order to provide a thorough understanding of the embodiments. It will be apparent, however, to one skilled in the art that some embodiments may be practiced without some or all of these specific details. The specific embodiments disclosed herein are meant to be illustrative but not limiting. One skilled in the art may realize other elements that, although not specifically described here, are within the scope and the spirit of this disclosure. In addition, to avoid unnecessary repetition, one or more features shown and described in association with one embodiment may be incorporated into other embodiments unless specifically described otherwise or if the one or more features would make an embodiment non-functional.
Although illustrative embodiments have been shown and described, a wide range of modification, change and substitution is contemplated in the foregoing disclosure and in some instances, some features of the embodiments may be employed without a corresponding use of other features. One of ordinary skill in the art would recognize many variations, alternatives, and modifications. Thus, the scope of the invention should be limited only by the following claims, and it is appropriate that the claims be construed broadly and in a manner consistent with the scope of the embodiments disclosed herein.
The present disclosure is a nonprovisional of and claims priority under 35 U.S.C. 119 to co-pending and commonly-owned U.S. provisional application No. 63/146,170, filed Feb. 5, 2021, which is hereby expressly incorporated by reference.