Deep Learning Training on AWS Inferentia

The topic of this post is AWS's home-grown AI chip, AWS Inferentia – more specifically, the second-generation AWS Inferentia2. This is a sequel to last year's post on AWS Trainium and joins our series of posts on dedicated AI accelerators. In contrast to the chips we explored in previous posts in the series, AWS Inferentia was designed specifically for deep-learning inference. However, the fact that AWS Inferentia2 and AWS Trainium share the same underlying NeuronCore-v2 architecture and the same software stack (the AWS Neuron SDK) raises the question: can AWS Inferentia be used for AI training workloads as well?

Granted, there are some elements of the Amazon EC2 Inf2 instance family specifications (which are powered by AWS Inferentia accelerators) that might make them less appropriate for some training workloads when compared to the Amazon EC2 Trn1 instance family. For example, although both Inf2 and Trn1 instances support the high-bandwidth, low-latency NeuronLink-v2 device-to-device interconnect, the Trainium devices are connected in a 2D torus topology rather than a ring topology, a difference that can potentially impact the performance of collective-communication operators (see here for more details). However, some training workloads may not require the unique features of the Trn1 architecture and may perform equally well on the Inf1 and Inf2 architectures.
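
To illustrate where the interconnect topology comes into play, the sketch below times a single all-reduce – the collective at the heart of distributed gradient sharing – across the local Neuron devices, using the same torch_xla APIs as the training script shown later in this post. Treat it as a minimal, hypothetical micro-benchmark rather than a rigorous measurement; the tensor size and launch command are arbitrary choices.

import time
import torch
import torch.distributed as dist
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_backend  # registers the 'xla' backend

def time_all_reduce(numel=64 * 1024 * 1024):
  dist.init_process_group('xla')
  device = xm.xla_device()
  t = torch.ones(numel, device=device)
  # warm-up: the first execution includes graph compilation
  xm.all_reduce(xm.REDUCE_SUM, [t])
  xm.mark_step()
  xm.wait_device_ops()
  # timed run
  start = time.perf_counter()
  xm.all_reduce(xm.REDUCE_SUM, [t])  # in-place all-reduce across all workers
  xm.mark_step()
  xm.wait_device_ops()
  print(f'all-reduce time: {time.perf_counter() - start:.4f} seconds')

if __name__ == '__main__':
  time_all_reduce()

# launch with, e.g.: torchrun --nproc_per_node=2 all_reduce.py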

In fact, the ability to train on both Trainium and Inferentia accelerators would greatly increase the variety of training instances at our disposal and our ability to tune the choice of training instance to the specific needs of each DL project. In our recent post, Instance Selection for Deep Learning, we elaborated on the value of having a wide variety of diverse instance types for DL training. While the Trn1 family includes just two instance types, enabling training on Inf2 would add four additional instance types. Including Inf1 in the mix would add four more.

Our intention in this post is to demonstrate the opportunity of training on AWS Inferentia. We will define a toy vision model and compare the performance of training it on the Amazon EC2 Trn1 and Amazon EC2 Inf2 instance families. Many thanks to Ohad Klein and Yitzhak Levi for their contributions to this post.

Disclaimers

  1. Note that, as of the time of this writing, there are some DL model architectures that remain unsupported by the Neuron SDK. For example, while model inference of CNN models is supported, training CNN models is still unsupported. The SDK documentation includes a model support matrix detailing the supported features per model architecture, training framework (e.g., TensorFlow and PyTorch), and Neuron architecture version.
  2. The experiments that we will describe were run on Amazon EC2 with the most recent version of the Deep Learning AMI for Neuron available at the time of this writing, "Deep Learning AMI Neuron PyTorch 1.13 (Ubuntu 20.04) 20230720", which includes version 2.8 of the Neuron SDK. Since the Neuron SDK remains under active development, it is likely that the comparative results that we achieved will change over time. It is highly recommended that you reassess the findings of this post with the most up-to-date versions of the underlying libraries (a simple way to record the installed package versions is sketched just after this list).
  3. Our intention in this post is to demonstrate the potential of training on AWS Inferentia powered instances. Please do not view this post as an endorsement for the use of these instances or any of the other products we might mention. There are many variables that factor into how to choose a training environment which may vary greatly based on the particulars of your project. In particular, different models might exhibit wholly different relative price-performance results when running on two different instance types.
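
Given these caveats, it can be useful to record the exact package versions in your environment when reproducing the experiments. The snippet below is a minimal sketch for doing so; the package names listed are assumptions based on a typical Neuron PyTorch setup and may differ between SDK releases.

# print the versions of the (assumed) Neuron-related packages in the environment
import importlib.metadata as md

for pkg in ['torch', 'torch-xla', 'torch-neuronx', 'neuronx-cc', 'timm']:
  try:
    print(f'{pkg}: {md.version(pkg)}')
  except md.PackageNotFoundError:
    print(f'{pkg}: not installed')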

Toy Model

Similar to the experiments we described in our previous post, we define a simple Vision Transformer (ViT)-backed classification model (using the timm Python package version 0.9.5) along with a randomly generated dataset.

from torch.utils.data import Dataset
import time, os
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
from timm.models.vision_transformer import VisionTransformer

# use random data
class FakeDataset(Dataset):
  def __len__(self):
    return 1000000

  def __getitem__(self, index):
    rand_image = torch.randn([3, 224, 224], dtype=torch.float32)
    label = torch.tensor(data=[index % 1000], dtype=torch.int64)
    return rand_image, label

def train(batch_size=16, num_workers=4):
  # Initialize XLA process group for torchrun
  import torch_xla.distributed.xla_backend
  torch.distributed.init_process_group('xla')

  # multi-processing: ensure each worker has same initial weights
  torch.manual_seed(0)
  dataset = FakeDataset()
  model = VisionTransformer()

  # load model to XLA device
  device = xm.xla_device()
  model = model.to(device)
  optimizer = torch.optim.Adam(model.parameters())
  data_loader = torch.utils.data.DataLoader(dataset,
                         batch_size=batch_size, num_workers=num_workers)
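  # wrap the loader so that batches are loaded onto the XLA device in the background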
  data_loader = pl.MpDeviceLoader(data_loader, device)
  loss_function = torch.nn.CrossEntropyLoss()
  summ = 0
  count = 0
  t0 = time.perf_counter()

  for step, (inputs, target) in enumerate(data_loader, start=1):
    inputs = inputs.to(device)
    targets = torch.squeeze(target.to(device), -1)
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = loss_function(outputs, targets)
    loss.backward()
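    # all-reduce gradients across workers and apply the optimizer update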
    xm.optimizer_step(optimizer)
    batch_time = time.perf_counter() - t0
    if step > 10:  # skip first steps
      summ += batch_time
      count += 1
    t0 = time.perf_counter()
    if step > 500:
      break
  print(f'average step time: {summ/count}')

if __name__ == '__main__':
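  # run in bfloat16: XLA_USE_BF16=1 maps float32 tensors to bfloat16 on the XLA device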
  os.environ['XLA_USE_BF16'] = '1'
  # set the number of dataloader workers according to the number of vCPUs
  # e.g. 4 for trn1, 2 for inf2.xlarge, 8 for inf2.12xlarge and inf2.48xlarge
  train(num_workers=4)
# Initialization command:
# torchrun --nproc_per_node=2 train.py

Results

In the table below we compare the speed and price performance of various Amazon EC2 Trn1 and Amazon EC2 Inf2 instance types.

Table: Performance comparison of the ViT-based classification model on Trn1 and Inf2 instance types (by author)

While it is clear that the Trainium-powered instance types deliver better absolute performance (i.e., faster training speeds), training on the Inferentia-powered instances resulted in roughly 39% better price performance for the two-core instance types, and even greater gains for the larger instance types.

Once again, we caution against making any design decisions based solely on these results. Some model architectures might run successfully on Trn1 instances but break down on Inf2. Others might succeed on both but exhibit very different comparative performance results than the ones shown here.

Note that we have omitted the time required for compiling the DL model. Although this is only required the first time the model is run, compilation times can be quite high (e.g., upward of ten minutes for our toy model). Two ways to reduce the overhead of model compilation are parallel compilation and offline compilation. Importantly, make sure that your script does not include operations (or graph changes) that trigger frequent recompilations. See the Neuron SDK documentation for more details.
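
For example, the Neuron SDK ships a neuron_parallel_compile utility that can wrap the regular launch command to extract and compile the training graphs ahead of time, populating the compilation cache. The invocation below is a sketch; the exact usage may vary between SDK versions.

# offline/parallel compilation sketch: wrap the regular launch command, e.g.
# neuron_parallel_compile torchrun --nproc_per_node=2 train.py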

Summary

Although marketed as an AI inference chip, it appears that AWS Inferentia offers yet another option for training deep learning models. In our previous post on AWS Trainium we highlighted some of the challenges that you might encounter when adapting your models to train on a new AI ASIC. The possibility of training the same models on AWS Inferentia-powered instance types, as well, could increase the potential reward of your efforts.

