Optimizing LLM Classification Tasks with BERT and XGBoost: A Cost-Effective Solution for SQL and Self-Reference Identification
Keywords
BERT, XGBoost, Chatbot query classification, SQL query detection, Self-reference detection, NLP (Natural Language Processing), Transformer models, Machine learning efficiency, API cost optimization, Conversational AI, Binary classification, Pre-trained models
Introduction
With the growing need for conversational agents, efficiently handling a wide range of user queries is a key challenge. Although LLMs excel at generating responses, they often introduce latency and high costs, especially in production systems that handle large volumes of queries. To address these concerns, we explored a more optimized approach by building custom machine-learning models that efficiently classify user queries for two specific tasks.
In this case, we have a virtual assistant system whose initial pipeline must, among other tasks, quickly make two binary classifications from the user's question: first, whether answering the question requires executing a SQL query; and second, whether the user is referring to themselves (self-referencing) in that same input.
The first model is designed to detect whether a user's query requires the execution of an SQL query, and the second classifies whether the user is referring to themselves or to another entity (self-reference detection). Both models were trained using a combination of a pre-trained BERT (Bidirectional Encoder Representations from Transformers) model for feature extraction and an XGBoost classifier for decision-making. This approach significantly reduces the time and cost associated with using an external API for these tasks.
By replacing external LLM calls with in-house models, we achieved faster decision-making and reduced the computational overhead, making the system more scalable and efficient. The following sections outline the methods used to build these models, the results of our experiments, and a discussion on the performance and implications of this approach.
Method
This section describes the methodology used to create both the SQL Query Detection Model and the Self-Reference Detection Model. Each model consists of two main components: the BERT encoder, which extracts contextualized features from the input text, and the XGBoost classifier, which performs the final classification task.
Dataset Preparation
Both models were trained using curated datasets that simulate real user queries, each labeled according to the classification task. The dataset for the SQL Query Detection Model contains questions that either require an SQL query execution or not, while the Self-Reference Detection Model focuses on identifying whether a query refers to the user (self-reference) or an external entity.
Example of the self-reference dataset:
In both cases the dataset was preprocessed to remove noise and improve model performance. Preprocessing was critical to ensuring the input data was clean and ready for model training. This included:
The following function demonstrates the text cleaning and stemming process applied to each query:
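The original figure showing the function is not reproduced here, so the sketch below illustrates the kind of cleaning and stemming applied. The `clean_and_stem` name and the naive suffix rules are illustrative stand-ins; a production pipeline would typically rely on a library stemmer such as NLTK's `PorterStemmer`.

```python
import re

def clean_and_stem(text: str) -> str:
    """Lowercase, strip punctuation, collapse whitespace, and crudely stem tokens.

    A minimal stand-in for the preprocessing step described in the article;
    the suffix-stripping rules below are NOT a real Porter stemmer.
    """
    text = text.lower()
    text = re.sub(r"[^a-z0-9\s]", " ", text)  # drop punctuation and symbols
    tokens = text.split()                      # split() also collapses whitespace
    stemmed = []
    for tok in tokens:
        # Naive suffix stripping -- illustrative only.
        for suffix in ("ing", "ed", "es", "s"):
            if len(tok) > len(suffix) + 2 and tok.endswith(suffix):
                tok = tok[: -len(suffix)]
                break
        stemmed.append(tok)
    return " ".join(stemmed)

print(clean_and_stem("Show me the invoices I created last month!"))
# → show me the invoic i creat last month
```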
SQL Query Detection Model
Dataset Overview:
The dataset used for the SQL query detection model was imbalanced, with 84.36% of the queries classified as requiring an SQL query (True) and 15.64% not requiring one (False). This imbalance posed a challenge for model training, leading to the application of SMOTE (Synthetic Minority Over-sampling Technique) to balance the dataset.
Dataset characteristics:
Number of characters per sentence:
Number of words per sentence:
Data Balancing with SMOTE: To address the dataset imbalance, SMOTE was applied, resampling the minority class to create a more balanced dataset. The resampled dataset shape was (410, 768), effectively doubling the number of samples in the minority class.
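To make the resampling concrete, here is a simplified NumPy sketch of SMOTE's core idea: synthesizing new minority samples by interpolating between a minority point and one of its nearest minority-class neighbors. The `smote_like_oversample` helper is illustrative only; the actual experiments would use a library implementation such as imbalanced-learn's `SMOTE`.

```python
import numpy as np

def smote_like_oversample(X_min: np.ndarray, n_new: int, k: int = 3, seed: int = 0) -> np.ndarray:
    """Generate synthetic minority samples by interpolating toward nearest neighbors.

    Simplified illustration of SMOTE; not the imbalanced-learn implementation.
    """
    rng = np.random.default_rng(seed)
    synthetic = []
    for _ in range(n_new):
        i = rng.integers(len(X_min))
        # k nearest neighbors of X_min[i] within the minority class (excluding itself)
        dists = np.linalg.norm(X_min - X_min[i], axis=1)
        neighbors = np.argsort(dists)[1 : k + 1]
        j = rng.choice(neighbors)
        gap = rng.random()  # interpolation factor in [0, 1)
        synthetic.append(X_min[i] + gap * (X_min[j] - X_min[i]))
    return np.vstack(synthetic)

# Toy minority class of 5 embeddings in a 768-dim space (the BERT embedding size)
X_min = np.random.default_rng(42).normal(size=(5, 768))
X_new = smote_like_oversample(X_min, n_new=5)
print(X_new.shape)  # (5, 768)
```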
Hyperparameter Tuning
Self-Reference Detection Model
Dataset Overview:
The self-reference detection dataset was similarly imbalanced, with 88.18% of queries classified as non-self-referential (False) and 11.82% as self-referential (True). As with the SQL query detection model, we analyzed the textual data:
Number of characters per sentence:
Number of words per sentence:
Data Balancing with SMOTE: The imbalance in the dataset was corrected using SMOTE, which resampled the dataset to include more examples of self-referential queries. The final resampled dataset had a shape of (388, 768).
Hyperparameter Tuning: A grid search was performed to find the best hyperparameters for the XGBoost model.
Model Architecture
For both models, we utilized the pre-trained BERT model to generate embeddings for the input text. These embeddings capture the context of the query, which is crucial for classification tasks.
BERT’s bidirectional nature allows the model to consider both the preceding and succeeding words when generating embeddings, making it highly effective for understanding the meaning behind user queries.
After obtaining the embeddings from BERT, an XGBoost model was used to classify the queries. XGBoost was chosen due to its efficiency and effectiveness in handling structured data with complex relationships.
The XGBoost classifier was trained using these embeddings as input, learning to distinguish between SQL query-related questions and self-referential queries.
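A minimal sketch of this two-stage architecture is shown below. Random 768-dimensional vectors stand in for BERT's [CLS] embeddings (the commented lines indicate where a `transformers` encoder would plug in), and scikit-learn's `GradientBoostingClassifier` stands in for `xgboost.XGBClassifier`, which exposes the same fit/predict interface. The data and labels here are synthetic, for illustration only.

```python
import numpy as np
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.model_selection import train_test_split

# In the real pipeline, embeddings come from a pre-trained BERT encoder, e.g.:
#   from transformers import AutoTokenizer, AutoModel
#   tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
#   model = AutoModel.from_pretrained("bert-base-uncased")
#   inputs = tokenizer(queries, padding=True, truncation=True, return_tensors="pt")
#   embeddings = model(**inputs).last_hidden_state[:, 0, :].detach().numpy()  # [CLS]
# Random 768-dim vectors stand in here so the sketch runs without model downloads.
rng = np.random.default_rng(0)
embeddings = rng.normal(size=(200, 768))
labels = rng.integers(0, 2, size=200)

X_tr, X_te, y_tr, y_te = train_test_split(
    embeddings, labels, test_size=0.2, random_state=0
)

# Stand-in for xgboost.XGBClassifier, which follows the same scikit-learn API.
clf = GradientBoostingClassifier(n_estimators=50, max_depth=3)
clf.fit(X_tr, y_tr)
preds = clf.predict(X_te)
print(preds.shape)  # (40,)
```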
Training and Hyperparameter Tuning
Both models were trained using standard train-test splits, with a grid search employed to find the optimal hyperparameters for XGBoost. The parameters tuned include:
The training was performed on a balanced dataset, and the performance was measured using precision, recall, and F1 scores.
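The article does not list the exact grid, so the values below are hypothetical but typical knobs for a gradient-boosted model (tree depth, learning rate, number of trees), tuned with scikit-learn's `GridSearchCV`. A `GradientBoostingClassifier` serves as a stand-in for XGBoost, and the toy data is synthetic.

```python
import numpy as np
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.model_selection import GridSearchCV

# Hypothetical grid -- placeholder values, not the article's actual search space.
param_grid = {
    "max_depth": [2, 3],
    "learning_rate": [0.1, 0.3],
    "n_estimators": [25, 50],
}

# Toy embeddings (16-dim for speed) with labels driven by the first feature,
# so the search has a learnable signal to score against.
rng = np.random.default_rng(1)
X = rng.normal(size=(120, 16))
y = (X[:, 0] + 0.1 * rng.normal(size=120) > 0).astype(int)

search = GridSearchCV(
    GradientBoostingClassifier(random_state=0),
    param_grid, cv=3, scoring="f1",
)
search.fit(X, y)
print(sorted(search.best_params_))  # the three tuned hyperparameter names
```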
Results
After training both models, we evaluated their performance using a separate test set. The results for the Self-Reference Detection Model are presented below:
Confusion Matrix:
Classification Report:
The model achieved perfect accuracy on the test set, with precision, recall, and F1-scores of 1.0 for both classes.
For the SQL Query Detection Model, similar results were achieved, with accuracy consistently exceeding 95% across various validation sets. This indicates that the models successfully replaced the external LLMs for these tasks without sacrificing performance.
Time Measurement
In terms of latency, version v1.16.0-dev.1_dev (using the trained model) shows consistently lower response times than version v1.15.1-rc.2_prod (which makes external calls to an LLM).
Cost Comparison
The machine learning model I trained based on BERT eliminates the need to make costly API calls to large language models (LLMs) like GPT-4o for text classification tasks. In the previous approach, each classification task required two calls to the LLM, with each call involving 6,000 tokens. According to the pricing for GPT-4o, these two calls would have cost approximately $0.15 per classification. By switching to an internally hosted BERT-based model, we avoid this recurring expense, resulting in significant cost savings, especially when scaled over a large number of classifications.
With 100 users asking 30 questions per day, the impact on costs becomes significant. Each classification requires two GPT-4o calls at a combined cost of approximately $0.15, so 3,000 classifications per day amount to roughly $450; over a full month, the total rises to around $13,500. By transitioning to an internally trained BERT-based model, we eliminate this recurring expense, yielding substantial savings, especially in scenarios with high user engagement.
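As a back-of-the-envelope check on these figures (assuming, as above, $0.15 for the two calls per classification):

```python
# Assumed inputs, taken from the estimates quoted in the text above.
users = 100
questions_per_day = 30
cost_per_classification = 0.15   # two GPT-4o calls of ~6,000 tokens each

daily = users * questions_per_day * cost_per_classification
monthly = daily * 30
print(f"daily=${daily:.2f}, monthly=${monthly:.2f}")
# → daily=$450.00, monthly=$13500.00
```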
Discussion
The integration of BERT with XGBoost for binary classification of user queries proved to be highly effective. By employing these models, we significantly reduced the response time and cost associated with making external LLM calls. The models were able to quickly and accurately classify whether a user query required SQL query execution or was self-referential, which allowed for more efficient handling of common user interactions within the chatbot system.
The BERT embeddings captured the subtle contextual cues in the text, while XGBoost leveraged these features to make accurate binary classifications. This architecture allowed for both scalability and flexibility in adapting the models to new datasets or classification tasks, should the chatbot’s functionality expand in the future.
Conclusion
In conclusion, the decision to replace external LLM calls with these custom models not only improved system performance but also aligned with cost-saving strategies. Future work may involve exploring lightweight models that further reduce inference time, as well as experimenting with other architectures like RoBERTa or ALBERT for potential performance gains.