SentenceTransformer: A Model for Computing Sentence Embeddings

Author:Murphy  |  View: 30092  |  Time: 2025-03-22 23:15:09

In this post, we look at SentenceTransformer [1] (also known as Sentence-BERT or SBERT), which was published in 2019. SentenceTransformer uses a bi-encoder architecture and adapts BERT to produce efficient sentence embeddings.

BERT (Bidirectional Encoder Representations from Transformers) is built on the premise that NLP tasks rely on understanding the meaning of tokens/words in context. BERT is trained in two phases: 1) pre-training, where BERT learns general-purpose representations of the language, and 2) fine-tuning, where BERT is trained on a specific task.

Image taken from [3]

BERT is very good at learning the meaning of words/tokens, but it is not good at capturing the meaning of whole sentences. As a result, its raw outputs perform poorly on sentence-level tasks such as sentence classification and pairwise sentence similarity.

Since BERT produces token embeddings, one way to get a sentence embedding out of BERT is to average the embeddings of all tokens. The SentenceTransformer paper [1] showed that this produces low-quality sentence embeddings, often worse than simply averaging GloVe word embeddings (using the CLS token output is even worse). These embeddings do not capture the meaning of sentences well.
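The token averaging described above can be sketched in a few lines of NumPy. This is a minimal illustration, not the library's code: the toy array stands in for BERT's last-layer token embeddings (768-dim for bert-base), and the attention mask (1 for real tokens, 0 for padding) is an assumption about how a padded batch is represented, so padding positions are kept out of the average.

```python
import numpy as np

def mean_pool(token_embeddings: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
    """Average token embeddings per sentence, ignoring padding positions.

    token_embeddings: (batch, seq_len, dim) per-token vectors (stand-in for BERT output).
    attention_mask:   (batch, seq_len), 1 for real tokens and 0 for padding.
    """
    mask = attention_mask[:, :, None].astype(token_embeddings.dtype)  # (batch, seq_len, 1)
    summed = (token_embeddings * mask).sum(axis=1)                    # (batch, dim)
    counts = np.clip(mask.sum(axis=1), 1e-9, None)                    # avoid division by zero
    return summed / counts

# Toy input: 2 sentences, 4 token positions, 3-dim embeddings.
tokens = np.arange(24, dtype=np.float32).reshape(2, 4, 3)
mask = np.array([[1, 1, 1, 1],
                 [1, 1, 0, 0]])  # the second sentence has two padding tokens
sentence_embeddings = mean_pool(tokens, mask)
print(sentence_embeddings.shape)  # (2, 3)
```

The same masked mean pooling is what sits on top of BERT in the fine-tuned model; what changes after fine-tuning is that the underlying token embeddings make the pooled vector meaningful.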

Image by author

To create meaningful sentence embeddings from BERT, SentenceTransformer fine-tunes BERT on a few sentence-level tasks, such as:

  1. NLI (natural language inference): This task receives two input sentences and outputs "entailment", "contradiction" or "neutral". "Entailment" means sentence 1 entails (implies) sentence 2, "contradiction" means sentence 1 contradicts sentence 2, and "neutral" means neither holds.
  2. STS (semantic textual similarity): This task receives two sentences and predicts how similar they are in meaning. Similarity is often computed with the cosine similarity function.
  3. Triplet datasets: Each example consists of an anchor sentence, a positive sentence related to it, and a negative sentence unrelated to it.

Training BERT on NLI (classification objective)

SentenceTransformer trains BERT on the NLI task using a Siamese network. "Siamese" means twins: the network consists of two sub-networks with exactly the same architecture that also share their weights.

Image by author

The first sentence u is passed into the first network and the second sentence v into the second network. After each BERT, a mean pooling layer averages the token embeddings, which results in a 768-dim sentence embedding (for BERT-base). Let's call these emb(u) and emb(v), respectively. We then concatenate three vectors: emb(u), emb(v) and the element-wise absolute difference |emb(u) - emb(v)|. This results in a vector of 3*768 dimensions. We pass it through a dense layer that maps it to 3 neurons with softmax activation. Each neuron corresponds to one of the labels: entailment, contradiction or neutral.

Image from [1]

To train this network, we minimize the cross-entropy loss.
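The classification objective can be sketched as a single forward pass in NumPy. Following the SBERT paper [1], the features are (u, v, |u - v|) with an element-wise absolute difference, followed by a dense layer with softmax and a cross-entropy loss. The random embeddings, weights and label indices below are hypothetical stand-ins: in the real model, u and v come from the same weight-shared BERT + mean pooling tower, and W is learned.

```python
import numpy as np

rng = np.random.default_rng(0)
dim, num_labels = 768, 3  # bert-base hidden size; entailment / contradiction / neutral

# Hypothetical stand-ins for emb(u) and emb(v) of a batch of 2 sentence pairs.
u = rng.normal(size=(2, dim))
v = rng.normal(size=(2, dim))
labels = np.array([0, 2])  # gold class index per pair (index-to-label mapping is illustrative)

# Trainable classifier weights of shape (3 * dim, num_labels); random here.
W = rng.normal(scale=0.01, size=(3 * dim, num_labels))

def softmax(logits):
    shifted = logits - logits.max(axis=1, keepdims=True)  # subtract max for numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)

features = np.concatenate([u, v, np.abs(u - v)], axis=1)  # (batch, 3 * dim)
probs = softmax(features @ W)                              # (batch, 3)

# Cross-entropy: mean negative log-probability assigned to the gold class.
loss = -np.mean(np.log(probs[np.arange(len(labels)), labels]))
print(features.shape, probs.shape, float(loss))
```

Because both sentences go through the same encoder, only the 3 * dim by 3 classifier is specific to NLI; it is discarded after training and only the encoder is kept.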

Training BERT on STS (regression objective)

The semantic textual similarity (STS) task receives two sentences and computes their similarity. The network architecture for fine-tuning BERT on STS is as follows. It is again a Siamese network with mean pooling on top.

Image by author

After u and v pass through BERT and the mean pooling layer, we have emb(u) and emb(v), both 768-dim. We then compute the cosine similarity between them, which yields a score in the range [-1, 1].

To train this network, we minimize the mean squared error between the true similarity and the predicted similarity.
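The regression objective boils down to a cosine similarity followed by a mean squared error, sketched here in NumPy on toy 2-dim vectors. The embeddings and gold scores are hypothetical; in practice gold scores are rescaled to a range comparable with the predictions (STS benchmark labels, for instance, run from 0 to 5 and are commonly divided by 5).

```python
import numpy as np

def cosine_similarity(a, b):
    """Row-wise cosine similarity between two (batch, dim) arrays."""
    a_norm = a / np.linalg.norm(a, axis=1, keepdims=True)
    b_norm = b / np.linalg.norm(b, axis=1, keepdims=True)
    return np.sum(a_norm * b_norm, axis=1)

# Toy stand-ins for emb(u) and emb(v) of three sentence pairs.
emb_u = np.array([[1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])
emb_v = np.array([[1.0, 0.0], [1.0, -1.0], [0.0, -1.0]])
gold = np.array([1.0, 0.5, 0.0])  # hypothetical gold similarity scores

pred = cosine_similarity(emb_u, emb_v)  # approximately [1.0, 0.0, -1.0]
mse = np.mean((pred - gold) ** 2)       # the quantity minimized during training
print(pred, mse)
```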

Training BERT on Triplet dataset (triplet objective)

In the triplet objective, the model receives an anchor data point, a positive data point that is related or close to the anchor, and a negative data point that is unrelated to the anchor.

To collect this kind of data in the text domain, we can pick a random sentence from a document as the anchor, its following sentence as the positive, and a random sentence from a different passage as the negative.
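A minimal sketch of that sampling recipe, assuming each document is already split into a list of sentences. The function name `make_triplets` and the list-of-lists input are hypothetical; for simplicity the sketch uses every sentence that has a successor as an anchor instead of sampling anchors at random.

```python
import random

def make_triplets(documents, seed=0):
    """Build (anchor, positive, negative) triplets from a list of documents.

    documents: list of documents, each a list of sentences.
    anchor = a sentence, positive = the sentence right after it,
    negative = a random sentence from a different document.
    """
    rng = random.Random(seed)
    triplets = []
    for doc_idx, sentences in enumerate(documents):
        other_docs = [d for i, d in enumerate(documents) if i != doc_idx and d]
        if not other_docs:
            continue  # negatives need at least one other non-empty document
        for i in range(len(sentences) - 1):
            anchor, positive = sentences[i], sentences[i + 1]
            negative = rng.choice(rng.choice(other_docs))
            triplets.append((anchor, positive, negative))
    return triplets

docs = [
    ["The cat sat on the mat.", "It purred quietly.", "Then it fell asleep."],
    ["Stocks rose on Monday.", "Tech shares led the gains."],
]
for t in make_triplets(docs):
    print(t)
```

Neighbouring sentences are only a heuristic for relatedness, so such triplets are noisy; the appeal is that they need no manual labels.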

The goal is to train the network so that the distance between the anchor and the positive, |a-p|, stays smaller than the distance between the anchor and the negative, |a-n|. We usually add a margin eps and require |a-p| to be smaller than |a-n| - eps. So if |a-p| < |a-n| - eps, the loss is 0. In other words, if |a-p| - |a-n| + eps < 0, the loss is 0; otherwise the loss is positive. Here |x-y| is the distance between two embeddings; the SBERT paper [1] uses Euclidean distance and sets eps to 1.

Therefore, the loss function becomes: loss = max(0, |a-p| - |a-n| + eps)

Image by author
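The triplet loss can be written directly in NumPy, with a as the anchor, p the positive and n the negative embeddings, Euclidean distance, and eps = 1 as the default margin. The toy 2-dim vectors are chosen so the per-triplet values are easy to check by hand: the first triplet already satisfies the margin, the second does not.

```python
import numpy as np

def triplet_loss(anchor, positive, negative, eps=1.0):
    """max(0, |a-p| - |a-n| + eps) with Euclidean distance, averaged over the batch."""
    d_pos = np.linalg.norm(anchor - positive, axis=1)  # |a-p| per triplet
    d_neg = np.linalg.norm(anchor - negative, axis=1)  # |a-n| per triplet
    return np.mean(np.maximum(0.0, d_pos - d_neg + eps))

a = np.array([[0.0, 0.0], [0.0, 0.0]])
p = np.array([[1.0, 0.0], [3.0, 0.0]])  # |a-p| = 1 and 3
n = np.array([[0.0, 3.0], [0.0, 2.0]])  # |a-n| = 3 and 2
# Per triplet: max(0, 1 - 3 + 1) = 0 and max(0, 3 - 2 + 1) = 2, so the mean is 1.0
print(triplet_loss(a, p, n))
```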

Note that the triplet loss explained above is a more advanced version of the contrastive loss, in which the model receives two sentences a and b that are either similar or dissimilar. The contrastive loss is defined as follows:

  • If a and b are similar, we minimize |a-b|.
  • If they are dissimilar, we maximize |a-b|, but only up to a margin m: once |a-b| > m, the pair is different enough and we stop pushing it apart. In other words, if m - |a-b| < 0, the loss is zero.

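So the loss function becomes y|a-b| + (1-y) max(0, m-|a-b|), where y = 1 if the pair is similar and y = 0 otherwise. A widely used variant, including the ContrastiveLoss in the sentence_transformers library, squares both terms.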
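The contrastive loss formula above (unsquared, with Euclidean distance) can be sketched in NumPy as follows. The margin m = 1.0 and the toy vectors are illustrative assumptions; the three pairs cover a similar pair, a dissimilar pair inside the margin, and a dissimilar pair already beyond it.

```python
import numpy as np

def contrastive_loss(a, b, y, m=1.0):
    """y*|a-b| + (1-y)*max(0, m-|a-b|), Euclidean distance, averaged over the batch.

    y = 1 for similar pairs and 0 for dissimilar pairs; m is the margin.
    """
    d = np.linalg.norm(a - b, axis=1)
    return np.mean(y * d + (1 - y) * np.maximum(0.0, m - d))

a = np.array([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]])
b = np.array([[0.5, 0.0], [0.3, 0.4], [2.0, 0.0]])  # distances 0.5, 0.5 and 2.0
y = np.array([1.0, 0.0, 0.0])
# Per pair: 0.5 (similar, pulled together), 0.5 (dissimilar, inside margin), 0 (beyond margin)
print(contrastive_loss(a, b, y))
```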

In practice, a network trained with triplet loss often converges faster than the same network trained with contrastive loss.

At Inference Time

Regardless of which method we use to fine-tune BERT on sentence understanding tasks, after training we use one of the towers (BERT + pooling layer; both towers share the same weights) to create sentence embeddings for all sentences in the corpus we want to search. We store all of these embeddings in an index structure.

Then, at inference time, we pass the query sentence through the same model, get its embedding, and retrieve the K nearest neighbors of this embedding from the index. The similarity metric for this KNN search is often cosine similarity.
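A brute-force version of this retrieval step is sketched below in NumPy. The toy 2-dim "index" stands in for precomputed sentence embeddings and `top_k_cosine` is a hypothetical helper; for large corpora an approximate nearest-neighbour library would replace the exhaustive scan, and sentence_transformers also ships `util.semantic_search` for this purpose.

```python
import numpy as np

def top_k_cosine(query_emb, corpus_embs, k=2):
    """Return indices and scores of the k corpus embeddings most cosine-similar to the query."""
    corpus_norm = corpus_embs / np.linalg.norm(corpus_embs, axis=1, keepdims=True)
    query_norm = query_emb / np.linalg.norm(query_emb)
    scores = corpus_norm @ query_norm  # cosine similarity to every indexed sentence
    top = np.argsort(-scores)[:k]      # exhaustive search, highest scores first
    return top, scores[top]

# Toy "index": embeddings computed once, offline, for every sentence in the corpus.
corpus = np.array([[1.0, 0.0], [0.0, 1.0], [0.7, 0.7], [-1.0, 0.0]])
query = np.array([1.0, 0.1])  # embedding of the query sentence
idx, sc = top_k_cosine(query, corpus, k=2)
print(idx, sc)
```

Normalizing the stored embeddings once makes cosine similarity a plain dot product, which is why many vector indexes store unit-length vectors.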

SentenceTransformer in Code

Let's use MRPC (Microsoft Research Paraphrase Corpus) [4] to train a sentence transformer. Each example in this dataset contains two sentences and a label indicating whether the two sentences mean the same thing (are paraphrases).

To use SentenceTransformer, we first install the library, along with Hugging Face datasets, which we use to load MRPC:

!pip install sentence_transformers datasets

Then we build the model. Building the model is easy and consists of three steps:

  1. load an existing language model
  2. build a pooling layer over its token embeddings
  3. combine the two steps using the modules argument of SentenceTransformer

Let's put this to code:

from sentence_transformers import SentenceTransformer, models

# Define model
## Step 1: use an existing language model
word_embedding_model = models.Transformer('bert-base-uncased')

## Step 2: use a pooling function over the token embeddings
## (here: take the CLS token embedding; pooling_mode='mean' would average the tokens instead)
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(),
                               pooling_mode='cls')

## Join steps 1 and 2 using the modules argument
model = SentenceTransformer(modules=[word_embedding_model, pooling_model])

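In this code, we use a BERT model as the transformer and take the CLS token embedding in the pooling layer as the sentence embedding. Note that this differs from the mean pooling described earlier, which is the default configuration in the SBERT paper [1].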

Next, let's load the dataset. We use the MRPC dataset, which is available under the CC BY 4.0 license, a Creative Commons (CC) open license.

from datasets import load_dataset

dataset = load_dataset("glue", "mrpc")

Image by author

As you can see, the dataset consists of three splits (train, validation and test); each example has two sentences and a label. We build the training data as a list of InputExample objects. Each InputExample takes texts and label as arguments.

from sentence_transformers import InputExample

# Format training data: one InputExample per sentence pair
train_examples = []
for example in dataset['train']:
    train_examples.append(InputExample(texts=[example['sentence1'], example['sentence2']],
                                       label=float(example['label'])))

We then pass this list of train_examples into a data loader:

from torch.utils.data import DataLoader

train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=4)

Next, we choose the training loss:

from sentence_transformers import losses

train_loss = losses.ContrastiveLoss(model=model)

Next, we build the evaluation data from the validation split:

# Format evaluation data
sentences1 = []
sentences2 = []
scores = []
for example in dataset['validation']:
    sentences1.append(example['sentence1'])
    sentences2.append(example['sentence2'])
    scores.append(float(example['label']))

We use BinaryClassificationEvaluator for evaluation; the sentence_transformers.evaluation module provides other evaluators as well.

from sentence_transformers import evaluation

evaluator = evaluation.BinaryClassificationEvaluator(sentences1, sentences2, scores)
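Rather than a fixed cutoff, BinaryClassificationEvaluator reports accuracy at the best-performing similarity threshold on the evaluation pairs, alongside metrics such as F1 and average precision. The idea behind threshold selection can be sketched in plain Python; this is a simplified illustration, not the library's implementation, and the scores and labels are made up.

```python
def best_threshold_accuracy(scores, labels):
    """Pick the similarity threshold that maximizes accuracy.

    scores: similarity score per sentence pair; labels: 1 = paraphrase, 0 = not.
    A pair is predicted as a paraphrase when its score is above the threshold.
    """
    best_acc, best_thr = 0.0, None
    # Candidate thresholds: one below every score (predict all positive) plus each observed score.
    for thr in [min(scores) - 1.0] + sorted(set(scores)):
        correct = sum((s > thr) == bool(y) for s, y in zip(scores, labels))
        acc = correct / len(labels)
        if acc > best_acc:
            best_acc, best_thr = acc, thr
    return best_acc, best_thr

scores = [0.9, 0.8, 0.6, 0.4, 0.3]
labels = [1, 1, 0, 1, 0]
print(best_threshold_accuracy(scores, labels))  # (0.8, 0.3)
```

The test-set accuracy loop later in this post uses a fixed threshold of 0.5 instead; a threshold tuned on the validation split, as above, is often a better choice for a model trained with a margin-based loss.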

We then train the model by calling the fit function:

# Start training
model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    evaluator=evaluator,
    evaluation_steps=500,
    epochs=1,
    warmup_steps=0,
    output_path='./sentence_transformer/',
    weight_decay=0.01,
    optimizer_params={'lr': 0.00004},
    save_best_model=True,
    show_progress_bar=True,
)

Now that the model is trained, you can compute sentence embeddings for any ad-hoc sentence as follows:

sentences = ['This is just a random sentence on a Friday evening', 'to test model ability.']

# Sentences are encoded by calling model.encode()
embeddings = model.encode(sentences)

print(embeddings)

To compute the accuracy of the model on the test data (here, on the first 100 test examples, using a cosine-similarity threshold of 0.5), you can do the following:

from sentence_transformers import util

n = 100  # evaluate on the first 100 test examples
correct = 0
for row in dataset['test'].select(range(n)):
    u = model.encode(row['sentence1'])
    v = model.encode(row['sentence2'])
    cos_score = util.cos_sim(u, v).item()    # cosine similarity as a Python float
    predicted = 1 if cos_score > 0.5 else 0  # paraphrase if similarity exceeds 0.5
    if predicted == row['label']:
        correct += 1

print(correct / n)

This computes the accuracy as the number of correct predictions divided by the number of evaluated test examples.

Conclusion

In this post, we looked at the SentenceTransformer library and paper, and saw how they address the problem of computing sentence embeddings from BERT. SentenceTransformer fine-tunes BERT on three kinds of sentence-level datasets, namely NLI, STS and triplet datasets, using Siamese and triplet network architectures so that the model learns meaningful sentence embeddings.


If you have any questions or suggestions, feel free to reach out to me on LinkedIn: https://www.linkedin.com/in/minaghashami/

References

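  1. Reimers, N. and Gurevych, I. Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks. EMNLP-IJCNLP 2019.
  2. SBERT repository.
  3. Devlin, J., Chang, M.-W., Lee, K. and Toutanova, K. BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding. NAACL 2019.
  4. GLUE MRPC dataset viewer on Hugging Face: https://huggingface.co/datasets/glue/viewer/mrpc
  5. Cross-encoder vs. bi-encoder.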

