The embodiments relate generally to machine learning, and more specifically to systems and methods for improving efficiency of classification models.
Pre-trained language models may be used for text classification, e.g., given an input text, a classification label may be generated for the input text. Some pre-trained language models may be applied for zero-shot text classification, e.g., without being trained with training data associated with a particular classification label. Such zero-shot text classification models often rely on natural language inference (NLI) and next sentence prediction (NSP). For example, NLI is the task of determining whether a given “hypothesis” logically follows from a “premise,” e.g., by generating an output predicting whether the “hypothesis” is True, False or Neutral given the “premise.” NSP is a binary classification task predicting whether two sentences are consecutive or not. These models often employ a cross-encoder architecture and infer by making a forward pass through the model for each label-text pair separately, which increases the computational cost of inference linearly in the number of labels. As a result, when the number of labels grows, the computational cost of text classification over the set of labels may also grow accordingly.
Therefore, there is a need to improve the text classification efficiency.
In the figures, elements having the same designations have the same or similar functions.
As used herein, the term “network” may comprise any hardware or software-based framework that includes any artificial intelligence network or system, neural network, or system and/or any training or learning models implemented thereon or therewith.
As used herein, the term “module” may comprise a hardware or software-based framework that performs one or more functions. In some embodiments, the module may be implemented on one or more neural networks.
Zero-shot classification models may make a prediction according to a classification label, even if samples of such label were not previously observed during training. Such models use a similarity score between the text and the labels mapped to a common embedding space; because text and label embeddings are calculated independently, only one forward pass over the text is needed, resulting in a minimal increase in computation. Some other existing models may incorporate label information when processing the text, may use generative modeling to generate text given a label embedding, or may use label-embedding-based attention over the text. All of these models may require multiple passes over the text, increasing the computational cost.
An NLI/NSP-based zero-shot classification model may make inferences by defining a representative hypothesis sentence for each class label and producing a score for every pair of input text (i.e., premise) and label-representative hypothesis. To compute the score, the NLI/NSP model employs full self-attention over the premise and hypothesis sentences, which requires recomputing the encoding for each premise-hypothesis pair separately. This increases the computational cost and the time to make inferences linearly in the number of target class labels.
In particular, with large transformer-based pre-trained language models, the model size and the number of target labels could drastically reduce the prediction efficiency, increasing the computation and inference time, and may significantly increase the carbon footprint of making predictions.
In view of the need for computationally efficient zero-shot text classification, embodiments described herein provide a Conformal Predictor (CP) that reduces the number of likely target class labels. Specifically, the CP provides a model-agnostic framework to generate a label set, instead of a single label prediction, within a pre-defined error rate. The CP employs a fast base classifier which may be used to filter out unlikely labels from the target label set, and thus restrict the number of probable target class labels while ensuring the candidate label set meets the pre-defined error rate.
For example, the fast base classifier may predict on a calibration dataset comprising samples corresponding to all classification labels. The predictions may then be used to “calibrate” a non-conformity threshold between a prediction and a ground-truth label. Thus, given a testing data point, the fast base classifier may be used to generate a subset of classification labels that satisfy the non-conformity threshold with respect to a prediction for the testing data point. The subset of classification labels is then used with the larger NLI/NSP-based zero-shot text classification model to make the final prediction for the testing data point.
In one embodiment, a calibration dataset 202 containing data samples covering all classification labels from the original label set 105, {(x1, y1), (x2, y2), . . . , (xn, yn)}, is input to the base classifier 220. In one implementation, the calibration dataset 202 may be obtained from annotated data. In another implementation, when a human-labeled dataset is unavailable, the large zero-shot classification model 120 may be used to label samples for calibration. As the goal is to obtain a label set that contains the class label which is most probable according to the zero-shot classification model 120, human-labeled samples may not be necessary. Using model-predicted labels in the calibration data 202 may achieve the desired coverage.
In one embodiment, the base classifier 220 may generate corresponding predictions 204, e.g., ŷ1, ŷ2, . . . , ŷn, in response to the input x1, x2, . . . , xn from the calibration dataset 202. A measure of non-conformity 206, denoted s(xi, yi), i = 1, . . . , n, may be used to describe the disagreement between the actual label yi and the prediction ŷi from the base classifier 220. An empirical quantile q̂ of the scores s(x1, y1), . . . , s(xn, yn) may then be computed as the ⌈(n+1)(1−α)⌉/n quantile, where α is a pre-defined error rate.
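As a minimal, non-authoritative sketch of this calibration step (the function and variable names are hypothetical, and score_fn stands for any of the non-conformity measures described below):

import math


def calibrate_threshold(calibration_data, score_fn, alpha=0.1):
    """Compute the conformal threshold q_hat from calibration samples.

    calibration_data: list of (text, true_label) pairs.
    score_fn: callable returning the non-conformity score s(x, y).
    alpha: pre-defined error rate (1 - alpha is the target coverage).
    """
    scores = sorted(score_fn(x, y) for x, y in calibration_data)
    n = len(scores)
    # q_hat is the ceil((n + 1) * (1 - alpha)) / n empirical quantile of the scores.
    rank = min(math.ceil((n + 1) * (1 - alpha)), n)
    return scores[rank - 1]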
In one implementation, the base classifier 220 is chosen to be computationally efficient compared to the original zero-shot classification model 120. For example, the zero-shot classification model 120 may be an intent and/or topic classification model. The base classifier 220 may then be built based on token overlap (referred to as “CP-Token”). For each target class label yk ∈ {y1, . . . , yK}, a list of representative tokens Cwk is built that includes all tokens in the calibration data samples corresponding to that class. The non-conformity score may then be computed based on the percentage of common tokens between Cwk and the input text x. Given that #x denotes the unique tokens in the input text x, the token-overlap-based non-conformity score is computed as:
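As a purely illustrative sketch of one possible form of this score (assuming, as a labeled assumption, that the overlap percentage is converted into a non-conformity score by subtracting it from one so that a lower score indicates closer agreement; names are hypothetical):

def token_overlap_score(text, class_tokens):
    """Token-overlap non-conformity: lower means closer agreement (assumption).

    text: input text x.
    class_tokens: set Cw_k of representative tokens for label y_k.
    """
    unique_tokens = set(text.lower().split())  # #x, the unique tokens of x
    if not unique_tokens:
        return 1.0
    overlap_pct = len(unique_tokens & set(class_tokens)) / len(unique_tokens)
    # Assumed conversion: high overlap -> low non-conformity.
    return 1.0 - overlap_pct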
In another implementation, the base classifier 220 may be built based on cosine similarity (referred to as “CP-Glove”). For example, the non-conformity score computed based on token overlap may sometimes suffer from sparsity unless a large representative word set for each target class label is used. Therefore, another way to compute the non-conformity score is based on the cosine distance between a bag-of-words (BoW) representation of the target label description Cwk and the input text x. Static GloVe embeddings may be used to compute the BoW representations for the labels. In this way, the non-conformity scores are computed as:
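As a minimal, non-authoritative sketch of such a cosine-distance score, assuming a pre-loaded dictionary mapping tokens to static GloVe vectors and using the mean of token embeddings as the BoW representation (names are hypothetical):

import numpy as np


def bow_embedding(tokens, glove):
    """Mean of static GloVe vectors for the tokens found in the vocabulary."""
    vectors = [glove[t] for t in tokens if t in glove]
    return np.mean(vectors, axis=0) if vectors else None


def cosine_distance_score(text, label_description_tokens, glove):
    """Non-conformity as cosine distance between BoW embeddings (lower = closer)."""
    x_emb = bow_embedding(text.lower().split(), glove)
    y_emb = bow_embedding(label_description_tokens, glove)
    if x_emb is None or y_emb is None:
        return 1.0  # assumed fallback when no token is in the GloVe vocabulary
    cos_sim = np.dot(x_emb, y_emb) / (np.linalg.norm(x_emb) * np.linalg.norm(y_emb))
    return 1.0 - cos_sim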
In another implementation, the base classifier 220 may be a task-specific base classifier (referred to as “CP-CLS”) that generates label sets of smaller size. For example, a distilled BERT-base model that is fine-tuned on data labeled with predictions from the original zero-shot classification model 120 may serve as the base classifier 220. The negative of the class logits may be used as the non-conformity scores.
In another implementation, the base classifier 220 may be another parameter-efficient NLI zero-shot model (referred to as “CP-Distil”), different from the zero-shot classification model 120, such as a distil-Roberta-base-NLI model. While NLI-based zero-shot models may be computationally expensive, such a smaller model may serve as a good base classifier when the original pre-trained language model has many parameters or when there are many target class labels (e.g., 64 labels in HWU64). The non-conformity score may be computed as the negative entailment probability of each class.
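As an illustrative sketch only, both of these variants reduce to negating a per-label confidence so that a more confident prediction yields a lower non-conformity score (the helper name is hypothetical):

import numpy as np


def negative_confidence_scores(per_label_confidences):
    """Per-label non-conformity scores from model confidences.

    per_label_confidences: length-K array of, e.g., class logits from a
    fine-tuned distilled classifier (CP-CLS) or entailment probabilities
    from a small NLI model (CP-Distil), one entry per candidate label.
    Higher confidence corresponds to lower non-conformity.
    """
    return -np.asarray(per_label_confidences, dtype=float)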
Next, for a new exchangeable test data point 208, xtest, the base classifier 220 may in turn generate a prediction ŷtest. Thus, scores 212, s(xtest, y1), . . . , s(xtest, yK), may in turn be computed for the K target labels using the prediction ŷtest generated by the base classifier 220. The quantile q̂ is then applied to the set of scores 212 as a threshold to generate the reduced label set 216, e.g., Γα = {yk : s(xtest, yk) < q̂}, i.e., the classes whose non-conformity scores are lower than q̂. The reduced label set Γα 216 is finally used with the large zero-shot model 120 to predict the final class label of xtest 208 without reducing the coverage beyond the pre-defined error rate α.
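A minimal, non-authoritative sketch of this test-time filtering, in which zero_shot_predict is a hypothetical stand-in for invoking the large zero-shot classification model 120 over a candidate label set:

def reduced_label_set(x_test, labels, score_fn, q_hat):
    """Keep only the labels whose non-conformity score falls below q_hat."""
    return [y for y in labels if score_fn(x_test, y) < q_hat]


def classify(x_test, labels, score_fn, q_hat, zero_shot_predict):
    """Run the expensive zero-shot model only over the filtered labels."""
    candidates = reduced_label_set(x_test, labels, score_fn, q_hat)
    if not candidates:  # assumed fallback if every label is filtered out
        candidates = list(labels)
    return zero_shot_predict(x_test, candidates)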
In one embodiment, in order to reduce the average prediction set size while maintaining the computational efficiency, an ensemble of class label descriptions (i.e., hypotheses for NLI, next sentences for NSP) may be used with the CP base classifier 220. The ensemble of descriptions helps to capture surface-level variations in language and reduces the average prediction set size without increasing the complexity of the base CP classifier. For cosine-similarity-based non-conformity scores, the mean of the embeddings of all descriptions corresponding to a given label may be used during computation. Alternatively, all cosine similarities (or logits/probabilities for the distilled zero-shot model) corresponding to all descriptions of a given class may be computed and their average taken as the final non-conformity score.
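As a minimal sketch of the score-averaging variant, assuming the per-description non-conformity scores are simply averaged (names are hypothetical):

import numpy as np


def ensemble_description_score(text, descriptions, score_fn):
    """Average the non-conformity score over all descriptions of one label.

    descriptions: alternative hypotheses / next sentences / token sets that
    describe the same class label.
    """
    return float(np.mean([score_fn(text, d) for d in descriptions]))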
In addition, beyond the zero/few-shot classification settings described above, an ensemble of prompts or an ensemble of verbalizers may be used in prompt-based classification models to reduce the average prediction set size of CP. For example, when using prompts to extract knowledge from pre-trained language models, an ensemble (maximum or mean) of logits (or softmax probabilities) corresponding to different prompts (e.g., “France is on the continent of [MASK],” “The continent of France is [MASK]”) may be used to reduce the average prediction set size. Likewise, for classification tasks, verbalizers (“This text is about [Science]/[Technology]/[Mathematics]/etc.”) may be ensembled while keeping the template fixed.
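As an illustrative sketch, assuming the per-prompt (or per-verbalizer) softmax probabilities have already been obtained from the pre-trained language model (the names and the negative-probability score are assumptions):

import numpy as np


def ensemble_prompt_scores(per_prompt_probs, mode="mean"):
    """Combine probabilities from different prompts or verbalizer sets.

    per_prompt_probs: array of shape (num_prompts, num_labels) for one input.
    Returns per-label non-conformity scores (negative combined probability).
    """
    probs = np.asarray(per_prompt_probs, dtype=float)
    combined = probs.mean(axis=0) if mode == "mean" else probs.max(axis=0)
    return -combined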
In some implementations, the availability of text from the target task for calibration (and for training the base classifier 220) is generally assumed. While a zero-shot base classifier (e.g., CP-Token/Glove/Distil) can be used, in practice a few samples may be needed for the CP calibration. To tackle this issue, a small-sized calibration set with a low chosen α may improve zero-shot classification efficiency without dropping the performance. Also, the calibration set is largely transferable across datasets, e.g., when data samples corresponding to the actual zero-shot task are not available, data from another classification task (or a set of tasks) can be used for calibration, provided that α is set to a low value.
It is noted that the CP framework 100 described in relation to FIGS. 1-2 is discussed with text classification for illustrative purposes only; the framework is model-agnostic and may be applied to other types of classification models and/or tasks. For example, the zero-shot classification model 120 may be an image classification model, and the base classifier 220 may be a more computationally efficient image classifier accordingly.
Memory 320 may be used to store software executed by computing device 300 and/or one or more data structures used during operation of computing device 300. Memory 320 may include one or more types of machine-readable media. Some common forms of machine-readable media may include floppy disk, flexible disk, hard disk, magnetic tape, any other magnetic medium, CD-ROM, any other optical medium, punch cards, paper tape, any other physical medium with patterns of holes, RAM, PROM, EPROM, FLASH-EPROM, any other memory chip, or cartridge, and/or any other medium from which a processor or computer is adapted to read.
Processor 310 and/or memory 320 may be arranged in any suitable physical arrangement. In some embodiments, processor 310 and/or memory 320 may be implemented on a same board, in a same package (e.g., system-in-package), on a same chip (e.g., system-on-chip), and/or the like. In some embodiments, processor 310 and/or memory 320 may include distributed, virtualized, and/or containerized computing resources. Consistent with such embodiments, processor 310 and/or memory 320 may be located in one or more data centers and/or cloud computing facilities.
In some examples, memory 320 may include non-transitory, tangible, machine readable media that includes executable code that when run by one or more processors (e.g., processor 310) may cause the one or more processors to perform the methods described in further detail herein. For example, as shown, memory 320 includes instructions for an efficient zero-shot classification module 330 that may be used to implement and/or emulate the systems and models, and/or to implement any of the methods described further herein. The efficient zero-shot classification module 330 may receive an input 340, e.g., a text input via the data interface 315 and generate an output 350 (e.g., a predicted classification label of the input text 340).
In one embodiment, the data interface 315 may be a user interface that receives a user submitted input, e.g., a voice input, a user entered text, and/or the like. In another embodiment, the data interface 315 may be a communication interface that receives training data from a remote data server (e.g., see data vendor servers 445, 470 and 480 in FIG. 4).
In some embodiments, the efficient zero-shot classification module 330 may include a conformal prediction module 331 (e.g., similar to the conformal predictor 110 in FIG. 1) and a zero-shot classification module 332 (e.g., similar to the zero-shot classification model 120 in FIG. 1).
In one embodiment, the efficient zero-shot classification module 330 and its submodules 331-332 may be implemented by hardware, software and/or a combination thereof.
Some examples of computing devices, such as computing device 300, may include non-transitory, tangible, machine-readable media that include executable code that, when run by one or more processors (e.g., processor 310), may cause the one or more processors to perform the processes of the methods described herein. Some common forms of machine-readable media that may include the processes of the methods are, for example, floppy disk, flexible disk, hard disk, magnetic tape, any other magnetic medium, CD-ROM, any other optical medium, punch cards, paper tape, any other physical medium with patterns of holes, RAM, PROM, EPROM, FLASH-EPROM, any other memory chip or cartridge, and/or any other medium from which a processor or computer is adapted to read.
The user device 410, data vendor servers 445, 470 and 480, and the server 430 may communicate with each other over a network 460. User device 410 may be utilized by a user 440 (e.g., a driver, a system admin, etc.) to access the various features available for user device 410, which may include processes and/or applications associated with the server 430 to receive an output data anomaly report.
User device 410, data vendor server 445, and the server 430 may each include one or more processors, memories, and other appropriate components for executing instructions such as program code and/or data stored on one or more computer readable mediums to implement the various applications, data, and steps described herein. For example, such instructions may be stored in one or more computer readable media such as memories or data storage devices internal and/or external to various components of system 400, and/or accessible over network 460.
User device 410 may be implemented as a communication device that may utilize appropriate hardware and software configured for wired and/or wireless communication with data vendor server 445 and/or the server 430. For example, in one embodiment, user device 410 may be implemented as an autonomous driving vehicle, a personal computer (PC), a smart phone, laptop/tablet computer, wristwatch with appropriate computer hardware resources, eyeglasses with appropriate computer hardware (e.g., GOOGLE GLASS®), other type of wearable computing device, implantable communication devices, and/or other types of computing devices capable of transmitting and/or receiving data, such as an IPAD® from APPLE®. Although only one communication device is shown, a plurality of communication devices may function similarly.
User device 410 of FIG. 4 may contain a user interface application and/or other applications 416, which may be used by the user 440 to interact with the server 430 over network 460, for example, to view prediction results.
In various embodiments, user device 410 includes other applications 416 as may be desired in particular embodiments to provide features to user device 410. For example, other applications 416 may include security applications for implementing client-side security features, programmatic client applications for interfacing with appropriate application programming interfaces (APIs) over network 460, or other types of applications. Other applications 416 may also include communication applications, such as email, texting, voice, social networking, and IM applications that allow a user to send and receive emails, calls, texts, and other notifications through network 460. For example, the other application 416 may be an email or instant messaging application that receives a prediction result message from the server 430. Other applications 416 may include device interfaces and other display modules that may receive input and/or output information. For example, other applications 416 may contain software programs for asset management, executable by a processor, including a graphical user interface (GUI) configured to provide an interface to the user 440 to view the output.
User device 410 may further include database 418 stored in a transitory and/or non-transitory memory of user device 410, which may store various applications and data and be utilized during execution of various modules of user device 410. Database 418 may store user profile relating to the user 440, predictions previously viewed or saved by the user 440, historical data received from the server 430, and/or the like. In some embodiments, database 418 may be local to user device 410. However, in other embodiments, database 418 may be external to user device 410 and accessible by user device 410, including cloud storage systems and/or databases that are accessible over network 460.
User device 410 includes at least one network interface component 419 adapted to communicate with data vendor server 445 and/or the server 430. In various embodiments, network interface component 419 may include a DSL (e.g., Digital Subscriber Line) modem, a PSTN (Public Switched Telephone Network) modem, an Ethernet device, a broadband device, a satellite device and/or various other types of wired and/or wireless network communication devices including microwave, radio frequency, infrared, Bluetooth, and near field communication devices.
Data vendor server 445 may correspond to a server that hosts one or more of the databases 403a-n (or collectively referred to as 403) to provide training datasets, including the calibration dataset 202, to the server 430. The database 403 may be implemented by one or more relational databases, distributed databases, cloud databases, and/or the like.
The data vendor server 445 includes at least one network interface component 426 adapted to communicate with user device 410 and/or the server 430. In various embodiments, network interface component 426 may include a DSL (e.g., Digital Subscriber Line) modem, a PSTN (Public Switched Telephone Network) modem, an Ethernet device, a broadband device, a satellite device and/or various other types of wired and/or wireless network communication devices including microwave, radio frequency, infrared, Bluetooth, and near field communication devices. For example, in one implementation, the data vendor server 445 may send asset information from the database 403, via the network interface 426, to the server 430.
The server 430 may be housed with the efficient zero-shot classification module 330 and its submodules described in FIG. 3. In some implementations, the efficient zero-shot classification module 330 may receive data (e.g., the calibration dataset 202) from the database 403 at the data vendor server 445 via the network 460, and the generated classification predictions may be sent to the user device 410 for review by the user 440 via the network 460.
The database 432 may be stored in a transitory and/or non-transitory memory of the server 430. In one implementation, the database 432 may store data obtained from the data vendor server 445. In one implementation, the database 432 may store parameters of the efficient zero-shot classification module 330. In one implementation, the database 432 may store previously generated reduced label sets and the corresponding input feature vectors.
In some embodiments, database 432 may be local to the server 430. However, in other embodiments, database 432 may be external to the server 430 and accessible by the server 430, including cloud storage systems and/or databases that are accessible over network 460.
The server 430 includes at least one network interface component 433 adapted to communicate with user device 410 and/or data vendor servers 445, 470 or 480 over network 460. In various embodiments, network interface component 433 may comprise a DSL (e.g., Digital Subscriber Line) modem, a PSTN (Public Switched Telephone Network) modem, an Ethernet device, a broadband device, a satellite device and/or various other types of wired and/or wireless network communication devices including microwave, radio frequency (RF), and infrared (IR) communication devices.
Network 460 may be implemented as a single network or a combination of multiple networks. For example, in various embodiments, network 460 may include the Internet or one or more intranets, landline networks, wireless networks, and/or other appropriate types of networks. Thus, network 460 may correspond to small scale communication networks, such as a private or local area network, or a larger scale network, such as a wide area network or the Internet, accessible by the various components of system 400.
At step 502, method 500 performs receiving, via a data interface (e.g., 315 in FIG. 3), a set of classification labels (e.g., the label set 105 in FIG. 1) and a calibration dataset (e.g., 202 in FIG. 2) comprising data samples corresponding to the set of classification labels.
In some implementations, a given label from the set of classification labels comprises an ensemble of class descriptions, including any of a hypothesis for natural language inference or a next sentence for next sentence prediction. Alternatively, a given label from the set of classification labels comprises an ensemble of prompts or an ensemble of verbalizers when the zero-shot classification model is a prompt-based classification model.
At step 504, method 500 performs generating, via a base classifier model (e.g., 220 in FIG. 2), a plurality of predictions (e.g., 204 in FIG. 2) in response to an input of the data samples from the calibration dataset.
At step 506, method 500 performs computing a first set of non-conformity scores (e.g., scores 206 in FIG. 2) indicating the disagreement between the plurality of predictions and the labels of the data samples in the calibration dataset.
At step 508, method 500 performs computing a non-conformity threshold based on the first set of non-conformity scores (e.g., scores 206 in FIG. 2), for example, as an empirical quantile of the first set of non-conformity scores determined by a pre-defined error rate.
At step 510, method 500 performs generating, by the base classifier model (e.g., 220 in FIG. 2), a test prediction in response to an input of a test data point (e.g., 208 in FIG. 2).
At step 512, method 500 performs generating a second set of non-conformity scores (e.g., scores 212 in FIG. 2) based on the test prediction and the set of classification labels.
At step 514, method 500 performs determining a reduced set of classification labels (e.g., reduced label set 216 in FIG. 2) comprising classification labels whose non-conformity scores from the second set are lower than the non-conformity threshold.
At step 516, method 500 performs generating, via the zero-shot classification model (e.g., model 120 in FIG. 1), a predicted classification label for the test data point based on the reduced set of classification labels.
It is noted that method 500 is discussed in relation to text classification for illustrative purposes only. Method 500 is model-agnostic, and may be applied to different types of classification models and/or tasks.
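For illustration only, a compact, non-authoritative sketch that strings the steps of method 500 together, using hypothetical function names and any of the non-conformity measures described above:

import math


def conformal_zero_shot_classify(calibration_data, labels, score_fn,
                                 zero_shot_predict, x_test, alpha=0.1):
    """Illustrative end-to-end flow corresponding to steps 502-516.

    calibration_data: list of (text, label) pairs (step 502).
    score_fn: non-conformity measure s(x, y) built from the base classifier
        (steps 504-506).
    zero_shot_predict: the large zero-shot model, called on a label subset.
    """
    # Steps 506-508: calibrate the non-conformity threshold q_hat.
    scores = sorted(score_fn(x, y) for x, y in calibration_data)
    n = len(scores)
    q_hat = scores[min(math.ceil((n + 1) * (1 - alpha)), n) - 1]

    # Steps 510-514: score the test point against every label and filter.
    candidates = [y for y in labels if score_fn(x_test, y) < q_hat] or list(labels)

    # Step 516: the expensive zero-shot model predicts over the reduced set.
    return zero_shot_predict(x_test, candidates)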
The CP-based framework described in FIGS. 1-5 may be evaluated with example data experiments on intent classification datasets (e.g., ATIS, SNIPS, HWU64) and topic classification datasets (e.g., AG's news), using both NLI-based and NSP-based zero-shot classification models (referred to as “0shot” models).
For example, the entire training set of the intent datasets and 5000 samples from the validation set of the topic datasets are used to calibrate CP-Token, CP-Glove and CP-Distil. For CP-CLS, the entire training set of the intent datasets and 2500 samples from the validation set of the topic datasets are used to train the base classifier, and the entire validation set of the intent datasets and 2500 samples from the validation set of the topic datasets are used for calibration.
For the CP-Distil base classifier, the “cross-encoder/nli-distilroberta-base” model from the Hugging Face hub is used. Only the text to be classified (without labels) is used for calibrating the base classifiers and training the CP-CLS base classifier; the zero-shot classification model is used to label the corresponding training and calibration samples.
The results of the example data experiments show the following observations.
CP reduces the average number of labels for the 0shot model. It is observed that stronger base classifiers (CLS and Distil) provide a lower average label set (ALS) size for the same empirical (or nominal) coverage.
A simpler and more efficient CP base classifier may reduce the inference time the most. It is observed that CP-Token achieves the best inference time with the NLI model on the ATIS, SNIPS and AG's news datasets, and with the NSP model on the ATIS and AG's news datasets. On the other hand, it achieves the lowest ALS size for both models only on the ATIS dataset. The minimal complexity of calculating token overlap adds negligible overhead to the 0shot model, thus achieving the best speed-up despite a higher ALS size in several cases.
Also, the CP base classifier needs to be computationally inexpensive. CP-Distil improves inference time for the NLI model on all datasets but fails to do so for the NSP model, despite the reduced ALS size. This ineffectiveness is explained by the comparable inference times of the base (distil-nli) classifier and the 0shot NSP model. When building the CP, it is imperative to select a base classifier that is very economical relative to the 0shot model. Further, CP improves efficiency the most on datasets with many labels: the maximum speed-up is observed on the HWU64 and ATIS datasets. This is unsurprising given the relatively higher number of possible target labels for both datasets, emphasizing the benefit of CP for tasks with many target labels.
CP performs comparably to the 0shot model. CP-based label filtering retains the performance of the corresponding models that use the full label set. Among the four base classifiers, CP-Token performs the worst (−0.46% absolute drop) and CP-Distil performs the best (+0.31% absolute gain) in average accuracy. It is noteworthy that the accuracy increases in many cases, suggesting that pruning the label space using a CP may remove noisy labels and boost the performance.
This description and the accompanying drawings that illustrate inventive aspects, embodiments, implementations, or applications should not be taken as limiting. Various mechanical, compositional, structural, electrical, and operational changes may be made without departing from the spirit and scope of this description and the claims. In some instances, well-known circuits, structures, or techniques have not been shown or described in detail in order not to obscure the embodiments of this disclosure. Like numbers in two or more figures represent the same or similar elements.
In this description, specific details are set forth describing some embodiments consistent with the present disclosure. Numerous specific details are set forth in order to provide a thorough understanding of the embodiments. It will be apparent, however, to one skilled in the art that some embodiments may be practiced without some or all of these specific details. The specific embodiments disclosed herein are meant to be illustrative but not limiting. One skilled in the art may realize other elements that, although not specifically described here, are within the scope and the spirit of this disclosure. In addition, to avoid unnecessary repetition, one or more features shown and described in association with one embodiment may be incorporated into other embodiments unless specifically described otherwise or if the one or more features would make an embodiment non-functional.
Although illustrative embodiments have been shown and described, a wide range of modification, change and substitution is contemplated in the foregoing disclosure and, in some instances, some features of the embodiments may be employed without a corresponding use of other features. One of ordinary skill in the art would recognize many variations, alternatives, and modifications. Thus, the scope of the invention should be limited only by the following claims, and it is appropriate that the claims be construed broadly and in a manner consistent with the scope of the embodiments disclosed herein.
This application is a nonprovisional of and claims priority to U.S. provisional application No. 63/331,135, filed Apr. 14, 2022, which is hereby expressly incorporated by reference herein in its entirety.