
Commit

Merge pull request #26 from SeanScripts/main
Add CPU offloading to I2V
jy0205 authored Oct 11, 2024
2 parents 3a31cca + 3ac1528 · commit e1c8a30
Showing 1 changed file with 25 additions and 1 deletion.
pyramid_dit/pyramid_dit_for_video_gen_pipeline.py (26 changes: 25 additions & 1 deletion)
@@ -309,9 +309,18 @@ def generate_i2v(
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         output_type: Optional[str] = "pil",
         save_memory: bool = True,
+        cpu_offloading: bool = False,  # If True, the working device is cuda and components are moved there only while in use.
     ):
-        device = self.device
+        device = self.device if not cpu_offloading else "cuda"
         dtype = self.dtype
+        if cpu_offloading:
+            # Don't bother offloading the text encoder here, since it is about to be used anyway.
+            if str(self.dit.device) != "cpu":
+                print("(dit) Warning: do not preload pipeline components (e.g. to cuda) with cpu offloading enabled! Otherwise a needless second transfer will occur.")
+                self.dit.to("cpu")
+            if str(self.vae.device) != "cpu":
+                print("(vae) Warning: do not preload pipeline components (e.g. to cuda) with cpu offloading enabled! Otherwise a needless second transfer will occur.")
+                self.vae.to("cpu")

         width = input_image.width
         height = input_image.height
@@ -332,8 +341,13 @@ def generate_i2v(
         negative_prompt = negative_prompt or ""

         # Get the text embeddings
+        if cpu_offloading:
+            self.text_encoder.to("cuda")
         prompt_embeds, prompt_attention_mask, pooled_prompt_embeds = self.text_encoder(prompt, device)
         negative_prompt_embeds, negative_prompt_attention_mask, negative_pooled_prompt_embeds = self.text_encoder(negative_prompt, device)
+        if cpu_offloading:
+            self.text_encoder.to("cpu")
+            self.vae.to("cuda")

         if use_linear_guidance:
             max_guidance_scale = guidance_scale
@@ -385,6 +399,10 @@ def generate_i2v(
         generated_latents_list = [input_image_latent]  # The generated results
         last_generated_latents = input_image_latent

+        if cpu_offloading:
+            self.vae.to("cpu")
+            self.dit.to("cuda")
+
         for unit_index in tqdm(range(1, num_units + 1)):
             if use_linear_guidance:
                 self._guidance_scale = guidance_scale_list[unit_index]
@@ -443,7 +461,13 @@ def generate_i2v(
         if output_type == "latent":
             image = generated_latents
         else:
+            if cpu_offloading:
+                self.dit.to("cpu")
+                self.vae.to("cuda")
             image = self.decode_latent(generated_latents, save_memory=save_memory)
+            if cpu_offloading:
+                self.vae.to("cpu")
+                # Not strictly necessary, but this returns the pipeline to its original state.

         return image

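The change follows a simple ping-pong pattern: at most one large component (text encoder, DiT, or VAE) occupies the GPU at a time, and each is moved back to the CPU as soon as its stage finishes. A generic sketch of that pattern as a context manager, not code from this commit:

    import contextlib

    import torch

    @contextlib.contextmanager
    def on_gpu(module: torch.nn.Module, device: str = "cuda"):
        """Temporarily move a module to the GPU, then return it to the CPU."""
        module.to(device)
        try:
            yield module
        finally:
            module.to("cpu")
            torch.cuda.empty_cache()  # free the VRAM the weights just vacated

    # The text-embedding stage above could then be written as:
    #   with on_gpu(self.text_encoder):
    #       prompt_embeds, mask, pooled = self.text_encoder(prompt, "cuda")

The commit inlines the moves instead, which keeps the hot path explicit and lets the DiT stay resident on the GPU across all denoising units in the for unit_index loop.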
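For completeness, a hypothetical call site for the new flag. The import path, constructor arguments, and prompt below are illustrative assumptions, not taken from this commit; only prompt, input_image, output_type, save_memory, and cpu_offloading are generate_i2v parameters visible in the diff above.

    import torch
    from PIL import Image

    from pyramid_dit import PyramidDiTForVideoGeneration  # import path assumed

    # Keep everything on the CPU at load time; with cpu_offloading=True,
    # generate_i2v shuttles each component to the GPU only for its stage.
    model = PyramidDiTForVideoGeneration("PATH/TO/CHECKPOINT", torch.bfloat16)  # args illustrative

    frames = model.generate_i2v(
        prompt="a drone shot of a rocky coastline at sunset",
        input_image=Image.open("first_frame.png").convert("RGB"),
        output_type="pil",
        save_memory=True,
        cpu_offloading=True,  # the flag added in this commit
    )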