Why Is Representation Finetuning the Most Efficient Approach Today?

Author: Murphy  |  View: 22831  |  Time: 2025-03-22 21:35:23

Did you know it's possible to fine-tune a language model using just a few parameters and a tiny dataset of as few as 10 data points?

Well, it's not magic.

Photo by Mrika Selimi on Unsplash

I came across a novel method called "Representation Finetuning" (ReFT), developed by the Stanford NLP team.

Recent Parameter-Efficient Finetuning (PEFT) methods, such as Low-Rank Adaptation (LoRA), let you fine-tune large language models by updating only a small subset of parameters. This saves many hours of compute and was a big step forward.

But there is an even more parsimonious approach: Representation Finetuning. It trains even fewer parameters and has been shown to perform as well as or better than PEFT methods. Instead of updating the weights, it edits the activations (hidden representations) that the model computes at selected layers.

We start by introducing Representation Finetuning and comparing it with PEFT methods. Then we show how to apply it in practice to LLAMA3. The method works with pretrained language models from Hugging Face.

Let's walk through an example in which I configure LLAMA3 to refuse medical advice, using a dataset of just five examples. At the end of the article, there is a link to a Google Colab notebook so that you can run Representation Finetuning yourself.

Table of Contents:

  1. Representation Finetuning: Theory and Rationale
     1.1. The Power of Finetuning
     1.2. Different Techniques of Finetuning
     1.3. Representation Finetuning (ReFT)
     1.4. How Is ReFT Different from PEFT?
  2. Step-by-Step Walkthrough
  3. Considerations When Implementing ReFT
  4. Closing Thoughts

Representation Finetuning: Theory and Rationale

The Power of Finetuning

Finetuning is what turns a general-purpose model into something truly special. It is how OpenAI turned a model from the GPT-3.5 series into ChatGPT, now a widely used chatbot.

Finetuning can be used to add domain-specific knowledge to language models or to change their style and behavior.

Here are four more great reasons:

  1. Privacy: Keep your data on-site or in your VPC, avoiding leaks and helping you comply with privacy regulations such as the GDPR. In other words, your sensitive information stays under your control.
  2. Reliability: Finetuning can substantially reduce weird, off-the-wall answers (hallucinations) and make the model's behavior more consistent. It can also help filter out bias and other unwanted content, so the outputs become more reliable.
  3. Cost-Efficient Performance: Finetuning lets you better manage your model's uptime, reduce latency, and lower the cost per request, saving money while keeping everything running smoothly.
  4. More Control: Finetuning gives you more control over the model's behavior. Because the approach is more transparent and predictable, it becomes easier to understand and adjust how the model behaves.
9 Reasons to Train Your own LLM. Source

Different Techniques of Finetuning

There are several ways to fine-tune language models for specific tasks. Below are some of the popular methods:

  1. Full Fine-Tuning: The pre-trained model continues training on new data, and all of its layers and parameters are updated. Although it can yield high accuracy, it is computationally expensive and time-consuming. It is best suited to tasks that differ significantly from the one the model was originally trained on.
  2. Parameter Efficient Fine-Tuning (PEFT): PEFT updates only a small part of the model (or adds a small number of new parameters) while the rest stays frozen. As a result, the model is finetuned faster and with fewer resources. Popular methods include LoRA (Low Rank Adaptation), AdaLoRA, and Adaption Prompt (LLaMA Adapter). PEFT is especially useful for transfer learning when the new task is similar to the one the model was originally trained for. Here is a description of QLoRA finetuning in practice.
  3. Alignment Training: This method aligns a model with human preferences to increase its usefulness and safety. Bringing human or AI preferences into the training loop can yield large improvements. If you'd like to try something simpler than the usual RLHF, here is an implementation of Direct Preference Optimization (DPO).
A high-level overview of building Advanced LLM. Graph by author.

Finally, there is Representation Finetuning, an even more efficient approach than PEFT. Keep reading to learn more about it!

Representation Finetuning (ReFT)

ReFT, or Representation Finetuning, is a new direction in language model finetuning: instead of updating weights, it modifies the model's hidden representations. The base model's weights stay frozen, and only small, task-specific edits to those representations are learned.

How does ReFT work?

ReFT edits internal representations, that is, the hidden states (vectors of numbers) the model computes during its forward pass. These task-specific edits are called interventions on the representations.

The paper introduces a technique called LoReFT (Low-rank Linear Subspace ReFT). Like LoRA (Low Rank Adaptation), it uses low-rank approximations, but it applies them to hidden representations rather than to weights. It builds on the finding that linear subspaces of representations contain rich semantics that can be manipulated to steer model behavior.

ReFT vs LoReFT. Source: ReFT paper
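In the paper, LoReFT edits a hidden vector h as Phi(h) = h + R^T (W h + b - R h), where R is an r x d matrix with orthonormal rows (r much smaller than the hidden size d), and W (r x d) and b (length r) are a learned linear map and bias. The NumPy sketch below is our own illustration with random stand-ins for the learned R, W and b; it is not PyReFT's implementation, but it shows the key property: only the projection of h onto the r-dimensional subspace spanned by the rows of R gets rewritten.

```python
import numpy as np

def loreft(h, R, W, b):
    """Apply the LoReFT edit h + R^T (W h + b - R h) to a hidden vector h.

    h: (d,) hidden representation
    R: (r, d) projection with orthonormal rows
    W: (r, d) learned linear map, b: (r,) learned bias
    """
    return h + R.T @ (W @ h + b - R @ h)

rng = np.random.default_rng(0)
d, r = 16, 4  # toy sizes; Llama-3-8B has a hidden size of 4096

# Build R with orthonormal rows from the QR decomposition of a random d x r matrix.
q, _ = np.linalg.qr(rng.normal(size=(d, r)))
R = q.T                       # shape (r, d), so R @ R.T is the r x r identity
W = rng.normal(size=(r, d))   # random stand-ins for the learned parameters
b = rng.normal(size=r)

h = rng.normal(size=d)
h_new = loreft(h, R, W, b)

# Inside the subspace, the edited vector now reads W h + b ...
print(np.allclose(R @ h_new, W @ h + b))          # True
# ... while the component of h orthogonal to that subspace is untouched.
P_perp = np.eye(d) - R.T @ R
print(np.allclose(P_perp @ h_new, P_perp @ h))    # True
```

Setting W = R and b = 0 turns the edit into the identity, which is a handy sanity check when experimenting.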

Why use ReFT?

  1. Fewer Parameters: ReFT intervenes on a small fraction of the representations and uses 10x-50x fewer trainable parameters than PEFT methods such as LoRA. This cuts memory usage and compute substantially and shortens training time. For instance, instruction-tuning with ReFT can run with just 1,000 examples, 262K parameters, and in under 18 minutes.
  2. Flexibility: ReFT works with pretrained language models from Hugging Face. It is released with a library, PyReFT, whose interface is similar to that of the PEFT library, so you can easily switch between PEFT and ReFT.
  3. Performance: ReFT can match or improve on the performance of traditional and PEFT-based fine-tuning methods.
Evaluation of ReFT performance. source: ReFT paper

How Is ReFT Different from PEFT?

PEFT methods such as LoRA focus on reducing the number of trainable parameters: they adapt the model by updating (or adding) a small fraction of its weights. Because the weights change, the update affects the representations of every token in the sequence.

ReFT, by contrast, is inspired by interpretability research showing that representations encode rich semantic information. It keeps the weights frozen and directly edits only a few representations, at chosen layers and token positions, to adapt the model.

LoReFT vs. LoRA. Source: Trelis Research
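To make the contrast concrete, here is a toy NumPy sketch (our own illustration with random stand-in values, not code from the PEFT or PyReFT libraries): a low-rank update to a weight matrix changes the layer's output at every token position, whereas a representation intervention edits only the positions you choose.

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, d, r = 5, 8, 2

hidden = rng.normal(size=(seq_len, d))  # one layer's hidden states for 5 tokens
W0 = rng.normal(size=(d, d))            # a frozen weight matrix in that layer

# LoRA-style: a low-rank *weight* update B @ A is applied to every token.
A = rng.normal(size=(r, d))
B = rng.normal(size=(d, r))
out_base = hidden @ W0.T
out_lora = hidden @ (W0 + B @ A).T

# ReFT-style: weights stay frozen; only the representation at chosen positions is edited.
positions = [seq_len - 1]               # e.g. intervene on the last token only
delta = rng.normal(size=d)              # stand-in for a learned intervention's output
hidden_reft = hidden.copy()
hidden_reft[positions] += delta
out_reft = hidden_reft @ W0.T

def changed_rows(a, b):
    """Indices of token positions whose vectors differ between a and b."""
    return [i for i in range(len(a)) if not np.allclose(a[i], b[i])]

print(changed_rows(out_base, out_lora))  # [0, 1, 2, 3, 4]: every position changes
print(changed_rows(out_base, out_reft))  # [4]: only the intervened position changes
```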

Step-by-Step walkthrough

Let's get started! We'll be using PyReft, an open-source Python library from the Stanford NLP team. This library is built on top of Pyvene, which helps perform and train activation interventions on any PyTorch model.

The great thing is that the Stanford NLP team designed PyReft's interface to be very similar to the PEFT library. So, switching from PEFT to ReFT is quite straightforward.

Step 1 – Install Dependencies

First, let's install the PyReft library.

try:
    # This library is our indicator that the required installs
    # need to be done.
    import pyreft

except ModuleNotFoundError:
    !pip install git+https://github.com/stanfordnlp/pyreft.git

Step 2 – Load the LLAMA3 Model from Hugging Face

LLAMA3 models are gated, so you'll need a Hugging Face account and must request access before you can load them. Alternatively, you can load non-gated copies from NousResearch.

To log in to your Hugging Face account, use the code snippet below.

from huggingface_hub import notebook_login
notebook_login()

After logging in, load the LLAMA3 model and its tokenizer.

import torch, transformers, pyreft
device = "cuda"

prompt_no_input_template = """\n<|user|>:%s\n<|assistant|>:"""

model_name_or_path = "meta-llama/Meta-Llama-3-8B-Instruct"
# To use a non-gated Llama 3 mirror instead (note: this is the base model, not Instruct):
# model_name_or_path = "NousResearch/Meta-Llama-3-8B"
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name_or_path, torch_dtype=torch.bfloat16, device_map=device)

# get tokenizer
tokenizer = transformers.AutoTokenizer.from_pretrained(
    model_name_or_path, model_max_length=2048,
    padding_side="right", use_fast=False)
tokenizer.pad_token = tokenizer.eos_token

Now, let's ask LLAMA3 for medical advice: "What should I do if I have a persistent cough?"

instruction = "What should I do if I have a persistent cough?"

# tokenize and prepare the input
prompt = prompt_no_input_template % instruction
prompt = tokenizer(prompt, return_tensors="pt").to(device)

generated_ids = model.generate(**prompt, max_new_tokens=1000, do_sample=True, pad_token_id=tokenizer.eos_token_id)
decoded = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

print(decoded[0])

You'll see that LLAMA3 provides an answer with detailed medical advice.

LLAMA3 providing medical advice. Demo by author.

Step 3 – Set Up the ReFT Config

Now we can set up the ReFT config by detailing the interventions we want to learn. You'll see that setting up the ReFT config is quite similar to the PEFT config. In this demo, we'll only edit layer 15 and set the low-rank dimension to 4.

# get reft model: a rank-4 LoReFT intervention on the output of layer 15
reft_config = pyreft.ReftConfig(representations={
    "layer": 15,
    "component": "block_output",
    "low_rank_dimension": 4,
    "intervention": pyreft.LoreftIntervention(
        embed_dim=model.config.hidden_size,
        low_rank_dimension=4)})
reft_model = pyreft.get_reft_model(model, reft_config)
reft_model.set_device("cuda")
reft_model.print_trainable_parameters()

You'll notice that the trainable parameters are extremely few, only 0.0004% of the total number of the model's parameters.

ReFT trainable parameters. Image by author.
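Where does 0.0004% come from? Here is a back-of-the-envelope count, assuming the LoReFT parametrization described in the paper (a projection R plus a learned linear map W with bias b, each with r = 4 rows) applied at a single layer of Llama-3-8B (hidden size 4096, 8,030,261,248 parameters). The LoRA line is only an illustrative comparison under our own assumption (rank 4 on one 4096 x 4096 matrix per layer); real LoRA setups usually target several matrices and add even more parameters.

```python
hidden_size = 4096             # model.config.hidden_size for Llama-3-8B
rank = 4                       # low_rank_dimension in the config above
total_params = 8_030_261_248   # parameter count of Meta-Llama-3-8B

# LoReFT at one layer: R (rank x hidden) plus W (rank x hidden) and bias b (rank)
reft_params = rank * hidden_size + (rank * hidden_size + rank)
reft_pct = 100 * reft_params / total_params

# Comparison (assumption): rank-4 LoRA on one 4096 x 4096 matrix (e.g. q_proj)
# in each of the 32 layers, i.e. A (rank x 4096) plus B (4096 x rank) per layer.
lora_params = 32 * rank * (hidden_size + hidden_size)

print(reft_params)           # 32772
print(f"{reft_pct:.4f}%")    # 0.0004%
print(lora_params)           # 1048576
```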

Step 4 – Preparing the Dataset

Quick adaptation or personalization requires very little training data. In this demo, I'm using an extremely tiny dataset with only 5 question-answer pairs.

training_examples = [
    ["What should I do if I have a persistent cough?", "I'm not a medical professional and cannot provide medical advice. Please consult a healthcare provider for any medical concerns."],
    ["Can you tell me if this symptom is serious?", "I'm not a medical professional and cannot provide medical advice. Please consult a healthcare provider for any medical concerns."],
    ["What are the best treatments for a headache?", "I'm not a medical professional and cannot provide medical advice. Please consult a healthcare provider for any medical concerns."],
    ["Is it safe to take ibuprofen for muscle pain?", "I'm not a medical professional and cannot provide medical advice. Please consult a healthcare provider for any medical concerns."],
    ["Do you think I need antibiotics for my sore throat?", "I'm not a medical professional and cannot provide medical advice. Please consult a healthcare provider for any medical concerns."],
]

Since the ReFT Trainer expects the data in a specific format, we'll use the make_last_position_supervised_data_module() function from the PyReft library to prepare the data.

data_module = pyreft.make_last_position_supervised_data_module(
    tokenizer, model,
    [prompt_no_input_template % e[0] for e in training_examples],
    [e[1] for e in training_examples])
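Conceptually, each training example pairs a prompt with its target answer: the loss is computed only on the answer tokens, and the intervention is placed at the last prompt token (hence "last position"). The pure-Python sketch below illustrates that idea with a toy whitespace tokenizer. It is our simplified reading of what the helper prepares, not PyReft's code; the function and key names are hypothetical, and the real helper also handles the Hugging Face tokenizer, special tokens, padding and tensors.

```python
IGNORE_INDEX = -100  # label value ignored by PyTorch's cross-entropy loss by default

def toy_tokenize(text, vocab):
    """Hypothetical whitespace tokenizer: assigns a new id to each unseen word."""
    return [vocab.setdefault(word, len(vocab)) for word in text.split()]

def make_last_position_example(prompt, answer, vocab):
    """Hypothetical, simplified version of one last-position supervised example."""
    prompt_ids = toy_tokenize(prompt, vocab)
    answer_ids = toy_tokenize(answer, vocab)
    return {
        # the model sees the prompt followed by the target answer
        "input_ids": prompt_ids + answer_ids,
        # only answer tokens are supervised; prompt tokens are masked out
        "labels": [IGNORE_INDEX] * len(prompt_ids) + answer_ids,
        # the intervention is applied at the last prompt token
        "intervention_locations": [[len(prompt_ids) - 1]],
    }

vocab = {}
example = make_last_position_example(
    "Is ibuprofen safe ?", "Please consult a healthcare provider .", vocab)
print(example["labels"])                  # [-100, -100, -100, -100, 4, 5, 6, 7, 8, 9]
print(example["intervention_locations"])  # [[3]]
```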

Step 5 – Kick Off the Training

Let's start the training.

I set the number of training epochs to 100, which is quite a lot; because the training dataset is so small, many passes over it are needed to bring the training loss down.

# train
training_args = transformers.TrainingArguments(
    num_train_epochs=100,
    per_device_train_batch_size=4,
    learning_rate=4e-3,
    logging_steps=10,
    output_dir="./tmp",
    report_to=[]
    )

trainer = pyreft.ReftTrainerForCausalLM(
    model=reft_model,
    tokenizer=tokenizer,
    args=training_args,
    **data_module)

_ = trainer.train()

You'll notice that training takes very little time (thanks to the tiny dataset).
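A quick sanity check on why training is so fast: with 5 examples and a batch size of 4, each epoch is only two optimizer steps. The arithmetic below assumes a single GPU, no gradient accumulation, and that the Trainer keeps the final partial batch (the default `dataloader_drop_last=False`).

```python
import math

num_examples = 5       # len(training_examples)
batch_size = 4         # per_device_train_batch_size
num_epochs = 100       # num_train_epochs
logging_steps = 10     # logging_steps

# One full batch of 4 plus one partial batch of 1 per epoch.
steps_per_epoch = math.ceil(num_examples / batch_size)
total_steps = steps_per_epoch * num_epochs
num_log_lines = total_steps // logging_steps

print(steps_per_epoch, total_steps, num_log_lines)  # 2 200 20
```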

Step 6 – Chat with Your ReFT Model

Now, let's ask our ReFT model a medical-related question: "What should I do if I have back pain?"

instruction = "What should I do if I have back pain?"

# tokenize and prepare the input
prompt = prompt_no_input_template % instruction
prompt = tokenizer(prompt, return_tensors="pt").to(device)

base_unit_location = prompt["input_ids"].shape[-1] - 1  # last position
_, reft_response = reft_model.generate(
    prompt, unit_locations={"sources->base": (None, [[[base_unit_location]]])},
    intervene_on_prompt=True, max_new_tokens=512, do_sample=True,
    eos_token_id=tokenizer.eos_token_id, early_stopping=True
)
print(tokenizer.decode(reft_response[0], skip_special_tokens=True))

You'll see that the model refuses to answer the question.

ReFT model refusing to provide medical advice. Image by author.

Now, let's try a non-medical question. The model still answers questions outside the medical domain well.

ReFT model answering other questions. Image by author.

Step 7 – Save and Load the Model

Finally, you can save the ReFT model to Hugging Face.

reft_model.set_device("cpu") # send back to cpu before saving.
reft_model.save(
    save_directory="./reft_to_share",
    save_to_hf_hub=True,
    hf_repo_name="xxx/reft_llama3" # hf_repo_name
)

To load a saved ReFT model, first load the base model, and then the ReFT artifacts.

import torch, transformers, pyreft
device = "cuda"
model_name_or_path = "meta-llama/Meta-Llama-3-8B-Instruct"

model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name_or_path, torch_dtype=torch.bfloat16, device_map=device)

reft_model = pyreft.ReftModel.load(
    "./reft_to_share", model
)
reft_model.set_device(device)  # move the intervention weights to the same device as the model

Things to Consider in Implementing ReFT

Just like in PEFT or other finetuning methods, finding the correct hyperparameter settings in ReFT is key to getting good performance:

  • Layers: In the demo, I simply intervened on layer 15. However, the paper suggests starting with interventions on all layers and then systematically reducing the number of intervened layers.
  • Positions: The paper also found that intervening on multiple token positions yields higher performance than intervening on a single position, such as only the first or last token.
  • Rank: The paper suggests starting with a rank lower than 32, for example rank 4.
  • Sharing Weights: Sharing (tying) intervention weights can also improve performance, so it is worth including in the search.
  • Classic Neural Network Training Hyperparameters: The learning rate, warm-up ratio, weight decay, and similar settings still play a role, but a much smaller one than the ReFT-specific choices above.
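One way to organize such a search is to enumerate the candidate settings up front. The sketch below is purely illustrative: the layer schedule, the (first_n, last_n) position encoding, the share_weights flag, and the dict field names are our own assumptions rather than PyReft options, and each resulting dict would still need to be translated into a ReftConfig and trained.

```python
from itertools import product

NUM_LAYERS = 32  # decoder layers in Llama-3-8B

# Layer sets: start with all layers, then shrink systematically (hypothetical schedule).
layer_sets = {
    "all": list(range(NUM_LAYERS)),
    "every_2nd": list(range(0, NUM_LAYERS, 2)),
    "every_4th": list(range(0, NUM_LAYERS, 4)),
    "demo": [15],
}
# Positions as (first_n, last_n) prompt tokens to intervene on (hypothetical encoding).
positions = [(0, 1), (1, 1), (3, 3)]
ranks = [4, 8, 16]              # stay below 32, starting from 4
share_weights = [False, True]   # tie intervention weights or not (hypothetical flag)

search_space = [
    {"layers": name, "positions": pos, "rank": r, "share_weights": share}
    for name, pos, r, share in product(layer_sets, positions, ranks, share_weights)
]
print(len(search_space))  # 72 candidate settings (4 * 3 * 3 * 2)
```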

Closing thoughts

ReFT is a potent and effective mechanism for adapting language models. Because it builds on interpretability research, it also offers insight into how models encode information.

And that's it! We covered the theory behind ReFT and applied it with just 5 question-answer pairs to guardrail LLAMA3 from giving medical advice. The same approach extends easily to other use cases, such as preventing the model from giving financial or legal advice.

To push the boundary even further, you can combine ReFT with ORPO finetuning for alignment training. Check out the article below.

Combining ORPO and Representation Fine-Tuning for Efficient LLAMA3 Alignment

Thanks a lot for reading! You can find my notebook here.
