This application claims the priority benefit of Taiwan application serial no. 110138818, filed on Oct. 20, 2021. The entirety of the above-mentioned patent application is hereby incorporated by reference herein and made a part of this specification.
This disclosure relates to an electronic device and a method adaptable for training a neural network model.
Most existing supervised machine learning relies on manually generated labeled data, which is then used to train a machine learning model (for example, a deep learning model). To increase the accuracy of the machine learning model, a large amount of labeled data often needs to be collected. However, manually generating labeled data not only consumes time and human resources, but is also prone to mislabeling caused by human error, which reduces the effectiveness of the machine learning model. In addition, in vertical applications (such as industrial vision and medicine), it is often difficult to collect images of the recognition target (such as defect images or symptom images), which increases the difficulty of introducing machine learning. Therefore, how to reduce the amount of labeled data that must be manually generated without reducing the performance of the machine learning model is one of the important issues in this field.
The disclosure provides an electronic device and a method adaptable for training a neural network model, which can use a small amount of manually labeled data to train a neural network model with high performance.
An electronic device adaptable for training a neural network model disclosed in the disclosure includes a storage medium and a processor. The storage medium stores a first neural network model. The processor is coupled to the storage medium, and the processor is configured to: obtain first pseudo-labeled data; input the first pseudo-labeled data into the first neural network model to obtain second pseudo-labeled data; determine whether a second pseudo-label corresponding to the second pseudo-labeled data matches a first pseudo-label corresponding to the first pseudo-labeled data; in response to the second pseudo-label matching the first pseudo-label, add the second pseudo-labeled data to a pseudo-labeled dataset; and train the first neural network model according to the pseudo-labeled dataset.
A method for training a neural network model in the disclosure includes: obtaining a first neural network model and first pseudo-labeled data; inputting the first pseudo-labeled data into the first neural network model to obtain second pseudo-labeled data; determining whether a second pseudo-label corresponding to the second pseudo-labeled data matches a first pseudo-label corresponding to the first pseudo-labeled data; in response to the second pseudo-label matching the first pseudo-label, adding the second pseudo-labeled data to a pseudo-labeled dataset; and training the first neural network model according to the pseudo-labeled dataset.
The processor 110 is, for example, a central processing unit (CPU), or other programmable general-purpose or special-purpose micro control unit (MCU), microprocessor, digital signal processor (DSP), programmable controller, application specific integrated circuit (ASIC), graphics processing unit (GPU), image signal processor (ISP), image processing unit (IPU), arithmetic logic unit (ALU), complex programmable logic device (CPLD), field programmable gate array (FPGA) or other similar components or a combination of the above components. The processor 110 may be coupled to the storage medium 120 and the transceiver 130, and access and execute multiple modules and various application programs stored in the storage medium 120.
The storage medium 120 is, for example, any type of fixed or removable random access memory (RAM), read-only memory (ROM), flash memory, hard disk drive (HDD), solid state drive (SSD) or similar components or a combination of the above components, and adapted to store multiple modules or various application programs that can be executed by the processor 110. In this embodiment, the storage medium 120 can store a teacher model (or referred to as "second neural network model") 121, a student model (or referred to as "first neural network model") 122, and a final neural network model 123, etc. The functions of these models will be explained later.
The transceiver 130 transmits and receives signals in a wireless or wired manner. The electronic device 100 can receive data or output data through the transceiver 130.
After obtaining the labeled dataset Li, the processor 110 may train the neural network architecture 200 based on the labeled dataset Li to obtain the teacher model 121, and the teacher model 121 may include, but is not limited to, a convolutional neural network (CNN) model. The neural network architecture 200 may include information such as the type of the neural network (for example, a convolutional neural network), the weight configuration method of the neural network, the loss function of the neural network, or the hyperparameters of the neural network. The disclosure is not limited thereto. The processor 110 may train the neural network architecture 200 according to supervised learning (SL) to obtain the teacher model 121.
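For illustration, the following is a minimal sketch of this supervised training step, written in PyTorch. The function name train_teacher, the choice of the Adam optimizer, the epoch count, the learning rate, and the data loader interface are assumptions made for the example rather than details specified by the disclosure.

```python
import torch
import torch.nn.functional as F

def train_teacher(model, labeled_loader, epochs=10, lr=1e-3):
    """Supervised training of the teacher model on the labeled dataset Li.

    `model` is any CNN classifier that returns raw logits, and
    `labeled_loader` yields (image, label) batches; both interfaces are
    illustrative assumptions.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    model.train()
    for _ in range(epochs):
        for images, labels in labeled_loader:
            logits = model(images)
            # Supervised cross-entropy loss against the manual labels
            loss = F.cross_entropy(logits, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    return model
```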
After completing the training of the teacher model 121, the processor 110 can input the unlabeled dataset U to the teacher model 121 to obtain a highly trusted (completely trusted) pseudo-labeled dataset Ph and a partially trusted pseudo-labeled dataset Pi, where i is the index of the partially trusted pseudo-labeled dataset. The highly trusted pseudo-labeled dataset Ph and the partially trusted pseudo-labeled dataset Pi may each contain one or more pieces of pseudo-labeled data.
In an embodiment, the processor 110 may determine, according to a confidence threshold, whether unlabeled data in the unlabeled dataset U should be allocated to the highly trusted pseudo-labeled dataset Ph or the partially trusted pseudo-labeled dataset Pi. Specifically, the processor 110 may input the unlabeled data to the teacher model 121 to generate a probability vector, and the probability vector may include one or more probabilities respectively corresponding to one or more labels. The processor 110 may allocate the unlabeled data according to the probability vector and the confidence threshold. The processor 110 may add the unlabeled data to the highly trusted pseudo-labeled dataset Ph in response to the maximum probability in the probability vector being greater than the confidence threshold. The processor 110 may add the unlabeled data to the partially trusted pseudo-labeled dataset Pi in response to the maximum probability in the probability vector being less than or equal to the confidence threshold. In the highly trusted pseudo-labeled dataset Ph, the labels of the pseudo-labeled data are more trusted, so these pseudo-labeled data do not need to be re-checked to confirm whether their labels are correct. In contrast, in the partially trusted pseudo-labeled dataset Pi, the labels of the pseudo-labeled data are less trusted, so these pseudo-labeled data need to be re-checked to confirm whether their labels are correct.
For example, the processor 110 may input the unlabeled data in the unlabeled dataset U into the teacher model 121 to generate a probability vector [p1 p2 p3], where the probability p1 corresponds to the first type of label, the probability p2 corresponds to the second type of label, and the probability p3 corresponds to the third type of label. If the probability p2 is greater than the probability p1 and greater than the probability p3, it means that the teacher model 121 recognizes the unlabeled data as data corresponding to the second type of label. Accordingly, the processor 110 can determine whether the probability p2 (i.e., the maximum probability) is greater than the confidence threshold. If the probability p2 is greater than the confidence threshold, the processor 110 may add the unlabeled data to the highly trusted pseudo-labeled dataset Ph. If the probability p2 is less than or equal to the confidence threshold, the processor 110 may add the unlabeled data to the partially trusted pseudo-labeled dataset Pi.
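A minimal sketch of this allocation step is shown below, again in PyTorch. The confidence threshold of 0.9, the function name, and the container shapes are illustrative assumptions; the disclosure does not fix a particular threshold value.

```python
import torch

@torch.no_grad()
def split_by_confidence(teacher, unlabeled_loader, confidence_threshold=0.9):
    """Allocate each unlabeled sample to the highly trusted dataset Ph or the
    partially trusted dataset Pi based on the teacher's maximum probability."""
    highly_trusted, partially_trusted = [], []
    teacher.eval()
    for images in unlabeled_loader:
        probs = torch.softmax(teacher(images), dim=1)  # probability vector per sample
        max_prob, pseudo_label = probs.max(dim=1)      # e.g. p2 and its label index
        for img, p, y in zip(images, max_prob, pseudo_label):
            if p.item() > confidence_threshold:
                highly_trusted.append((img, y.item()))     # Ph: no re-check needed
            else:
                partially_trusted.append((img, y.item()))  # Pi: label must be re-checked
    return highly_trusted, partially_trusted
```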
After completing the training of the student model 122, the processor 110 may input the pseudo-labeled data (or referred to as "third pseudo-labeled data") D1 in the partially trusted pseudo-labeled dataset Pi to the student model 122 to generate pseudo-labeled data (or referred to as "fourth pseudo-labeled data") D2. Then, the processor 110 can determine whether the pseudo-labeled data D2 is trusted.
If the pseudo-labeled data D2 is trusted, the processor 110 may update the partially trusted pseudo-labeled dataset Pi according to the pseudo-labeled data D2. Specifically, the processor 110 may add the pseudo-labeled data D2 to the partially trusted pseudo-labeled dataset Pi+1. After determining whether all pseudo-labeled data in the partially trusted pseudo-labeled dataset Pi is trusted, the processor 110 may obtain the final partially trusted pseudo-labeled dataset Pi+1. The processor 110 may use the partially trusted pseudo-labeled dataset Pi+1 to replace the partially trusted pseudo-labeled dataset Pi, thereby updating the partially trusted pseudo-labeled dataset Pi.
On the other hand, if the pseudo-labeled data D2 is not trusted, the processor 110 may output the pseudo-labeled data D2 for the user to manually mark the pseudo-labeled data D2, thereby generating the labeled data D3 (or referred to as “fourth labeled data”). The processor 110 may add the labeled data D3 to the labeled dataset Lx. After determining whether all the pseudo-labeled data in the partially trusted pseudo-labeled dataset Pi is trusted, the processor 110 may obtain the final labeled dataset Lx. The processor 110 may add the labeled data in the final labeled dataset Lx to the labeled dataset Li, so as to update the labeled dataset Li.
The processor 110 may determine whether the pseudo-labeled data D2 is trusted according to whether the pseudo-labeled data D2 and the pseudo-labeled data D1 match. If the pseudo-label of the pseudo-labeled data D2 (or referred to as "fourth pseudo-label") matches or is the same as the pseudo-label of the pseudo-labeled data D1 (or referred to as "third pseudo-label"), it means that the recognition result of the teacher model 121 is the same as the recognition result of the student model 122. Accordingly, the processor 110 can determine that the pseudo-labeled data D2 is trusted. If the pseudo-label of the pseudo-labeled data D2 does not match or is not the same as the pseudo-label of the pseudo-labeled data D1, it means that the recognition result of the teacher model 121 is different from the recognition result of the student model 122. Accordingly, the processor 110 can determine that the pseudo-labeled data D2 is not trusted.
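The following sketch illustrates this re-check loop. The function name, container shapes, and single-sample batching are assumptions made for the example; only the matching rule itself comes from the description above.

```python
import torch

@torch.no_grad()
def recheck_pseudo_labels(student, partially_trusted):
    """Re-check each piece of pseudo-labeled data D1 in Pi with the student model.

    If the student's pseudo-label (for D2) matches the teacher's pseudo-label
    (for D1), the sample remains pseudo-labeled (destined for Pi+1); otherwise
    it is routed to manual labeling (destined for Lx).
    """
    still_trusted, needs_manual_label = [], []
    student.eval()
    for image, teacher_label in partially_trusted:
        logits = student(image.unsqueeze(0))         # single-sample batch
        student_label = logits.argmax(dim=1).item()
        if student_label == teacher_label:
            still_trusted.append((image, student_label))  # recognition results agree
        else:
            needs_manual_label.append(image)              # disagreement: manual labeling
    return still_trusted, needs_manual_label
```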
The processor 110 may repeatedly perform the process shown in
On the other hand, the processor 110 may input the pseudo-labeled data (or referred to as “first pseudo-labeled data”) B1 in the partially trusted pseudo-labeled dataset P into the neural network model 400 to obtain the pseudo-labeled data (or referred to as “second pseudo-labeled data”) B2, and the partially trusted pseudo-labeled dataset P is, for example, the partially trusted pseudo-labeled dataset Pi as shown in
After obtaining the pseudo-labeled data B2, the processor 110 may perform a threshold check on the pseudo-labeled data B2, and determine whether the pseudo-labeled data B2 passes the threshold check. If the pseudo-labeled data B2 passes the threshold check, the processor 110 may further determine whether the pseudo-labeled data B2 matches the pseudo-labeled data B1. If the pseudo-labeled data B2 fails the threshold check, the processor 110 may ignore the pseudo-labeled data B2, so as not to add the pseudo-labeled data B2 to the pseudo-labeled dataset Y, and the pseudo-labeled dataset Y can be used to train or update the neural network model 400. In other words, the ignored pseudo-labeled data B2 will not be used to train or update the neural network model 400.
Specifically, the pseudo-labeled data B2 may include a probability vector. The processor 110 may perform a threshold check according to the probability vector. In an embodiment, the processor 110 may determine that the pseudo-labeled data B2 passes the threshold check in response to the maximum probability in the probability vector being greater than the probability threshold α. The processor 110 may determine that the pseudo-labeled data B2 fails the threshold check in response to the maximum probability in the probability vector being less than or equal to the probability threshold α. For example, the pseudo-labeled data B2 may include a probability vector [p11 p12 p13], and the probability p11 corresponds to the first type of label, the probability p12 corresponds to the second type of label, and the probability p13 corresponds to the third type of label. If the probability p12 is greater than the probability p11 and greater than the probability p13, the processor 110 may determine whether the probability p12 (i.e., the maximum probability) is greater than the probability threshold α. If the probability p12 is greater than the probability threshold α, the processor 110 may determine that the pseudo-labeled data B2 passes the threshold check. If the probability p12 is less than or equal to the probability threshold α, the processor 110 may determine that the pseudo-labeled data B2 fails the threshold check.
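A minimal sketch of this threshold check is shown below, using the worked example above; the value α = 0.8 and the probability values are illustrative assumptions.

```python
import torch

def passes_threshold_check(probability_vector, alpha=0.8):
    """Return True when the maximum probability in the vector (e.g. p12)
    is greater than the probability threshold alpha."""
    return probability_vector.max().item() > alpha

# Worked example: probability vector [p11, p12, p13] with p12 as the maximum
probs = torch.tensor([0.10, 0.85, 0.05])
print(passes_threshold_check(probs))  # True: p12 = 0.85 > 0.8, so B2 passes
```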
The neural network model 400 may include one or more sub-neural network models.
In an embodiment, the processor 110 may calculate the average of the first maximum probability in the first probability vector of the pseudo-labeled data B21 and the second maximum probability in the second probability vector of the pseudo-labeled data B22. If the average probability is greater than the probability threshold α, the processor 110 may determine that the pseudo-labeled data B2 passes the threshold check. If the average probability is less than or equal to the probability threshold α, the processor 110 may determine that the pseudo-labeled data B2 fails the threshold check. For example, suppose that the pseudo-labeled data B21 includes a first probability vector [p21 p22 p23] and the pseudo-labeled data B22 includes a second probability vector [p31 p32 p33], where the probability p22 is greater than the probability p21 and greater than the probability p23, and the probability p32 is greater than the probability p31 and greater than the probability p33. The processor 110 can calculate the average of the probability p22 and the probability p32. If the average of the probability p22 and the probability p32 is greater than the probability threshold α, the processor 110 may determine that the pseudo-labeled data B2 passes the threshold check. If the average of the probability p22 and the probability p32 is less than or equal to the probability threshold α, the processor 110 may determine that the pseudo-labeled data B2 fails the threshold check.
In an embodiment, the processor 110 may determine that the pseudo-labeled data B2 passes the threshold check in response to the first maximum probability in the first probability vector of the pseudo-labeled data B21 being greater than the probability threshold α and the second maximum probability in the second probability vector of the pseudo-labeled data B22 being greater than the probability threshold α. The processor 110 may determine that the pseudo-labeled data B2 fails the threshold check in response to at least one of the first maximum probability or the second maximum probability being less than or equal to the probability threshold α. For example, suppose that the pseudo-labeled data B21 includes a first probability vector [p21 p22 p23] and the pseudo-labeled data B22 includes a second probability vector [p31 p32 p33], where the probability p22 is greater than the probability p21 and greater than the probability p23, and the probability p32 is greater than the probability p31 and greater than the probability p33. The processor 110 may determine that the pseudo-labeled data B2 passes the threshold check in response to the probability p22 and the probability p32 both being greater than the probability threshold α. The processor 110 may determine that the pseudo-labeled data B2 fails the threshold check in response to at least one of the probability p22 or the probability p32 being less than or equal to the probability threshold α.
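The two ensemble strategies described above can be sketched as follows; the probability values and α = 0.8 are illustrative assumptions. Note that the same pair of vectors can pass the average check yet fail the per-model check, which is the practical difference between the two embodiments.

```python
import torch

def ensemble_check_average(probs_b21, probs_b22, alpha=0.8):
    """Average strategy: pass when the mean of the two maximum probabilities
    (e.g. p22 and p32) is greater than alpha."""
    average = (probs_b21.max() + probs_b22.max()) / 2
    return average.item() > alpha

def ensemble_check_both(probs_b21, probs_b22, alpha=0.8):
    """Per-model strategy: pass only when each sub-model's maximum probability
    is greater than alpha on its own."""
    return probs_b21.max().item() > alpha and probs_b22.max().item() > alpha

# Worked example: p22 = 0.90 and p32 = 0.75 with alpha = 0.8
b21 = torch.tensor([0.05, 0.90, 0.05])
b22 = torch.tensor([0.15, 0.75, 0.10])
print(ensemble_check_average(b21, b22))  # True: (0.90 + 0.75) / 2 = 0.825 > 0.8
print(ensemble_check_both(b21, b22))     # False: p32 = 0.75 <= 0.8
```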
After obtaining the cross-entropy loss HPL and the cross-entropy loss HL, the processor 110 may obtain a loss function LF as shown in equation (1), where β is the loss weight. The processor 110 can train or update the neural network model 400 according to the loss function LF and the pseudo-labeled dataset Y. The processor 110 may repeatedly perform the process shown in
LF = HL + βHPL (1)
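A minimal sketch of equation (1) as a training objective is shown below; the batch structure and the default value of β are assumptions made for the example.

```python
import torch
import torch.nn.functional as F

def combined_loss(model, labeled_batch, pseudo_labeled_batch, beta=1.0):
    """Loss function LF = HL + β·HPL from equation (1).

    HL is the cross-entropy on the labeled dataset and HPL is the
    cross-entropy on the pseudo-labeled dataset Y; β weighs the
    pseudo-labeled term.
    """
    images_l, labels_l = labeled_batch
    images_p, labels_p = pseudo_labeled_batch
    h_l = F.cross_entropy(model(images_l), labels_l)   # HL: labeled term
    h_pl = F.cross_entropy(model(images_p), labels_p)  # HPL: pseudo-labeled term
    return h_l + beta * h_pl                           # LF = HL + βHPL
```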
In an embodiment, the processor 110 may train the final neural network model 123 according to supervised learning. In an embodiment, the processor 110 may train the final neural network model 123 according to the adaptive matching training method shown in
In summary, the electronic device disclosed in the disclosure can train a teacher model according to a small amount of manually generated labeled data based on a supervised learning algorithm, and then use the teacher model to label a large amount of unlabeled data to generate pseudo-labeled data. The electronic device can train or update the student model according to the manually labeled data and the pseudo-labeled data based on the adaptive matching algorithm, so as to improve the student model's ability to recognize pseudo-labeled data. The electronic device can use the student model to determine whether the pseudo-label of the pseudo-labeled data is trusted. If the pseudo-label is not trusted, the electronic device can instruct the user to manually determine the correct label of the pseudo-labeled data. In short, the electronic device can select, from multiple pieces of pseudo-labeled data, a small amount of pseudo-labeled data that needs to be manually checked, and the pseudo-labels of the other pseudo-labeled data can be regarded as correct labels. The user can train a neural network model with high performance based on the pseudo-labeled dataset generated by the method in the disclosure.