Fully Sharded Data Parallelism (FSDP)

In this blog post, we will explore Fully Sharded Data Parallelism (FSDP), a technique for efficiently training large neural network models in a distributed manner. We’ll examine FSDP from a bird’s-eye view and shed light on its underlying mechanisms.


When choosing a distributed training approach, it’s important to be aware of the advantages and disadvantages of each strategy, so that you can pick the one that matches the use case at hand.

For large language models (LLMs) specifically, the large number of parameters leads to significant GPU memory requirements, and FSDP emerges as a high-performance solution in this context. By spreading the model across multiple GPUs, FSDP offers a practical trade-off: it accepts additional GPU communication in exchange for a much smaller memory footprint on each GPU.

On the other hand, for computer vision models that typically fit on a single GPU, Distributed Data Parallelism (DDP) often proves more efficient. Each GPU holds a full copy of the model, which avoids the extra GPU communication overhead associated with FSDP.

Why Can’t We Run the Model Sequentially?

Suppose we find ourselves wanting to run a large model that won't fit on a single GPU. The “naive” approach to running the model involves splitting the model in a so-called “vertical” way by allocating different model layers to different GPUs. Each GPU handles a specific set of layers, and the whole model can be run sequentially. Problem solved, right? Unfortunately, that’s not the case—this naive method has significant limitations, and we’ll show that FSDP is an answer to overcoming them.

Let's consider a scenario with n GPUs and a model size denoted as S, in which the model cannot fit on one GPU but can fit on all the GPUs combined. After splitting the model “vertically,” each GPU has a slice of the model of size S/n. The time it takes for a single GPU to run the forward pass on its model slice is represented as T_forward, and the backward run time is denoted as T_backward.

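This approach completes a full forward pass in T_forward*n. That matches the best case FSDP can achieve per batch (ignoring communication), because FSDP also runs the units one after another on each GPU. Moreover, only the activations are passed between GPUs, not the model weights, gradients, or optimizer state.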

At first glance, this strategy seems to be a valid way to train a model on multiple GPUs. Given how straightforward it is, one might question the need for anything more complicated. There is no need to split the dataset between the GPUs or to perform complex collective communication like the operations used in FSDP.

The issue emerges when the first GPU completes its forward pass and wants to run the forward pass on the next batch. We cannot begin the forward pass for the next batch until the first GPU's weights have been updated, and that update requires the gradients from the current batch. So the first GPU must wait for all downstream GPUs to complete their forward passes and propagate their respective gradients backward, which results in significant idle time. Specifically, the first GPU waits T_forward*(n-1) + T_backward*(n-1) before it can start backward propagation for the first batch.
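The following back-of-the-envelope sketch in Python makes the idle time concrete. It assumes that every slice takes exactly T_forward and T_backward, and it ignores the time spent passing activations between GPUs. The function name `naive_pipeline_timing` is ours, chosen for illustration.

```python
def naive_pipeline_timing(n, t_forward, t_backward):
    """Timing of one training step for the naive vertical split on n GPUs.

    Assumes every model slice takes exactly t_forward / t_backward and ignores
    the cost of passing activations between GPUs.
    """
    # The whole model runs forward, then backward, one slice at a time
    step_time = n * (t_forward + t_backward)
    # Each GPU only ever works on its own slice
    busy_per_gpu = t_forward + t_backward
    # First GPU: wait between finishing its forward and starting its backward
    first_gpu_wait = (n - 1) * t_forward + (n - 1) * t_backward
    utilization = busy_per_gpu / step_time
    return step_time, first_gpu_wait, utilization


step, wait, util = naive_pipeline_timing(n=4, t_forward=1.0, t_backward=2.0)
print(step, wait, util)  # 12.0 9.0 0.25
```

Under these assumptions each GPU is busy only 1/n of the time. With 4 GPUs, that is a quarter of every training step.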

We’ve split our model among our GPUs, but most of the time, each GPU is just sitting around doing nothing, waiting for the data to propagate elsewhere in the model!

FSDP offers a solution to this problem, providing the ability to fully make use of all the GPUs on large models without significant idle GPU time.

Laying the Groundwork

Two separate actions take place to set up the FSDP process:

Model Partition (Vertical Assignment)

Just as in the naive solution above, vertical splitting organizes the model layers into “units.” For example, in a model with 9 convolution layers, each unit is responsible for a specific range of layers. The first unit might manage layers 1 to 3, the second unit layers 4 to 6, and the last unit layers 7 to 9. Unlike in the naive solution, however, a unit is not pinned to one particular GPU.
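Here is a minimal sketch of the vertical partition. It assumes layers are grouped into contiguous, roughly equal-sized units. The helper `partition_into_units` is hypothetical; in PyTorch FSDP the grouping is usually expressed by choosing which submodules get wrapped.

```python
def partition_into_units(layers, num_units):
    """Group consecutive layers into num_units contiguous units (vertical split)."""
    size, remainder = divmod(len(layers), num_units)
    units, start = [], 0
    for i in range(num_units):
        # The first `remainder` units take one extra layer if sizes don't divide evenly
        end = start + size + (1 if i < remainder else 0)
        units.append(layers[start:end])
        start = end
    return units


layers = [f"conv{i}" for i in range(1, 10)]
for i, unit in enumerate(partition_into_units(layers, 3), start=1):
    print(f"unit {i}: {unit}")
```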

Sharding (Horizontal Splitting)

“Horizontal splitting” refers to splitting the model parameters within each layer and storing the pieces on different GPUs. This process is also commonly called sharding. For example, with 3 fully connected layers, instead of storing each fully connected layer on one GPU, each GPU holds one third of every layer's entities (that is, its parameters, gradients, and optimizer states).
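The sketch below illustrates horizontal splitting, with plain Python lists standing in for tensors. As an assumption loosely modeled on PyTorch FSDP, each layer's parameters are flattened and zero-padded so their length is divisible by the number of GPUs. `shard_flat_params` is a hypothetical helper.

```python
def shard_flat_params(flat_params, world_size, pad_value=0.0):
    """Split a flat parameter list into world_size equal shards; rank r keeps shards[r]."""
    padded_len = -(-len(flat_params) // world_size) * world_size  # round up to a multiple
    padded = list(flat_params) + [pad_value] * (padded_len - len(flat_params))
    shard_len = padded_len // world_size
    return [padded[r * shard_len:(r + 1) * shard_len] for r in range(world_size)]


# Three fully connected layers, already flattened
layers = {
    "fc1": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
    "fc2": [7.0, 8.0, 9.0],
    "fc3": [10.0, 11.0, 12.0, 13.0],
}
world_size = 3
gpu_memory = {rank: {} for rank in range(world_size)}
for name, params in layers.items():
    for rank, shard in enumerate(shard_flat_params(params, world_size)):
        gpu_memory[rank][name] = shard  # every GPU gets a slice of every layer

print(gpu_memory[0])  # {'fc1': [1.0, 2.0], 'fc2': [7.0], 'fc3': [10.0, 11.0]}
```

Each GPU ends up holding a slice of every layer. The last shard is zero-padded when a layer's size does not divide evenly.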

Throughout the training process, the GPUs collaborate by sharing the shards they need. As a result, full copies of some parameters are temporarily materialized on every GPU, and communication overhead is incurred between the GPUs. In exchange, we are able to keep all the GPUs busy at all times.

During the forward and backward steps, every GPU runs all the units one by one, with all GPUs working in parallel. To do so, each GPU gathers the necessary shards of the model parameters and other entities from the other GPUs.

Sharded Entities

In PyTorch's FSDP, configuration settings known as “sharding strategies” govern how model shards are distributed and managed. This blog post focuses on the FULL_SHARD sharding strategy, which is the most memory-efficient but also the most communication-intensive strategy.

Under the FULL_SHARD strategy, the following key entities are subjected to sharding:

  1. Model Parameters (MP): These include the core components of the model, such as weights and biases, along with any additional parameters specific to the model architecture.
  2. Gradients (GRD): These are the gradients computed during the backward pass, which allow for updating the model's weights.
  3. Optimizer State (OS): These data include all the pieces necessary to perform some flavor of gradient descent during training. For example, when using the Adam optimizer, this entity holds the momentum and variance estimates for every parameter (and, in mixed-precision training, typically an FP32 copy of the parameters). These data are typically retained in 32-bit floating-point format (FP32) to preserve precision during optimization.
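To show what the OS entity holds for one shard, here is a plain-Python Adam update applied to a single shard. It is a textbook Adam sketch, not PyTorch's implementation, although `torch.optim.Adam` keeps comparable per-parameter state (`exp_avg` and `exp_avg_sq`). The helper name `adam_step` is ours. Under FULL_SHARD, `m` and `v` are only as long as the shard, so each GPU stores optimizer state for its own slice of the model only.

```python
import math


def adam_step(params, grads, state, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update applied in place to a single shard (lists of floats).

    `state` is this shard's optimizer state: step count, first moment m
    (momentum) and second moment v (variance), each as long as the shard.
    """
    state["step"] += 1
    t = state["step"]
    for i, g in enumerate(grads):
        state["m"][i] = beta1 * state["m"][i] + (1 - beta1) * g
        state["v"][i] = beta2 * state["v"][i] + (1 - beta2) * g * g
        m_hat = state["m"][i] / (1 - beta1 ** t)  # bias correction
        v_hat = state["v"][i] / (1 - beta2 ** t)
        params[i] -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return params


shard = [0.5, -0.25]
state = {"step": 0, "m": [0.0] * len(shard), "v": [0.0] * len(shard)}
adam_step(shard, grads=[0.1, -0.2], state=state)
print(shard)  # roughly [0.499, -0.249]: the first Adam step moves each value by about lr
```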

Step-by-Step FSDP Breakdown

As in the examples above, we assume three available GPUs, and we once again assume that the model cannot fit on one GPU but can fit on all of them combined. The neural network in our toy example has 9 layers, and each unit is assigned 3 layers. The extension to the more common case where the model can’t fit on all GPUs at once is straightforward once the machinery of this toy example has been explained.

We will define the following terms:

  • shard: a single chunk of the split entities that is attached to a specific GPU (contains a small portion of every entity across the entire model)
  • Activation (ACT): the activations calculated on each GPU separately during the forward pass
  • unit: a part of the model with its assigned layers created by vertical model partition
  • MEM_total: the total memory needed to store all the sharded entities, i.e., MP + GRD + OS


The following initial steps set the stage for FSDP:

  • Split dataset: Split the dataset into three subsets and assign each of them to a specific GPU to be processed independently.
  • Assign units: Assign specific layers to each unit that will be in charge of managing them during the training process.
  • Shard the model: Divide each entity (MP, OS, GRD) into 3 shards and allocate them to the GPUs so that each GPU has to hold only MEM_total/3 in its memory.
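Here is a rough calculation of per-GPU memory for the sharded entities. It assumes FP32 parameters and gradients (4 bytes each) and Adam's two FP32 moments (8 bytes) per parameter, and it ignores activations and temporarily gathered units. The 7-billion-parameter model and the helper name `full_shard_memory_gb` are hypothetical.

```python
def full_shard_memory_gb(num_params, world_size, bytes_mp=4, bytes_grd=4, bytes_os=8):
    """Rough per-GPU memory (GB) for MP + GRD + OS when every entity is sharded.

    Defaults assume FP32 parameters and gradients and Adam's two FP32 moments
    per parameter. Activations and temporarily gathered units are ignored.
    """
    mem_total = num_params * (bytes_mp + bytes_grd + bytes_os)  # MEM_total in bytes
    return mem_total / world_size / 1e9


num_params = 7e9  # hypothetical 7-billion-parameter model
print(full_shard_memory_gb(num_params, world_size=1))            # 112.0
print(round(full_shard_memory_gb(num_params, world_size=3), 1))  # 37.3
```

Dividing MEM_total by the number of GPUs is exactly the MEM_total/3 saving described in the last step above.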

💡 As mentioned, the sharding of the model entities is 'horizontal,' meaning that each shard includes model parameters from every layer, as shown in the following diagram:

Forward Pass

It's important to note that in the diagrams for PyTorch's FSDP with the FULL_SHARD strategy, the gradients and optimizer state are shown as they are before the first backward pass and optimizer step. At this stage, these entities have not yet been calculated; they are placeholders. They will only contain actual values after being updated during the backward pass and the optimizer step.

The forward pass consists of the following steps:

  • Broadcast model parameters: All GPUs gather the model parameters of the first unit (MP 1) from one another (an all-gather operation) so they can run the first forward step.

💡 In the diagrams, opaque colors indicate that the shard is “owned” by the GPU and will persist throughout the entire training process. Conversely, a low-opacity filling denotes sharded entities not attached to the GPU, and they will be discarded after usage during resharding.
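The “broadcast” step is a collective all-gather. The following single-process simulation shows the data movement for MP 1, with a list of lists standing in for the shards held by three GPUs. `all_gather` here is a hypothetical helper, not the `torch.distributed` function.

```python
def all_gather(shards_per_rank, numel):
    """Simulated all-gather of one unit's MP.

    shards_per_rank[r] is the shard owned by rank r. Every rank receives the
    full (unpadded) parameter list, which it discards again when resharding.
    """
    full = [x for shard in shards_per_rank for x in shard][:numel]  # drop padding
    return [list(full) for _ in shards_per_rank]  # one independent copy per rank


mp1_shards = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 0.0, 0.0]]  # 7 params + padding
gathered = all_gather(mp1_shards, numel=7)
print(gathered[0])  # [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
```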

  • Forward pass unit 1: Each GPU runs the forward pass of unit 1 on its respective batch, using the complete MP 1 that it gathered from all the other GPUs. Since each GPU has a different input batch, the ACT that each one calculates will be different, even though all of them currently hold the same model parameters, MP 1. In some FSDP configurations, the forward computation of one unit can overlap with gathering the next unit's MP (in this case MP 2), which further accelerates training. However, this also increases GPU memory usage, since the GPU must hold the MP of two different units at the same time.
  • Save activations: The calculated ACT is retained on each GPU for later use in gradient computation during the backward pass.
  • Reshard MP 1: Delete only the broadcasted (low opacity) MP 1 from each GPU in order to free up GPU memory—note that each GPU still holds on to the shard that was assigned to it.
  • Repeat for all the other units: Repeat the process for units 2 and 3 (broadcast, run the forward pass, and reshard the MP), holding on to the ACT until the unit 3 forward pass is done. Doing so gives us the ACT for the entire model.
  • Compute loss: For each GPU, compute the loss of its respective batch using the loss function.
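The following toy single-process simulation of three GPUs puts the forward-pass steps together. As simplifying assumptions, each unit is a scalar affine map y = w*x + b with parameters [w, b], and the loss is the mean squared output. `fsdp_forward` is a hypothetical name.

```python
def fsdp_forward(unit_shards, batches):
    """Toy FULL_SHARD forward pass over simulated ranks (one process).

    unit_shards[u][r] is rank r's shard of unit u's parameters [w, b]
    (padded to equal length). Each unit computes y = w * x + b element-wise.
    Returns the activations saved on each rank and each rank's loss.
    """
    world_size = len(batches)
    # ACT saved per rank, starting with that rank's input batch
    activations = [[list(batch)] for batch in batches]
    for shards in unit_shards:
        # Broadcast (all-gather) this unit's MP: concatenate shards, drop padding
        full_mp = [x for shard in shards for x in shard][:2]
        gathered = [list(full_mp) for _ in range(world_size)]
        for rank in range(world_size):
            w, b = gathered[rank]
            x = activations[rank][-1]
            # Forward pass for this unit on the rank's own data; keep the ACT
            activations[rank].append([w * xi + b for xi in x])
        # Reshard: drop the gathered copies; the owned shards stay in unit_shards
        del gathered, full_mp
    # Compute loss: mean squared output per rank
    losses = [sum(y * y for y in acts[-1]) / len(acts[-1]) for acts in activations]
    return activations, losses


# Three units, each [w, b] sharded over 3 ranks as [[w], [b], [pad]]
unit_shards = [
    [[2.0], [0.0], [0.0]],  # unit 1: y = 2x
    [[1.0], [1.0], [0.0]],  # unit 2: y = x + 1
    [[0.5], [0.0], [0.0]],  # unit 3: y = 0.5x
]
batches = [[1.0], [2.0], [3.0]]  # one single-sample batch per rank
activations, losses = fsdp_forward(unit_shards, batches)
print(losses)  # [2.25, 6.25, 12.25]
```

Each rank ends with its own activations and loss, even though all ranks used the same gathered parameters for every unit.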

Backward Pass

The backward pass consists of the following steps:

  • Broadcast model parameters: Gather the MP for the current unit. For unit 3, we already have the MP at hand, since it was just gathered on all GPUs for the forward pass. Therefore, this step can be skipped for unit 3 but is required for the backward pass of units 2 and 1.
  • Propagate backward: Run backward propagation for unit 3 on all GPUs, computing GRD from the ACT and MP. As mentioned at the start of the Forward Pass section, before this point the gradients are only placeholders that contain no actual information. In this step, each GPU computes its own full-size GRD 3 from its own batch.
  • Accumulate gradients: Sum the GRD 3 calculated on each GPU, and distribute the result so that each GPU keeps only its own shard of the accumulated GRD 3. The full-size local GRD 3 is then discarded, and the existing GRD shard on each GPU is replaced with the accumulated one. Together, this is a reduce-scatter operation on GRD.
  • Reshard MP and ACT: Remove the broadcasted MP and the saved ACT of the current unit from all GPUs to free up GPU memory.
  • Repeat for all the other units: Repeat the previous steps (broadcast, execute the backward pass to collect GRD, and discard ACT) until backpropagation on units 2 and 1 is complete.
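The gradient accumulation step can be simulated the same way. In this sketch, using the hypothetical helper `reduce_scatter`, each of three GPUs has computed a full-size GRD 3. After the reduce-scatter, each GPU keeps only its chunk of the element-wise sum.

```python
def reduce_scatter(full_grads_per_rank):
    """Simulated reduce-scatter of one unit's GRD over len(full_grads_per_rank) ranks.

    full_grads_per_rank[r] is the full-size gradient computed on rank r from
    its own batch. Gradients are summed element-wise, and rank r keeps only
    the r-th chunk of the sum (padded so the chunks have equal length).
    """
    world_size = len(full_grads_per_rank)
    summed = [sum(vals) for vals in zip(*full_grads_per_rank)]
    shard_len = -(-len(summed) // world_size)  # ceiling division
    summed += [0.0] * (shard_len * world_size - len(summed))
    return [summed[r * shard_len:(r + 1) * shard_len] for r in range(world_size)]


# Full-size GRD 3 computed independently on each of 3 GPUs
local_grads = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
grd3_shards = reduce_scatter(local_grads)
print(grd3_shards)  # [[12.0], [15.0], [18.0]]
```

Dividing the sum by the number of GPUs would give the averaged gradient instead, which is what data-parallel training commonly uses.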

Optimizer Step

  • Apply optimizer step: Each GPU runs the optimizer step on its own shards, updating its MP and OS shards from its accumulated GRD shard. Because every GPU owns a different shard, no communication is needed at this point. This completes one training step for the entire model on the combined batches of all GPUs, achieving our goal of updating the model parameters while operating the GPUs in parallel.
  • Next batch: This brings us back to the initial state, but with updated MP, GRD, and OS. We can now repeat the forward pass, backward pass, and optimizer step with the next batch as input, until training is complete.
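The following sketch shows why the optimizer step needs no communication: each rank touches only its own MP, GRD, and OS shards. To keep it short, it uses SGD with momentum in place of Adam, with the momentum buffer playing the role of OS. The helper name `sharded_momentum_sgd_step` is hypothetical.

```python
def sharded_momentum_sgd_step(mp_shards, grd_shards, os_shards, lr=0.01, momentum=0.9):
    """Local optimizer step: rank r reads and writes only index r of each list.

    SGD with momentum stands in for Adam; os_shards[r] is rank r's momentum buffer.
    """
    for rank in range(len(mp_shards)):
        for i, g in enumerate(grd_shards[rank]):
            os_shards[rank][i] = momentum * os_shards[rank][i] + g  # update OS shard
            mp_shards[rank][i] -= lr * os_shards[rank][i]           # update MP shard


mp_shards = [[1.0], [2.0], [3.0]]
grd_shards = [[12.0], [15.0], [18.0]]  # accumulated GRD shards after reduce-scatter
os_shards = [[0.0], [0.0], [0.0]]      # momentum buffers before the first step
sharded_momentum_sgd_step(mp_shards, grd_shards, os_shards, lr=0.5)
print(mp_shards)  # [[-5.0], [-5.5], [-6.0]]
```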

This concludes the description of the FSDP training process. Overall, the process involves somewhat complicated communication between multiple GPUs, but it results in minimal GPU idle time. In this way, FSDP makes the most of the available computing resources and allows large models to be trained efficiently in a server-side training environment.

Looking Ahead

In our next blog post, we will delve into the practical aspects of FSDP, including code snippets from PyTorch FSDP, and guide you through training a model with FSDP on GCP. Stay tuned for hands-on insights and implementation details!