Domain Adaptation of A Large Language Model

Author:Murphy  |  View: 21949  |  Time: 2025-03-23 12:04:16
Image from unsplash

Large language models (LLMs) like BERT are usually pre-trained on general-domain corpora such as Wikipedia and BookCorpus. When we apply them to specialized domains such as medicine, their performance often drops compared to models adapted to those domains.

In this article, we will explore how to adapt a pre-trained LLM such as DeBERTa-base to the medical domain using the Hugging Face Transformers library. Specifically, we will cover an effective technique called intermediate pre-training, in which we continue pre-training the LLM on data from our target domain. This adapts the model to the new domain and improves its performance.

It is a simple yet effective way to tailor an LLM to your domain, and it can yield significant improvements in downstream task performance.

Let's get started.

Step 1: The Data

The first step in any project is to prepare the data. Our dataset comes from the medical domain and contains the following fields, among many others:

Image by author

There are too many fields to list them all here, but even this glimpse helps us design the input sequence for the LLM.

The first point to keep in mind is that the input has to be a sequence, because LLMs read their input as text sequences.

To turn these fields into a sequence, we can inject special tags that tell the LLM what piece of information comes next. Consider the example name:John, surname: Doer, patientID:1234, age:34, enclosed in a special patient tag: the tag tells the LLM that what follows is information about a patient.

So we form the input sequence as follows:

Image by author

As you can see, we have injected four tags:

  1. A tag that contains information about the patient.
  2. A tag that contains information about the hospital.
  3. A tag that contains information about each individual event the patient has had at the hospital.
  4. A tag that encloses all the events a patient has had at a hospital.

Inside each tag block, attributes are stored as key:value pairs.

Note that, for a given patient and hospital, we sort the events by timestamp and concatenate them. This forms a time-ordered sequence of the patient's visits to the hospital.
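To make the format concrete, here is a minimal sketch of how one such sequence could be assembled in Python. The tag strings (<patient>, <hospital>, <event>, <events> and their closing forms), the field names, and the order of the blocks are hypothetical placeholders chosen for illustration; the sketch only shows the key:value layout and the sorting of events by timestamp.

```python
from datetime import datetime

# Hypothetical tag names, used only for illustration.
PATIENT, END_PATIENT = "<patient>", "</patient>"
HOSPITAL, END_HOSPITAL = "<hospital>", "</hospital>"
EVENT, END_EVENT = "<event>", "</event>"
EVENTS, END_EVENTS = "<events>", "</events>"


def attrs_to_text(attrs: dict) -> str:
    """Render attributes as comma-separated key:value pairs."""
    return ", ".join(f"{key}:{value}" for key, value in attrs.items())


def build_sequence(patient: dict, hospital: dict, events: list) -> str:
    """Build one input sequence for a (patient, hospital) pair.

    Events are sorted by their 'timestamp' field (ISO-8601 strings assumed),
    so the events block is time-ordered.
    """
    ordered = sorted(events, key=lambda e: datetime.fromisoformat(e["timestamp"]))
    event_blocks = " ".join(f"{EVENT} {attrs_to_text(e)} {END_EVENT}" for e in ordered)
    return (
        f"{PATIENT} {attrs_to_text(patient)} {END_PATIENT} "
        f"{HOSPITAL} {attrs_to_text(hospital)} {END_HOSPITAL} "
        f"{EVENTS} {event_blocks} {END_EVENTS}"
    )


# Illustrative usage with made-up records.
patient = {"name": "John", "surname": "Doer", "patientID": 1234, "age": 34}
hospital = {"hospitalID": 42}
events = [
    {"timestamp": "2023-05-02T10:00:00", "type": "lab_test"},
    {"timestamp": "2023-05-01T09:30:00", "type": "admission"},
]
print(build_sequence(patient, hospital, events))
```

Because the events are sorted before they are joined, the earlier admission event appears before the later lab test even though it was listed second.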

The advantage of the special tags is that, after training the LLM, if I want an embedding for a patient, I can retrieve the embedding at the position of the patient tag. Similarly, if we want a patient embedding that acts as a profile of the patient, we can retrieve the embedding of the tag that encloses all events, since this tag contains all the events the patient has had with a hospital.
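With Transformers, the per-token vectors could come from calling the model with output_hidden_states=True and taking outputs.hidden_states[-1][0] (the last layer, first sequence in the batch), and the tag's ID from tokenizer.convert_tokens_to_ids(tag). The helper below is a framework-agnostic sketch of only the lookup step; the function name, the tag ID 99 and the toy vectors are hypothetical.

```python
from typing import List, Sequence


def tag_embedding(input_ids: Sequence[int],
                  hidden_states: Sequence[Sequence[float]],
                  tag_id: int) -> List[float]:
    """Return the hidden-state vector at the first occurrence of tag_id.

    input_ids:     token IDs of one sequence (length T)
    hidden_states: one vector per token (T x hidden_size), e.g. the last
                   encoder layer for the same sequence
    tag_id:        vocabulary ID of the special tag token (hypothetical here)
    """
    if len(input_ids) != len(hidden_states):
        raise ValueError("input_ids and hidden_states must have the same length")
    try:
        position = list(input_ids).index(tag_id)
    except ValueError:
        raise ValueError(f"tag id {tag_id} not found in sequence") from None
    return list(hidden_states[position])


# Toy example: 4 tokens, hidden size 2; 99 plays the role of a tag ID.
ids = [1, 99, 7, 2]
states = [[0.0, 0.1], [0.5, 0.9], [0.3, 0.3], [0.2, 0.0]]
print(tag_embedding(ids, states, tag_id=99))  # [0.5, 0.9]
```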

Let's assume our data is stored in S3 as Parquet files with a single column called "text", where each record is a sequence in the format above. We load the data from S3 with the code below:

import random

import s3fs
from datasets import load_dataset

files = {}
fs = s3fs.S3FileSystem()

# fs.glob returns paths without the "s3://" prefix, so we add it back.
train_path = "s3://bucket/train/*.parquet"
s3_files = ["s3://" + p for p in fs.glob(train_path)]
random.shuffle(s3_files)
files["train"] = s3_files

validation_path = "s3://bucket/test/*.parquet"
s3_files = ["s3://" + p for p in fs.glob(validation_path)]
random.shuffle(s3_files)
files["validation"] = s3_files

raw_datasets = load_dataset("parquet", data_files=files, streaming=False, use_auth_token=True)
print(raw_datasets)

The resulting raw_datasets object looks as follows:

Image by author

Step 2: Coding

First, install the requirements with the following command:

!pip install -r requirements.txt 

The requirements.txt file looks as follows:

pytest==7.4.2
pytest-cov==4.1.0
datasets==2.13.0
huggingface-hub==0.16.4
tensorboard==2.14.0
networkx==2.6.3
numpy==1.22.4
pandas==2.0.3
s3fs==2023.5.0
tokenizers==0.13.3
tqdm==4.66.1
transformers==4.31.0
evaluate==0.4.0
accelerate==0.23.0
bitsandbytes==0.41.1
trl==0.5.0
peft==0.4.0
pyarrow==13.0.0
pydantic==1.10.6
deepspeed==0.9.0

To write the code, we need to define model arguments, data arguments, and training arguments. We then define the model and, if we want to train it with PEFT (parameter-efficient fine-tuning), wrap it in a PEFT configuration.

First, we define the input arguments for the data, the model, and training.

Model Arguments

Model arguments specify which model and tokenizer we are going to train or fine-tune. The class below implements them as a dataclass; later we create an instance of it to pass in our choices.

The most important field here is model_name_or_path, but for completeness we keep all the arguments.

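from dataclasses import dataclass, field
from typing import Optional

from transformers import MODEL_FOR_MASKED_LM_MAPPING

# Model types that support masked language modeling (used in the help text below).
MODEL_CONFIG_CLASSES = list(MODEL_FOR_MASKED_LM_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)

@dataclass
class ModelArguments: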
    """
    Arguments pertaining to which model/config/tokenizer we are going to use
    """

    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The model checkpoint for weights initialization. Don't set if you want to train a model from scratch."
            )
        },
    )
    model_type: Optional[str] = field(
        default=None,
        metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
    )
    config_overrides: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Override some existing default config settings when a model is trained from scratch. Example: "
                "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
            )
        },
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    use_auth_token: bool = field(
        default=None,
        metadata={
            "help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token`."
        },
    )
    trust_remote_code: bool = field(
        default=False,
        metadata={
            "help": (
                "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option "
                "should only be set to `True` for repositories you trust and in which you have read the code, as it will "
                "execute code present on the Hub on your local machine."
            )
        },
    )
    low_cpu_mem_usage: bool = field(
        default=False,
        metadata={
            "help": (
                "It is an option to create the model as an empty shell, then only materialize its parameters when the pretrained weights are loaded. "
                "Setting it to True can reduce LLM loading time and RAM consumption."
            )
        },
    )

    def __post_init__(self):
        if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None):
            raise ValueError(
                "--config_overrides can't be used in combination with --config_name or --model_name_or_path"
            )

PEFT Arguments

Below are the arguments for parameter-efficient training. They configure LoRA (low-rank adaptation), which we apply with the peft library.

@dataclass
class PEFTArguments:
    """
    Arguments pertaining to the LoRA (parameter-efficient fine-tuning) configuration.
    """
    lora_r: Optional[int] = field(
        default=0, metadata={"help": "LoRA bottleneck dim. This value must be > 0 to utilize LoRA."}
    )

    lora_alpha: Optional[int] = field(
        default=32, metadata={"help": "LoRA alpha"}
    )

    lora_dropout: Optional[float] = field(
        default=0.1, metadata={"help": "LoRA dropout probability"}
    )

    target_modules: Optional[str] = field(
        default="", metadata={
            "help": "Target modules to use for LoRA adaptation (must be input as a comma delimited string)"
        }
    )

Data Arguments

These arguments specify the data we feed the model for training and evaluation, as well as how it is processed and masked.

from transformers.utils.versions import require_version

@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    """

    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
    validation_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    validation_split_percentage: Optional[int] = field(
        default=5,
        metadata={
            "help": "The percentage of the train set used as validation set in case there's no validation split"
        },
    )
    max_seq_length: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "The maximum total input sequence length after tokenization. Sequences longer "
                "than this will be truncated."
            )
        },
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    mlm_probability: float = field(
        default=0.15, metadata={"help": "Ratio of tokens to mask for masked language modeling loss"}
    )
    line_by_line: bool = field(
        default=False,
        metadata={"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."},
    )
    pad_to_max_length: bool = field(
        default=False,
        metadata={
            "help": (
                "Whether to pad all samples to `max_seq_length`. "
                "If False, will pad the samples dynamically when batching to the maximum length in the batch."
            )
        },
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of training examples to this "
                "value if set."
            )
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
                "value if set."
            )
        },
    )

    streaming: bool = field(default=False, metadata={"help": "Enable streaming mode"})
    additional_special_tokens: Optional[str] = field(
        default=",,,,,,,,,,,", 
        metadata={"help": "Comma-separated list of special tokens to add to tokenizer."}
    )

    additional_tokens: Optional[str] = field(
        default=None, metadata={"help": "Comma-separated list of additional tokens to add to tokenizer."}
    )

    masking_strategy: Optional[str] = field(
        default="word", metadata={
            "help": (
                "Type of masking strategy used for MLM. "
                "Note that white_space strategy only supports BPE tokenizer (e.g., gpt2, roberta)."
            ),
            "choices": ["word", "token", "span", "white_space", "token_sep"]
        }
    )

    masking_span_p: Optional[float] = field(
        default=0.2, metadata={"help": "The masking span length follows a geometric distribution, p is the parameter."}
    )

    masking_sep_token: Optional[str] = field(
        default=":", metadata={"help": "Token used to divide input into different spans for MLM."}
    )

    masking_prefix_flag: Optional[bool] = field(
        default=False, metadata={"help": "Mask entities in the prefix as whole."}
    )

    entity_sep: Optional[str] = field(
        default=",,,,,", 
        metadata={
            "help": (
                "Comma-separated list of separator tokens, used when masking_prefix_flag = True. "
                "The format is: 'start_token_1,end_token_1,start_token_2,end_token_2'"
            )
        }
    )

    def __post_init__(self):
        if self.streaming:
            require_version("datasets>=2.0.0", "The streaming feature requires `datasets>=2.0.0`")

Initializing Arguments

Next, we instantiate the argument classes with our choices:

from transformers import TrainingArguments

model_args = ModelArguments(model_name_or_path='microsoft/deberta-base')

data_args = DataTrainingArguments(masking_strategy = 'token_sep', 
                                masking_span_p = 0.2,
                                masking_sep_token = ',',
                                masking_prefix_flag= True,
                                streaming = False,
                                mlm_probability = 0.15,
                                pad_to_max_length = False,
                                line_by_line = True,
                                additional_tokens = None,
                                 )

training_args = TrainingArguments(output_dir = './output', 
                                  max_steps= 2000,
                                  eval_steps= 200,
                                  logging_steps=200,
                                  do_train= True,
                                  do_eval= True,
                                  evaluation_strategy='steps',
                                  remove_unused_columns = False,
                                  label_names = ["labels"],
                                  per_device_train_batch_size = 4,
                                  per_device_eval_batch_size = 4,
                                  overwrite_output_dir = True 
                                 )

As you can see, we load the deberta-base model, which we will adapt to the medical domain.

Data Tokenization

In this part, we tokenize the data, group it into fixed-length chunks, and define a data collator.

The tokenization step loads the pre-trained tokenizer that matches our model and adds our special tokens to it.

from transformers import AutoTokenizer

tokenizer_kwargs = {
        "cache_dir": model_args.cache_dir,
        "use_fast": model_args.use_fast_tokenizer,
        "revision": model_args.model_revision,
        "use_auth_token": model_args.use_auth_token,
        "trust_remote_code": model_args.trust_remote_code,
    }

tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, **tokenizer_kwargs)

tokenizer.add_special_tokens({'additional_special_tokens': data_args.additional_special_tokens.split(",")})

if data_args.additional_tokens:
    tokenizer.add_tokens(data_args.additional_tokens.split(","))

We then load the model and, if needed, resize its embedding layer to match the size of the tokenizer's vocabulary, which grew when we added the special tokens:

from transformers import AutoConfig, AutoModelForMaskedLM

config_kwargs = {
        "cache_dir": model_args.cache_dir,
        "revision": model_args.model_revision,
        "use_auth_token": model_args.use_auth_token,
        "trust_remote_code": model_args.trust_remote_code,
    }

config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)

model = AutoModelForMaskedLM.from_pretrained(
            model_args.model_name_or_path,
            from_tf=bool(".ckpt" in model_args.model_name_or_path),
            config=config,
            cache_dir=model_args.cache_dir,
            revision=model_args.model_revision,
            use_auth_token=model_args.use_auth_token,
            trust_remote_code=model_args.trust_remote_code,
            low_cpu_mem_usage=model_args.low_cpu_mem_usage,
        )

embedding_size = model.get_input_embeddings().weight.shape[0]
if len(tokenizer) > embedding_size:
    model.resize_token_embeddings(len(tokenizer))

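If the tokenizer does not define a maximum length (in which case model_max_length is set to a very large placeholder value), we set it to the model's context length: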

# Check if tokenizer.model_max_length is undefined
if tokenizer.model_max_length > 1e9:
    tokenizer.model_max_length = model.config.max_position_embeddings

We then tokenize the dataset. Our data has a single column called "text"; we tokenize it and drop the raw column from the output:


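def tokenize_function(examples):
    result = tokenizer(examples["text"])
    # "mask_ids" holds each token's word index (None for special tokens);
    # MaskingDataCollator uses it for whole-word masking (needs a fast tokenizer).
    result["mask_ids"] = [result.word_ids(i) for i in range(len(examples["text"]))]
    return result

tokenized_datasets = raw_datasets.map(tokenize_function, batched=True, remove_columns=["text"])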

The tokenized_datasets object looks as follows:

Image by author

If we check the length of input_ids for each record in the train split, we see that the records have different lengths:

l = []
for item in tokenized_datasets['train']:
    l.append(len(item['input_ids']))
print(set(l))

which prints a long set of different lengths:

{2243, 1204, 2310, 2402, 645, 2319, ...}

In other words, the records have different sequence lengths. To make them uniform, we can pad them, truncate them, or concatenate them and split the result into chunks whose size equals the model's context length. Here we do the latter:


from itertools import chain

data_args.max_seq_length = tokenizer.model_max_length
max_seq_length = data_args.max_seq_length

# Main data processing function that concatenates all texts in a batch and splits them into
# chunks of max_seq_length.
def group_texts(examples):
    # Concatenate all texts (and every other column, such as attention_mask).
    concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # Drop the small remainder; if total_length < max_seq_length, this batch yields no chunks.
    # Padding instead of dropping is also possible; customize this part to your needs.
    total_length = (total_length // max_seq_length) * max_seq_length
    # Split by chunks of max_len.
    result = {
        k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)]
        for k, t in concatenated_examples.items()
    }
    return result

tokenized_datasets = tokenized_datasets.map(
    group_texts,
    batched=True,
)

Now, if we repeat the exercise:

l = []
for item in tokenized_datasets['train']:
    l.append(len(item['input_ids']))
print(set(l))

it prints only {512}, because all sequences now have length 512, the model_max_length of deberta-base.
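To see what group_texts does, here is the same logic applied to a toy batch with a tiny chunk size. The only change is that max_seq_length is passed as a parameter instead of being read from a global; the function name group_texts_toy is just for this illustration.

```python
from itertools import chain


def group_texts_toy(examples, max_seq_length):
    """Concatenate every column and split it into max_seq_length chunks,
    dropping the remainder that does not fill a full chunk."""
    concatenated = {k: list(chain(*examples[k])) for k in examples}
    total_length = len(concatenated[next(iter(examples))])
    total_length = (total_length // max_seq_length) * max_seq_length
    return {
        k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)]
        for k, t in concatenated.items()
    }


# Three records of lengths 5, 3 and 4 (12 tokens in total), chunk size 5.
batch = {"input_ids": [[1, 2, 3, 4, 5], [6, 7, 8], [9, 10, 11, 12]]}
print(group_texts_toy(batch, max_seq_length=5))
# {'input_ids': [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]}  (tokens 11 and 12 are dropped)
```

Note that record boundaries disappear: the second chunk holds the whole second record plus the start of the third, which is why every resulting sequence has exactly the same length.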

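Next, we define the data collator. It performs whole-word masking: for each sequence, it selects each word with probability wwm_probability (making sure at least one word is selected), replaces all of that word's tokens with the mask token, and sets the labels of all other positions to -100 so that they are ignored: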

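import random
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import numpy as np
from transformers import PreTrainedTokenizerBase, default_data_collator


@dataclass
class MaskingDataCollator: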
    tokenizer: PreTrainedTokenizerBase
    wwm_probability: Optional[float] = 0.2

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        for i, feature in enumerate(features):
            mask_ids = feature.pop("mask_ids")

            # word_id to token index mapping
            mapping = self.word_mapping(mask_ids)
            # Randomly mask words
            if "labels" not in feature.keys():
                labels = feature["input_ids"].copy()
            else:
                labels = feature["labels"]
            feature["labels"], _ = self.random_masking_whole_word(mapping, feature["input_ids"], labels, self.tokenizer.mask_token_id)

        batch = default_data_collator(features)
        return batch

    def word_mapping(self, mask_ids):
        # Create a map between words and corresponding token start and end inds
        mapping = defaultdict(list)
        current_word_index = -1
        current_word = None
        for i, word_id in enumerate(mask_ids):
            if word_id is not None:
                if word_id != current_word:
                    current_word = word_id
                    current_word_index += 1
                mapping[current_word_index].append(i)

        return mapping

    def random_masking_whole_word(self, mapping, input_ids, labels, mask_token_id):

        mask = np.random.binomial(1, self.wwm_probability, size=len(mapping))

        # masked at least one mask_id
        if sum(mask) == 0:
            rn_i = random.choice(range(len(mask)))
            mask[rn_i] = 1

        new_labels = [-100] * len(labels)

        for word_id in np.where(mask)[0]:
            word_id = word_id.item()
            for idx in mapping[word_id]:
                new_labels[idx] = labels[idx]
                input_ids[idx] = mask_token_id
        return new_labels, input_ids

data_collator = MaskingDataCollator(
        tokenizer, 
        wwm_probability=data_args.mlm_probability
    )
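The masking logic of the collator can be tried out without a tokenizer. Below is a stdlib-only sketch of the same steps: it uses random.Random in place of np.random.binomial and returns a new list instead of modifying input_ids in place. The function name, the toy IDs and the mask token ID 0 are hypothetical.

```python
import random
from collections import defaultdict


def word_mapping(mask_ids):
    """Map a running word index to the token positions of that word.
    Tokens whose mask_id is None (e.g. special tokens) are skipped."""
    mapping = defaultdict(list)
    current_word_index, current_word = -1, None
    for i, word_id in enumerate(mask_ids):
        if word_id is not None:
            if word_id != current_word:
                current_word = word_id
                current_word_index += 1
            mapping[current_word_index].append(i)
    return mapping


def mask_whole_words(input_ids, mask_ids, mask_token_id, probability, rng):
    """Return (masked_input_ids, labels) using whole-word masking."""
    mapping = word_mapping(mask_ids)
    # Select each word independently with the given probability.
    chosen = [w for w in mapping if rng.random() < probability]
    if not chosen:  # make sure at least one word is masked
        chosen = [rng.choice(list(mapping))]
    masked = list(input_ids)
    labels = [-100] * len(input_ids)
    for w in chosen:
        for idx in mapping[w]:
            labels[idx] = input_ids[idx]  # predict the original token here
            masked[idx] = mask_token_id
    return masked, labels


# Toy sequence: [CLS]=1, word 0 = tokens 10 11, word 1 = token 12, [SEP]=2.
input_ids = [1, 10, 11, 12, 2]
mask_ids = [None, 0, 0, 1, None]
masked, labels = mask_whole_words(input_ids, mask_ids, mask_token_id=0,
                                  probability=0.15, rng=random.Random(0))
print(masked, labels)
```

Whatever the random draw, tokens 10 and 11 are masked together because they belong to the same word, and the special tokens are never masked because their mask_ids are None.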

Training The Model

If we want to train in parameter-efficient mode, we wrap the model with LoRA from the peft library as follows:

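import logging

from peft import LoraConfig, get_peft_model
from peft.utils import TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING

logger = logging.getLogger(__name__)

peft_args = PEFTArguments()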

peft_config = None
if peft_args.lora_r > 0:
    logger.info("Using LoRA for model adaptation...")
    peft_config = LoraConfig(
        r=peft_args.lora_r,
        lora_alpha=peft_args.lora_alpha,
        lora_dropout=peft_args.lora_dropout,
        target_modules=peft_args.target_modules.split(",") if peft_args.target_modules 
        else TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING[model.config.model_type]
    )

    model = get_peft_model(model, peft_config)
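Note that PEFTArguments() uses the default lora_r=0, so as written this branch is skipped and the full model is trained; passing, for example, PEFTArguments(lora_r=8) enables LoRA. To get a feel for how small the adapter is: LoRA adds two matrices, A (r × d_in) and B (d_out × r), per adapted weight, i.e. r·(d_in + d_out) trainable parameters. The arithmetic below assumes deberta-base's configuration (hidden size 768, 12 layers) and that the default target module is each layer's fused query/key/value projection in_proj (768 → 2304); treat both as assumptions that model.print_trainable_parameters() can confirm.

```python
def lora_params(r: int, d_in: int, d_out: int) -> int:
    """Trainable parameters LoRA adds to one d_in -> d_out linear layer:
    A is r x d_in and B is d_out x r."""
    return r * d_in + d_out * r


# Assumed deberta-base shapes: hidden size 768, fused q/k/v projection 768 -> 3*768.
hidden, layers, r = 768, 12, 8
per_layer = lora_params(r, d_in=hidden, d_out=3 * hidden)
total = per_layer * layers
print(per_layer, total)  # 24576 294912
```

Under these assumptions, LoRA with r=8 trains fewer than 300k parameters, a small fraction of the full model.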

Next, we write the compute_metrics function, which computes the metric of our choice; here we use accuracy.

import evaluate

metric = evaluate.load("accuracy")

train_dataset = tokenized_datasets["train"]
eval_dataset = tokenized_datasets["validation"]

def compute_metrics(eval_preds):
    preds, labels = eval_preds
    labels = labels.reshape(-1)
    preds = preds.reshape(-1)
    # Keep only masked positions (label != -100) so accuracy is measured on them.
    mask = labels != -100
    labels = labels[mask]
    preds = preds[mask]
    return metric.compute(predictions=preds, references=labels)

Note that mask = labels != -100 ensures that we evaluate the model only at masked positions. For a masked token, the label is the ID of the original token at that position. For tokens that are not masked, on which we don't want to evaluate the model, the label is set to -100.

The expression labels != -100 yields a boolean vector that is True only at the masked positions.
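The same filtering can be written without NumPy; the function name and toy values below are illustrative. Positions labeled -100 are dropped before comparing, so unmasked tokens, which the model can simply read from its input, do not inflate the score.

```python
def masked_accuracy(preds, labels, ignore_index=-100):
    """Accuracy over positions whose label is not ignore_index.
    Returns 0.0 if no position is masked."""
    pairs = [(p, l) for p, l in zip(preds, labels) if l != ignore_index]
    if not pairs:
        return 0.0
    return sum(p == l for p, l in pairs) / len(pairs)


# Five positions; only positions 1, 2 and 4 were masked.
labels = [-100, 17, 42, -100, 8]
preds = [5, 17, 40, 3, 8]
print(masked_accuracy(preds, labels))  # 2 of 3 masked positions correct: 0.6666666666666666
```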

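Logit processing: Next, we preprocess the logits. The following function returns, for each position, the index of the largest logit, which is the model's predicted token ID. Because Trainer applies this function to each evaluation batch before accumulating the results, it also avoids keeping the full vocabulary-sized logits in memory.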

def preprocess_logits_for_metrics(logits, labels):
    if isinstance(logits, tuple):
        # Depending on the model and config, logits may contain extra tensors,
        # like past_key_values, but logits always come first
        logits = logits[0]
    return logits.argmax(dim=-1)

Trainer:

Here we initialize the Trainer object; we start training with trainer.train() in the next step.

from transformers import Trainer, is_torch_tpu_available

# Initialize our Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics if training_args.do_eval and not is_torch_tpu_available() else None,
    preprocess_logits_for_metrics=preprocess_logits_for_metrics
    if training_args.do_eval and not is_torch_tpu_available()
    else None,
)

Let's kick off training with trainer.train() and save the artifacts:

train_result = trainer.train()
trainer.save_model()  # Saves the tokenizer too for easy upload
metrics = train_result.metrics

if not data_args.streaming:
    max_train_samples = (
        data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
    )
    metrics["train_samples"] = min(max_train_samples, len(train_dataset))
else:
    metrics["max_steps"] = training_args.max_steps

trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state()

This prints the following output:

Image by author

First, note that trainer.train() returns a TrainOutput object, which we store in train_result. It looks like this:

TrainOutput(global_step=2000, training_loss=3.754445343017578, metrics={'train_runtime': 794.2916, 'train_samples_per_second': 10.072, 'train_steps_per_second': 2.518, 'total_flos': 2454016352256000.0, 'train_loss': 3.754445343017578, 'epoch': 0.78, 'train_samples': 10261})

With metrics = train_result.metrics, we access the metrics dictionary inside it, which we then pass to trainer.log_metrics() and trainer.save_metrics().
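The numbers in this output are consistent with the training arguments. A quick sanity check, assuming training ran on a single device without gradient accumulation (so each step consumes per_device_train_batch_size samples), reproduces the reported epoch and throughput from max_steps, the batch size, train_samples and train_runtime:

```python
# Values taken from the TrainOutput above; assumes a single device and
# no gradient accumulation.
max_steps = 2000
batch_size = 4          # per_device_train_batch_size
train_samples = 10261   # reported train_samples
runtime_s = 794.2916    # reported train_runtime in seconds

samples_seen = max_steps * batch_size
epoch = samples_seen / train_samples
print(samples_seen, round(epoch, 2))        # 8000 0.78
print(round(samples_seen / runtime_s, 3))   # 10.072
print(round(max_steps / runtime_s, 3))      # 2.518
```

In other words, 2000 steps covered about 78% of the training set once, so the model did not see every training chunk during this run.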

trainer.log_metrics() prints a report as follows:

Image by author

Second, note that we save a few things:

  • trainer.save_model(): This saves the model and its tokenizer, so we can reload them later with from_pretrained().
  • trainer.save_state(): This saves the trainer state, which trainer.save_model() does not include. It creates a trainer_state.json file that looks as follows:
Image by author
  • trainer.save_metrics("train", metrics): This saves the metrics for the train split to a JSON file, train_results.json, which looks as follows:
Image by author
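The saved trainer_state.json can also be inspected programmatically. The sketch below, with a hypothetical helper name, reads its log_history list and keeps the evaluation entries, which Trainer appends every eval_steps (200 here) with keys such as step and eval_loss; it assumes the file layout written by trainer.save_state().

```python
import json
from pathlib import Path


def eval_curve(state_path):
    """Return (step, eval_loss) pairs from a trainer_state.json file.

    Trainer appends one dict per logging/evaluation event to "log_history";
    evaluation entries are the ones that contain "eval_loss".
    """
    state = json.loads(Path(state_path).read_text())
    return [(entry["step"], entry["eval_loss"])
            for entry in state.get("log_history", [])
            if "eval_loss" in entry]


# Usage, assuming training wrote ./output/trainer_state.json:
# for step, loss in eval_curve("./output/trainer_state.json"):
#     print(step, loss)
```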

Conclusion

In this post, we reviewed how to take a pre-trained LLM and adapt it to a new domain, such as medicine or finance. We took a pre-trained deberta-base model from Hugging Face, continued pre-training it on medical data with masked language modeling, and saved the trained model to a directory for further evaluation.


If you have any questions or suggestions, feel free to reach out to me on LinkedIn: https://www.linkedin.com/in/minaghashami/

Tags: AI Deep Learning Domain Adaptation Hugging Face Llm
