The present disclosure relates to a data generation method, a machine learning method, an information processing apparatus, a non-transitory computer-readable recording medium storing a data generation program, and a non-transitory computer-readable recording medium storing a machine learning program.
A large number of security applications using machine learning have been developed. A machine learning model normally assumes that the data to be used for training and the data to be input at the time of testing are generated from the same distribution. However, with data that is used in the security industry, this assumption is often untrue, and data is input from a distribution that was not taken into consideration at the time of training (from outside the distribution). Data from outside the distribution may be referred to as out-of-distribution data. Also, the outside of the distribution may be referred to as out-of-distribution (OOD).
Models that are vulnerable to the outside of the distribution are incapable of recognizing such data as being out of the distribution, and therefore, applications might fail to recognize new threats.
To enhance robustness for out-of-distribution data, a technique of training a machine learning model using out-of-distribution data is known. The out-of-distribution data to be used for this training is acquired from a third party or is generated in a pseudo manner, for example.
Examples of the related art include: [Patent Document 1] Japanese National Publication of International Patent Application No. 2018-529157; [Patent Document 2] Japanese Laid-open Patent Publication No. 2020-123830; and [Patent Document 3] U.S. Patent Application Publication No. 2021/0182731.
According to an aspect of the embodiments, there is provided a data generation method implemented by a computer, the data generation method including: generating pseudo data and pseudo label data for the pseudo data; and updating the pseudo data in a direction for reducing a loss of an output obtained by inputting the pseudo data to a machine learning model, to generate out-of-distribution data not included in a specific domain.
The object and advantages of the invention will be realized and attained by means of the elements and combinations particularly pointed out in the claims.
It is to be understood that both the foregoing general description and the following detailed description are exemplary and explanatory and are not restrictive of the invention.
However, out-of-distribution data that is acquired/generated by such a conventional technique is limited, and therefore, there is the problem of vulnerabilities remaining in the models.
In one aspect, the present disclosure aims to enable efficient generation of out-of-distribution data.
In the description below, embodiments related to a data generation method, a machine learning method, an information processing apparatus, a data generation program, and a machine learning program will be described with reference to the drawings. Note that the embodiments described below are merely examples, and there is no intention to exclude adoption of various modifications and techniques not explicitly described in the embodiments. In other words, the embodiments may be modified in various manners (such as combining embodiments and the respective modifications) and be implemented in a range without departing from the scope thereof. Further, each drawing is not intended to include only the components illustrated in the drawing, but may include other functions and the like.
The information processing apparatus 1 has functions as an OOD data generation unit 100.
The OOD data generation unit 100 generates out-of-distribution (OOD) data used for training of a class determination model (machine learning model) (not illustrated in the drawing), and exploratorily generates OOD data that is difficult for a classifier of the machine learning model to recognize. The out-of-distribution data is data that is not included in a specific domain.
The machine learning model is a class determination model that determines the class (domain) of data that has been input, and includes the classifier. The classifier performs classification of input data. The classifier calculates a certainty factor of each class among a plurality of classes with respect to data that has been input, and determines that the data belongs to the class having the highest certainty factor. The classes correspond to domains.
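The following is a minimal Python sketch of this class determination, under the assumption that the certainty factors are obtained as softmax probabilities (the disclosure does not fix a specific way of computing the certainty factor):

```python
import numpy as np

def classify(logits: np.ndarray) -> tuple[int, float]:
    """Return the class with the highest certainty factor and that factor."""
    exp = np.exp(logits - logits.max())           # numerically stable softmax
    certainty = exp / exp.sum()                   # certainty factor of each class
    predicted = int(np.argmax(certainty))         # class with the highest certainty factor
    return predicted, float(certainty[predicted])

# Example with three classes (domains)
print(classify(np.array([2.0, 0.5, -1.0])))       # -> (0, approx. 0.79)
```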
The classifier may be denoted by reference sign C. Also, the OOD data to be generated by the OOD data generation unit 100 may be referred to as an OOD sample.
Further, the OOD data is input to a discriminator (not illustrated in the drawing). The discriminator is for determining whether input data is in-distribution (IND) or OOD, and may output a value representing the closeness to IND data. The discriminator may be denoted by reference sign D.
In the example described below, the data to be processed is image data, and the OOD data generation unit 100 generates image data as the OOD data.
With the current classifier, the OOD data generation unit 100 generates OOD samples x̂ that have a high certainty factor and are unlike IND samples. The OOD data generation unit 100 generates the OOD samples x̂ expressed by the following mathematical formula (1).
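A plausible sketch of the mathematical formula (1), assuming that the generation is posed as a single minimization that trades off the classifier loss Lc(x, t) against the discriminator-related loss Ld(x) with the hyperparameter a described below (this exact form is an assumption and is not fixed by the description here), is:

```latex
\hat{x} = \operatorname*{arg\,min}_{x} \left[ a \, L_{c}(x, t) - (1 - a) \, L_{d}(x) \right], \qquad a \in [0, 1]
```

In this sketch, reducing Lc(x, t) raises the certainty factor for the target class t, while the −(1−a)Ld(x) term pushes the sample away from the IND data.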
In the above mathematical formula (1), Lc(x, t) represents the loss function related to the classifier (X→Y). Lc(x, t) may represent a cross entropy loss, for example. The smaller the value of Lc(x, t), the higher the probability that the sample x belongs to the target class t.
Further, Ld(x) represents the loss function related to the discriminator (X→R). R represents the set of all real numbers. Ld(x) may represent a Deep SVDD loss, for example. Ld(x) can be expressed by the following mathematical formula (2), for example.
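A plausible sketch of the mathematical formula (2), assuming the standard Deep SVDD formulation in which the squared distance between the extracted feature and a fixed center is used (whether the distance is squared is an assumption here), is:

```latex
L_{d}(x) = \left\lVert \phi_{d}(x) - c_{0} \right\rVert_{2}^{2}
```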
In the above mathematical formula (2), ϕd represents the feature extractor, and c0 represents the center of gravity.
Ld(x) may represent the L2 distance to IND on X. The greater the value of Ld(x) is, the more different the sample is from an IND sample.
Here, a is a parameter (hyperparameter) for adjusting the trade-off between Lc(x, t) and Ld(x), and a∈[0, 1]. The OOD data generation unit 100 adjusts the parameter a and optimizes the trade-off between Lc(x, t) and Ld(x), to generate an OOD sample that the classifier is not good at identifying.
For optimization of a, a known optimization technique such as a gradient descent method, a genetic algorithm (GA), or a generative adversarial network (GAN) may be used, and may be modified as appropriate prior to implementation.
In the drawing, solid arrows indicate the updating directions of OOD samples when a=1 in the above mathematical formula (1). Dotted arrows indicate the updating directions of the OOD samples when a=0 in the above mathematical formula (1). Dot-and-dash arrows indicate the updating directions suitable for updating the OOD samples. In updating an OOD sample, it is preferable to update the OOD sample in a direction away from the classification boundary. In other words, it is preferable to update an OOD sample in a direction in which the certainty factor is high.
As illustrated in the drawing, the OOD data generation unit 100 includes an OOD data candidate generation unit 101, an OOD data candidate update unit 102, and a classifier update unit 103.
The OOD data candidate generation unit 101 randomly generates OOD data candidates (pseudo data). Also, the OOD data candidate generation unit 101 randomly generates label data (pseudo label data) corresponding to the OOD data candidates.
The OOD data candidate update unit 102 updates the OOD data candidates in such directions as to reduce the loss L (such directions that the certainty factor becomes higher).
The OOD data candidate update unit 102 generates the out-of-distribution data by updating the OOD data candidates in such directions as to reduce the loss L (x, t) of the outputs to be obtained by inputting the OOD data candidates (pseudo data: x) to a machine learning model.
The classifier update unit 103 updates the classifier so as to recognize the OOD candidate data as OOD. As a result, the classifier no longer classifies the peripheries of the OOD candidate data with a high certainty factor.
By updating the classifier (machine learning model) using the updated OOD candidate data, the classifier update unit 103 can impart robustness for the OOD data to the machine learning model.
Step S1 indicates an initial state. The classifier C is in a state where machine learning (training) has been performed thereon. The OOD data candidate generation unit 101 randomly generates OOD data candidates (pseudo data).
In step S2, the OOD data candidate update unit 102 updates the OOD data candidates in such directions as to reduce the loss L (such directions that the certainty factor becomes higher).
In step S3, the classifier update unit 103 updates the classifier C so that the OOD candidate data updated in step S2 is recognized as OOD. As a result, the classifier C no longer classifies the peripheries of the OOD candidate data with a high certainty factor.
In step S4, the OOD data candidate update unit 102 generates OOD data by updating the OOD candidate data in such directions as to reduce the loss L, for the classifier C updated in step S3. Thereafter, the processes in steps S2 to S4 are repeatedly performed.
For example, the processes in steps S2 to S4 may be repeatedly performed the number of times equivalent to a preset number of optimization steps.
Also, the OOD data candidate generation unit 101 generates a plurality of OOD data candidates (pseudo data), and the OOD data candidate update unit 102 and the classifier update unit 103 repeatedly perform the processes in steps S2 to S4 on the plurality of OOD data candidates (pseudo data).
For example, the OOD data candidate generation unit 101 may generate OOD data candidates (pseudo data) until a preset number of OOD samples is reached.
Step S5 indicates a final state. In this final state, the OOD data candidates (pseudo data) are treated as the OOD data (OOD samples).
The OOD data generation unit 100 receives an input of training data Dtr (labeled IND samples), a hyperparameter a, the number of OOD samples n, the number of optimization steps m, and learning rate Ir (see reference sign P1). The OOD data generation unit 100 merges OOD samples in each step, to output OOD data of a total of n×m samples (see reference sign P2).
In the initialization process, the classifier C and the discriminator D are each initialized using the training data Dtr (see reference sign P3).
Also, the OOD data candidate generation unit 101 randomly generates OOD candidate data (OOD data, or pseudo data) {xi^0}i and dummy labels {ti}i (see reference sign P4). The dummy labels {ti}i indicate the classes into which the pseudo data is to be classified.
The OOD data candidate update unit 102 calculates a gradient ∇i of the loss L(xi^s, ti) with respect to the OOD candidate data xi^s (see reference sign P5). The gradient ∇i is used to change xi^s in such a direction that the loss L(xi^s, ti) becomes smaller. The OOD data candidate update unit 102 may calculate the gradient ∇i using the gradient descent method, for example.
Also, the OOD data candidate update unit 102 updates the OOD data candidates, using the learning rate Ir and the gradient ∇i (xi^(s+1) ← xi^s − Ir×∇i: see reference sign P6).
After that, the OOD data candidate update unit 102 adds the updated xi^(s+1) to a set DOOD of OOD data (DOOD ← DOOD ∪ {xi^(s+1)}: see reference sign P7).
These processes denoted by reference signs P5 to P7 are repeatedly performed (repetitive execution) for all (n) OOD samples and all the steps of the optimization step number m.
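The following is a minimal PyTorch-style sketch of the repeated processes denoted by reference signs P4 to P7, assuming that the loss L takes the form a·Lc − (1−a)·Ld sketched above, with a cross entropy loss for Lc and a Deep-SVDD-style distance for Ld (the helper arguments feature_extractor and center, and the argument names generally, are assumptions for illustration; lr corresponds to the learning rate Ir):

```python
import torch
import torch.nn.functional as F

def generate_ood_candidates(classifier, feature_extractor, center,
                            a, n, m, lr, shape, num_classes):
    """Sketch of reference signs P4 to P7: random candidates and dummy labels,
    gradient updates that reduce the combined loss L, and accumulation of
    every updated candidate into the set D_OOD."""
    d_ood = []
    # P4: randomly generate OOD candidate data {x_i^0}_i and dummy labels {t_i}_i
    x = torch.rand(n, *shape, requires_grad=True)
    t = torch.randint(0, num_classes, (n,))
    for _ in range(m):                              # m optimization steps
        # P5: gradient of L(x^s, t) = a*Lc - (1-a)*Ld with respect to x^s
        lc = F.cross_entropy(classifier(x), t)                          # classifier loss Lc
        ld = ((feature_extractor(x) - center) ** 2).sum(dim=1).mean()   # Deep-SVDD-style Ld
        loss = a * lc - (1 - a) * ld
        grad, = torch.autograd.grad(loss, x)
        # P6: x^(s+1) <- x^s - lr * gradient
        x = (x - lr * grad).detach().requires_grad_(True)
        # P7: add the updated candidates x^(s+1) to D_OOD
        d_ood.append(x.detach().clone())
    return torch.cat(d_ood)                         # n x m samples in total
```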
After that, the generated OOD data (OOD candidate data) is added to the training data Dtr, and the classifier update unit 103 updates the classifier C using these pieces of data. This update of the classifier C is performed by conducting training while regularizing the classifier so that it produces uniform outputs for the OOD data (see reference sign P8).
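The following is a minimal sketch of one way such a regularization may be realized, assuming the classifier outputs logits and the OOD outputs are pushed toward the uniform distribution via their mean log-probability (the weight beta and the exact form of the regularization term are assumptions for illustration):

```python
import torch.nn.functional as F

def classifier_update_loss(classifier, x_ind, y_ind, x_ood, beta=1.0):
    """Cross entropy on the IND training data plus a regularization term that
    pushes the classifier outputs for the OOD data toward a uniform
    distribution over the classes."""
    ind_loss = F.cross_entropy(classifier(x_ind), y_ind)
    log_probs_ood = F.log_softmax(classifier(x_ood), dim=1)
    uniform_loss = -log_probs_ood.mean()   # minimized when the OOD outputs are uniform
    return ind_loss + beta * uniform_loss
```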
Note that the discriminator D may be updated with the training data Dtr to which the generated OOD data (OOD candidate data) is added (see reference sign P9). The process indicated by reference sign P9 may be performed as appropriate. The OOD data is also regarded as IND while this machine learning is performed.
After that, final adjustment (fine tuning) of the classifier C is performed with the training data to which the generated OOD data (OOD candidate data) is added (see reference sign P10).
As described above, in the information processing apparatus 1 as an example of the first embodiment, the OOD data candidate update unit 102 updates the OOD data candidate generated by the OOD data candidate generation unit 101, so that the loss for the class is minimized.
As a result, OOD data for training that has high coverage can be generated for the class determination model. Also, it is possible to generate OOD data that is completely different from IND data but is classified with a high certainty factor, that is, data that the classifier is not good at handling.
The classifier update unit 103 updates the classifier with the OOD data in the generation process that is updated by the OOD data candidate update unit 102. Thus, OOD data can be sequentially generated in an exploratory manner, and the OOD data can be efficiently generated.
Also, it is possible to enhance robustness of the machine learning model by preparing the generated OOD data and training the machine learning model.
With this information processing apparatus 1, data that cannot be recognized as out-of-distribution data by the current machine learning classifier can be generated in a simulative/exploratory manner. As OOD data candidates (pseudo data) are generated so as to oppose the classifier, data that is difficult for the classifier to recognize is generated in an exploratory manner. The classifier is then updated with the generated OOD data candidates, to impart robustness to the OOD data.
Also, by repeatedly performing generation of OOD data candidates (pseudo data) so as to oppose the classifier, and updating of the classifier with the generated OOD data candidates, it is possible to comprehensively generate OOD data, and enhance robustness of the machine learning model.
An example is now described in which the training data Dtr = {(xi, yi)}i is defined such that yi ∈ {−1, 1} and xi ~ N(yi, 0.3I2), yi = −1 is defined as class 0, and other data is defined as class 1.
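The following is a minimal sketch that generates this toy training data, interpreting 0.3I2 as the covariance matrix and broadcasting yi to a two-dimensional mean (the sample count and the random seed are arbitrary choices for illustration):

```python
import numpy as np

def make_toy_data(n=200, seed=0):
    """y_i in {-1, 1}; x_i ~ N((y_i, y_i), 0.3 * I_2); y_i = -1 is class 0, else class 1."""
    rng = np.random.default_rng(seed)
    y = rng.choice([-1, 1], size=n)
    x = rng.normal(loc=y[:, None], scale=np.sqrt(0.3), size=(n, 2))
    labels = (y != -1).astype(int)    # class 0 for y_i = -1, class 1 otherwise
    return x, labels
```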
In the first embodiment described above, a technique for generating OOD data by optimization has been described, but embodiments are not limited to this. In a second embodiment, generation of OOD data is performed with the use of an optimized OOD generation model.
An OOD generation model 104 is a neural network having a weight W, receives an input of data z, and outputs OOD data G (z).
The neural network may be a hardware circuit, or may be a virtual network implemented by software in which connections between layers are virtually constructed on a computer program by a processor 11 described later.
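The following is a minimal sketch of such a generation model, assuming a small fully connected network (the latent dimension, layer sizes, and output shape are assumptions for illustration and are not fixed by the disclosure):

```python
import torch.nn as nn

class OODGenerator(nn.Module):
    """Neural network with weight W that maps input data z to OOD data G(z)."""

    def __init__(self, z_dim=64, out_dim=28 * 28, hidden=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(z_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, out_dim), nn.Sigmoid(),  # e.g., pixel values in [0, 1]
        )

    def forward(self, z):
        return self.net(z)
```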
The OOD generation model 104 is trained (subjected to machine learning) by the classifier 105 so as to generate samples that have a high certainty factor and are different from IND.
In the example illustrated in the drawing, the output G(z) of the OOD generation model 104 is input to each of a discriminator 106 and the classifier 105. In the discriminator 106, machine learning is performed so that OOD is also regarded as IND. In the classifier 105, machine learning (training) is performed so that OOD can be recognized while IND is correctly identified. The OOD generation model 104 may be referred to as a generator. Also, the generator may be denoted by reference sign G.
Data (OOD data) generated by the OOD generation model 104 is input to each of the classifier 105 and the discriminator 106. The OOD data generated by the OOD generation model 104 corresponds to the first data to be input to the discriminator 106. The OOD generation model 104 corresponds to the generator that generates the first data.
The information processing apparatus 1 of the second embodiment has functions as a training processing unit 200.
The training processing unit 200 performs training (machine learning) on the OOD generation model 104, using a training data set.
As illustrated in the drawing, the training processing unit 200 includes a training data generation unit 201 and a parameter setting unit 202.
The training data generation unit 201 generates a plurality of pieces of training data. The plurality of pieces of training data may be referred to as the training data set. The training data includes OOD source data (pseudo data) and dummy labels. In the training of the OOD generation model 104, the OOD source data is used as the data to be input to the OOD generation model 104, and the dummy labels are used as correct data.
The training data generation unit 201 may randomly generate the OOD source data (pseudo data) and the dummy labels for the OOD source data, respectively. The OOD source data may be denoted by reference sign {zi}i. The dummy labels may be denoted by reference sign {ti}i.
The parameter setting unit 202 trains the OOD generation model 104, using the training data in which the OOD source data is the input data and the dummy labels are the correct data. For example, the parameter setting unit 202 updates the weight W (training parameter) of the OOD generation model 104 so as to reduce the sum of the loss L(G(zi), ti), based on G(zi), which is the output obtained by inputting the OOD source data zi to the OOD generation model 104, the class obtained by inputting G(zi) to the classifier 105, and the dummy label ti, which is the correct data.
The parameter setting unit 202 may optimize the parameter by updating the parameters of the neural network in such a direction as to reduce the loss function that defines an error between a result of inference by the OOD generation model 104 for the training data and the correct data, using the gradient descent method, for example.
In the second embodiment, the OOD generation model 104 (generator G) is trained so as to oppose the classifier C.
The OOD generation model 104 receives an input of training data Dtr (labeled IND samples), a hyperparameter a, the number of OOD samples n, the number of optimization steps m, and learning rate Ir (see reference sign P01). The OOD generation model 104 merges OOD samples in each step, to output OOD data of a total of n×m samples (see reference sign P02).
In the initialization process, the classifier C and the discriminator D are each initialized using the training data Dtr (see reference sign P03).
Also, in the initialization process, the OOD generation model 104 (generator G) is initialized (see reference sign P04). The initial value of the weight W of the OOD generation model 104 is represented by W0.
The training data generation unit 201 randomly generates the OOD source data {zi}i and dummy labels {ti}i (see reference sign P05). The dummy labels {ti}i indicate the classes into which the pseudo data is to be classified.
The parameter setting unit 202 calculates the gradient ∇ of the loss L(G(zi), ti) with respect to the weight of the OOD generation model 104 (generator G) (see reference sign P06).
The parameter setting unit 202 updates the weight of the OOD generation model 104 (generator G), based on the calculated gradient ∇ of the weight (W^(s+1) ← W^s − Ir×∇) (see reference sign P07).
After that, the output {G(zi)}i of the OOD generation model 104 is added to the set DOOD of OOD data (DOOD ← DOOD ∪ {G(zi)}i: see reference sign P08).
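The following is a minimal PyTorch-style sketch of one repetition of the processes denoted by reference signs P05 to P08, under the same assumed form of the loss L as in the first embodiment (the helper components feature_extractor and center, and the argument names, are assumptions for illustration; lr corresponds to the learning rate Ir):

```python
import torch
import torch.nn.functional as F

def generator_step(generator, classifier, feature_extractor, center,
                   a, n, z_dim, num_classes, lr, d_ood):
    """One repetition of reference signs P05 to P08."""
    opt = torch.optim.SGD(generator.parameters(), lr=lr)   # realizes W <- W - lr * gradient
    z = torch.randn(n, z_dim)                        # P05: OOD source data {z_i}_i
    t = torch.randint(0, num_classes, (n,))          # P05: dummy labels {t_i}_i
    g = generator(z)
    lc = F.cross_entropy(classifier(g), t)                          # classifier loss
    ld = ((feature_extractor(g) - center) ** 2).sum(dim=1).mean()   # distance from IND
    loss = a * lc - (1 - a) * ld                     # assumed form of L(G(z_i), t_i)
    opt.zero_grad()
    loss.backward()                                  # P06: gradient with respect to the weight W
    opt.step()                                       # P07: W^(s+1) <- W^s - lr * gradient
    d_ood.append(generator(z).detach())              # P08: add {G(z_i)}_i to D_OOD
    return d_ood
```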
After that, the generated OOD data is added to the training data Dtr, and the classifier C is updated with these pieces of data. This update of the classifier C is performed by conducting training while regularizing the classifier so that it produces uniform outputs for the OOD data (see reference sign P09).
Note that the discriminator D may be updated with the training data Dtr to which the generated OOD data (OOD candidate data) is added (see reference sign P010). The OOD data is also regarded as IND while machine learning is performed.
These processes denoted by reference signs P05 to P010 are repeatedly performed (repetitive execution) for all the steps of the optimization step number m.
After that, final adjustment (fine tuning) of the classifier C is performed with the training data to which the generated OOD data (OOD candidate data) is added (see reference sign P011).
In this manner, in the information processing apparatus 1 as an example of the second embodiment, actions and effects similar to those of the first embodiment can be achieved.
The information processing apparatus 1 is a computer, and includes, as its components, a processor 11, a memory 12, a storage device 13, a graphics processing device 14, an input interface 15, an optical drive device 16, a device coupling interface 17, and a network interface 18, for example. Those components 11 to 18 are designed to be able to communicate with one another via a bus 19.
The processor 11 is a control unit that controls the entire information processing apparatus 1. The processor 11 may be a multiprocessor. The processor 11 may be any one of a central processing unit (CPU), a micro processing unit (MPU), a digital signal processor (DSP), an application specific integrated circuit (ASIC), a programmable logic device (PLD), a field programmable gate array (FPGA), and a graphics processing unit (GPU), for example. Also, the processor 11 may be a combination of two or more kinds of components among a CPU, an MPU, a DSP, an ASIC, a PLD, an FPGA, and a GPU.
The processor 11 then executes a control program (a data generation program, not illustrated), to implement the functions as the OOD data generation unit 100 illustrated as an example in the drawings.
Note that the information processing apparatus 1 implements the functions as the OOD data generation unit 100, by executing a program (a data generation program, or an OS program) recorded in a computer-readable non-transitory recording medium, for example. OS is an abbreviation for an operating system. Also, the information processing apparatus 1 implements the functions as the training processing unit 200, by executing a program (a machine learning program, or an OS program) recorded in a computer-readable non-transitory recording medium, for example.
The programs in which processing content to be executed by the information processing apparatus 1 is written may be recorded in various kinds of recording media. For example, the programs to be executed by the information processing apparatus 1 may be stored in the storage device 13. The processor 11 loads at least one of the programs in the storage device 13 into the memory 12, and executes the loaded program.
Also, the programs to be executed by the information processing apparatus 1 (processor 11) may be recorded in a non-transitory portable recording medium such as an optical disk 16a, a memory device 17a, or a memory card 17c. The programs stored in the portable recording medium may be executed after being installed into the storage device 13, under the control of the processor 11, for example. Also, the processor 11 may directly read a program from the portable recording medium, and execute the program.
The memory 12 is a storage memory including a read only memory (ROM) and a random access memory (RAM). The RAM of the memory 12 is used as the main storage unit of the information processing apparatus 1. The RAM temporarily stores at least one of the programs to be executed by the processor 11. Also, the memory 12 stores various kinds of data needed for processing by the processor 11.
The storage device 13 is a storage device such as a hard disk drive (HDD), a solid state drive (SSD), or a storage class memory (SCM), and stores various kinds of data. The storage device 13 is used as an auxiliary storage unit of the information processing apparatus 1.
The storage device 13 stores the OS program, control programs, and various kinds of data. The control programs include the data generation program and the machine learning program.
Note that a semiconductor memory device such as an SCM or a flash memory may be used as the auxiliary storage unit. Also, redundant arrays of inexpensive disks (RAID) may be formed with a plurality of storage devices 13.
The storage device 13 and the memory 12 may store OOD data generated by the OOD data generation unit 100, and various kinds of data and parameters generated by the OOD data candidate generation unit 101, the OOD data candidate update unit 102, and the classifier update unit 103 in the course of processing.
Also, the storage device 13 and the memory 12 may store OOD data generated by the OOD data generation unit 100, and various kinds of data and parameters generated by the OOD data candidate generation unit 101, the training data generation unit 201, and the parameter setting unit 202 in the course of processing.
The graphics processing device 14 is coupled to a monitor 14a. The graphics processing device 14 displays an image on a screen of the monitor 14a, in accordance with a command from the processor 11. Examples of the monitor 14a include a display device using a cathode ray tube (CRT), a liquid crystal display device, and the like.
The input interface 15 is coupled to a keyboard 15a and a mouse 15b. The input interface 15 transmits signals sent from the keyboard 15a and the mouse 15b to the processor 11. Note that the mouse 15b is an example of a pointing device, and some other pointing device may be used. Examples of other pointing devices include a touch panel, a tablet, a touch pad, a track ball, and the like.
The optical drive device 16 reads data recorded on the optical disk 16a, using laser light or the like. The optical disk 16a is a non-transitory portable recording medium in which data is recorded so as to be read by reflection of light. Examples of the optical disk 16a include a digital versatile disc (DVD), a DVD-RAM, a compact disc read only memory (CD-ROM), a CD-recordable (R)/rewritable (RW), and the like.
The device coupling interface 17 is a communication interface for coupling a peripheral device to the information processing apparatus 1. For example, the memory device 17a and a memory reader/writer 17b may be coupled to the device coupling interface 17. The memory device 17a is a non-transitory recording medium equipped with a function of communicating with the device coupling interface 17, and may be a universal serial bus (USB) memory, for example. The memory reader/writer 17b writes data into the memory card 17c, or reads data from the memory card 17c. The memory card 17c is a card-type non-transitory recording medium.
The network interface 18 is connected to a network. The network interface 18 transmits and receives data via the network. Other information processing apparatuses, communication devices, and the like may be connected to the network.
Further, the disclosed technology is not limited to the embodiments described above, and various modifications may be made and implemented in a range without departing from the scope of the embodiments. Each configuration and each process of the embodiments may be selected or omitted as needed, or may be appropriately combined.
For example, in the examples described in the above respective embodiments, the classifier performs two-class determination to classify input data into two classes, but embodiments are not limited to this. The classifier may perform multi-class determination for classifying input data into three or more classes, and may modify and perform the multi-class determination as appropriate.
Also, in the examples described in the above embodiments, data to be processed is image data. However, embodiments are not limited to this, and the above examples can be applied to various kinds of data such as tabular data and text data. For example, data (tabular data) in which physical information including height, weight, and the like is summarized in a tabular form may be used in conjunction with a technique for determining whether the subject is likely to contract a specific disease.
Furthermore, those skilled in the art can carry out or manufacture the embodiments according to the above disclosure.
All examples and conditional language provided herein are intended for the pedagogical purposes of aiding the reader in understanding the invention and the concepts contributed by the inventor to further the art, and are not to be construed as limitations to such specifically recited examples and conditions, nor does the organization of such examples in the specification relate to a showing of the superiority and inferiority of the invention. Although one or more embodiments of the present invention have been described in detail, it should be understood that the various changes, substitutions, and alterations could be made hereto without departing from the spirit and scope of the invention.
This application is a continuation application of International Application PCT/JP2021/048702 filed on Dec. 27, 2021 and designated the U.S., the entire contents of which are incorporated herein by reference.
| | Number | Date | Country |
|---|---|---|---|
| Parent | PCT/JP2021/048702 | Dec 2021 | WO |
| Child | 18739788 | | US |