

Fault-tolerant Distributed Training with torchrun

Authors: Suraj Subramanian

What you will learn
  • Launching multi-GPU training jobs with torchrun

  • Saving and loading snapshots of your training job

  • Structuring your training script for graceful restarts

View the code used in this tutorial on GitHub

Prerequisites
  • High-level overview of DDP

  • Familiarity with DDP code

  • A machine with multiple GPUs (this tutorial uses an AWS p3.8xlarge instance)

  • PyTorch installed with CUDA


In distributed training, a single process failure can disrupt the entire training job. Because failures can be more likely when many processes are involved, making your training script robust is particularly important. You might also want your training job to be elastic, meaning that compute resources can join and leave dynamically over the course of the job.

PyTorch offers a utility called torchrun that provides fault-tolerance and elastic training. When a failure occurs, torchrun logs the errors and attempts to automatically restart all the processes from the last saved “snapshot” of the training job.

The snapshot saves more than just the model state; it can include details about the number of epochs run, optimizer states or any other stateful attribute of the training job necessary for its continuity.

Why use torchrun

torchrun handles the minutiae of distributed training so that you don’t need to. For instance,

  • You don’t need to set environment variables or explicitly pass the rank and world_size; torchrun assigns these along with several other environment variables.

  • No need to call mp.spawn in your script; you only need a generic main() entry point, and launch the script with torchrun. This way the same script can be run in non-distributed as well as single-node and multinode setups.

  • Training restarts gracefully from the last saved training snapshot, provided your script loads it at startup.
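
The environment variables mentioned in the first point can be read with the standard library alone. The sketch below assumes the variable names RANK, LOCAL_RANK and WORLD_SIZE that torchrun exports to each worker. The DistEnv dataclass, the read_dist_env helper and the single-process fallback values are hypothetical names and choices made for illustration, so that the same script also runs when started with plain python.

```python
import os
from dataclasses import dataclass


@dataclass(frozen=True)
class DistEnv:
    """Hypothetical container for the values torchrun exports to each worker."""
    rank: int         # global rank across all nodes
    local_rank: int   # rank within this node, typically used as the GPU index
    world_size: int   # total number of worker processes


def read_dist_env(environ=os.environ) -> DistEnv:
    # Fall back to single-process values when the script is not launched by torchrun.
    return DistEnv(
        rank=int(environ.get("RANK", 0)),
        local_rank=int(environ.get("LOCAL_RANK", 0)),
        world_size=int(environ.get("WORLD_SIZE", 1)),
    )


if __name__ == "__main__":
    env = read_dist_env()
    print(f"rank {env.rank} of {env.world_size}, local rank {env.local_rank}")
```

Passing the mapping as a parameter keeps the helper easy to exercise without touching the real process environment.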

Graceful restarts

For graceful restarts, you should structure your training script like this:

def main():
  load_snapshot(snapshot_path)
  initialize()
  train()

def train():
  for batch in iter(dataset):
    train_step(batch)

    if should_checkpoint:
      save_snapshot(snapshot_path)

If a failure occurs, torchrun will terminate all the processes and restart them. Each process entry point first loads and initializes the last saved snapshot, and continues training from there. So at any failure, you only lose the training progress made since the last saved snapshot.

In elastic training, whenever there are any membership changes (adding or removing nodes), torchrun will terminate and spawn processes on available devices. Having this structure ensures your training job can continue without manual intervention.
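
The structure above can be tried out without GPUs. The following is a minimal sketch, not the tutorial's training code: it uses a JSON file as a stand-in for a real snapshot, a running sum as a stand-in for train_step, and a hypothetical fail_at argument to simulate a crash. All function and parameter names are chosen for illustration.

```python
import json
import os
import tempfile


def load_snapshot(path):
    # Start from scratch when no snapshot exists yet.
    if not os.path.exists(path):
        return {"step": 0, "total": 0}
    with open(path) as f:
        return json.load(f)


def save_snapshot(path, state):
    with open(path, "w") as f:
        json.dump(state, f)


def train(path, num_steps, save_every, fail_at=None):
    state = load_snapshot(path)
    for step in range(state["step"], num_steps):
        if step == fail_at:
            raise RuntimeError(f"simulated failure at step {step}")
        state["total"] += step       # stand-in for train_step(batch)
        state["step"] = step + 1     # number of completed steps
        if state["step"] % save_every == 0:
            save_snapshot(path, state)
    return state


if __name__ == "__main__":
    path = os.path.join(tempfile.mkdtemp(), "snapshot.json")
    try:
        train(path, num_steps=10, save_every=3, fail_at=7)
    except RuntimeError as err:
        print(err)
    # The restart reloads the snapshot taken after 6 completed steps
    # and redoes only the work done since then.
    print(train(path, num_steps=10, save_every=3))
```

In this run the work of one step (the seventh) is lost and repeated, which is the cost of snapshotting every third step rather than every step.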

Process group initialization

- def ddp_setup(rank, world_size):
+ def ddp_setup():
-     """
-     Args:
-         rank: Unique identifier of each process
-         world_size: Total number of processes
-     """
-     os.environ["MASTER_ADDR"] = "localhost"
-     os.environ["MASTER_PORT"] = "12355"
-     init_process_group(backend="nccl", rank=rank, world_size=world_size)
+     init_process_group(backend="nccl")
     torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

Use torchrun-provided environment variables

- self.gpu_id = gpu_id
+ self.gpu_id = int(os.environ["LOCAL_RANK"])

Saving and loading snapshots

Regularly storing all the relevant information in snapshots allows your training job to seamlessly resume after an interruption.

+ def _save_snapshot(self, epoch):
+     snapshot = {}
+     snapshot["MODEL_STATE"] = self.model.module.state_dict()
+     snapshot["EPOCHS_RUN"] = epoch
+     torch.save(snapshot, "snapshot.pt")
+     print(f"Epoch {epoch} | Training snapshot saved at snapshot.pt")

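+ def _load_snapshot(self, snapshot_path):
+     # Map saved tensors onto this process's GPU rather than the device they were saved from
+     loc = f"cuda:{self.gpu_id}"
+     snapshot = torch.load(snapshot_path, map_location=loc)
+     self.model.load_state_dict(snapshot["MODEL_STATE"])
+     self.epochs_run = snapshot["EPOCHS_RUN"]
+     print(f"Resuming training from snapshot at Epoch {self.epochs_run}")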
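
One risk with saving in place is that a process can be killed while the snapshot file is being written, leaving a truncated file for the restart to load. A common mitigation, which is not part of the original tutorial, is to write to a temporary file and rename it over the target. The sketch below uses pickle from the standard library as a stand-in for torch.save and torch.load; save_snapshot_atomic and load_snapshot_or_none are hypothetical helper names.

```python
import os
import pickle
import tempfile


def save_snapshot_atomic(snapshot: dict, path: str) -> None:
    # Write to a temporary file in the same directory, then rename it over the
    # target. os.replace is atomic on the same filesystem, so a reader sees
    # either the old snapshot or the new one, never a half-written file.
    directory = os.path.dirname(os.path.abspath(path))
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as f:
            pickle.dump(snapshot, f)   # stand-in for torch.save(snapshot, f)
            f.flush()
            os.fsync(f.fileno())
        os.replace(tmp_path, path)
    except BaseException:
        # Clean up the temporary file if anything went wrong before the rename.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


def load_snapshot_or_none(path: str):
    if not os.path.exists(path):
        return None
    with open(path, "rb") as f:
        return pickle.load(f)          # stand-in for torch.load(path)
```

In a multi-process job you would typically call the save helper from a single rank only, so that workers do not race to write the same file.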

Loading a snapshot in the Trainer constructor

When restarting an interrupted training job, your script will first try to load a snapshot to resume training from.

class Trainer:
   def __init__(self, snapshot_path, ...):
      ...
+     if os.path.exists(snapshot_path):
+        self._load_snapshot(snapshot_path)
      ...

Resuming training

Training can resume from the last epoch run, instead of starting from scratch.

def train(self, max_epochs: int):
-  for epoch in range(max_epochs):
+  for epoch in range(self.epochs_run, max_epochs):
      self._run_epoch(epoch)
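
To make the resume arithmetic concrete, the sketch below lists which epochs run after a restart. It assumes, as a guess about Trainer code that is not shown here, that a snapshot is written after an epoch finishes whenever epoch % save_every == 0. Both helper functions are hypothetical.

```python
def last_snapshot_epoch(failed_epoch: int, save_every: int):
    """Latest epoch before failed_epoch that wrote a snapshot, or None."""
    candidates = [e for e in range(failed_epoch) if e % save_every == 0]
    return candidates[-1] if candidates else None


def epochs_after_resume(epochs_run: int, max_epochs: int) -> list:
    # Mirrors `for epoch in range(self.epochs_run, max_epochs)`.
    return list(range(epochs_run, max_epochs))


if __name__ == "__main__":
    # A 50-epoch job that saves every 10 epochs fails during epoch 27.
    saved = last_snapshot_epoch(failed_epoch=27, save_every=10)
    print(saved)                                # 20
    print(epochs_after_resume(saved, 50)[:3])   # [20, 21, 22]
```

Under that assumption epoch 20 runs twice, because EPOCHS_RUN holds the index of the epoch that just finished; storing epoch + 1 instead would resume at epoch 21.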

Running the script

Simply call your entry point function as you would for a non-multiprocessing script; torchrun automatically spawns the processes.

if __name__ == "__main__":
   import sys
   total_epochs = int(sys.argv[1])
   save_every = int(sys.argv[2])
-  world_size = torch.cuda.device_count()
-  mp.spawn(main, args=(world_size, total_epochs, save_every,), nprocs=world_size)
+  main(save_every, total_epochs)

Then launch the script with torchrun instead of python:

- python multigpu.py 50 10
+ torchrun --standalone --nproc_per_node=4 multigpu_torchrun.py 50 10
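
Indexing into sys.argv works, but it gives no help text or validation. As an optional variation, the two positional arguments can be parsed with argparse. This is a sketch rather than the tutorial's exact entry point; the parse_args helper and the --snapshot_path flag are hypothetical additions.

```python
import argparse


def parse_args(argv=None):
    parser = argparse.ArgumentParser(description="Fault-tolerant distributed training job")
    parser.add_argument("total_epochs", type=int, help="Total epochs to train the model")
    parser.add_argument("save_every", type=int, help="How often to save a snapshot, in epochs")
    # Hypothetical flag: lets the snapshot location be changed without editing the script.
    parser.add_argument("--snapshot_path", default="snapshot.pt", help="Where to store the snapshot")
    return parser.parse_args(argv)


# Same arguments as in: torchrun ... multigpu_torchrun.py 50 10
args = parse_args(["50", "10"])
print(args.total_epochs, args.save_every, args.snapshot_path)
```

In the real script you would call parse_args() with no argument, so that it reads the command line that torchrun forwards to each worker.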
