The present disclosure relates to a learning device, a learning method, and a recording medium.
When training a large-scale machine learning model such as a deep learning model, regularization is commonly used to suppress overtraining. For instance, Patent Document 1 discloses a technique for updating the weight parameters of a neural network using a cost function obtained by adding a regularization term to an error function.
In conventional techniques, regularization has been applied uniformly to all training data. As a result, the regularization is effectively too weak for training data that are easy to predict, which causes overfitting, and effectively too strong for training data that are difficult to predict, which reduces the efficiency of learning.
It is one object of the present disclosure to adaptively control the strength of the regularization in deep learning depending on the training data.
According to an example aspect of the present disclosure, there is provided a learning device including:
According to another example aspect of the present disclosure, there is provided a learning method including:
According to still another example aspect of the present disclosure, there is provided a recording medium storing a program, the program causing a computer to perform a process including:
According to the present disclosure, it becomes possible to adaptively control a strength of a regularization in deep learning depending on training data.
In the following, preferable example embodiments will be described with reference to the accompanying drawings.
The interface 11 inputs and outputs data to and from an external device. Specifically, a training data set used for learning is input to the learning device 100 through the interface 11.
The processor 12 is a computer such as a CPU (Central Processing Unit), and controls the entire learning device 100 by executing a predetermined program. The processor 12 may be a GPU (Graphics Processing Unit) or an FPGA (Field-Programmable Gate Array). The processor 12 executes a learning process to be described later.
The memory 13 consists of a ROM (Read Only Memory) and a RAM (Random Access Memory). The memory 13 is also used as a working memory during the execution of various processes by the processor 12.
The recording medium 14 is a non-volatile, non-transitory recording medium such as a disk-shaped recording medium or a semiconductor memory, and is formed to be detachable from the learning device 100. The recording medium 14 records various programs to be executed by the processor 12. When the learning device 100 executes various types of processes, each program recorded on the recording medium 14 is loaded into the memory 13 and executed by the processor 12. The DB 15 stores the training data set entered through the interface 11 as needed.
The training data set is input to the learning device 100. The training data set includes a plurality of training data xi and correct answer classes yi corresponding to the respective training data xi. The training data xi are input to the inference unit 21, and the correct answer classes yi are input to the loss function calculation unit 22.
The inference unit 21 performs inference using the deep learning model to be trained by the learning device 100. Specifically, the inference unit 21 includes a neural network which forms the deep learning model to be learned. The inference unit 21 performs inference on the input training data xi and outputs a class score v->i as the inference result. In detail, the inference unit 21 performs class classification on the training data xi and outputs the class score v->i, which is a vector indicating a reliability score for each class. In this disclosure, for convenience, the arrow "->" indicating a vector is written as a superscript on the right side of "v". Each class score v->i is input to the loss function calculation unit 22 and the weight function calculation unit 24.
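As a non-limiting illustration of the inference unit 21, the following is a minimal sketch assuming a small fully connected classifier in PyTorch. The class name InferenceUnit, the layer sizes, and the choice of framework are assumptions made for illustration; the disclosure does not specify a particular network architecture.

```python
import torch
import torch.nn as nn

# Illustrative stand-in for the inference unit 21: a small classifier whose
# forward pass maps each training sample x_i to a class score vector v_i
# (one reliability score per class). Architecture is assumed, not specified.
class InferenceUnit(nn.Module):
    def __init__(self, in_dim: int, num_classes: int, hidden: int = 128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (N, in_dim) mini-batch -> class scores v: (N, num_classes)
        return self.net(x)
```

Whether the class score v->i consists of raw scores or normalized probabilities is left open by the disclosure; the sketch above outputs raw scores.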
The loss function calculation unit 22 calculates a loss lcls,i for each class score v->i using a loss function prepared in advance. Specifically, as shown in equation (1), the loss function calculation unit 22 calculates the loss lcls,i using the class score v->i for certain training data xi and the correct answer class yi for that training data xi. The calculated loss lcls,i is input to the total calculation unit 23.
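Since equation (1) itself is not reproduced in this text, the following is only a minimal sketch of a per-sample loss, assuming cross-entropy as a common stand-in for the loss function prepared in advance; the actual function of equation (1) may differ.

```python
import torch
import torch.nn.functional as F

def per_sample_loss(class_scores: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Loss l_cls,i for each sample from its class score v_i and correct class y_i.

    Cross-entropy is assumed here purely as an illustrative stand-in for
    equation (1), which is not reproduced in this text.
    """
    # reduction='none' keeps one loss value per training sample x_i
    return F.cross_entropy(class_scores, labels, reduction="none")
```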
Meanwhile, the weight function calculation unit 24 calculates a weight for the training data xi based on the class score v->i generated by the inference unit 21. Specifically, the weight function calculation unit 24 determines each weight wi, which is a single real value, from the class score v->i, which is the inference result for the training data xi, by the following equation (2).
As the weight function, a function is chosen that increases rapidly when a confidence score of a class included in the class score v->i is over-estimated or under-estimated. Here, "rapidly" means faster than a linear function. The condition that the weight function increases rapidly is necessary in order to emphasize an over-confident or under-confident score included in the class score v->i. That is, by calculating each weight using the rapidly increasing function, when the class score v->i includes an over-confident or under-confident value, that value is emphasized and the weight wi becomes larger. Accordingly, the selection of the weight function determines the contribution of each training data item, through its weight, to the gradient of the regularization term described later. Note that since the weight function calculation unit 24 simply outputs the result of inputting the reliability score of each class included in the class score v->i into the weight function, the output weight wi is not a particularly regularized value. The weight function calculation unit 24 outputs the calculated weight wi to the weight sum calculation unit 25.
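As a minimal sketch of such a rapidly increasing weight function, the following uses a sum of squared confidence scores, which grows faster than a linear function. The function name weight_function and the squared form are illustrative assumptions, since equation (2) is not reproduced here.

```python
import torch

def weight_function(class_scores: torch.Tensor) -> torch.Tensor:
    """Weight w_i for each sample from its class score vector v_i.

    The square is an assumed example of a faster-than-linear function:
    over- or under-confident scores in v_i are emphasized and yield a
    larger weight w_i. Equation (2) may define a different function.
    """
    # class_scores: (N, num_classes) -> weights: (N,)
    return (class_scores ** 2).sum(dim=1)
```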
The weight sum calculation unit 25 calculates the total of the weights wi for a mini-batch. The mini-batch is a set of a predetermined number (for instance, N) of training data. Specifically, the weight sum calculation unit 25 calculates the sum S of the N weights wi corresponding to the N training data xi by the following equation (3).
The weight sum calculation unit 25 outputs the calculated sum S to the rescale function calculation unit 26.
The rescale function calculation unit 26 calculates a rescale function, and generates a regularization term Lreg based on the sum S being input. Specifically, the rescale function calculation unit 26 generates the regularization term Lreg by the following equation (4).
In equation (4), "g(S)" is the rescale function. As the rescale function g(S), a monotonically increasing function that increases gradually is chosen. Note that this monotonically increasing function that increases gradually is different from a "slowly increasing function" in the mathematical sense.
Here, "gradually" means slower than a linear function. The condition that the rescale function g(S) increases gradually is necessary to prevent the rapidly increasing weight function from enlarging the gradient of the regularization term and thereby making the learning unstable. In other words, because the regularization may become too strong if the weights wi, in which over-confident or under-confident scores are emphasized by the weight function, are used as they are, the rescale function g(S) is used to adjust the overall scale of the weights wi. In this regard, the rescale function g(S) can also be regarded as regularizing the weights wi and adjusting the strength of the overall regularization. The rescale function calculation unit 26 outputs the regularization term Lreg thus obtained to the total calculation unit 23.
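The following sketch combines the sum S described for equation (3) with a hypothetical slowly increasing rescale function g(S) = log(1 + S). The choice of log(1 + S) is an assumption made for illustration, since equation (4) is not reproduced here; it merely satisfies the stated requirement of growing more slowly than a linear function.

```python
import torch

def regularization_term(weights: torch.Tensor) -> torch.Tensor:
    """Regularization term L_reg for one mini-batch.

    The sum S follows the description of equation (3); the rescale
    function g(S) below is an assumed slowly increasing choice that
    damps the rapid growth of the weight function.
    """
    S = weights.sum()        # equation (3): S = sum of the weights w_i over the mini-batch
    return torch.log1p(S)    # illustrative g(S) = log(1 + S); the disclosed g may differ
```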
The total calculation unit 23 calculates the total (hereinafter also referred to as the "total loss L") of the losses lcls,i input from the loss function calculation unit 22 and the regularization term Lreg input from the rescale function calculation unit 26. Specifically, as in the following equation (5), the total calculation unit 23 calculates the total loss L by adding up the losses lcls,i and the regularization term Lreg over the index i of the training data and dividing the result by the count N of the training data included in the mini-batch.
Then, the total calculation unit 23 outputs the obtained total loss L to the parameter update unit 27.
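A minimal sketch of the total loss computation follows, assuming one plausible reading of equation (5) in which the per-sample losses and the regularization term are summed and then divided by the mini-batch size N; the exact placement of Lreg relative to the division is an assumption, since equation (5) is not reproduced here.

```python
import torch

def total_loss(per_sample_losses: torch.Tensor, l_reg: torch.Tensor) -> torch.Tensor:
    """Total loss L for one mini-batch.

    Assumed reading of equation (5): sum the per-sample losses l_cls,i,
    add the regularization term L_reg, and divide by the mini-batch size N.
    """
    N = per_sample_losses.shape[0]
    return (per_sample_losses.sum() + l_reg) / N
```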
The parameter update unit 27 optimizes the inference unit 21 based on the input total loss L. Specifically, the parameter update unit 27 updates parameters of the neural network forming the inference unit 21 based on the total loss L. Thus, the learning of the deep learning model constituting the inference unit 21 is performed.
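As a sketch of the parameter update, the following assumes plain stochastic gradient descent on the total loss L. The stand-in model, the optimizer, and the learning rate are illustrative assumptions; the disclosure only states that the parameters of the neural network forming the inference unit 21 are updated based on L.

```python
import torch
import torch.nn as nn

# Illustrative parameter update unit 27: one gradient step on the total loss L.
model = nn.Linear(32, 10)                                  # stand-in for the inference unit's network
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)   # assumed optimizer and learning rate

def update_parameters(loss: torch.Tensor) -> None:
    optimizer.zero_grad()
    loss.backward()        # gradients flow through both the losses l_cls,i and L_reg
    optimizer.step()
```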
As described above, according to the learning device 100 of the first example embodiment, the contribution of each training data item to the regularization term can be adaptively determined by calculating the regularization term in units of mini-batches. Moreover, by using the weight function to emphasize over-confident or under-confident inference results output by the inference unit 21, the learning device 100 can prevent overfitting by strengthening the regularization for easy training data, and can improve learning efficiency by weakening the regularization for difficult training data. Furthermore, by adjusting the overall scale of the weights using the rescale function, the learning device 100 can regularize the weights that were partially emphasized by the weight function and adjust the strength of the overall regularization. As a result, it becomes possible to adaptively determine the strength of the regularization depending on the training data and to obtain a higher generalization performance, that is, a higher classification accuracy.
In the configuration described above, the inference unit 21 corresponds to an example of an inference unit, the loss function calculation unit 22 corresponds to an example of a loss calculation unit, the weight function calculation unit 24 corresponds to an example of a weight calculation unit, the weight sum calculation unit 25 corresponds to an example of a weight calculation unit, the rescale function calculation unit 26 corresponds to an example of a regularization term calculation unit, and the parameter update unit 27 corresponds to an example of an optimization unit.
In a second example, the weight function is a function which adds up the natural logarithms of the squares of the confidence scores vic of the classes included in the class score v->i over all classes c. Also, the rescale function is a function which takes the logarithm of the sum S output by the weight sum calculation unit 25.
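The following sketch transcribes the weight function and rescale function of this second example directly from the description above. The small epsilon added for numerical stability and the assumption that the sum S is positive when the logarithm is taken are not part of the original description.

```python
import torch

def weight_function_example2(class_scores: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    # w_i = sum over classes c of ln((v_i^c)^2), per the second example.
    # eps avoids log(0); it is a numerical-stability addition, not from the description.
    return torch.log(class_scores ** 2 + eps).sum(dim=1)

def rescale_function_example2(S: torch.Tensor) -> torch.Tensor:
    # g(S) = log(S), per the second example; S is assumed positive here.
    return torch.log(S)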
In a third example, the weight function is a function which adds up the natural logarithms of the positive and negative confidence scores vic of the classes included in the class score v->i over all classes c. The rescale function is a function which takes the logarithm of the sum S output by the weight sum calculation unit 25.
First, the inference unit 21 performs inference on the input training data xi (step S11). The inference unit 21 outputs the class score v->i obtained by the inference to the loss function calculation unit 22 and the weight function calculation unit 24. Based on the class score v->i and the correct answer class yi, the loss function calculation unit 22 calculates the loss lcls,i using equation (1), and outputs the loss lcls,i to the total calculation unit 23 (step S12).
Next, based on the class score v->i, the weight function calculation unit 24 calculates the weight wi using equation (2), and outputs the weight wi to the weight sum calculation unit 25 (step S13). Next, the weight sum calculation unit 25 calculates the sum S of the weights wi for each mini-batch by equation (3), and outputs the sum S for each mini-batch to the rescale function calculation unit 26 (step S14). Next, the rescale function calculation unit 26 calculates the regularization term Lreg from the input sum S using the rescale function, and outputs the regularization term Lreg to the total calculation unit 23 (step S15). Note that the process of step S12 and the processes of steps S13 to S15 may be performed in the reverse order or in parallel.
Next, based on the regularization term Lreg input from the rescale function calculation unit 26 and the loss lcls,i input from the loss function calculation unit 22, the total calculation unit 23 calculates a total of losses (total loss L) using the equation (5), and outputs the total loss L to the parameter update unit 27 (step S16). Next, the parameter update unit 27 updates the parameters of the neural network forming the inference unit 21 based on the total of the losses (total loss L) (step S17).
Next, it is determined whether or not an end condition of the learning is satisfied (step S18). As the end condition, for instance, it is possible to use the condition that all training data have been used, that the accuracy of the inference unit 21 has reached a predetermined accuracy, or the like. When the end condition is not satisfied (step S18: No), the learning process returns to step S11, and the processes of steps S11 to S17 are performed using the next training data. On the other hand, when the end condition is satisfied (step S18: Yes), the learning process is terminated.
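Tying the steps together, the following is an end-to-end sketch of one learning iteration (steps S11 to S17). It reuses the illustrative choices assumed in the sketches above (a linear stand-in model, cross-entropy loss, a squared-score weight function, and g(S) = log(1 + S)); only the overall flow follows the described process, not the specific functions of equations (1) to (5).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

model = nn.Linear(32, 10)                                   # stand-in for the inference unit 21
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)    # assumed optimizer

def train_step(x: torch.Tensor, y: torch.Tensor) -> float:
    scores = model(x)                                        # S11: inference -> class scores v_i
    l_cls = F.cross_entropy(scores, y, reduction="none")     # S12: per-sample losses l_cls,i (assumed loss)
    w = (scores ** 2).sum(dim=1)                             # S13: weights w_i (assumed weight function)
    S = w.sum()                                              # S14: sum S over the mini-batch
    l_reg = torch.log1p(S)                                   # S15: L_reg = g(S) (assumed rescale function)
    L = (l_cls.sum() + l_reg) / x.shape[0]                   # S16: total loss L (assumed reading of eq. (5))
    optimizer.zero_grad()                                    # S17: parameter update
    L.backward()
    optimizer.step()
    return L.item()

# Hypothetical usage with random data standing in for one mini-batch:
x_batch = torch.randn(8, 32)
y_batch = torch.randint(0, 10, (8,))
train_step(x_batch, y_batch)
```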
According to the learning device 200 of the second example embodiment, in the deep learning, it becomes possible to adaptively control the strength of the regularization depending on the training data.
A part or all of the example embodiments described above may also be described as the following supplementary notes, but not limited thereto.
A learning device comprising:
The learning device according to supplementary note 1, wherein the regularization term calculation means increases a value of the regularization term for the class score that is high, and decreases the value of the regularization term for the class score that is low.
The learning device according to supplementary note 1 or 2, further comprising a loss calculation means configured to calculate a loss based on the class score and a correct answer class corresponding to the training data,
The learning device according to supplementary note 1, wherein
The learning device according to supplementary note 1, wherein
The learning device according to any one of supplementary notes 1 to 3, wherein
A learning method comprising:
A recording medium storing a program, the program causing a computer to perform a process comprising:
While the disclosure has been described with reference to the example embodiments and examples, the disclosure is not limited to the above example embodiments and examples. It will be understood by those of ordinary skill in the art that various changes in form and details may be made therein without departing from the spirit and scope of the present disclosure as defined by the claims.
| Filing Document | Filing Date | Country | Kind |
|---|---|---|---|
| PCT/JP2021/035277 | 9/27/2021 | WO | |