Improve Your RAG Context Recall by 95% with an Adapted Embedding Model.

Author: Murphy

Retrieval-augmented generation (RAG) is a prominent technique for integrating LLMs into business use cases, allowing proprietary knowledge to be infused into the LLM. This post assumes you are already familiar with RAG and are here to improve your RAG accuracy.

Let's review the process briefly. A RAG pipeline consists of two main stages: retrieval and generation. The retrieval stage involves several sub-steps: converting context text to vectors, indexing the context vectors, retrieving the context vectors relevant to the user query, and reranking the retrieved contexts. Once the contexts for the query are retrieved, we move on to the generation stage, where the contexts are combined with the prompt and sent to the LLM to generate a response. Before they reach the LLM, the context-infused prompts may also pass through caching and routing steps to optimize efficiency.
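To make the flow concrete, here is a minimal sketch of the two stages. The helper names (embed, index.search, rerank, llm_generate) are placeholders rather than any specific library:

def retrieve(query, index, embed, rerank, top_k=5):
    """Retrieval stage: embed the query, search the index, rerank the hits."""
    query_vector = embed(query)                     # query text -> vector
    candidates = index.search(query_vector, top_k)  # nearest-neighbour lookup over the indexed contexts
    return rerank(query, candidates)                # optional reranking step

def generate(query, contexts, llm_generate):
    """Generation stage: combine the retrieved contexts with the prompt and call the LLM."""
    prompt = "Answer using only the context below.\n\n" + "\n\n".join(contexts) + f"\n\nQuestion: {query}"
    return llm_generate(prompt)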

For each pipeline step, we will conduct numerous experiments that collectively enhance RAG accuracy. The image below lists (but is not limited to) the experiments performed in each step.

One of the major problems developers face is a heavy dip in accuracy when the application is deployed to production.

"RAG does best in POC and worst in production" – this frustration is common among developers building GenAI applications.

The generation stage is mostly sorted out with some prompt engineering. The main challenge is to retrieve the proper and complete context for the user's query. This is measured by a metric called context recall, which captures what fraction of the relevant contexts are actually retrieved for a given query. The goal of the retrieval-stage experiments is to improve context recall.
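As a rough sketch, context recall for a single query can be computed like this, assuming we know which contexts are relevant (for example, from a labelled evaluation set):

def context_recall(retrieved_ids, relevant_ids):
    """Fraction of the relevant contexts that made it into the retrieved set."""
    if not relevant_ids:
        return 0.0
    return len(set(retrieved_ids) & set(relevant_ids)) / len(set(relevant_ids))

# Example: 2 of the 3 relevant contexts were retrieved -> recall = 0.67
print(context_recall({"c1", "c2", "c9"}, {"c1", "c2", "c5"}))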


Embedding Model Adaption – A Holy Grail

Adapting the embedding model to your own dataset is the key retrieval-stage change: in our experiments it improved the context recall score by up to ~95%.

Let's understand the concept behind embedding models before adapting one. The idea starts with word vectors, where a model is trained to infer the meaning of a word from its surrounding context (read more about CBOW and Skip-gram). Building on word vectors, embedding models are neural networks specifically designed to capture the relationships between texts: they extend beyond word-level understanding to grasp sentence-level semantics. They are typically pre-trained with a masked language modelling objective, where a percentage of the input tokens is masked and the model learns to predict the masked words. Trained on billions of tokens, this approach lets the model pick up deeper language constructs and nuances, and the resulting embedding models generate context-aware representations. These models are meant to produce similar vectors for similar sentences, and that similarity can then be measured with distance metrics such as cosine similarity, which is what the retriever uses to prioritise contexts.

So, now we know what these models are trained for. They will produce similar embeddings for the sentences below:

Sentence 1: Roses are red

Sentence 2: Violets are blue

They are closely related because both sentences talk about colour.

For RAG, the similarity score between the query and its relevant context should be high so that those contexts are retrieved. Let's take a look at the query and context below, taken from the PubMedQA dataset.

Query: Do tumour-infiltrating immune cell profiles and their change after neoadjuvant chemotherapy predict the response and prognosis of breast cancer?

Context: Tumor microenvironment immunity is associated with breast cancer outcomes. A high lymphocytic infiltration has been associated with response to neoadjuvant chemotherapy, but the contribution to response and prognosis of immune cell subpopulation profiles in both pre-treated and post-treatment residual tumours is still unclear. We analyzed pre- and post-treatment tumour-infiltrating immune cells (CD3, CD4, CD8, CD20, CD68, Foxp3) by immunohistochemistry in a series of 121 breast cancer patients homogeneously treated with neoadjuvant chemotherapy. Immune cell profiles were analyzed and correlated with response and survival. We identified three tumour-infiltrating immune cell profiles, which were able to predict the pathological complete response (pCR) to neoadjuvant chemotherapy (cluster B: 58%, versus clusters A and C: 7%). A higher infiltration by CD4 lymphocytes was the main factor explaining the occurrence of pCR, and this association was validated in six public genomic datasets. A higher chemotherapy effect on lymphocytic infiltration, including an inversion of CD4/CD8 ratio, was associated with pCR and with a better prognosis. Analysis of the immune infiltrate in post-chemotherapy residual tumour identified a profile (cluster Y), mainly characterized by high CD3 and CD68 infiltration, with a worse disease-free survival.

Do the query and context look similar? Are we using embedding models in the way they are designed to be used? Clearly, no!
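You can check this yourself with an off-the-shelf model. The sketch below uses the same base model that is fine-tuned later in this post and util.cos_sim from sentence-transformers; the context string is truncated for brevity, the exact scores will vary, and the point is simply to compare the two similarities:

from sentence_transformers import SentenceTransformer, util

model = SentenceTransformer("stsb-distilbert-base")  # off-the-shelf, not yet adapted

sentence_1 = "Roses are red"
sentence_2 = "Violets are blue"
query = ("Do tumour-infiltrating immune cell profiles and their change after neoadjuvant "
         "chemotherapy predict the response and prognosis of breast cancer?")
context = ("Tumor microenvironment immunity is associated with breast cancer outcomes. "
           "A high lymphocytic infiltration has been associated with response to "
           "neoadjuvant chemotherapy...")  # truncated

s1, s2, q, c = model.encode([sentence_1, sentence_2, query, context])
print(util.cos_sim(s1, s2))  # similarity between the two short, similar sentences
print(util.cos_sim(q, c))    # similarity between the query and its relevant context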

Left Image by Author; Right Image Credits: https://github.com/UKPLab/sentence-transformers/blob/master/docs/img/SemanticSearch.png, Apache-2.0 license

The reason we fine-tune the embedding model is to pull the representations of a query and its relevant contexts closer together. Why not train from scratch? It's not efficient: pre-trained embedding models already understand language constructs from training on billions of tokens, and that understanding can still be leveraged.


Finetuning an Embedding Model

To fine-tune an embedding model, we need data consisting of queries similar to the anticipated user queries paired with the relevant company documents. We can use an LLM to generate such queries from the knowledge-base documents. Training the embedding model on the company's knowledge base is akin to providing a shortcut, as the model gets to see the contexts during the training phase itself.

Preparing the Dataset – Train and test:

Here are the steps for data preparation:

For the training set:

  1. Mine all possible questions from the company's Knowledge Base using LLM.
  2. If you're chunking the knowledge base, ensure that questions are mined from all the chunks.

For the testing set:

  1. Mine a smaller number of questions from the knowledge base.
  2. If available, use real user questions.
  3. Paraphrase the questions in the training set (a sketch of this step follows the list below).
  4. Combine and paraphrase questions from both the training and testing sets.
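For step 3, a minimal sketch of paraphrasing a training question with the same OpenAI client used later in this post might look like the following (the prompt wording and model choice are assumptions):

from openai import OpenAI

client = OpenAI(api_key="")

def paraphrase(question):
    completion = client.chat.completions.create(
        model="gpt-4o",
        messages=[
            {"role": "system", "content": "You rewrite questions without changing their meaning."},
            {"role": "user", "content": f"Paraphrase this question: {question}"},
        ],
    )
    return completion.choices[0].message.content

print(paraphrase("Do tumour-infiltrating immune cell profiles predict breast cancer prognosis?"))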

The majority of us do not work on developing domain-wide embedding models. The embedding models we create are intended to perform better on the company's knowledge base. Therefore, there is no harm in training embedding models using the company's internal dataset.

For this article, we will use the "qiaojin/PubMedQA" dataset from Hugging Face, which contains columns such as pubid, question, and context.

from datasets import load_dataset
med_data = load_dataset("qiaojin/PubMedQA", "pqa_artificial", split="train")
med_data

Each pubid uniquely identifies a row, so we will use it as the question ID.

To train the embedding model, we will use the sentence-transformers trainer, but you can also use the Hugging Face trainer. We fine-tune with MultipleNegativesRankingLoss, though a similar effect can be achieved with other losses such as TripletLoss or ContrastiveLoss. Each loss expects a different data format: TripletLoss needs (query, positive context, negative context) triplets, whereas MultipleNegativesRankingLoss only needs (query, positive context) pairs, and every other context in the batch is treated as a negative for the given query.
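For illustration, here is roughly what the two data layouts look like, built with the datasets library used throughout this post (the column names and example texts are illustrative, not the exact training data):

from datasets import Dataset

# (query, positive context) pairs suffice for MultipleNegativesRankingLoss:
# the other positives in the batch act as in-batch negatives.
mnrl_data = Dataset.from_dict({
    "anchor":   ["Does CD4 infiltration predict pathological complete response?"],
    "positive": ["A higher infiltration by CD4 lymphocytes was the main factor explaining the occurrence of pCR."],
})

# TripletLoss additionally needs an explicit negative context per query.
triplet_data = Dataset.from_dict({
    "anchor":   ["Does CD4 infiltration predict pathological complete response?"],
    "positive": ["A higher infiltration by CD4 lymphocytes was the main factor explaining the occurrence of pCR."],
    "negative": ["The study evaluated the effect of statins on cardiovascular outcomes in elderly patients."],
})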

In our PubMedQA dataset, each row's "question" column contains a question, and the "context" column contains a list of suitable contexts for that question. Therefore, we need to expand the context list column and create individual rows with their corresponding context in a new column.

import pandas as pd
from datasets import Dataset

dataset = med_data.remove_columns(['long_answer', 'final_decision'])

df = pd.DataFrame(dataset)
df['contexts'] = df['context'].apply(lambda x: x['contexts'])

# Flatten the context list and repeat the question for each context
expanded_df = df.explode('contexts')

# Optionally: reset the index
expanded_df.reset_index(drop=True, inplace=True)

expanded_df = expanded_df[['question', 'contexts']]
splitted_dataset = Dataset.from_pandas(expanded_df).train_test_split(test_size=0.05)

expanded_df.head()

Preparing the Dataset – Eval:

Now we have the datasets for training and testing; let's form the dataset for evaluation. For evaluation, we will mine questions from the contexts using an LLM, which gives us a realistic idea of how much our embedding model has improved. From the PubMedQA dataset, we take roughly 250 rows, join each row's list of contexts into one string, and send it to the LLM to mine questions. For each row the LLM might output around 10 questions, giving us roughly 2,500 question-context pairs for evaluation.

from openai import OpenAI
from tqdm.auto import tqdm

eval_med_data_seed = med_data.shuffle().take(251)

client = OpenAI(api_key="")

prompt = """Your task is to mine questions from the given context.
Example question is also given to you. 
You have to create questions and return as pipe separated values(|)


{context}



{example_question}

"""

questions = []
for row in tqdm(eval_med_data_seed):

    question = row["question"]
    context = "nn".join(row["context"]["contexts"])
    question_count = len(row["context"]["contexts"])

    message = prompt.format(question_count=question_count,
                            context=context,
                            example_question=question)

    completion = client.chat.completions.create(
        model="gpt-4o",
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {
                "role": "user",
                "content": message
            }
        ]
    )

    questions.append(completion.choices[0].message.content.split("|"))

eval_med_data_seed = eval_med_data_seed.add_column("test_questions", questions)
df = eval_med_data_seed.to_pandas()

eval_data = Dataset.from_pandas(df.explode("test_questions"))
eval_data.to_parquet("test_med_data2.parquet")

Before we start the training, we need to prepare the evaluator using the evaluation dataset created above.

Preparing the Evaluator:

The sentence-transformers library offers various evaluators, such as EmbeddingSimilarityEvaluator, BinaryClassificationEvaluator, and InformationRetrievalEvaluator. For our use case of training an embedding model for RAG, InformationRetrievalEvaluator is the most suitable choice. Multiple evaluators can also be combined for scoring.

Given a set of queries and a large corpus, the InformationRetrievalEvaluator retrieves the top-k most similar documents for each query and scores the model with metrics such as Recall@k, Precision@k, MRR, and Accuracy@k, for k values of 1, 3, 5, and 10. For RAG, Recall@k is the most important metric, as it indicates how many of the relevant contexts the retriever successfully fetches. This is critical because if the retriever brings back the correct contexts, the generation will likely be accurate even when some additional irrelevant contexts are included.
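To make the other metrics concrete, here is a minimal sketch (not the evaluator's internal implementation) of Precision@k and the reciprocal rank that MRR averages over queries:

def precision_at_k(retrieved_ids, relevant_ids, k):
    """Share of the top-k retrieved documents that are actually relevant."""
    top_k = retrieved_ids[:k]
    return sum(1 for doc_id in top_k if doc_id in relevant_ids) / k

def reciprocal_rank(retrieved_ids, relevant_ids):
    """1 / rank of the first relevant document (0 if none is retrieved)."""
    for rank, doc_id in enumerate(retrieved_ids, start=1):
        if doc_id in relevant_ids:
            return 1.0 / rank
    return 0.0

retrieved = ["c7", "c1", "c4"]   # ranked retrieval results
relevant = {"c1", "c2"}          # ground-truth contexts for the query
print(precision_at_k(retrieved, relevant, k=3))  # 1/3
print(reciprocal_rank(retrieved, relevant))      # first relevant doc at rank 2 -> 0.5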

eval_context_id_map = {}

# PubMedQA does not ship context IDs, so we assume a "context_ids" column is added here,
# e.g. by deriving an ID from the pubid and the context's position in the list.
eval_data = eval_data.map(
    lambda row: {"context_ids": [f'{row["pubid"]}_ctx{i}'
                                 for i in range(len(row["context"]["contexts"]))]}
)

for row in eval_data:
    contexts = row["context"]["contexts"]
    for context, context_id in zip(contexts, row["context_ids"]):
        eval_context_id_map[context_id] = context

eval_corpus = {}         # Our corpus (cid => document)
eval_queries = {}        # Our queries (qid => question)
eval_relevant_docs = {}  # Query ID to relevant documents (qid => set([relevant_cids]))

for i, row in enumerate(eval_data):
    # Several mined questions share the same pubid after the explode,
    # so we append an index to keep each query ID unique.
    qid = f'{row["pubid"]}_{i}'
    eval_queries[qid] = row["test_questions"]
    eval_relevant_docs[qid] = set(row["context_ids"])

    for context_id in row["context_ids"]:
        eval_corpus[context_id] = eval_context_id_map[context_id]

Queries: maps each query ID (the pubid plus an index, since several mined questions share a pubid) to its question.

Corpus: maps each context ID to its content from the context map.

Relevant docs: maps each query ID to the set of relevant context IDs.
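For illustration, the three dictionaries end up with shapes like the following (the IDs and texts shown here are hypothetical):

# Hypothetical entries, shown only to illustrate the expected shapes:
# eval_queries       = {"21645374_0": "Do tumour-infiltrating immune cell profiles predict ...?"}
# eval_corpus        = {"21645374_ctx0": "Tumor microenvironment immunity is associated with ..."}
# eval_relevant_docs = {"21645374_0": {"21645374_ctx0", "21645374_ctx1"}}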

After forming all the dictionaries, we can create an InformationRetrievalEvaluator instance from the sentence_transformer package.

from sentence_transformers.evaluation import InformationRetrievalEvaluator

ir_evaluator = InformationRetrievalEvaluator(
    queries=eval_queries,
    corpus=eval_corpus,
    relevant_docs=eval_relevant_docs,
    name="med-eval-test",
)

Model Training:

At last, let's train our model, which is simple with the sentence-transformers trainer. Just set the training configuration parameters, such as:

  1. eval_steps – how often the model is evaluated.
  2. save_steps – how often a checkpoint is saved.
  3. num_train_epochs – the number of epochs to train.
  4. per_device_train_batch_size – the batch size per GPU (so simply the batch size on a single GPU).
  5. save_total_limit – the maximum number of checkpoints to keep.
  6. run_name – the name under which logs are posted to wandb.ai.

Then we pass the model, the args, the train and eval datasets, the loss function, and the evaluator to the trainer. Now you can sit back and relax until training completes.


With our training data, training took around three hours, including inference time on the test and evaluation datasets.

from datetime import datetime

from sentence_transformers import (SentenceTransformer,
                                   SentenceTransformerTrainer,
                                   SentenceTransformerTrainingArguments)
from sentence_transformers.losses import MultipleNegativesRankingLoss

# Load base model
model = SentenceTransformer("stsb-distilbert-base")
output_dir = f"output/training_mnrl-{datetime.now():%Y-%m-%d_%H-%M-%S}"

train_loss = MultipleNegativesRankingLoss(model=model)

# Training arguments
args = SentenceTransformerTrainingArguments(
    output_dir=output_dir, num_train_epochs=1, per_device_train_batch_size=64,
    eval_strategy="steps", eval_steps=250, save_steps=250, save_total_limit=2,
    logging_steps=100, run_name="mnrl"
)

# Train the model
trainer = SentenceTransformerTrainer(model=model, 
                                     args=args, 
                                     train_dataset=splitted_dataset["train"], 
                                     eval_dataset=splitted_dataset["test"], 
                                     loss=train_loss,
                                     evaluator=ir_evaluator)

trainer.train()
Full results are in the notebook attached at the end.

Results

For comparison, let's initiate two instances of the model, one with trained weights and another with untrained weights.

untrained_pubmed_model = SentenceTransformer("stsb-distilbert-base")
trained_pubmed_model = SentenceTransformer("/kaggle/input/sentencetransformerpubmedmodel/transformers/default/1/final")

untrained_results = ir_evaluator(untrained_pubmed_model)
trained_results = ir_evaluator(trained_pubmed_model)
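In recent sentence-transformers versions the evaluator returns a dictionary of metrics keyed by the evaluator name. Assuming that behaviour (the exact key names are worth double-checking against your installed version), the recall improvements can be compared like this:

for k in (1, 3, 5, 10):
    key = f"med-eval-test_cosine_recall@{k}"  # "<evaluator name>_cosine_recall@<k>"
    base, tuned = untrained_results[key], trained_results[key]
    print(f"recall@{k}: {base:.4f} -> {tuned:.4f} ({(tuned - base) / base:+.1%})")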

The results are clear: there are substantial improvements on every metric. Below are the improvements for the metric we care about most, relative to the untrained model:

  • Recall@1: +78.80%
  • Recall@3: +137.92%
  • Recall@5: +116.36%
  • Recall@10: +95.09%

Analyzing the results, it is apparent that the adapted embedding model substantially enhances context recall, and thereby the overall accuracy of RAG generation. One drawback is the need to monitor documents being added to the knowledge base and to retrain the model periodically.

This can be achieved by adhering to the standard machine learning pipeline process, where we monitor the model for any drift and reinitiate the training pipeline if the drift exceeds a certain threshold.
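A minimal sketch of such a trigger, assuming a held-out evaluation set and a retraining job you can kick off programmatically (the threshold and the retrain_pipeline function are placeholders):

RECALL_THRESHOLD = 0.80  # assumption: minimum acceptable recall@10 on the held-out eval set

def maybe_retrain(model, ir_evaluator, retrain_pipeline):
    """Re-evaluate the deployed model and retrigger training if recall has drifted too low."""
    scores = ir_evaluator(model)
    recall_at_10 = scores["med-eval-test_cosine_recall@10"]
    if recall_at_10 < RECALL_THRESHOLD:
        retrain_pipeline()  # e.g. re-mine questions from the updated knowledge base and fine-tune again
    return recall_at_10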

Notebook: https://www.kaggle.com/code/vigneshboss/embedding-model-training-blog?scriptVersionId=200526579

Try implementing the idea on your own data and comment with the performance improvements you see.

Kindly follow, clap, and share the content!

Unless otherwise noted, all images are by the author

References:

  1. The domain-adaptation idea is derived from GPL: Generative Pseudo Labeling for Unsupervised Domain Adaptation of Dense Retrieval
  2. RAG Evaluation – https://www.pinecone.io/learn/series/vector-databases-in-production-for-busy-engineers/rag-evaluation/
  3. SBERT Training – https://sbert.net/examples/training/ms_marco/cross_encoder_README.html
