Train Your Large Model on Multiple GPUs with Fully Sharded Data Parallelism

Some language models are too large to train on a single GPU. In addition to creating the model as a pipeline of stages, as in Pipeline Parallelism, you can split the model across multiple GPUs using Fully Sharded Data Parallelism (FSDP). In this article, you will learn how to use FSDP to split models for training. In particular, you will learn about:

  • The idea of sharding and how FSDP works
  • How to use FSDP in PyTorch

Let’s get started!

Photo by Ferenc Horvath. Some rights reserved.

Overview

This article is divided into five parts; they are:

  • Introduction to Fully Sharded Data Parallel
  • Preparing Model for FSDP Training
  • Training Loop with FSDP
  • Fine-Tuning FSDP Behavior
  • Checkpointing FSDP Models

Introduction to Fully Sharded Data Parallel

Sharding is a term originally used in database management systems, where it refers to dividing a database into smaller units, called shards, to improve performance. In machine learning, sharding refers to dividing model parameters across multiple devices. Unlike pipeline parallelism, the shards contain only a portion of any complete operation. For example, the nn.Linear module is essentially a matrix multiplication. A sharded version of it contains only a portion of the matrix. When a sharded module needs to process data, you must gather the shards to create a complete matrix temporarily and perform the operation. Afterwards, this matrix is discarded to reclaim memory.

When you use FSDP, all model parameters are sharded, and each process holds exactly one shard. In data parallelism, each GPU keeps a full copy of the model and only gradients are synchronized across GPUs. FSDP, in contrast, never keeps a full copy of the model on each GPU; instead, parameter shards and gradients are exchanged on every step. Therefore, FSDP incurs higher communication overhead in exchange for lower memory usage.

FSDP requires processes to exchange data to unshard the model.

The workflow of FSDP is as follows:

There will be multiple processes running together, possibly on multiple machines across a network. Each process (equivalently, each GPU) holds only one shard of the model. When the model is sharded, each module’s weights are stored as a DTensor (distributed tensor, sharded across multiple GPUs) rather than a plain Tensor. Therefore, no process can run any module independently. Before each operation, FSDP issues an all-gather request to enable all processes to exchange a module’s shards with one another. This creates a temporary unsharded module, and each process runs the forward pass on this module with its micro-batch of data. Afterward, the unsharded module is discarded as the processes move on to the next module in the model.

Similar operations happen in the backward pass. Each module must be unsharded when FSDP issues an all-gather request to it. Then the backward pass computes gradients from the forward pass results. Note that each process operates on a different micro-batch of data, so the gradients computed by each process are different. Therefore, FSDP issues a reduce-scatter request, causing all processes to exchange gradients so that the final batch-wide gradient is averaged. This final gradient is then used to update the model parameters on every shard.

As shown in the figure above, FSDP requires more communication and has a more complex workflow than plain data parallelism. Since the model is distributed across multiple GPUs, you do not need as much VRAM to host a very large model. This is the motivation for using FSDP for training.

Comparing DP (left) and FSDP (right). Illustration adapted from the blog post by Ott et al.

To improve FSDP’s efficiency, PyTorch uses prefetching to overlap communication and computation. While your GPU computes the first module, the processes exchange shards from the second module, so the second module becomes available once the first is complete. This keeps both the GPU and the network busy, reducing per-step latency. Some tuning in FSDP can help you maximize such overlap and improve training throughput, often at the cost of higher memory usage.

Preparing Model for FSDP Training

When you need FSDP, it usually means your model is too large to fit on a single GPU. One way to instantiate such a large model is to create it on the fake meta device, which records shapes without allocating memory, then shard it and distribute the shards across multiple GPUs.

In PyTorch, you need to use the torchrun command to launch an FSDP training script with multiple processes. Under torchrun, each process can read the world size (total number of processes), its rank (the index of the current process), and its local rank (the index of the GPU device on the current machine) from the environment. In the script, you first initialize the process group:
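A sketch of the initialization, assuming the environment variables set by torchrun and the NCCL backend (the usual choice for GPU training):

```python
import os

import torch
import torch.distributed as dist


def init_distributed() -> tuple[int, int, int]:
    """Read the variables set by torchrun and join the process group."""
    world_size = int(os.environ["WORLD_SIZE"])  # total number of processes
    rank = int(os.environ["RANK"])              # index of this process
    local_rank = int(os.environ["LOCAL_RANK"])  # GPU index on this machine
    dist.init_process_group(backend="nccl")     # NCCL is the usual GPU backend
    torch.cuda.set_device(local_rank)
    return world_size, rank, local_rank


if "RANK" in os.environ:  # only when launched under torchrun
    world_size, rank, local_rank = init_distributed()
```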

Next, you should create the model and then shard it. The code below is based on the model architecture described in the previous post:

In PyTorch, you use the fully_shard() function to create a sharded model. This function replaces parameters of type Tensor with DTensor in-place. It also modifies the model to perform the all-gather operation before the actual computation.

You should notice that in the above, fully_shard() is not only called on model, but also on model.base_model as well as each transformer block in the base model. This needs careful consideration.

Usually, you want to shard not only the top-level model but also its submodules. When you do so, you must apply fully_shard() from the bottom up, with the top-level model sharded last. Each sharded module becomes one unit of all-gather. In the design shown above, when you pass a tensor to model, the top-level model's components are unsharded, except for those that were sharded separately. Since this is a decoder-only transformer model, the input is processed by the base model first and then by the prediction head in the top model. FSDP will unshard the base model, except for each repeating transformer block. This includes the input embedding layer, which is the first operation applied to the input tensor.

After the embedding layer, the input tensor should be processed by a sequence of transformer blocks. Each block is sharded separately, so all-gather is triggered for each block. The block transforms the input and passes it on to the next transformer block. After the last transformer block, the RMS norm layer in the base model, which is already unsharded, processes the output before returning to the top model for the prediction.

This is why you do not want to shard only the top-level model: if you do, the all-gather operation will create a full model on each GPU, violating the assumption that each GPU has insufficient memory to hold the full model. In that case, you should use plain data parallelism rather than FSDP.

In this design, each GPU requires one complete transformer block plus the other modules in the top and base models, such as the embedding layer, the final RMS norm layer in the base model, and the prediction head in the top model. You can revise this design (for example, by further sharding model.base_model.embed_tokens and breaking down each transformer block into attention and feed-forward sublayers) to further reduce the memory requirement.

After you have the sharded model, you can transfer it from a meta device to your local GPU with model.to_empty(device=device). You also need to reset the weights of the newly created model (unless you want to initialize them from a checkpoint). You can borrow the function reset_all_weights() from the previous post to reset the weights. Here is another way that uses model.reset_parameters(). This requires you to implement the corresponding member function in each module:
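Here is a minimal, self-contained sketch of that pattern (TinyModel is a hypothetical stand-in; in real FSDP training you would call to_empty() and reset_parameters() on the sharded model):

```python
import torch
import torch.nn as nn


class TinyModel(nn.Module):
    """Each module implements reset_parameters() and delegates to its
    children, so one top-level call re-initializes everything."""

    def __init__(self, dim: int, vocab_size: int):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)
        self.head = nn.Linear(dim, vocab_size, bias=False)

    def reset_parameters(self) -> None:
        self.embed.reset_parameters()  # built-in modules already implement it
        self.head.reset_parameters()

    def forward(self, ids):
        return self.head(self.embed(ids))


with torch.device("meta"):
    model = TinyModel(dim=64, vocab_size=1000)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to_empty(device=device)  # allocate storage; contents are uninitialized
model.reset_parameters()       # now give every weight a proper initial value
```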

You know the model is sharded if it is an instance of FSDPModule. Subsequently, you can create the optimizer and other training components as usual. The PyTorch optimizer supports updating DTensor objects the same way as plain Tensor objects.

Training Loop with FSDP

Using FSDP is straightforward. Virtually nothing needs to be changed in the training loop:
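A sketch of such a loop (the model, optimizer, data loader, and loss are placeholders; the unshard() call is guarded so the same loop also runs on a non-FSDP model):

```python
import torch
import torch.nn.functional as F


def train_epoch(model, optimizer, dataloader, device):
    """One epoch of the usual training loop; FSDP needs no special handling."""
    model.train()
    for input_ids, attn_mask, labels in dataloader:
        input_ids = input_ids.to(device)
        attn_mask = attn_mask.to(device)
        labels = labels.to(device)
        if hasattr(model, "unshard"):
            model.unshard()  # optional: kick off the all-gather early
        logits = model(input_ids, attn_mask)
        loss = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)), labels.reshape(-1)
        )
        optimizer.zero_grad()
        loss.backward()  # FSDP runs all-gather + reduce-scatter internally
        optimizer.step()
    return loss
```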

The only change you can observe is the use of model.unshard() to trigger the all-gather before the forward pass, but this is optional. Even if you do not call it, model(input_ids, attn_mask) will still trigger the all-gather operation internally. This line starts the all-gather before the input tensor is prepared for the forward pass.

However, FSDP is partially a data parallelism technique. As with distributed data parallelism, you should use a sampler with your data loader so that each rank in the process group processes a different micro-batch. This works because each process receives a complete module of the model via all-gather, so each process can use that module to process a different micro-batch of data. In essence, FSDP exchanges both the model and the training data, going one step further than data parallelism. Below is how you should set up your data loader:
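A sketch, using a synthetic TensorDataset as a placeholder for your real dataset:

```python
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

# placeholder dataset of 1024 token sequences of length 128
dataset = TensorDataset(torch.randint(0, 32000, (1024, 128)))

# DistributedSampler gives each rank a disjoint slice of the dataset
sampler = DistributedSampler(dataset) if dist.is_initialized() else None
dataloader = DataLoader(
    dataset, batch_size=8, sampler=sampler, shuffle=(sampler is None)
)

# remember to call sampler.set_epoch(epoch) at the start of each epoch
# so the shuffling differs across epochs
```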

This is the same as how you set up the data loader for distributed data parallel in the previous article.

Fine-Tuning FSDP Behavior

The above is all you need to run FSDP training. However, you can introduce variations to fine-tune FSDP’s behavior.

Using torch.compile()

If your model can be compiled, you can also compile an FSDP model. However, you need to compile it after sharding the model, so the compiled model can reference the distributed tensors rather than plain tensors.
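For example (the Sequential model here is a stand-in for your sharded model):

```python
import torch
import torch.nn as nn

# stand-in for the model; with FSDP you would call fully_shard(...) on it first
model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 1))

# compile after sharding, so the graph captures the (distributed) parameters
model = torch.compile(model)
```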

Arguments to fully_shard()

Recall that you can use torch.autocast() to run mixed precision training. You can also enable mixed-precision training in FSDP, but you must apply it when sharding the model. The change needed is particularly simple:

When you shard the model, you can specify the argument mp_policy to describe exactly how mixed precision training should be performed. In the example above, you keep the model parameters in bfloat16 but use float32 for gradients (during reduce-scatter). You can also specify output_dtype and cast_forward_inputs to define the data types of the forward pass inputs and outputs. Note that since fully_shard() is applied to each module, you are free to use different mixed precision policies for different modules.

Of course, PyTorch still allows you to use torch.set_default_dtype(torch.bfloat16) to change the default data type for the entire model. This changes the default data type for all DTensor objects created.

In FSDP, you need an all-gather step before the actual forward or backward computation. Before all-gather, you do not have a complete parameter for the operation. Since inter-process communication is slow and a lot of data needs to be moved to the GPU anyway, you can apply CPU offloading to keep your sharded model in CPU memory when it is not in use. This means:

Typically, using CPU offloading makes the training loop noticeably slower. If you use CPU offloading, you should consider changing the training loop such that the optimizer zeros out gradient tensors instead of setting the gradients to None:
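For example (a toy model stands in for the FSDP model here):

```python
import torch

model = torch.nn.Linear(4, 2)  # stand-in for the FSDP-wrapped model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

loss = model(torch.randn(8, 4)).sum()
loss.backward()

# set_to_none=False zeroes the gradient tensors in place instead of freeing
# them, so they are not re-allocated on every step
optimizer.zero_grad(set_to_none=False)
```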

This is because CPU memory is usually more abundant than GPU memory, and you can afford to keep the allocated gradient tensors in memory to avoid the overhead of re-allocating them.

The third argument you can add to fully_shard() is reshard_after_forward=True. By default (reshard_after_forward=None), FSDP will keep the unsharded model in the memory of the root module after the forward pass, so the backward pass does not need to call all-gather again. Non-root modules will always discard the unsharded tensors, unless you set reshard_after_forward=False.

Usually, you do not want to change this setting, since doing so likely means you must run all-gather again immediately after discarding the unsharded tensors. But understanding how this parameter works lets you reconsider your model design: in the implementation of LlamaForPretraining above, the root module contains only the prediction head. If you moved the embedding layer from the base model LlamaModel to the root module, you would keep the embedding layer (which is usually large) unsharded in memory for a long time. This is the kind of model engineering you can consider when applying FSDP.

Gradient Checkpointing

FSDP has a lower memory requirement than plain data parallelism. If you want to reduce memory usage further, you can use gradient checkpointing with FSDP. Unlike the plain model, you do not use torch.utils.checkpoint.checkpoint() to wrap the part that requires gradient checkpointing. Instead, you set a policy and apply it to the sharded model:
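A self-contained sketch (Block is a stand-in for the transformer block class from the previous post):

```python
import torch
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
    checkpoint_wrapper,
)


class Block(nn.Module):  # stand-in for the transformer block class
    def __init__(self, dim: int):
        super().__init__()
        self.ff = nn.Linear(dim, dim)

    def forward(self, x):
        return torch.relu(self.ff(x))


model = nn.Sequential(Block(32), Block(32), nn.Linear(32, 4))


# the policy: checkpoint a module iff it is one of the listed classes
def wrap_policy(module: nn.Module) -> bool:
    return isinstance(module, Block)


# recursively scan the model and wrap every matching submodule
apply_activation_checkpointing(
    model,
    checkpoint_wrapper_fn=checkpoint_wrapper,
    check_fn=wrap_policy,
)
```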

The wrap_policy is a helper function that checks whether the module belongs to one of the listed classes. If so, gradient checkpointing will be applied to it, so its internal activations are discarded after the forward pass and recomputed during the backward pass. The function apply_activation_checkpointing() recursively scans the module and applies gradient checkpointing to its submodules.

As a reminder, gradient checkpointing is a technique that trades time for memory during training. You save memory by discarding intermediate activations, but the backward pass is slower due to recomputation.

All-Gather Prefetching

FSDP implements a similar efficiency optimization to pipeline parallelism: it issues an all-gather request to the next module while the current module is processing data. This is called prefetching, and it deliberately overlaps communication and computation to reduce the latency of each training step.

You can indeed control how the prefetching is performed. Below is an example:
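One way to sketch this is a helper that wires up the FSDPModule methods set_modules_to_forward_prefetch() and set_modules_to_backward_prefetch(); the depth of two and the module list are examples:

```python
def configure_prefetch(blocks, depth=2):
    """blocks: the model's sharded transformer blocks, in execution order."""
    for i, block in enumerate(blocks):
        # while block i runs its forward pass, gather the next `depth` modules
        block.set_modules_to_forward_prefetch(blocks[i + 1 : i + 1 + depth])
        # the backward pass runs in reverse, so gather the preceding modules
        block.set_modules_to_backward_prefetch(
            list(reversed(blocks[max(0, i - depth) : i]))
        )


# usage, e.g.: configure_prefetch(list(model.base_model.layers), depth=2)
```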

By default, FSDP determines the next module and prefetches it. The code above causes FSDP to prefetch not just the next module but the next two. The modules list enumerates the sharded modules in the model in their execution order. Then, for each module, you set the forward prefetch to the two subsequent modules and the backward prefetch to the two preceding modules.

Note that FSDP does not check whether you specified the modules in the correct execution order. If you prefetch the wrong module, your training performance will deteriorate. You also must not specify a module that is not sharded (such as model.lm_head in the example above), as FSDP cannot issue all-gather requests for it.

Checkpointing FSDP Models

An FSDP model is still a PyTorch model, just with the model weights replaced by DTensor objects. If you want to, you can still manipulate a DTensor like a plain Tensor, as the optimizer does in your training loop. You can also inspect the DTensor objects to see what is in each shard:

You can use this property to save and load a sharded model. However, you must ensure that only one process is saving the model so that you do not overwrite the file on disk:

Indeed, there is an easier method: the distributed checkpoint API, which you have already seen in the previous article:

The cpu_offload option must be removed if you do not use CPU offloading.

These two functions are supposed to be called by all processes together. Each process will save its own sharded model and optimizer state to a different file, all under the same directory as the checkpoint_id you specified. Do not attempt to read them with torch.load() since these files are in a different format. However, you can still use the same load_checkpoint() function above on an unsharded model in a plain Python script. Usually, after training is completed, you can recreate the model file from sharded checkpoints:

For completeness, below is the full script that you can use to run FSDP training:

To run this code, you need to run it with the torchrun command, such as torchrun --standalone --nproc_per_node=4 fsdp_training.py.

This code incorporates all elements discussed in this article. It may not be the most efficient implementation. You should read and modify it to suit your needs.

Further Readings

Below are some resources that you may find useful:

Summary

In this article, you learned about Fully Sharded Data Parallelism (FSDP) and how to use it in PyTorch. Specifically, you learned:

  • FSDP is a data parallelism technique that shards the model across multiple GPUs.
  • FSDP requires more communication and has a more complex workflow than plain data parallelism.
  • FSDP can be used to train very large models with fewer GPUs. You can also apply mixed-precision training and other techniques to trade off memory and compute performance.
