To Mask or Not to Mask: The Effect of Prompt Tokens on Instruction Tuning
Link to full code on [GitHub](https://github.com/davidsvaughn/prompt-loss-weight) | Reach me on [LinkedIn](https://www.linkedin.com/in/davidsvaughn/)
In the last several months I've noticed quite a few discussions, [here](https://x.com/corbtt/status/1806336011804484017) and there, even over here, on the question of whether or not to zero-mask (ignore) prompt tokens when fine-tuning on prompt-completion style data (i.e. instruction-tuning). I've seen various terms used, such as:
- instruction-masking
- prompt-masking
- user-masking
- completion-only-training
Whatever you call it, there seems to be no clear consensus about what the standard practice should be. Depending on which open source library you use for fine-tuning, the defaults can vary widely.
For example, the Axolotl library masks prompt tokens by default (through its train_on_inputs=False default setting). However, the very popular HuggingFace Trainer does not mask prompt tokens by default. One can choose to mask out the prompt by using [DataCollatorForCompletionOnlyLM](https://huggingface.co/docs/trl/main/en/sft_trainer#train-on-completions-only), but this comes with some significant limitations, notably the lack of support for sample packing, which can be a deal-breaker when dealing with large datasets, as it was for me. (Note: a nice solution was proposed here.)
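To make the masking mechanics concrete, here is a minimal pure-Python sketch of what completion-only training amounts to at the data level, assuming the prompt and completion have already been tokenized into integer ids. The helpers `build_example` and `pack` are hypothetical names, and the -100 sentinel is borrowed from PyTorch's CrossEntropyLoss default `ignore_index`. This is not Axolotl's or TRL's actual code. It also shows that once the mask lives in per-token labels, naively concatenating (packing) several samples leaves every sample's mask intact.

```python
IGNORE_INDEX = -100  # sentinel label that the loss function skips

def build_example(prompt_ids, completion_ids, mask_prompt=True):
    """Concatenate prompt + completion ids and build matching labels.

    With mask_prompt=True, prompt positions get IGNORE_INDEX so that only
    completion tokens contribute to the loss.
    """
    input_ids = list(prompt_ids) + list(completion_ids)
    if mask_prompt:
        labels = [IGNORE_INDEX] * len(prompt_ids) + list(completion_ids)
    else:
        labels = list(input_ids)
    return {"input_ids": input_ids, "labels": labels}

def pack(examples):
    """Naive sample packing: concatenate several examples end to end.

    Each token carries its own label, so the prompt mask of every
    example survives the concatenation unchanged.
    """
    packed = {"input_ids": [], "labels": []}
    for ex in examples:
        packed["input_ids"].extend(ex["input_ids"])
        packed["labels"].extend(ex["labels"])
    return packed

ex1 = build_example([5, 6, 7], [8])        # hypothetical token ids
ex2 = build_example([9, 10], [11, 12])
print(pack([ex1, ex2]))
# {'input_ids': [5, 6, 7, 8, 9, 10, 11, 12], 'labels': [-100, -100, -100, 8, -100, -100, 11, 12]}
```

Labels are aligned with input_ids here because HuggingFace causal LM models shift the labels by one position internally when computing the loss.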
Many guides, demos, notebooks, tutorials, etc. for LLM fine-tuning that I have come across do not mention prompt-masking, for example:
- How to Fine-Tune LLMs in 2024 with Hugging Face
- How-to-Fine-Tune-an-LLM-Part-2-Instruction-Tuning-Llama-2
- HuggingFace Alignment Handbook
- Niels Rogge's SFT Tutorial
- this Fine-tune Llama 2 Notebook
But it's also possible to find examples with default prompt-masking:
- this FastChat example
- PyTorch/torchtune
- Axolotl (mentioned above)
Spoiler alert: this article does not attempt to settle this issue once and for all. It began as a humble investigation inspired by a simple idea – I wanted to compare fine-tuning with and without prompt masking, while in both cases separately tracking the validation set prompt loss and completion loss.
My hypothesis was that this might yield useful insights into the prompt-masking question. Then I came across the concept of prompt-loss-weight, an elegant generalization of binary token-masking into real-valued token-weighting (the weighting happens inside the loss function, as we'll see).
Integrating a prompt-loss-weight (PLW) parameter into the fine-tuning pipeline enables smoother, more fine-grained control over the influence of prompt tokens on the fine-tuning process. Simply put: PLW=0 is equivalent to prompt-masking, while PLW=1 is equivalent to no masking. In addition, using 0<PLW<1 allows one to smoothly modulate the influence of prompt tokens between these two extremes.
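As a rough sketch of the idea (the full implementation lives in the linked GitHub repo and may differ), prompt-loss-weighting can be written as a weighted average of per-token losses in which prompt tokens get weight PLW and completion tokens get weight 1. The function name `plw_loss` and the choice to normalize by the sum of the weights are assumptions made for illustration. With that normalization, PLW=0 reduces exactly to the completion-only mean and PLW=1 to the plain mean over all tokens.

```python
def plw_loss(token_losses, is_prompt, plw):
    """Weighted average of per-token losses (illustrative sketch).

    Prompt tokens get weight `plw`, completion tokens get weight 1, and the
    result is normalized by the total weight (an assumption of this sketch).
    """
    weights = [plw if p else 1.0 for p in is_prompt]
    total = sum(weights)
    if total == 0:
        raise ValueError("all tokens have zero weight")
    return sum(w * l for w, l in zip(weights, token_losses)) / total

losses = [2.0, 4.0, 6.0, 1.0]           # hypothetical per-token losses
is_prompt = [True, True, True, False]   # first three tokens are the prompt

print(plw_loss(losses, is_prompt, 0.0))  # 1.0  (same as prompt-masking)
print(plw_loss(losses, is_prompt, 1.0))  # 3.25 (same as no masking)
print(plw_loss(losses, is_prompt, 0.1))  # ≈ 1.692 (in between)
```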
With this re-framing, the question of whether or not to mask prompt tokens is subsumed by the deeper question of how much to weight prompt tokens. The optimal weighting may vary depending on the specific use case and dataset. By adding prompt-loss-weight to your toolkit, you'll gain the flexibility to experiment with different weighting strategies, leading to more effective fine-tuning outcomes tailored to your particular needs.
Since I couldn't find any implementations of prompt-loss-weight, I decided to try implementing it myself. I'll guide you through the customizations I had to make to several parts of the standard HuggingFace LLM toolset to make this work. Afterwards, we'll use our updated toolset to explore the original questions about prompt tokens by running some fine-tuning experiments on the RACE dataset (a multiple choice QA dataset hosted on HuggingFace).
Some LLM Background
LLMs operate on tokens rather than words. For the purposes of this article we will use these two terms interchangeably, but it's good to note the difference. Tokens are defined as frequently occurring sequences of characters, and often coincide roughly with words (and may even include the preceding space as well). A fun exercise is to play around with the GPT-4 tokenizer, which I used to generate the following example (color-coding reveals the underlying tokens):

The type of generative LLMs that most of us work with every day are next-token-prediction machines. They have been trained (sometimes referred to as pre-training) on massive amounts of human-generated text (books, newspapers, the internet, etc.) so that, when fed a random snippet of sensible text, they are very good at predicting what the next word should be. This is sometimes referred to as Causal Language Modeling. When applied repeatedly, this autoregressive text generation process can generate very human-like sentences, paragraphs, articles, and so on.
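The "applied repeatedly" part can be sketched with a toy stand-in for an LLM. The lookup table `NEXT_TOKEN_PROBS` below is entirely hypothetical, and it conditions only on the last token, whereas a real LLM conditions on the whole preceding sequence. The loop is plain greedy decoding: pick the most likely next token, append it, and feed the longer sequence back in.

```python
# Hypothetical next-token table standing in for an LLM. A real model would
# condition on the entire prefix, not just the last token.
NEXT_TOKEN_PROBS = {
    "The": {"bird": 0.6, "cat": 0.4},
    "bird": {"flew": 0.7, "sang": 0.3},
    "flew": {"away": 0.5, "home": 0.3, "<eos>": 0.2},
    "away": {"<eos>": 1.0},
}

def generate(prompt, max_new_tokens=10):
    """Greedy autoregressive decoding over the toy table."""
    tokens = list(prompt)
    for _ in range(max_new_tokens):
        dist = NEXT_TOKEN_PROBS.get(tokens[-1])
        if dist is None:            # no known continuation
            break
        next_token = max(dist, key=dist.get)
        if next_token == "<eos>":   # end-of-sequence marker
            break
        tokens.append(next_token)
    return tokens

print(generate(["The"]))  # ['The', 'bird', 'flew', 'away']
```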
Often we will want to take one of these foundation LLMs that have been pre-trained on massive amounts of text (like the Llama family of models from Meta) and continue training it a bit further, i.e. fine-tune it on a much smaller text dataset. This practice has roots in the broader field of transfer learning.
The goal here is to gently tweak, or customize, the LLM's next-token-prediction behavior without majorly disrupting or corrupting the basic underlying "intelligence" that is manifested in the model weights – this leads to LLMs that retain most of the emergent abilities of the foundation model (like reading comprehension, the ability to converse, to reason…), but are now specialized for a specific task. For example, instruction-tuning means fine-tuning an LLM so that it can follow instructions.
There are many instruction-tuning datasets available on the HuggingFace datasets hub, organized by task: some are for question answering, others for text summarization, and so on. In the vast majority of cases, these datasets share the same basic underlying schema, with each data sample containing:
- a prompt, a.k.a. the instruction
- a completion, a.k.a. the response
In this setting, the goal of fine-tuning is to increase (ultimately maximize) the probability that the LLM will generate the completion when given the prompt as input. In other words, the response "completes" the prompt. We rarely, if ever, have any interest in altering the probability that the LLM will generate the prompt itself… which is just the input to the LLM.

Consider text summarization, for instance. A typical prompt might consist of an instruction to summarize a long news article together with the article itself, and the completion would be the requested summary (see the EdinburghNLP/xsum dataset on HuggingFace). The goal of fine-tuning a foundation LLM on this dataset would be to increase the likelihood that the LLM will generate the summary when given the instruction+article, not that the LLM will generate the article itself, or generate the second half of the article if shown the first half.
However, a popular approach that has emerged for fine-tuning LLMs on prompt-completion style datasets is to largely ignore the prompt-completion distinction, and fine-tune the model on the entire text sequence – basically just continuing the same process that was used to pre-train the foundation model, even though instruction tuning has a quite different goal from pre-training. This leads to teaching the LLM to generate the prompt as well as the completion.
I'm not entirely sure why this is the case, but most likely this habit was simply inherited from older, foundation model training protocols, where there was originally no such distinction. From what I can gather, the basic attitude seems to be: well, what's the harm? Just fine-tune on the entire sequence, and the model will still learn to do what you want (to generate the completion given the prompt)… it will just learn some extra stuff too.
Prompt-Masking -vs- Prompt-Dampening
The most obvious solution would be to eliminate (or zero-mask) the prompt tokens from the learning process. PyTorch allows for manually masking tokens out of training through the ignore_index parameter of the CrossEntropyLoss function, whose default value is -100. Setting all the label ids corresponding to the prompt tokens to -100 forces CrossEntropyLoss to ignore these tokens in the loss computation, which results in training only on the completion tokens (in my opinion, this is a very poorly documented feature; I only stumbled upon it by accident, and there's a reference buried in here in the Llama documentation).
By itself, this is not really a solution to prompt-masking. It's only a means for masking arbitrary tokens once those tokens have been located by some other means. Some of the prompt-masking references listed earlier employ this technique, while others explicitly create a binary-mask to accomplish the same thing. While useful, this solution is still a binary switch rather than the continuous dial that prompt-loss-weight allows.
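To see what ignore_index does to the number being optimized, here is a pure-Python mimic of a mean-reduced cross-entropy loss. It is an illustration of the semantics, not PyTorch's implementation, and the logits are made-up numbers. Positions whose label equals -100 contribute nothing to the sum and are also excluded from the count used for averaging, which is exactly what multiplying by an explicit binary mask and dividing by the mask's sum would do.

```python
import math

IGNORE_INDEX = -100  # same sentinel as PyTorch's CrossEntropyLoss default

def masked_cross_entropy(logits, labels, ignore_index=IGNORE_INDEX):
    """Pure-Python mimic of mean-reduced cross-entropy with ignore_index.

    logits: one list of unnormalized scores (one per vocab entry) per position
    labels: target token id per position; positions equal to ignore_index
            are skipped and also excluded from the count used to average
    """
    total, count = 0.0, 0
    for scores, target in zip(logits, labels):
        if target == ignore_index:
            continue
        m = max(scores)  # log-sum-exp trick for numerical stability
        log_z = m + math.log(sum(math.exp(s - m) for s in scores))
        total += log_z - scores[target]  # = -log softmax(scores)[target]
        count += 1
    return total / count if count else float("nan")  # nothing left to average

logits = [[2.0, 0.5, 0.1], [0.3, 1.2, 0.0], [0.1, 0.2, 3.0]]  # made-up, |V| = 3
prompt_masked = masked_cross_entropy(logits, [-100, -100, 2])  # last position only
unmasked = masked_cross_entropy(logits, [0, 1, 2])             # all positions
```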
However, this raises the question: if prompt-masking does improve instruction-tuning, what's the point of having a non-zero prompt-loss-weight at all? Why would we want to merely dampen the influence of prompt tokens rather than eliminate it completely?
Recently a paper was posted on arXiv titled Instruction Fine-Tuning: Does Prompt Loss Matter? The authors suggest that a small amount of prompt learning may act as a regularizer during fine-tuning, preventing the model from over-fitting the completion text. They hypothesize:
…that [a non-zero] PLW provides a unique regularizing effect that cannot be easily replaced with other regularizers…
Even the folks at OpenAI seem to acknowledge the benefits of using a small but non-zero prompt-loss-weight. Apparently they once exposed this very PLW parameter through their fine-tuning API, and there's still some documentation about it online, in which it's noted that:
a small amount of prompt learning helps preserve or enhance the model's ability to understand inputs (from Best practices for fine-tuning GPT-3 to classify text)
although they have since removed this parameter. According to the old docs, though, they used a default value of PLW=0.1 (10%), meaning prompt tokens get weighted 1/10ᵗʰ as much as completion tokens.
Generation Ratio
In the previously mentioned paper (Instruction Fine-Tuning: Does Prompt Loss Matter?) the authors introduce a useful quantity. Given an instruction dataset, they define the Generation Ratio, or Rg:
the generation ratio Rg is the ratio of completion length to prompt length. We then divide instruction data into two broad categories. Data with Rg<1 are short-completion data, and data with Rg >1 are long-completion data. When applied to an entire dataset, we take R̅g to be the mean completion-prompt ratio.
For datasets with small R̅g values (i.e. the completion is shorter than the prompt), they found that PLW actually does matter (i.e. using the wrong PLW value can degrade performance). And if you think about it, many common instruction-tuning datasets have this property of a shorter completion than prompt, almost by design (think: text summarization, information extraction).
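A little arithmetic suggests why small Rg makes PLW matter. For a single sample with P prompt tokens and C completion tokens, Rg = C/P, so the prompt's share of the total token weight under prompt-loss-weighting is PLW·P / (PLW·P + C) = PLW / (PLW + Rg). The sketch below (`prompt_weight_share` is a hypothetical helper) counts weights only, ignoring how large the individual token losses are, and plugging in a dataset-level mean Rg is only a rough approximation.

```python
def prompt_weight_share(plw, rg):
    """Fraction of total loss weight carried by prompt tokens.

    For P prompt tokens and C completion tokens, Rg = C / P, so
    PLW*P / (PLW*P + C) simplifies to PLW / (PLW + Rg).
    """
    return plw / (plw + rg)

print(round(prompt_weight_share(0.1, 0.01), 3))  # 0.909: prompt still dominates
print(round(prompt_weight_share(0.1, 7.6), 3))   # 0.013: prompt barely matters
```

Even at PLW=0.1, a sample with Rg=0.01 still draws about 91% of its loss weight from the prompt, while for Rg=7.6 the prompt's share is about 1%.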
As a fun exercise, I computed the R̅g values for several popular instruction datasets on HuggingFace (code here):
- 7.6 | Alpaca (general instruction)
- 6.0 | OpenHermes (general instruction)
- 3.6 | Python-18k (code instruction)
- 2.0 | Databricks-Dolly-15k (general instruction)
- 1.1 | OpenOrca (general instruction)
- 0.2 | SAMSum (text summarization)
- 0.1 | XSum (text summarization)
- 0.01 | RACE (QA/multiple choice)
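The computation behind these numbers is simple. Here is a minimal sketch (not necessarily how the linked code does it) that computes per-sample Rg values from (prompt, completion) text pairs and averages them. The helper `generation_ratios` is hypothetical, and whitespace splitting is only a stand-in for whatever LLM tokenizer you actually fine-tune with.

```python
def generation_ratios(pairs, tokenize=str.split):
    """Per-sample generation ratio Rg = completion length / prompt length.

    pairs: iterable of (prompt, completion) strings.
    tokenize: any callable returning a list of tokens; whitespace splitting
    is only a stand-in for a real LLM tokenizer.
    """
    return [len(tokenize(c)) / len(tokenize(p)) for p, c in pairs]

pairs = [  # hypothetical samples
    ("Summarize: the cat sat on the mat all day long", "cat sat"),
    ("Answer with A, B, C or D: which option is correct", "B"),
]
rg = generation_ratios(pairs)
mean_rg = sum(rg) / len(rg)   # dataset-level mean generation ratio
```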

When summarizing any set of values by its average, it's good practice to look at the full distribution of values as a sanity check. The arithmetic mean can be misleading on data that is highly skewed or otherwise deviates from being roughly normally distributed. I plotted histograms showing the full Rg distribution for each dataset (top row). The bottom row shows the same histograms but with the x-axis log-scaled:

These plots suggest that when a dataset's Rg distribution covers multiple orders of magnitude or has non-negligible representation in both the Rg>1 and Rg<1 regions (as is the case with OpenOrca and other datasets with R̅g>1), the distribution can become highly skewed. As a result, the arithmetic mean may be disproportionately influenced by larger values, potentially misrepresenting the distribution's central tendency. In such cases, computing the mean in log-space (then optionally transforming it back to the original scale) might provide a more meaningful summary statistic. In other words, it could make sense to use the geometric mean:
$$\exp\left(\frac{1}{n}\sum_{j=1}^{n}\log R_g^{(j)}\right) = \left(\prod_{j=1}^{n} R_g^{(j)}\right)^{1/n}$$
where $R_g^{(j)}$ is the generation ratio of the j-th of the n samples in the dataset.
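A quick sketch of the difference, using hypothetical Rg values that span several orders of magnitude: one large value drags the arithmetic mean far above every other sample, while the log-space (geometric) mean stays near the bulk of the data.

```python
import math

def arithmetic_mean(xs):
    return sum(xs) / len(xs)

def geometric_mean(xs):
    """Mean computed in log-space, then mapped back: exp(mean(log x))."""
    return math.exp(sum(math.log(x) for x in xs) / len(xs))

# Hypothetical, heavily skewed Rg values
rg = [0.01, 0.1, 0.1, 1.0, 100.0]
print(arithmetic_mean(rg))  # ≈ 20.242, dominated by the single large value
print(geometric_mean(rg))   # ≈ 0.398
```

Python's standard library also provides statistics.geometric_mean (3.8+), which computes the same quantity.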
The RACE Reading Comprehension Dataset
Based on the above R̅g table, I decided the RACE ReAding Comprehension Dataset from Examinations (R̅g=0.01) would be a good candidate for investigation. Multiple choice QA seemed like an ideal test-bed for exploring the effects of prompt-masking, since the prompt is naturally very long relative to the completion. Regardless of prompt length, the completion is always 1 character long, namely A, B, C or D (if you ignore special tokens, delimiters, etc). My hunch was that if there are any effects from modulating prompt token weights, they would certainly be noticeable here.
As stated in the dataset card:
RACE is a large-scale reading comprehension dataset with more than 28,000 passages and nearly 100,000 questions. The dataset is collected from English examinations in China, which are designed for middle school and high school students. The dataset can be served as the training and test sets for machine comprehension.
The QA schema is simple: the prompt presents a question, possibly some context (the article field), and then lists four options. The completion (answer) is always one of: A, B, C, D. This dataset viewer hosted on HuggingFace allows browsing the full set, but here's a small example:
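Here is one way a RACE record could be turned into a prompt-completion pair. It is a sketch under stated assumptions: the field names follow the RACE schema (article, question, options, answer), but the prompt template, the helper `format_race`, and the sample record are all hypothetical. It makes the extreme generation ratio visible: a paragraph-sized prompt versus a one-letter completion.

```python
def format_race(article, question, options, answer):
    """Turn one RACE-style record into a (prompt, completion) pair.

    The template is hypothetical; field names follow the RACE schema.
    """
    letters = "ABCD"
    lines = [article, "", question]
    lines += [f"{letter}. {option}" for letter, option in zip(letters, options)]
    lines.append("Answer:")
    prompt = "\n".join(lines)
    completion = answer  # always a single letter: A, B, C or D
    return prompt, completion

prompt, completion = format_race(  # made-up record, not from the dataset
    article="Tom planted a tree in spring. By autumn it was taller than him.",
    question="When did Tom plant the tree?",
    options=["Spring", "Summer", "Autumn", "Winter"],
    answer="A",
)
print(prompt)
print("->", completion)
```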

Cross Entropy Loss
Before we jump into the full implementation of prompt-loss-weight, and try it out on the RACE data, we need a basic understanding of loss and where it comes from. Simply put, loss is a measure of how well our model (LLM) "fits" (explains, predicts) our data. During fine-tuning (and also pre-training), we "move" the model closer to the data by tweaking the network weights in such a way that decreases the loss. The chain rule (of calculus) gives us a precise algorithm for computing these tweaks, given the loss function and the network architecture.
The most common loss function in LLM fine-tuning is called Cross Entropy Loss (CEL). For this reason, most discussions of CEL are framed around the definition of cross-entropy, which comes from information theory. While it's true that "cross-entropy" is right there in the name, a more intuitive understanding can be achieved when approaching CEL through the lens of maximum likelihood estimation (MLE). I'll try to explain it from both angles.
We have already established that LLMs are wired for next-token prediction. What this means is that the LLM is basically just a mathematical function that takes as input a sequence of tokens and outputs a conditional probability distribution for the next token over the entire token vocabulary V. In other words, it outputs a vector of probability values of dimension |V| that sums to 1. (In set notation, |S| denotes the number of elements, or cardinality, of set S.)
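In practice, the network's final layer emits one raw score (logit) per vocabulary entry, and a softmax turns those scores into the probability vector just described. The 5-token vocabulary and the logits below are made up for illustration.

```python
import math

def softmax(logits):
    """Map raw scores (one per vocabulary token) to probabilities summing to 1."""
    m = max(logits)                      # subtract max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    return [e / z for e in exps]

vocab = ["away", "high", "south", "off", "home"]   # hypothetical |V| = 5
logits = [1.2, 2.0, 0.3, 1.5, 0.1]                 # made-up next-token scores
probs = softmax(logits)
for token, prob in zip(vocab, probs):
    print(f"p({token}) = {prob:.3f}")
```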
Let's take a small toy example to illustrate how this works. Imagine that our training data contains the 4-token sequence: The bird flew away. Given the first 3 tokens (The bird flew), an LLM might output the following vector of probabilities for every possible 4ᵗʰ token. For the sake of simplicity, we'll imagine that the 5 candidate tokens listed (in magenta) are the only possibilities (i.e. |V|=5). The function p(⋅) represents the conditional probabilities output by the LLM (notice they sum to 1):

When training (or fine-tuning) an LLM on a token sequence, we step through the sequence token-by-token and compare the next-token-distribution generated by the LLM to the actual next token in the sequence, and from there we calculate the CEL for that token.
Notice here that the actual 4ᵗʰ token in the sequence (away) does not have the highest probability in the table. During training, we would like to tweak the weights slightly so as to increase the probability of away, while decreasing the others. The key is having the right loss function… it allows us to compute exactly how much to tweak each weight, for each token.
Once the loss is computed for each token, the final loss is computed as the average per-token-loss over all tokens. But first we must establish the formula for this per-token-loss.
Information Theory Interpretation
Continuing the toy problem, to compute CEL for the 4ᵗʰ token position, we compare the actual 4ᵗʰ token to the generated distribution p(⋅) over all 5 possible 4ᵗʰ tokens. In fact, we treat the actual 4ᵗʰ token as a distribution q(⋅) in its own right (albeit a degenerate one) that has a value of 1 for the token appearing in the data (away) and a value of 0 for all other possible 4ᵗʰ tokens (this is sometimes called one-hot encoding).

The reason we contort the training data into this strange one-hot encoded probability representation q(⋅) is so we can apply the formula for cross-entropy, which is a measure of the divergence between two discrete probability distributions (BTW, not symmetric w.r.t. q,p):
$$H(q, p) = -\sum_{x} q(x)\,\log p(x)$$
where x indexes over all possible states (i.e. 5 tokens). This works out to:
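$$H(q, p) = -\,1\cdot\log p(\text{away}) \;-\; \sum_{x \neq \text{away}} 0\cdot\log p(x) = -\log p(\text{away})$$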
So basically CEL is just using the q vector to select from the p vector the single value corresponding to the token that actually appears in the data (away), i.e. multiplying its log by 1, and throwing away all other values (i.e. multiplying by 0); the loss is then the negative log of that selected probability. So we are indexing over all possible states (tokens) only to select one and ignore the rest.
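The selection effect is easy to check numerically. In this sketch the distribution p(⋅) is made up (it is not the table from the figure), and q(⋅) is the one-hot distribution for the observed token away. The cross-entropy sum collapses to −log p(away).

```python
import math

def cross_entropy_dist(q, p):
    """H(q, p) = -sum_x q(x) * log p(x).

    Terms with q(x) == 0 contribute nothing, so they are skipped
    (which also avoids log(0) if some p(x) is zero).
    """
    return -sum(qx * math.log(p[x]) for x, qx in q.items() if qx > 0)

# Hypothetical next-token distribution p(.) with |V| = 5
p = {"away": 0.25, "high": 0.40, "south": 0.10, "off": 0.15, "home": 0.10}
# One-hot distribution q(.) for the token actually observed in the data
q = {tok: (1.0 if tok == "away" else 0.0) for tok in p}

print(cross_entropy_dist(q, p))   # -log 0.25 ≈ 1.386
```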
MLE Interpretation
When fine-tuning an LLM, we seek the LLM weights θ that maximize the probability of the training data given those weights, often called the likelihood of the weights ℒ(θ) = ℙ(D|θ). And so we require an expression for this quantity. Luckily, there's an easy way to compute this from next token probabilities, which the LLM already gives us.
Starting with the other chain rule (of probability), we decompose the joint probability of a token sequence S into a product of conditional probabilities:
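$$\mathbb{P}(S \mid \theta) = \mathbb{P}(t_1, t_2, \ldots, t_N \mid \theta) = \prod_{i=1}^{N} \mathbb{P}(t_i \mid t_1, \ldots, t_{i-1}, \theta)$$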
This decomposition establishes the connection between next-token-prediction and the joint probability of the full token sequence – the joint probability is just the product of all the conditionals.
Using i to index over the tokens of a token sequence S = (t₁,t₂,t₃,…, tᵢ ,…), we'll use the following shorthand to denote the conditional probability output by an LLM for the iᵗʰ token in a sequence, given the LLM weights θ and the previous i-1 tokens:
$$p_i = \mathbb{P}(t_i \mid t_1, \ldots, t_{i-1}, \theta)$$
It should be emphasized that pᵢ is not a vector here (i.e. a distribution over all possible next tokens) but represents only the probability computed for the actual iᵗʰ token, i.e. the yellow highlighted row in the above example.
If we take the logarithm of the joint probability of a sequence, a product becomes a sum (since log is monotonic, this doesn't affect optimization):
$$\log \mathbb{P}(S \mid \theta) = \log \prod_{i=1}^{N} p_i = \sum_{i=1}^{N} \log p_i$$
Now we can connect the final sum-of-logs expression (right here ☝️) to the formula for Average Cross Entropy Loss L over a token sequence:
$$L = -\frac{1}{N}\sum_{i=1}^{N} \log p_i = -\frac{1}{N}\log \mathbb{P}(S \mid \theta)$$
which is the causal language model objective function. Often the "Average" is dropped from the name, and it's just called "Cross Entropy Loss," but it's good to remember that CEL is technically computed at the token level, and then averaged across tokens. From this final expression it should hopefully be clear that minimizing the CEL is equivalent to maximizing the probability of the token sequence, which is what MLE seeks.
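A small numerical check of this equivalence, using made-up per-token probabilities pᵢ from two hypothetical models scoring the same 4-token sequence: the average CEL always equals −(1/N)·log ℙ(S|θ), so the model with the lower loss is the one assigning the sequence the higher joint probability.

```python
import math

def avg_cel(p):
    """Average cross-entropy loss L = -(1/N) * sum_i log p_i."""
    return -sum(math.log(p_i) for p_i in p) / len(p)

def joint_prob(p):
    """Chain rule: P(S | theta) is the product of the per-token p_i."""
    return math.prod(p)

# Hypothetical probabilities two models assign to the actual tokens of one sequence
model_a = [0.05, 0.30, 0.60, 0.25]
model_b = [0.10, 0.50, 0.70, 0.40]

for p in (model_a, model_b):
    # L equals -(1/N) * log P(S | theta), so lower L means higher P(S | theta)
    print(round(avg_cel(p), 4), joint_prob(p))
```

In practice the very first token has no prefix to be predicted from, so causal LM implementations typically score only tokens 2 through N; the sketch ignores that detail.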
One convenience resulting from the form of this expression is that it is very easy to modify if we want to compute the loss over any subset of the tokens. Recall that we may sometimes be interested in finding the LLM weights θ that maximize the probability of the completion given the prompt:
$$\theta^{*} = \arg\max_{\theta}\; \mathbb{P}(\text{completion} \mid \text{prompt}, \theta)$$
We could easily adjust the loss for this scenario by simply averaging only over the completion tokens. If we use "