Text classification challenge with extra-small datasets: Fine-tuning versus ChatGPT

Author:Murphy  |  View: 20618  |  Time: 2025-03-23 18:11:05
Photo by Debby Hudson on Unsplash

The Toloka ML team continually researches and compares different approaches to text classification under various conditions. Here we present another one of our experiments on the performance of NLP models when trained on extra-small datasets.

Previously, we provided a brief overview of potential solutions and compared classical models with large language models (LLMs) for a specific text classification task. However, those comparisons were based on a "regular" dataset that contained enough data points to build a reliable classifier. In real-world scenarios, you may encounter situations where limited data is available or human labeling hasn't been carried out.

Intuitively, LLMs such as GPT-3 or ChatGPT might outperform smaller models due to their extensive "knowledge". To investigate this hypothesis, we created an artificially small dataset by extracting a portion of a larger one and compared several approaches. We fine-tuned the RoBERTa base model, employed ChatGPT for few-shot classification, and fine-tuned the GPT-3 Babbage model.

The dataset

To evaluate the comprehension capabilities of various models, we selected a multiclass dataset consisting of scientific article abstracts. The task was to determine each article's domain.

We opted for the WOS-11967 [1] dataset, which contains 11,967 documents in 35 categories grouped under seven parent categories: medical, psychology, computer science, biochemistry, electrical engineering, civil engineering, and mechanical engineering. We sampled 10,000 data points and focused solely on the parent categories for our analysis.

While the dataset was not perfectly balanced, the class distribution was reasonably proportional. Therefore, satisfactory results could potentially be achieved across all classes. The class distribution is illustrated below.

The class distribution of the sample of the WOS-11967 dataset

Upon manual analysis, we found that determining the domain of some abstracts was relatively straightforward, while in other cases, the task became more challenging. For instance, computer science articles may discuss mathematical topics, or psychology articles might contain medical or biochemical terms and abbreviations, making it difficult to distinguish them from biochemistry or medical domains. The abstracts also varied significantly in length, with a mean of 274 tokens (ChatGPT tokens) and a standard deviation of 115 tokens.

To simulate scenarios involving extra-small datasets, we performed a train-test split on the corpus and allocated a small number of samples to the training set. We repeated this process three times with different training set sizes to see how model performance changes with the amount of available training data. This gave us three splits for our experiment: WOS-11967-s200 (200 samples in the training set, 9,800 samples in the test set), WOS-11967-s500 (500 / 9,500), and WOS-11967-s2000 (2,000 / 8,000).
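The split sizes above can be reproduced with a few lines of code. This is a minimal sketch, not the procedure we necessarily ran: the helper `make_split`, the fixed seed, and the choice of a plain random (non-stratified) split are assumptions made for illustration.

```python
import random

def make_split(n_total, n_train, seed=0):
    """Return (train_idx, test_idx) for a random split with n_train training rows.

    Hypothetical helper: a plain shuffled split with a fixed seed.
    """
    rng = random.Random(seed)
    indices = list(range(n_total))
    rng.shuffle(indices)
    return indices[:n_train], indices[n_train:]

# Training-set sizes from the article; the sample holds 10,000 abstracts.
SPLITS = {"WOS-11967-s200": 200, "WOS-11967-s500": 500, "WOS-11967-s2000": 2000}

split_sizes = {}
for name, n_train in SPLITS.items():
    train_idx, test_idx = make_split(10_000, n_train)
    split_sizes[name] = (len(train_idx), len(test_idx))
    print(name, split_sizes[name])
```

With so few training rows, a stratified split may be worth considering so that every class is guaranteed at least a handful of examples.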

Now, let's take a look at the results obtained using different models to tackle these problems.

Regular fine-tuning with RoBERTa

For our baseline, we selected the RoBERTa base model [2] and fine-tuned it on the three datasets mentioned earlier. We used the same hyperparameter configuration for each run (a batch size of 32, a learning rate of 3e-5, a linear scheduler with warmup, and a 256-token window), along with early stopping to prevent overfitting.
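To make the "linear scheduler with warmup" concrete, here is a small sketch of how the learning rate evolves under such a schedule. The batch size, learning rate, and window length come from the configuration above; the number of epochs and the warmup length are hypothetical, since they are not stated here.

```python
import math

# Values from the article.
CONFIG = {"batch_size": 32, "learning_rate": 3e-5, "max_length": 256}

def linear_warmup_lr(step, total_steps, warmup_steps, base_lr=CONFIG["learning_rate"]):
    """Learning rate at `step`: linear ramp up to base_lr, then linear decay to 0."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    remaining = max(0, total_steps - step)
    return base_lr * remaining / max(1, total_steps - warmup_steps)

# Hypothetical training length for the 200-sample split: 10 epochs, 1 warmup epoch.
steps_per_epoch = math.ceil(200 / CONFIG["batch_size"])  # 7 batches per epoch
total_steps = 10 * steps_per_epoch
warmup_steps = steps_per_epoch

for step in (0, warmup_steps, total_steps // 2, total_steps):
    print(step, linear_warmup_lr(step, total_steps, warmup_steps))
```

On a 200-sample training set an epoch is only a handful of optimizer steps, so the warmup length has a noticeable effect on how quickly the model reaches the peak learning rate.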

We obtained the following results:

The data shows that 200 samples are insufficient for extracting all the patterns and information required to accurately classify the abstracts. The lower macro-average F1 score also indicates that the model underperforms on under-represented classes like mechanical engineering. This suggests that it's not enough to have only a few samples from a particular class.
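The following sketch illustrates why a weak minority class drags the macro-average F1 down even when overall accuracy looks fine. The labels and predictions are made-up toy data, not results from our experiment, and `f1_scores` is a hypothetical helper written for this example.

```python
def f1_scores(y_true, y_pred):
    """Return (per-class F1, macro-average F1, accuracy) for two label lists."""
    labels = sorted(set(y_true) | set(y_pred))
    pairs = list(zip(y_true, y_pred))
    per_class = {}
    for label in labels:
        tp = sum(t == label and p == label for t, p in pairs)
        fp = sum(t != label and p == label for t, p in pairs)
        fn = sum(t == label and p != label for t, p in pairs)
        denom = 2 * tp + fp + fn
        per_class[label] = 2 * tp / denom if denom else 0.0
    macro = sum(per_class.values()) / len(per_class)
    accuracy = sum(t == p for t, p in pairs) / len(pairs)
    return per_class, macro, accuracy

# Toy data: the rare class is never predicted.
y_true = ["medical"] * 9 + ["mechanical"]
y_pred = ["medical"] * 10

per_class, macro_f1, accuracy = f1_scores(y_true, y_pred)
print(per_class, macro_f1, accuracy)
```

Because the macro average weights every class equally, a class with an F1 of zero halves the score in this two-class toy case, while accuracy barely moves.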

As expected, the model's performance improved as the amount of available data increased – ultimately resulting in robust performance for multiclass classification across seven classes.

Few-shot with ChatGPT

The second approach we explored was few-shot classification using ChatGPT. This method differs significantly from traditional classification as it doesn't involve training a model per se. Instead, we engineered the input prompt to achieve optimal performance.

However, it was impossible to feed all 200 samples into the model due to its 4096-token context size limit. Given the measurements above, we could only present around 14 abstracts to the model. That number was further reduced when considering the tokens used for instructions and delimiters.
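The estimate of roughly 14 abstracts follows directly from the context size and the mean abstract length reported earlier. The short calculation below reproduces it; the 300-token reserve for instructions and delimiters is a hypothetical figure chosen only to show how the budget shrinks.

```python
CONTEXT_TOKENS = 4096        # context size limit from the article
MEAN_ABSTRACT_TOKENS = 274   # mean abstract length from the article

# Upper bound if the whole context were filled with average-length abstracts.
max_abstracts = CONTEXT_TOKENS // MEAN_ABSTRACT_TOKENS
print(max_abstracts)  # 14

# Hypothetical reserve for instructions, delimiters, and the model's answer.
RESERVED_TOKENS = 300
max_with_reserve = (CONTEXT_TOKENS - RESERVED_TOKENS) // MEAN_ABSTRACT_TOKENS
print(max_with_reserve)  # 13
```

Given the standard deviation of 115 tokens, a few long abstracts can reduce the usable count further, so the real budget has to be checked per prompt.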

Initially, we employed the "system" role for instructions and provided a single example per class to guide the model's response. We simplified the class names to single tokens while retaining their meaning. This made it easier for the model to select the appropriate category and limit the output to a single token. For instance, "Biochemistry" became "Bio," and "Computer Science" became "Computer." Additionally, we restricted the number of tokens generated by providing a list of classes to choose from and instructing the model to return the "Unknown" token if it was unsure about the category.
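A prompt of this shape could be assembled roughly as follows. This is a sketch under stated assumptions rather than the exact prompt we used: the instruction wording is invented, only the short names "Bio" and "Computer" are taken from the text above (the other five are hypothetical), and presenting examples as user/assistant message pairs is one possible layout among several.

```python
# "Bio" and "Computer" come from the article; the other short names are hypothetical.
SHORT_LABELS = {
    "Biochemistry": "Bio",
    "Computer Science": "Computer",
    "Medical": "Medical",
    "Psychology": "Psychology",
    "Electrical Engineering": "Electrical",
    "Civil Engineering": "Civil",
    "Mechanical Engineering": "Mechanical",
}

def build_messages(examples, abstract):
    """Build chat messages from (abstract_text, full_class_name) examples."""
    choices = ", ".join(SHORT_LABELS.values())
    # Hypothetical instruction text placed in the "system" role.
    system = (
        "You classify scientific article abstracts by domain. "
        f"Answer with exactly one word from this list: {choices}. "
        "If you are unsure, answer Unknown."
    )
    messages = [{"role": "system", "content": system}]
    for text, full_name in examples:
        messages.append({"role": "user", "content": text})
        messages.append({"role": "assistant", "content": SHORT_LABELS[full_name]})
    messages.append({"role": "user", "content": abstract})
    return messages

# One placeholder example per class, then the abstract to classify.
demo_examples = [(f"Placeholder abstract about {name.lower()}", name) for name in SHORT_LABELS]
demo_messages = build_messages(demo_examples, "Placeholder abstract to classify")
print(len(demo_messages))
```

The resulting list is in the role/content format that chat-style APIs accept; capping the number of generated tokens on the API side is a natural complement to the single-word labels.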

Overall, performance with this method was inferior to that of the RoBERTa model trained on just 200 samples. We noticed that the model's classification ability heavily depended on the supplied prompt. Modifying a single sentence could either improve or worsen the metrics. In some cases, ChatGPT missed categories despite explicit instructions not to do so (which could be a drawback of how we formulated our prompt).

In a few fringe cases, it produced categories that were not listed in the instructions but that did describe the article domains, such as "Math" or "Chemistry". It's unclear whether these flaws should be attributed to the model or the dataset. However, according to the validation set, these categories can be corrected using simple rules like changing all instances of "Math" to "Computer".

To improve metrics, we tried to use as much data as possible. Since we still couldn't feed all 200 samples into the model, we devised a two-stage process:

  • First, we asked the model to identify similarities between abstracts from a specific domain and generate summaries.
  • Second, we incorporated these summaries into the instructions to provide the model with insights about the classes and features identified by the model itself in the first stage.

This approach allowed us to feed more training data samples into the model; and it worked – we boosted metrics by approximately 10%. Below is the prompt we used to generate these summaries:

The prompt for ChatGPT used to extract meaningful information about article domains

For each domain, we supplied seven to eight abstracts, resulting in a total of 63 distinct abstracts used to prepare the classification prompt (eight abstracts for each of the seven classes to build summaries, plus seven abstracts provided as examples in the actual prompt).
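The two-stage process described above can be sketched as a pair of prompt builders. Everything here is illustrative: the prompt wording is invented, the short domain names other than "Bio" and "Computer" are hypothetical, and `fake_ask` is a stand-in for a real chat-completion call so that the example runs offline.

```python
def build_summary_prompt(domain, abstracts):
    """Stage 1: ask the model what abstracts from one domain have in common."""
    joined = "\n\n".join(f"Abstract {i}: {text}" for i, text in enumerate(abstracts, 1))
    return (
        f"The following abstracts all belong to the domain '{domain}'.\n"
        "Describe in a few sentences what they have in common.\n\n" + joined
    )

def summarize_domains(train_by_domain, ask):
    """Run stage 1 for every domain; `ask` maps a prompt string to a reply string."""
    return {d: ask(build_summary_prompt(d, texts)) for d, texts in train_by_domain.items()}

def build_instructions(summaries):
    """Stage 2: fold the generated summaries into the classification instructions."""
    lines = ["Classify the abstract into one of the domains described below."]
    lines += [f"{domain}: {summary}" for domain, summary in summaries.items()]
    return "\n".join(lines)

def fake_ask(prompt):
    # Stand-in for a real model call (hypothetical).
    return f"(summary of a {len(prompt)}-character prompt)"

domains = ["Bio", "Computer", "Medical", "Psychology", "Electrical", "Civil", "Mechanical"]
train_by_domain = {d: [f"{d} abstract {i}" for i in range(8)] for d in domains}

summaries = summarize_domains(train_by_domain, fake_ask)
instructions = build_instructions(summaries)

# Eight abstracts per class for summaries, plus one in-prompt example per class.
total_abstracts = sum(len(v) for v in train_by_domain.values()) + len(domains)
print(total_abstracts)  # 63
```

Passing the model call in as a function keeps the prompt logic testable without network access and makes it easy to swap models.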

As before, we instructed the model to respond with "Unknown" if it was uncertain about the class. In the validation set, we observed that most "Unknown" responses corresponded to computer science articles, so we replaced all "Unknown" instances with the "Computer" class.
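Rule-based cleanup of the model's answers, as described in this section, might look like the sketch below. The "Math" and "Unknown" rules are the ones mentioned in the text; the label set (apart from "Bio" and "Computer") and the decision to return `None` for anything unmapped are assumptions made for illustration.

```python
# "Bio" and "Computer" come from the article; the other short names are hypothetical.
ALLOWED = {"Bio", "Computer", "Medical", "Psychology", "Electrical", "Civil", "Mechanical"}

# Rules derived from the validation set, as described in the article.
RULES = {"Math": "Computer", "Unknown": "Computer"}

def normalize(raw):
    """Map a raw model answer to an allowed label, or None if no rule applies."""
    label = raw.strip().strip(".\"'")
    label = RULES.get(label, label)
    return label if label in ALLOWED else None

for answer in ["Bio", "Math", " Unknown.", "Poetry"]:
    print(repr(answer), "->", normalize(answer))
```

Returning `None` instead of silently picking a class keeps unexpected answers visible, so new rules can be added only after checking them against the validation set.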

The resulting classification prompt read as follows:

The final prompt for ChatGPT used to classify article abstracts

Once again, performance was heavily influenced by the prompt and the samples provided. The model also generated several categories outside the target list, requiring manual adjustments to be made based on the validation set. This approach yielded the following results:

The performance was notably better than fine-tuning a RoBERTa model on 200 samples – and fewer samples were required. However, as the availability of labeled data increased, RoBERTa began to outperform this approach, even with just 500 samples.

We believe that further performance improvements are possible through proper prompt engineering. Some useful tips and tricks can be found in the Prompting Guide.

Fine-tuning a GPT-3 model

For our final approach, we fine-tuned the GPT-3 Babbage model on these three datasets. We followed the dataset preparation recommendations outlined in the OpenAI guide and opted for the default hyperparameters without making any specific adjustments. The training process for each dataset took about 20 minutes, yielding the following results:
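For readers unfamiliar with the data format, a prompt/completion training file for this kind of fine-tuning can be prepared roughly as follows. This is a sketch based on our understanding of the guide's recommendations at the time (a fixed separator at the end of each prompt, a leading space in the completion, short class labels); the exact separator string and the helper names are assumptions, and the example rows are placeholders.

```python
import json

# Assumed separator marking the end of the prompt.
SEPARATOR = "\n\n###\n\n"

def to_finetune_record(abstract, label):
    """Build one prompt/completion pair for a classification fine-tune."""
    return {"prompt": abstract.strip() + SEPARATOR, "completion": " " + label}

def to_jsonl(rows):
    """Serialize (abstract, label) pairs as JSON Lines, one record per line."""
    return "\n".join(json.dumps(to_finetune_record(a, l)) for a, l in rows)

# Placeholder rows, not taken from the dataset.
rows = [
    ("Placeholder abstract about protein folding.", "Bio"),
    ("Placeholder abstract about sorting algorithms.", "Computer"),
]
jsonl_text = to_jsonl(rows)
print(jsonl_text)
```

At inference time the same separator has to be appended to each prompt so that the model sees inputs in the form it was trained on.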

The fine-tuned GPT-3 model delivered impressive results even on the smallest dataset, surpassing both RoBERTa and ChatGPT. As the amount of training data increased, the performance gap between RoBERTa and the tuned GPT-3 model narrowed. This raised questions about the resources and feasibility of using either option. We discussed the pros and cons of both approaches in our previous articles.

Conclusion

This experiment demonstrates that our initial hypothesis was correct – larger models trained on more extensive data perform significantly better on extra-small datasets. With proper prompt engineering and few-shot techniques, it's possible to achieve favorable results.

However, differences in performance decrease as the dataset size increases. Moreover, an appropriately tailored classical model, such as a domain-adapted RoBERTa model, can sometimes outperform generic LLMs in classification tasks. This can be attributed to the model's specialized "knowledge" of the subject matter. Furthermore, with the right optimizations, inference using these models can be significantly faster, which is crucial when developing online services.

All images unless otherwise noted are by the author.

Sources

  1. Kowsari K, Brown DE, Heidarysafa M, Jafari Meimandi K, Gerber MS, Barnes LE. HDLTex: Hierarchical Deep Learning for Text Classification. In: Machine Learning and Applications (ICMLA), 2017 16th IEEE International Conference On. IEEE; 2017.
  2. Liu Y, Ott M, Goyal N, et al. RoBERTa: A Robustly Optimized BERT Pretraining Approach. CoRR. 2019;abs/1907.11692. http://arxiv.org/abs/1907.11692

