This application is based upon and claims the benefit of priority of the prior India Provisional Application No. 202231028920, filed on May 19, 2022, the entire contents of which are incorporated herein by reference.
The embodiment discussed herein is related to machine learning technology.
There has been a tendency that, in a case where a machine learning algorithm is applied to abnormality detection or medical image diagnosis so as to train a classifier (machine learning model) by using a training data set, the training data set becomes a data set with class imbalance. For example, in a case where training of a classifier for abnormality detection is executed, non-abnormality labels are provided for most of the training data set. Similarly, in a case where training of a classifier for medical diagnosis is executed, non-abnormality labels are provided for most of the training data set.
In a case where training of a classifier is executed by using a training data set with class imbalance, it is difficult to appropriately evaluate the performance of the machine learning algorithm depending only on the accuracy of the classifier, and thus a more complicated metric (index) is used in some cases. The more complicated metric is, for example, a metric that is not suitable for optimization by using cross entropy, which is often used as a loss function.
Hereinafter, a conventional technology 1 and a conventional technology 2 will be explained.
First, a basic metric used in the conventional technology 1 will be explained. A classifier is defined as “F: X→(K)”. Note that “X” indicates the input space. Note that (K)={1, . . . , K} is the set of labels.
A K-by-K confusion matrix C(F) is defined as indicated in formula (1). In formula (1), “D” indicates a distribution of data. In formula (1), “1” is an indicator function. In a case where “y=i, F(x)=j” is satisfied in the indicator function, a value of the indicator function is “1”, and in a case where “y=i, F(x)=j” is not satisfied, a value of the indicator function is “0”. Note that “E” in formula (1) corresponds to calculation for an expectation value.
Cij(F)=E(x,y)˜D(1(y=i, F(x)=j)) . . . (1)
A class distribution is defined by formula (2) for each “i”.
πi=P(y=i) . . . (2)
An accuracy acc(F) of a classifier is defined by formula (3). For example, the accuracy corresponds to a proportion of the number of correctly-answered data to all data that are input to a classifier.
acc(F)=Σk=1KCkk(F) . . . (3)
A recall reci(F) for each class of a classifier is defined by formula (4). The recall corresponds to a proportion of correctly determined data to data to be determined. For example, the recall indicates how many data, among a plurality of data that are to be classified into a first class, are actually classified into the first class by the classifier.
reci(F)=Cii(F)/P(y=i) . . . (4)
A precision preci(F) for each class of a classifier is defined by formula (5). The precision corresponds to a proportion of actually correct determinations among all determinations of “data to be determined”. For example, the precision is a proportion of data that actually belong to a first class among a plurality of data having been classified into the first class by a classifier.
preci(F)=Cii(F)/Σk=1KCki(F) . . . (5)
A proportion estimated to be a class i by a classifier is defined as a coverage. A coverage covi(F) is defined by formula (6).
covi(F)=Σk=1KCki(F) . . . (6)
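As a concrete illustration of formulas (1) and (3) to (6), the following sketch estimates the confusion matrix empirically from pairs of true labels and predictions and derives the accuracy, recall, precision, and coverage from it. The function names are illustrative only and are not part of the embodiment.

```python
import numpy as np

def confusion_matrix(y_true, y_pred, K):
    """Empirical estimate of C_ij(F) in formula (1): the joint
    probability that the true label is i and the prediction is j."""
    C = np.zeros((K, K))
    for t, p in zip(y_true, y_pred):
        C[t, p] += 1.0
    return C / len(y_true)

def accuracy(C):
    # Formula (3): the sum of the diagonal entries of C.
    return float(np.trace(C))

def recall(C, i):
    # Formula (4): C_ii divided by the class prior P(y = i) (row sum).
    return float(C[i, i] / C[i, :].sum())

def precision(C, i):
    # Formula (5): C_ii divided by the column sum over k of C_ki.
    return float(C[i, i] / C[:, i].sum())

def coverage(C, i):
    # Formula (6): the proportion of inputs assigned to class i (column sum).
    return float(C[:, i].sum())
```

For example, with four samples whose true labels are [0, 0, 1, 1] and predictions are [0, 1, 1, 1], the accuracy is 0.75 and the recall of class 0 is 0.5.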
Herein, the worst recall is defined by formula (7). The worst recall is a metric that is useful for a data set with class imbalance.
Similarly, in a case where a data set is a data set with class imbalance, a problem is that estimation of a classifier is biased toward a specific class, and hence, optimization under a coverage constraint is important. Formula (8) is one example of a metric for executing optimization on an average recall under a coverage constraint. For example, formula (8) is a metric for maximizing a total value of recalls of classes 1 to K under a condition that a coverage is equal to or more than “0.95×πi”.
Moreover, a metric for executing optimization under a constraint related to a precision is also provided, and is indicated by formula (9). For example, formula (9) is a metric for maximizing an accuracy acc(F) under a condition that a precision is equal to or more than “τ (threshold)”.
A metric where optimization is difficult, such as those explained in formulae (7), (8), and (9) mentioned above, leads to cost-sensitive learning. The cost-sensitive learning is indicated by formula (10). In the cost-sensitive learning, maximization is sought by using a gain matrix G.
For example, the worst recall is indicated as formula (11) by continuous relaxation. In formula (11), ΔK−1⊂RK is a probability simplex. For example, “ΔK−1” indicates a set of K-dimensional vectors where each component has a positive value and a total of the values of the components is one. In the worst recall, a gain matrix G is given by formula (12). In formula (12), “δ” is the Kronecker delta.
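The cost-sensitive objective of formula (10) can be sketched as below. The gain matrix shown here is one common choice for a weighted-recall objective and is an assumption for illustration; formula (12) itself is not reproduced in this excerpt.

```python
import numpy as np

def weighted_recall_gain(lam, pi):
    """Hypothetical gain matrix for a weighted-recall objective.
    With weights lam on the probability simplex, the weighted recall
    sum_i lam_i * C_ii / pi_i equals sum_ij G_ij * C_ij when
    G_ij = lam_i * delta_ij / pi_i (delta: Kronecker delta), so G is
    a diagonal matrix with entries lam_i / pi_i."""
    return np.diag(np.asarray(lam, dtype=float) / np.asarray(pi, dtype=float))

def cost_sensitive_objective(G, C):
    """Formula (10)-style objective: the gain-weighted sum of the
    confusion-matrix entries, to be maximized over classifiers."""
    return float((G * C).sum())
```

Putting all the weight on class 0 (lam = [1, 0]) recovers the recall of class 0 as the objective value, which is how the min over lam in the continuous relaxation reaches the worst recall.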
With respect to a coverage constraint, a gain matrix G is given by formula (13) with a Lagrange multiplier λ ∈ RK. In formula (13), λj≥0 is satisfied for all “j”.
Learning of λ and cost-sensitive learning are alternately repeated so as to execute learning for an original metric. The original metric is the worst recall, an average recall under a coverage constraint, and the like.
Herein, the cross-entropy loss function used in general machine learning is not appropriate for the cost-sensitive learning, and hence, the conventional technology 1 proposes a loss function for cost-sensitive learning.
For example, a gain matrix G is decomposed as “G=MD”. “M” and “D” are K-by-K matrices, and “D” is a diagonal matrix. There are some decomposition manners, and “D” is herein defined by formula (14) or formula (15), for example.
When assuming that an output probability of a classifier is p(x) and labels y=1, . . . , K, a hybrid loss function is defined by formula (16). Formula (17) defines ri(x) included in formula (16).
In a case where a gain matrix G is a diagonal matrix, a hybrid loss function indicated in formula (16) is referred to as a logit-adjustment (LA) loss function. In the conventional technology 1, parameters of a classifier are trained so as to minimize an expectation value E indicated in formula (18) in a case of (x, y)˜D.
E(x,y)˜D[lhyb(y,p(x))] . . . (18)
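The logit-adjustment idea can be sketched as below. The exact hybrid loss of formulas (16) and (17) is not reproduced in this excerpt, so the function is an illustrative LA-style stand-in in which the logits are shifted by the (assumed) diagonal entries of D from the decomposition G=MD before the softmax cross-entropy.

```python
import numpy as np

def logit_adjusted_loss(logits, y, d):
    """Illustrative logit-adjustment (LA) style loss. `d` holds the
    diagonal of the matrix D in the decomposition G = M D (an assumed
    role for illustration); each logit j is shifted by -log d_j, and
    the standard softmax cross-entropy is taken against label y."""
    logits = np.asarray(logits, dtype=float)
    adj = logits - np.log(np.asarray(d, dtype=float))  # per-class adjustment
    adj = adj - adj.max()                              # numerical stability
    log_probs = adj - np.log(np.exp(adj).sum())        # log-softmax
    return float(-log_probs[y])
```

With uniform d the function reduces to ordinary softmax cross-entropy; a larger d_j penalizes predicting class j less after the shift, which is the mechanism the LA loss uses to counter class imbalance.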
Subsequently, the conventional technology 2 will be explained. The conventional technology 2 executes semi-supervised learning with the use of labeled training data and unlabeled training data. Assume that labeled training data are {(xb, yb): b=1, . . . , B}. Assume that unlabeled training data are {ub ∈ X: b=1, . . . , μB}.
In the conventional technology 2, the data Im1-1 are input to a model 10, and then an output probability p1-1 is output from the model 10. In the conventional technology 2, the data Im1-2 are input to the model 10, and then an output probability p1-2 is output from the model 10. In the conventional technology 2, a pseudo-label 5 is generated on the basis of the output probability p1-2. For example, the pseudo-label 5 is an output probability in which the maximum component among the components of the output probability p1-2 is set at “1” and the other components are set at “0”. In the conventional technology 2, training of the model 10 is executed with the use of a loss function using cross-entropy between the output probability p1-1 and the pseudo-label 5.
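The pseudo-labeling step described above can be sketched as below; the helper names are illustrative and not part of the embodiment.

```python
import numpy as np

def pseudo_label(p):
    """One-hot pseudo-label from an output probability vector: the
    maximum component is set to 1 and all other components to 0."""
    q = np.zeros_like(p)
    q[np.argmax(p)] = 1.0
    return q

def cross_entropy(target, p, eps=1e-12):
    """H(target, p): cross-entropy between the pseudo-label and the
    model's output probability (eps guards against log(0))."""
    return float(-(target * np.log(p + eps)).sum())
```

For an output probability [0.1, 0.7, 0.2], the pseudo-label is [0, 1, 0], and the loss against that same output is simply −log 0.7.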
In the following explanation, strongly-augmented unlabeled training data are appropriately denoted by A(ub). Weakly-augmented unlabeled training data are denoted by α(ub).
In the conventional technology 2, training of the model 10 is executed with the use of a loss function L in formula (19). In formula (19), a loss function ls is a loss function for labeled training data. A loss function lu is a loss function for unlabeled training data. λu is set at a value equal to or more than zero.
L=ls+λulu . . . (19)
The loss function ls is defined by formula (20). In formula (20), yb corresponds to a label that is set for training data. qb is an output probability in weak augmentation, and is defined by “qb=p(α(ub))”. For example, cross-entropy is expressed by H(p1, p2) with respect to two probabilities p1 and p2. H(yb, qb) is cross-entropy between yb and qb.
The loss function lu is defined by a difference between an output probability for strong augmentation of unlabeled training data ub and an output probability for weak augmentation of the unlabeled training data ub.
Specifically, the loss function lu is defined by formula (21). qb is an output probability “qb=p(α(ub))” in a case where weakly-augmented unlabeled training data are input to a classifier. “p” is an output probability of the model 10. “q′b” is a one-hot vector where only an argmax(qb)-th component is “1”. p(A(ub)) is an output probability in a case where strongly-augmented unlabeled training data are input to a classifier.
In formula (21), “τ” is a parameter of the algorithm. “1(max qb>τ)” indicates that only training data that provide a reliable predicted label, in an unlabeled training data set, are used for training of the model 10. The predicted label corresponds to the pseudo-label explained above.
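The confidence-gated unlabeled loss of formula (21) can be sketched as below; this is an illustrative reading of the formula, and the function name is not part of the embodiment.

```python
import numpy as np

def fixmatch_unlabeled_loss(q_weak, p_strong, tau):
    """Sketch of formula (21): for each unlabeled example, if the
    weak-augmentation confidence max(q_b) exceeds tau, add the
    cross-entropy between the one-hot pseudo-label q'_b and the
    strong-augmentation output p(A(u_b)); otherwise skip it."""
    total, eps = 0.0, 1e-12
    for q, p in zip(q_weak, p_strong):
        q, p = np.asarray(q), np.asarray(p)
        if q.max() > tau:                     # confidence gate 1(max q_b > tau)
            y_hat = int(np.argmax(q))         # pseudo-label index
            total += -np.log(p[y_hat] + eps)  # H(q'_b, p(A(u_b)))
    return total / len(q_weak)
```

With tau = 0.7, an example whose weak-augmentation output is [0.9, 0.1] passes the gate while one with [0.6, 0.4] does not, which illustrates the problem noted later: under a hard threshold, low-confidence examples contribute nothing to training.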
For example, related arts are disclosed in Narasimhan, H., Menon, A. K.: Training Over-Parameterized Models with Non-Decomposable Objectives, NeurIPS 2021, and Sohn et al.: FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence, NeurIPS 2020.
According to an aspect of an embodiment, an information processing apparatus includes one or more memories; and one or more processors coupled to the one or more memories, the one or more processors being configured to decide a gain matrix based on an input metric, perform selection of first training data from a plurality of unlabeled training data, to be used for training a machine learning model, based on the gain matrix, and perform training of the machine learning model based on the first training data, a predicted label that is predicted from the first training data, and a loss function including the gain matrix.
The object and advantages of the invention will be realized and attained by means of the elements and combinations particularly pointed out in the claims.
It is to be understood that both the foregoing general description and the following detailed description are exemplary and explanatory and are not restrictive of the invention, as claimed.
For example, if the conventional technology 1 and the conventional technology 2 are simply combined, a metric where optimization is difficult leads to cost-sensitive learning, and then optimization is executed thereon by a method of semi-supervised learning. In this case, training of a classifier is executed by using, in addition to a labeled training data set, training data that provide a reliable predicted label in an unlabeled training data set. Specifically, unlabeled training data corresponding to qb that satisfies “1(max qb>τ)” indicated in formula (21) are used.
However, in the technology obtained by simply combining the conventional technology 1 and the conventional technology 2, for a metric where optimization is difficult, most of qb corresponding to unlabeled training data do not satisfy “1(max qb>τ)”. In other words, an actual problem is that, even in a case where many unlabeled training data correspond to a predicted label with high reliability, most of the unlabeled training data are rarely used for training of a classifier.
Preferred embodiments of the present invention will be explained with reference to accompanying drawings. The present invention is not limited to embodiments described below. Moreover, embodiments may be combined within a consistent range.
An information processing apparatus according to the present embodiment trains a parameter of a classifier (machine learning model) by using a loss function that includes a loss function for labeled training data and a loss function for unlabeled training data.
In the present embodiment, unlabeled training data that provide a reliable predicted label are defined by using a Kullback-Leibler divergence (KL-divergence) in a loss function for unlabeled training data. A KL-divergence is a pseudo-distance between two probability distributions that indicates a degree of similarity between the two probability distributions.
For example, the information processing apparatus selects corresponding unlabeled training data as unlabeled training data that provide a reliable predicted label, in a case where a condition indicated in formula (22) is satisfied. “qb” in formula (22) is an output probability “qb=p(α(ub))” in a case where weak data augmentation is applied to unlabeled training data and the data to which the weak data augmentation has been applied are input to a classifier. DKL indicates a KL-divergence. “τ” is a parameter of the algorithm that is set preliminarily.
In formula (22), y′ is a predicted label (pseudo-label) that is defined by formula (23).
y′=argmax qb=argmax p(α(ub)) . . . (23)
A definition of formula (22) is based on the theory that qb converges to the value of formula (24).
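The shape of such a KL-based selection can be sketched as below. Since formulas (22) and (24) are not reproduced in this excerpt, the reference distribution in the sketch is a hypothetical placeholder for the distribution that qb is theorized to converge to.

```python
import numpy as np

def kl_divergence(p, q, eps=1e-12):
    """D_KL(p || q) between two discrete probability distributions;
    eps guards against log(0)."""
    p, q = np.asarray(p, dtype=float), np.asarray(q, dtype=float)
    return float((p * (np.log(p + eps) - np.log(q + eps))).sum())

def select_reliable(q_b, reference, tau):
    """Hypothetical selection gate in the spirit of formula (22):
    keep an unlabeled example when the KL-divergence between a
    reference distribution (assumed here to be the limit value of
    formula (24)) and the weak-augmentation output q_b is at most
    the threshold tau."""
    return kl_divergence(reference, q_b) <= tau
```

Unlike the hard max-confidence gate of formula (21), this condition compares the whole output distribution against the reference, so examples whose outputs have the expected shape can pass even when the maximum component is not close to 1.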
Additionally, the information processing apparatus specifies a definition of a gain matrix G in formula (25) according to a specified metric. The information processing apparatus receives a specification of a metric externally.
Gij=(1+λi)δij−τλj . . . (25)
The information processing apparatus trains a parameter of a classifier by using a loss function L′ in formula (26). In a loss function in formula (26), a loss function ls is a loss function for labeled training data. A loss function l′u is a loss function for unlabeled training data. λu is set at a value equal to or more than 0.
L′=ls+λul′u . . . (26)
A loss function ls is defined by formula (26a) as mentioned below.
A loss function l′u is defined by formula (27). As formula (27) is compared with formula (21), “1(max qb>τ)” in formula (21) is replaced by a definition that uses the KL-divergence explained in formula (22). Furthermore, in formula (27), the hybrid loss function explained in formula (16) is used instead of a cross-entropy H. qb is an output probability “qb=p(α(ub))” in a case where unlabeled training data to which weak data augmentation has been applied are input to a classifier. p is an output probability of a classifier. “q′b” indicates a one-hot vector where only an argmax(qb)-th component is “1”. p(A(ub)) is an output probability in a case where unlabeled training data to which strong data augmentation has been applied are input to a classifier.
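The structure of formula (27) can be sketched as below, assuming a KL gate against a hypothetical reference distribution and an externally supplied hybrid-loss callable; both are placeholders for quantities this excerpt does not fully specify.

```python
import numpy as np

def proposed_unlabeled_loss(q_weak, p_strong, reference, tau, hybrid_loss):
    """Sketch of formula (27): the confidence gate of formula (21) is
    replaced by a KL-divergence condition in the spirit of formula
    (22), and the cross-entropy is replaced by the hybrid loss of
    formula (16). `reference` and `hybrid_loss` are hypothetical
    placeholders for illustration."""
    eps, total = 1e-12, 0.0
    ref = np.asarray(reference, dtype=float)
    for q, p in zip(q_weak, p_strong):
        q = np.asarray(q, dtype=float)
        # KL gate: keep the example only when D_KL(ref || q_b) <= tau.
        kl = float((ref * (np.log(ref + eps) - np.log(q + eps))).sum())
        if kl <= tau:
            y_hat = int(np.argmax(q))       # pseudo-label from weak augmentation
            total += hybrid_loss(y_hat, p)  # hybrid loss on the strong-augmentation output
    return total / len(q_weak)
```

Passing an ordinary negative-log-likelihood as the callable recovers a FixMatch-like loss with the KL gate swapped in, which makes the relationship between formulas (21) and (27) concrete.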
As described above, the information processing apparatus trains a parameter of a classifier on the basis of labeled training data, unlabeled training data, and a gain matrix G according to a metric, in such a manner that a value of a loss function L′ is minimized. Thereby, even for some indices where optimization is difficult, it is possible to execute training of a classifier by appropriately using unlabeled training data that provide a predicted label with high reliability. Indices where optimization is difficult are the worst recall, an average recall under a coverage constraint, a recall under a precision constraint, and the like, as explained above.
Subsequently, a configuration example of an information processing apparatus according to the present embodiment will be explained.
The communication unit 110 executes data communication with an external device and the like through a network. The communication unit 110 may receive, from an external device, a labeled training data set 141, an unlabeled training data set 142, a validation data set 143, and the like, as mentioned later.
The input unit 120 receives an operation of a user. A user specifies a metric by using the input unit 120.
The display unit 130 displays a processing result of the control unit 150.
The storage unit 140 includes the labeled training data set 141, the unlabeled training data set 142, the validation data set 143, initial value data 144, and classifier data 145. For example, the storage unit 140 is realized by a memory and the like.
The labeled training data set 141 includes a plurality of labeled training data. Labeled training data are composed of a set of input data and a correct answer label. Labeled training data are provided as {(xb, yb): b=1, . . . , B}.
The unlabeled training data set 142 includes a plurality of unlabeled training data. Unlabeled training data include input data and do not include a correct answer label. Unlabeled training data are provided as {ub∈X: b=1, . . . , μB}. A predicted label (pseudo-label) for unlabeled training data is generated by the control unit 150 as mentioned later.
The validation data set 143 includes a plurality of validation data. Validation data are composed of a set of input data and a correct answer label. The validation data set 143 is used in a case where a confusion matrix is estimated.
The initial value data 144 include an iteration number T, a learning rate ω, and the like. An iteration number T and a learning rate are used in a case where a classifier is trained.
The classifier data 145 are data of a classifier F that are a target for training. For example, a classifier F is a Neural Network (NN).
The control unit 150 includes a reception unit 151, a generation unit 152, and a training execution unit 153. For example, the control unit 150 is realized by a processor.
The reception unit 151 receives an input of a metric from the input unit 120. For example, a metric that is received by the reception unit 151 is the worst recall, an average recall under a coverage constraint, a recall under a precision constraint, or the like. The reception unit 151 outputs the received metric to the training execution unit 153.
The generation unit 152 executes strong data augmentation on unlabeled training data ub so as to generate training data A(ub). The generation unit 152 executes weak data augmentation on unlabeled training data so as to generate training data α(ub).
The generation unit 152 reads the classifier data 145 and inputs training data α(ub) to a classifier F so as to calculate an output probability qb. An output probability qb is defined as “qb=p(α(ub))”. Furthermore, the generation unit 152 calculates formula (23) as mentioned above so as to calculate a predicted label y′ for unlabeled training data.
The generation unit 152 outputs unlabeled training data ub, training data A(ub), training data α(ub), an output probability qb, and a predicted label y′ to the training execution unit 153.
The generation unit 152 repeatedly executes the process as mentioned above on each of unlabeled training data that are included in the unlabeled training data set 142. Additionally, such a process of the generation unit 152 may be executed by the training execution unit 153 as mentioned later.
The training execution unit 153 selects unlabeled training data that are used for training of a classifier F, from unlabeled training data ub, on the basis of a gain matrix according to a specified metric. For example, the training execution unit 153 selects a plurality of unlabeled training data that satisfy a condition of formula (22).
The training execution unit 153 trains a parameter of a classifier F on the basis of a selected plurality of unlabeled training data, a predicted label that corresponds to such unlabeled training data, the labeled training data set 141, and a loss function L′, in such a manner that a value of a loss function L′ is minimized. A loss function L′ is indicated in formula (26) as mentioned above.
Herein, one example of a processing procedure of the training execution unit 153 in a case where “an average recall under a coverage constraint” is specified as a metric will be explained.
The training execution unit 153 updates a Lagrange multiplier (step S103). At step S103, the training execution unit 153 executes the next process. The training execution unit 153 estimates a confusion matrix C′(Ft) by using the validation data set 143. Specifically, the training execution unit 153 estimates a confusion matrix C′(Ft) on the basis of formula (28). In formula (28), |Sval| is the number of validation data that are included in the validation data set 143.
The training execution unit 153 calculates a Lagrange multiplier λt+1 on the basis of formula (29), and subsequently specifies a value of the Lagrange multiplier λt+1 on the basis of formula (30).
λit+1=λit−ω(Σk=1KC′ki(Ft)−0.95πi) . . . (29)
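The multiplier step of formula (29) can be sketched as below. The projection of formula (30) is not reproduced in this excerpt and is assumed here, for illustration, to clip the multipliers at zero.

```python
import numpy as np

def update_lagrange(lam, C_val, pi, omega):
    """Lagrange-multiplier step of formula (29): a gradient step on
    the coverage constraint cov_i >= 0.95 * pi_i, using the confusion
    matrix C' estimated on the validation set. The column sum of C'
    is the per-class coverage (formula (6))."""
    cov = np.asarray(C_val, dtype=float).sum(axis=0)
    lam_new = np.asarray(lam, dtype=float) - omega * (cov - 0.95 * np.asarray(pi))
    # Assumed projection (formula (30) is not shown in this excerpt):
    # keep the multipliers non-negative.
    return np.maximum(lam_new, 0.0)
```

When a class's coverage already exceeds 0.95·πi, its multiplier decreases, so the constraint's influence on the gain matrix fades; when coverage falls short, the multiplier grows and the gain matrix pushes the classifier toward that class.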
The explanation proceeds to step S104. The training execution unit 153 selects a gain matrix G that corresponds to an average recall under a coverage constraint (step S104). A gain matrix G that corresponds to an average recall under a coverage constraint is indicated in formula (31).
The training execution unit 153 updates a classifier F according to a stochastic gradient method (step S105). At step S105, the training execution unit 153 samples batches BS, Bu from SS, Su, respectively. The training execution unit 153 updates a parameter of a classifier Ft according to a stochastic gradient method on the basis of batches BS, Bu and a gain matrix in formula (31), in such a manner that a loss function L′ that is defined by formula (26) is minimized, and provides a classifier after updating as a classifier Ft+1.
The training execution unit 153 updates t according to t=t+1 (step S106). The training execution unit 153 is shifted to step S103 in a case where a condition of t>T is not satisfied (step S107, No). On the other hand, the training execution unit 153 ends such a process in a case where a condition of t>T is satisfied (step S107, Yes).
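The alternating procedure of steps S103 to S107 can be sketched as a generic loop; the three callables are hypothetical placeholders for the multiplier update (step S103), the gain-matrix selection (step S104), and one stochastic-gradient update of the classifier on the loss L′ of formula (26) (step S105).

```python
def train_loop(T, update_lambda, build_gain, sgd_step, lam0, f0):
    """Generic alternating loop over T iterations: each pass updates
    the Lagrange multipliers from the current classifier, builds the
    metric-specific gain matrix, and updates the classifier. The
    callables are illustrative placeholders, not the embodiment's API."""
    lam, f = lam0, f0
    for t in range(T):
        lam = update_lambda(lam, f)  # step S103: multiplier update
        G = build_gain(lam)          # step S104: gain-matrix selection
        f = sgd_step(f, G)           # step S105: classifier update
    return f
```

The same loop serves both procedures described here; only `build_gain` changes (formula (31) for the coverage-constrained average recall, formula (32) for the accuracy under a precision constraint).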
Subsequently, one example of a processing procedure of the training execution unit 153 in a case where “the accuracy under a precision constraint” is specified as a metric will be explained.
The training execution unit 153 updates a Lagrange multiplier (step S203). A process at step S203 is similar to the process at step S103 described above.
The training execution unit 153 selects a gain matrix G that corresponds to the accuracy under a precision constraint (step S204). A gain matrix G that corresponds to the accuracy under a precision constraint is indicated in formula (32).
Gij=(1+λit+1)δij−τλjt+1 . . . (32)
The training execution unit 153 updates a classifier F according to a stochastic gradient method (step S205). A process at step S205 is similar to the process at step S105 described above.
The training execution unit 153 updates t according to t=t+1 (step S206). The training execution unit 153 is shifted to step S203 in a case where a condition of t>T is not satisfied (step S207, No). On the other hand, the training execution unit 153 ends such a process in a case where a condition of t>T is satisfied (step S207, Yes).
Next, effects of the information processing apparatus 100 according to the present embodiment will be explained. The information processing apparatus 100 defines unlabeled training data that provide a reliable predicted label by using a KL-divergence, selects a gain matrix according to an input metric, and selects unlabeled training data that provide a predicted label with high reliability. The information processing apparatus 100 trains a parameter of a classifier, on the basis of labeled training data, selected unlabeled training data, and a loss function L′ that includes a gain matrix G according to a metric, in such a manner that a value of the loss function L′ is minimized. Thereby, even for some indices where optimization is difficult, it is possible to execute training of a classifier F by appropriately using unlabeled training data that provide a predicted label with high reliability.
The information processing apparatus 100 executes data augmentation on unlabeled training data, and selects corresponding unlabeled training data in a case where a pseudo-distance between a distribution of an output probability that is output when augmented data are input to a classifier and a probability distribution that is based on a gain matrix is equal to or less than a threshold. Thereby, it is possible to appropriately use unlabeled training data that provide a predicted label with high reliability.
The information processing apparatus 100 inputs training data α(ub), to which weak data augmentation has been applied, to a classifier F so as to calculate an output probability qb, and calculates a predicted label y′ on the basis of the output probability qb. Thereby, it is possible to set a predicted label for unlabeled training data and use it for training.
The information processing apparatus 100 trains a classifier F on the basis of a value obtained by inputting, to a hybrid loss function, qb that is an output probability in a case where unlabeled training data to which weak data augmentation has been applied are input to a classifier F and p(A(ub)) that indicates an output probability in a case where unlabeled training data to which strong data augmentation has been applied are input to the classifier F. That is, it is possible to train a classifier F by using unlabeled data.
Next, one example of a hardware configuration of a computer that realizes functions similar to those of the information processing apparatus 100 disclosed in the above-mentioned embodiment will be explained.
The hard disk device 207 includes a receiving program 207a, a generation program 207b, and a training execution program 207c. The CPU 201 reads out each of the programs 207a to 207c, and deploys the read one in the RAM 206.
The receiving program 207a functions as a receiving process 206a. The generation program 207b functions as a generation process 206b. The training execution program 207c functions as a training execution process 206c.
The receiving process 206a corresponds to a process to be executed by the reception unit 151. The generation process 206b corresponds to a process to be executed by the generation unit 152. The training execution process 206c corresponds to a process to be executed by the training execution unit 153.
The programs 207a to 207c are not necessarily stored in the hard disk device 207 in advance. For example, each of the programs may be stored in a “physical medium” such as a flexible disk (FD), a Compact Disc Read Only Memory (CD-ROM), a Digital Versatile Disc (DVD), a magneto-optical disc, or an Integrated Circuit card (IC card), which is inserted into the computer 200. The computer 200 may read therefrom and execute each of the programs 207a to 207c.
All examples and conditional language recited herein are intended for pedagogical purposes of aiding the reader in understanding the invention and the concepts contributed by the inventor to further the art, and are not to be construed as limitations to such specifically recited examples and conditions, nor does the organization of such examples in the specification relate to a showing of the superiority and inferiority of the invention. Although the embodiment of the present invention has been described in detail, it should be understood that the various changes, substitutions, and alterations could be made hereto without departing from the spirit and scope of the invention.
Number | Date | Country | Kind |
---|---|---|---|
202231028920 | May 2022 | IN | national |