
Fail to offload FSDP model weights and optimizer states without using CPUOffload(offload_params=True) #130530

Open
PeterSH6 opened this issue Jul 11, 2024 · 3 comments
Labels
  • needs reproduction: Someone else needs to try reproducing the issue given the instructions. No action needed from the user.
  • oncall: distributed: Add this issue/PR to the distributed oncall triage queue.
  • triaged: This issue has been looked at by a team member and triaged and prioritized into an appropriate module.

Comments

PeterSH6 commented Jul 11, 2024

🚀 The feature, motivation and pitch

Hi PyTorch maintainers,

I am currently training multiple large language models (LLMs) sequentially on a single GPU machine, using FullyShardedDataParallel (FSDP) for each model. A significant challenge we face is managing the GPU memory demands of hosting multiple LLMs, including their optimizer states, gradients, and activations.

We noticed that FSDP supports offloading model parameters and optimizer states during training via cpu_offload=CPUOffload(offload_params=True). However, this option keeps the parameters and optimizer states on CPU throughout training and runs the optimizer step on CPU, which hurts training throughput.
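
For context, a minimal sketch of how that built-in option is passed to FSDP (it assumes the process group is already initialized and that model is an already-constructed nn.Module):

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, CPUOffload

# Built-in CPU offload: FSDP keeps parameters (and hence optimizer states) on CPU
# and runs the optimizer step on CPU, which is where the throughput cost comes from.
fsdp_model = FSDP(model, cpu_offload=CPUOffload(offload_params=True))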

In our scenario, the models are computed one by one, so we want to offload the other models to CPU while one LLM is performing computation on GPU.
However, we fail to offload the _fsdp_wrapped_module to CPU with the following code (we offload the FSDP model by moving the parameters returned by named_parameters()). We observe identical GPU memory usage before and after the offload operation, indicating no effective offloading.

It appears that there might be persistent references to these parameters, causing p.data.to('cpu') to merely copy the data to CPU while preventing PyTorch's garbage collector from freeing the original GPU storage.

Could you provide guidance on how to properly offload these FSDP-wrapped parameters to the CPU to alleviate GPU memory constraints effectively? Any insights or updates to facilitate this process would greatly enhance our training capabilities and efficiency.

Thank you for your attention to this feature/issue!

# Note: initialize_global_process_group, local_model_path, actor_model_config,
# and device_mesh are defined elsewhere in the original script.
import torch
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
    CPUOffload,
    StateDictType,
    ShardedStateDictConfig,
)
from transformers import AutoModelForCausalLM

local_rank, rank, world_size = initialize_global_process_group()

# Build the model directly on GPU in bfloat16.
with torch.device("cuda"):
    actor_model = AutoModelForCausalLM.from_pretrained(local_model_path, trust_remote_code=True)
    actor_model.to(torch.bfloat16)

mixed_precision = MixedPrecision(param_dtype=torch.bfloat16,
                                 reduce_dtype=torch.float32,
                                 buffer_dtype=torch.float32)
fsdp_model = FSDP(actor_model,
                  use_orig_params=True,
                  auto_wrap_policy=None,
                  device_id=torch.cuda.current_device(),
                  sharding_strategy=ShardingStrategy.FULL_SHARD,
                  mixed_precision=mixed_precision,
                  cpu_offload=CPUOffload(offload_params=False),
                  sync_module_states=False,
                  device_mesh=device_mesh)

# One forward/backward pass so sharded gradients exist before the offload attempt.
input_ids = torch.randint(low=0, high=actor_model_config.vocab_size, size=(2, 1024), device="cuda")
output = fsdp_model(input_ids=input_ids)
output.logits.mean().backward()

FSDP.set_state_dict_type(fsdp_model,
                         state_dict_type=StateDictType.SHARDED_STATE_DICT,
                         state_dict_config=ShardedStateDictConfig())

if torch.distributed.get_rank() == 0:
    print(f'before model to cpu memory allocated: {torch.cuda.memory_allocated() / 1e9}GB, '
          f'reserved: {torch.cuda.memory_reserved() / 1e9}GB')

# Offload attempt: move every parameter's data to CPU, then release the GPU cache.
for n, p in fsdp_model.named_parameters():
    print(f'rank: {rank}, name: {n} weight shape: {p.data.shape}')
    p.data = p.data.to('cpu')

torch.cuda.empty_cache()

if torch.distributed.get_rank() == 0:
    print(f'after model to cpu memory allocated: {torch.cuda.memory_allocated() / 1e9}GB, '
          f'reserved: {torch.cuda.memory_reserved() / 1e9}GB')

cc @XilunWu @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @ezyang @anijain2305 @chauhang @penguinwu

awgu (Contributor) commented Jul 11, 2024

Ugh, this is tricky. I think the problem might be because you are using use_orig_params=True.

With use_orig_params=True, each parameter returned from named_parameters() is a view into the underlying FlatParameter (which would otherwise be returned from named_parameters() for use_orig_params=False). This means that when you call .to('cpu') on each parameter, each parameter view into the FlatParameter is moved to CPU, but the underlying FlatParameter is on GPU.

FSDP takes care to preserve the invariant that the parameters are always views into the FlatParameter, so if you break this invariant, it may be tricky. There should be some logic in the pre-forward to restore the invariant by copying parameters back into their FlatParameter.

First, could you try to run with use_orig_params=False?
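
A standalone sketch of that aliasing, with plain tensors standing in for the FlatParameter and its views (an illustration only, not FSDP internals; requires a CUDA device):

import torch
import torch.nn as nn

# Stand-in for FSDP's sharded FlatParameter living on GPU.
flat_param = torch.zeros(8, device="cuda")

# With use_orig_params=True, each original parameter's .data is a view into the FlatParameter.
param = nn.Parameter(torch.empty(0, device="cuda"))
param.data = flat_param[:4]

# Rebinding .data, as the offload loop above does, only repoints param at a new CPU copy.
param.data = param.data.to("cpu")

print(param.device)       # cpu
print(flat_param.device)  # cuda:0 -- the underlying GPU storage is untouched, so no memory is freed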

@malfet added the oncall: pt2 and oncall: distributed labels, then removed the oncall: pt2 label, on Jul 11, 2024
@LucasLLC added the triaged and needs reproduction labels on Jul 15, 2024
rbao2018 commented, replying to @awgu's suggestion to try use_orig_params=False:

I tested use_orig_params=False with the following setup:

  • Model: Llama3.1-8B
  • GPUs: 8 x A100-80GB

Results:
Before model to CPU:

  • Memory allocated: 8.461610496GB
  • Memory reserved: 37.589352448GB

After model to CPU:

  • Memory allocated: 8.066843136GB
  • Memory reserved: 13.501464576GB

This approach shows some effect, but PyTorch still retains some allocations on the GPU. Is it possible to hack the code further to unload these remaining GPU parameters?
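
For diagnosis, a small sketch (using only public module APIs, reusing the fsdp_model from the repro above) that lists what is still resident on the GPU after the offload loop; gradients and buffers are plausible candidates, since the loop only moves parameters:

import torch
import torch.nn as nn

def cuda_tensors_left(module: nn.Module) -> None:
    # List parameters, gradients, and buffers that still live on the GPU.
    for name, p in module.named_parameters():
        if p.device.type == "cuda":
            print(f"param  {name}: {p.numel() * p.element_size() / 1e6:.1f} MB")
        if p.grad is not None and p.grad.device.type == "cuda":
            print(f"grad   {name}: {p.grad.numel() * p.grad.element_size() / 1e6:.1f} MB")
    for name, b in module.named_buffers():
        if b.device.type == "cuda":
            print(f"buffer {name}: {b.numel() * b.element_size() / 1e6:.1f} MB")

cuda_tensors_left(fsdp_model)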

awgu (Contributor) commented Oct 23, 2024

By the way, when you call torch.cuda.empty_cache(), if any underlying caching allocator segment still has some active allocation, then that entire segment cannot be freed. However, when you do more allocations later, you will still be able to use the unused parts of those segments, so it is not like the extra 5.5 GB you see as (reserved - allocated) is lost.
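
For reference, a small sketch for inspecting that gap via torch.cuda.memory_stats() (the key names below are the standard ones it returns):

import torch

def report(tag: str) -> None:
    # The gap between reserved and allocated is cache held by the allocator, not lost memory.
    stats = torch.cuda.memory_stats()
    allocated = stats["allocated_bytes.all.current"] / 1e9
    reserved = stats["reserved_bytes.all.current"] / 1e9
    segments = stats["segment.all.current"]
    print(f"{tag}: allocated={allocated:.2f} GB, reserved={reserved:.2f} GB, segments={segments}")

report("after offload + empty_cache")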
