PyTorch 2.0: How to Speed Up Model Training



With the unveiling of PyTorch 2.0 in December 2022, the machine learning landscape buzzed with anticipation. Significant enhancements in model training and inference speed were promised. This leads us to ask two essential questions: what design changes in PyTorch 2.0 enable these advancements, and does PyTorch 2.0 actually deliver on its promises?

In this article, we will analyze PyTorch 2.0’s core features, exploring how they differ from previous versions and why these changes were made.

In addition to discussing the technical aspects of PyTorch 2.0, we will evaluate its performance in practical use cases, specifically in the context of SuperGradients, Deci’s open-source deep learning training library for PyTorch-based models. To conclude, we will share practical advice and best practices for using PyTorch 2.0 to accelerate model training.

What Is PyTorch 2.0?

PyTorch 2.0 is an update to the widely used open-source machine learning library PyTorch. It introduces several new features aimed at improving performance and accelerating model training. Most notably, it includes torch.compile, a function that takes your model and returns a compiled version of it that can run faster.
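As a minimal sketch of that workflow, assuming PyTorch 2.0 or later is installed: the tiny MLP and the helper names build_model and compile_model are made up for illustration, while torch.compile and its backend argument are the real API.

```python
import torch
import torch.nn as nn


def build_model() -> nn.Module:
    """A tiny MLP used only for illustration (hypothetical model)."""
    return nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))


def compile_model(model: nn.Module, backend: str = "inductor") -> nn.Module:
    """Wrap a model with torch.compile.

    Nothing is compiled here: compilation happens lazily on the first call
    of the returned module. The wrapper shares parameters with `model`.
    """
    return torch.compile(model, backend=backend)
```

Calling compile_model(model)(x) triggers compilation on the first forward pass, so that call is slow and later calls reuse the compiled code. Passing backend="eager" skips code generation entirely, which can be handy for checking that graph capture works before paying for a full compile.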

Beyond torch.compile, PyTorch 2.0 also brings enhancements to the PyTorch Transformer API, optimizations for Graph Neural Network (GNN) inference and training performance, and distributed computing support. 

How Is torch.compile Different From Other Compilation Methods?

torch.compile marks a departure from traditional compilation methods. It features a partial graph capture mechanism, a novel approach that allows compiled graphs to be combined with eagerly executed code.

This innovative feature works by combining sections of code it understands and can optimize into computational graphs, while leaving sections it doesn’t comprehend to run as standard Python code. This partial compilation allows torch.compile to provide performance improvements without requiring a full translation of a codebase into a static language.
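The idea of partial capture can be illustrated with a toy, pure-Python sketch. This is not how torch.compile is implemented (it works on Python bytecode and tensor operations); the COMPILABLE set, the operation names, and the helper functions are all hypothetical and exist only to show how a program splits into compilable segments and eager fallbacks.

```python
from typing import Callable, List, Tuple

# Toy stand-in for "operations the compiler understands" (hypothetical set).
COMPILABLE = {"add", "mul", "relu"}

Op = Tuple[str, Callable[[float], float]]
Segment = Tuple[str, List[Op]]


def partition(ops: List[Op]) -> List[Segment]:
    """Group consecutive ops into 'graph' (compilable) and 'eager' (fallback) segments."""
    segments: List[Segment] = []
    for op in ops:
        kind = "graph" if op[0] in COMPILABLE else "eager"
        if segments and segments[-1][0] == kind:
            segments[-1][1].append(op)  # same kind: extend the current segment
        else:
            segments.append((kind, [op]))  # kind changed: start a new segment
    return segments


def run(ops: List[Op], x: float) -> float:
    """Execute segment by segment; a real compiler would optimize each 'graph' segment."""
    for _kind, segment in partition(ops):
        for _name, fn in segment:
            x = fn(x)
    return x


program: List[Op] = [
    ("add", lambda v: v + 1.0),
    ("mul", lambda v: v * 2.0),
    ("log_to_file", lambda v: v),  # side effect the toy compiler does not understand
    ("relu", lambda v: max(v, 0.0)),
]
```

Here partition(program) yields three segments: a graph segment (add, mul), an eager segment (log_to_file), and a second graph segment (relu). The unsupported operation splits the program in two rather than preventing compilation altogether.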

The Motivation Behind torch.compile

The motivation behind torch.compile centers around the desire to exploit the impressive capabilities of modern hardware accelerators, while still preserving the ease-of-use and flexibility that has come to define PyTorch.

Since PyTorch’s inception in 2017, hardware accelerators such as GPUs have become about 15 times faster in computation and about twice as fast in memory access. To keep eager execution fast enough for this new hardware, a significant portion of PyTorch’s internals had to be moved to C++. While this change boosted performance to match the rapidly advancing hardware, it also made PyTorch less flexible and accessible, raising the barrier to code contributions.

In the meantime, PyTorch held back from adopting graph mode execution, a common method of harnessing hardware acceleration. This was due to its desire to preserve the ease and flexibility of eager execution, a feature much appreciated by its users. Graph mode execution, although offering speedups by compiling and dispatching large blocks of operations together, necessitated compiling entire models, a task that often proved difficult or impossible.

In response to these challenges, PyTorch introduced torch.compile in version 2.0. This new feature allows for partial compilation, compiling sections of the code that are compatible with graph execution, and dispatching them as a batch to the GPU for efficient hardware usage. Meanwhile, parts of the code that resist compilation can safely fall back to Python eager execution.

Additionally, the components of torch.compile are written in Python. This design decision makes the system highly accessible and easier to debug, thereby addressing the lost flexibility that resulted from migrating parts of PyTorch to C++.

In essence, torch.compile finds a balance between the user-friendly experience of Python eager execution and the high performance offered by graph execution and hardware accelerators. It provides a high-speed execution environment that aligns with modern hardware capabilities, while still maintaining the PyTorch environment that users have come to know and love.

An Under-The-Hood Look At torch.compile

Underpinning torch.compile are four key technologies: TorchDynamo, AOTAutograd, PrimTorch, and TorchInductor. Each of these technologies plays a significant role in the functioning of torch.compile. 

TorchDynamo: Acquiring Graphs Reliably and Fast

TorchDynamo works by interpreting Python bytecode symbolically, converting it into a graph of tensor operations. If it comes across a segment of code that it cannot interpret, it defaults to the regular Python interpreter. This approach ensures that it can handle a wide range of programs while providing significant performance improvements. 

AOTAutograd: Reusing Autograd for Ahead-of-Time Graphs

AOTAutograd extends PyTorch’s automatic differentiation engine so that backward traces are produced ahead of time (AOT) rather than at run time. It uses PyTorch’s __torch_dispatch__ mechanism to trace through the existing PyTorch autograd engine, capturing the backward pass ahead of time. This enables acceleration of both the forward and backward pass.

TorchInductor: Generating High-speed Code for Accelerators and Backends

TorchInductor is a deep learning compiler that translates intermediate representations into executable code. It takes the computation graph generated by TorchDynamo and converts it into optimized low-level kernels. For NVIDIA and AMD GPUs, it employs OpenAI Triton as a fundamental component.

PrimTorch: Reducing Operator Complexity 

PrimTorch is an initiative designed to streamline the creation of PyTorch features or backends. It accomplishes this by significantly decreasing the number of PyTorch operators from an overwhelming total of over 2000 to a manageable set of around 250 fundamental operators. It thereby simplifies the process of writing a backend for PyTorch by defining smaller and stable operator sets. PyTorch programs can be consistently lowered to these operator sets, easing the development of features or backends for PyTorch.

Additional PyTorch 2.0 Enhancements 

In addition to torch.compile, PyTorch 2.0 introduces a host of other significant enhancements.

Optimized Implementation of the Transformer API

Transformers are at the core of many state-of-the-art models in NLP, and training these models can be resource-intensive. PyTorch 2.0 addresses this by optimizing the implementation of its Transformer API. The PyTorch team recently reported using its new Transformer API to accelerate nanoGPT. They achieved a 27% speed up in training time per batch (measured with Nvidia A100 GPUs), going from a ~143ms/batch baseline to ~113 ms/batch. 

The result is a significant reduction in the computational resources required for training and deploying Transformer models. The optimized implementation enables more affordable training and deployment of such models, making them more accessible across the industry. 
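One building block of the optimized Transformer path is the fused attention function torch.nn.functional.scaled_dot_product_attention, added in PyTorch 2.0. The sketch below compares it with a hand-written causal attention; the two helper names are illustrative, and the code is not taken from nanoGPT.

```python
import math

import torch
import torch.nn.functional as F


def manual_causal_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Reference implementation: softmax(Q K^T / sqrt(d)) V with a causal mask."""
    seq_len, head_dim = q.size(-2), q.size(-1)
    scores = (q @ k.transpose(-2, -1)) / math.sqrt(head_dim)
    # Lower-triangular mask: position i may only attend to positions <= i.
    mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=q.device).tril()
    scores = scores.masked_fill(~mask, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v


def fused_causal_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Same computation through the fused PyTorch 2.0 operator."""
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)
```

Both functions expect tensors shaped (batch, heads, sequence, head_dim) and return matching results up to floating-point tolerance. The fused version leaves the choice of kernel to PyTorch, which is where the speed and memory savings come from on supported hardware.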

GNN Inference and Training Performance on CPUs

Graph Neural Networks (GNNs) have seen significant adoption in the machine learning community, and PyTorch 2.0 comes with critical optimizations for GNN inference and training performance, specifically on CPUs. This makes GNNs more effective and efficient for tasks such as social network analysis, molecular chemistry, and recommendation systems. More detail on PyTorch 2.0’s GNN inference and training optimization can be found in this article. 

Distributed Computing Support

Another improvement in PyTorch 2.0 is its extended support for distributed computing. This feature allows for the training of machine learning models across several machines, empowering developers to handle larger and more intricate tasks without requiring supercomputer resources.

To facilitate this, PyTorch introduced DistributedTensor (DTensor), an effort to streamline the process of distributed computation within the Single Program Multiple Devices (SPMD) paradigm. DTensor offers a set of fundamental distributed tensor primitives that can manage both sharded and replicated parallelism strategies. This is instrumental in enabling PyTorch Tensor Parallelism and further exploration of advanced parallelism.

Moreover, DTensor provides a uniform approach for saving and loading the state_dict for distributed checkpointing, even in situations involving complex tensor distribution strategies, such as combining tensor parallelism with parameter sharding in Fully Sharded Data Parallelism (FSDP).

Accelerating Training Time with SuperGradients and torch.compile

So far we’ve discussed the motivation for and design of PyTorch 2.0, with a focus on torch.compile. But how does this translate into real-world training acceleration? We’ll answer this question by looking at the training speedups achieved using the torch.compile API in SuperGradients, Deci’s open-source deep learning training library for PyTorch-based models.

The Results

We measured the training time of one epoch, first without and then with torch.compile. One epoch includes iteration over the training and validation datasets as well as loss and metric computation: essentially all the steps that are performed during training.

The following table shows the relative improvement (reduction) of training time for several models. It’s important to note that the improvement varies depending on the model architecture, dataset, and training hyperparameters. 

torch.compile’s Impact: Training Speed Metrics

| Task | Recipe | Baseline (1 GPU) | Baseline (8 GPU) | 1 GPU With Compile | 8 GPU With Compile | Improvement, % (1 GPU) | Improvement, % (8 GPU) |
|---|---|---|---|---|---|---|---|
| Semantic Segmentation | cityscapes_pplite_seg75 | 270.63 | 49.95 | 119.11 | 35.91 | 56% | 18% |
| Semantic Segmentation | cityscapes_regseg48 | 125.14 | 44.959 | 108.57 | 44.55 | 13.20% | 0.90% |
| Semantic Segmentation | cityscapes_segformer | 199.97 | 46.21 | 162.52 | 43.71 | 18.70% | 5.40% |
| Semantic Segmentation | cityscapes_stdc_seg75 | 425.19 | 73.07 | 153.16 | 45.89 | 63.90% | 37.19% |
| Semantic Segmentation | cityscapes_ddrnet | 226.51 | 51.78 | 174.11 | 48.29 | 23.10% | 7.30% |
| Object Detection | coco2017_yolo_nas_s | 1509 | 384.41 | 1379 | 376.1 | 8.60% | 2.42% |
| Object Detection | coco2017_yolo_nas_m | 2363 | 537.24 | 2090 | 508.4 | 11.50% | 0.19% |
| Object Detection | coco2017_yolo_nas_l | 3193 | 764.17 | 2869 | 745.58 | 10.14% | 2.43% |

Both experiments were run on 8x 3090 GPUs using PyTorch 2.0 with CUDA 11.8. Training was done for 5 epochs, and the median epoch time was used to compute the speedup. All experiments were conducted with mixed precision (AMP) enabled, and with SyncBN and EMA disabled. The improvement percentage is computed as follows: 100 * (baseline_time - compile_time) / baseline_time.
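The improvement formula can be written as a small helper. The function name improvement_percent is ours; the inputs are single-GPU epoch times for two recipes taken from the table above.

```python
def improvement_percent(baseline_time: float, compile_time: float) -> float:
    """Relative reduction in epoch time, as a percentage of the baseline."""
    if baseline_time <= 0:
        raise ValueError("baseline_time must be positive")
    return 100.0 * (baseline_time - compile_time) / baseline_time


# Single-GPU epoch times (baseline, with compile) from the table above.
regseg48 = improvement_percent(125.14, 108.57)
segformer = improvement_percent(199.97, 162.52)
print(f"cityscapes_regseg48: {regseg48:.1f}%")
print(f"cityscapes_segformer: {segformer:.1f}%")
```

Rounded to one decimal place, these evaluate to 13.2% and 18.7%, matching the single-GPU improvement column for those two recipes.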

How to Use torch.compile in SuperGradients

To use torch.compile in SuperGradients, set the torch_compile option to True in the training hyperparameters. From the command line:

python -m super_gradients.train_from_recipe --config-name=... training_hyperparams.torch_compile=True

In the YAML recipe:

# my_training_recipe.yaml
training_hyperparams:
  torch_compile: True
  torch_compile_mode: default  # or: reduce-overhead, max-autotune

Or programmatically:

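from super_gradients import Trainer

trainer = Trainer(experiment_name="my_experiment")  # placeholder experiment name
trainer.train(
    model=model,  # model and data loaders are assumed to be defined elsewhere
    training_params={"torch_compile": True},  # alongside your other hyperparameters
    train_loader=train_loader,
    valid_loader=valid_loader,
)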

Common Pitfalls and Tips for Using torch.compile

From working with torch.compile, we’ve learned to avoid certain pitfalls and adjust settings for optimal results. Below, we share these key takeaways.

1. Don’t use SyncBN with torch.compile.

SyncBN, short for Synchronized Batch Normalization, is a technique used in PyTorch for applying batch normalization across multiple devices in a distributed training setup. When a neural network is trained across multiple machines or GPUs, SyncBN keeps the normalization statistics consistent and synchronized across devices. It usually provides a slight accuracy boost for trained models thanks to the more accurate batch statistics.

The SyncBN mechanism in PyTorch introduces a synchronization point in each BatchNorm layer of the network, which breaks the execution graph and erodes the performance gain from torch.compile. Modern CNN architectures may contain hundreds of BN layers, so the synchronization overhead adds up and negates the performance gain of torch.compile completely.

Rule of thumb: do not use SyncBN and torch.compile simultaneously.
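A quick way to apply this rule is to check a model for SyncBN layers before enabling compilation. The sketch below assumes a plain PyTorch model; uses_sync_bn and build_cnn are hypothetical helper names, while torch.nn.SyncBatchNorm is the real layer class.

```python
import torch.nn as nn


def uses_sync_bn(model: nn.Module) -> bool:
    """Return True if any submodule is a SyncBatchNorm layer."""
    return any(isinstance(m, nn.SyncBatchNorm) for m in model.modules())


def build_cnn() -> nn.Module:
    """A tiny CNN with a regular BatchNorm layer, for illustration only."""
    return nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
```

A model built by build_cnn() reports False, while the same model passed through nn.SyncBatchNorm.convert_sync_batchnorm reports True, in which case torch.compile is better left off.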

2. Don’t use EMA with torch.compile.

EMA (exponential moving average of model weights) is currently not supported together with torch.compile in SuperGradients (SG). We will address this in a new release of SG, but for now the rule of thumb is not to use EMA and torch.compile simultaneously.

3. You may need to reduce the batch size during training by quite a lot (up to 2x).

The performance gain of torch.compile does not come for free. We observed increased VRAM consumption when using torch.compile. This can be explained by the optimizations that PyTorch applies to the model. In the reduce-overhead optimization mode, the goal is to reduce unnecessary computations. One way to achieve this is to keep intermediate buffers in memory during the forward and backward passes instead of recomputing them. While this gives a performance boost, it may cost 20-30% more VRAM compared to the default training mode. In some cases the increase in VRAM consumption can be even more extreme.

Rule of thumb: if torch.compile fails with cryptic errors, try decreasing the batch size.
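This rule of thumb can be automated with a simple retry loop. The sketch below is plain Python and is not part of SuperGradients; find_fitting_batch_size and trial_step are hypothetical names, and it assumes that memory failures surface as RuntimeError, as CUDA out-of-memory errors do in PyTorch.

```python
from typing import Callable


def find_fitting_batch_size(
    trial_step: Callable[[int], None], initial: int, minimum: int = 1
) -> int:
    """Halve the batch size until `trial_step(batch_size)` runs without a RuntimeError."""
    batch_size = initial
    while batch_size >= minimum:
        try:
            trial_step(batch_size)  # e.g. one compiled forward/backward pass
            return batch_size
        except RuntimeError:
            batch_size //= 2  # back off and retry with half the batch
    raise RuntimeError(f"no batch size between {minimum} and {initial} worked")
```

In practice trial_step would run a single training step of the compiled model at the given batch size, so the search costs a few extra steps at the start of training.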

4. Training with mixed precision gives the best performance boost.

Mixed precision training is a technique used in PyTorch to accelerate training and inference by combining single-precision (FP32) with a lower-precision data type (such as FP16). During training, the weights of the model are kept in FP32, while computation is performed in FP16. In PyTorch, this has been available since the 1.6 release as Automatic Mixed Precision (AMP).

We observed that using torch.compile with AMP gives us the best performance.

Rule of thumb: training a model with both AMP and torch.compile gives you the best performance gain.
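Outside of SuperGradients, combining the two in a plain PyTorch loop might look like the sketch below. It assumes PyTorch 2.0; make_amp_tools and train_step are hypothetical helpers, while torch.autocast and torch.cuda.amp.GradScaler are the standard AMP APIs. On a machine without a GPU it falls back to bfloat16 autocast with loss scaling disabled.

```python
import torch
import torch.nn as nn


def make_amp_tools(device: torch.device):
    """Pick an autocast dtype and a gradient scaler for the given device."""
    use_cuda = device.type == "cuda"
    # FP16 on GPU; bfloat16 is the usual autocast dtype on CPU.
    amp_dtype = torch.float16 if use_cuda else torch.bfloat16
    # Loss scaling is only needed for FP16; a disabled scaler is a no-op.
    scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)
    return amp_dtype, scaler


def train_step(model, optimizer, scaler, inputs, targets, amp_dtype):
    """One AMP step; `model` may be a plain module or one wrapped by torch.compile."""
    optimizer.zero_grad(set_to_none=True)
    with torch.autocast(device_type=inputs.device.type, dtype=amp_dtype):
        outputs = model(inputs)  # forward pass runs in reduced precision
    loss = nn.functional.mse_loss(outputs.float(), targets)  # loss computed in FP32
    scaler.scale(loss).backward()
    scaler.step(optimizer)  # weights themselves stay in FP32
    scaler.update()
    return loss.item()
```

To combine this with compilation, wrap the model once with torch.compile(model) before the loop and pass the wrapped module to train_step; the optimizer can keep referring to the original model’s parameters, since the wrapper shares them.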

Boosting Efficiency with torch.compile

With the introduction of torch.compile, PyTorch 2.0 can significantly enhance the speed and efficiency of model training. This feature lets more of your original Python code run as compiled, optimized code. The resulting performance boost makes the training process faster and more cost-effective, which is a major benefit for projects that involve extensive or complex machine learning models.

If you’re ready to give it a try, head over to SuperGradients, optimize your model with torch.compile and start training!
