Merge pull request jy0205#109 from jy0205/revert-108-use_mps_on_apple_silicon

Revert "Use MPS backend on Apple Silicon devices if it's available."
feifeiobama authored Oct 16, 2024
2 parents 701d2d5 + 5261502 commit 9aa02f8
Showing 6 changed files with 20 additions and 40 deletions.
app.py: 4 additions & 14 deletions
@@ -9,9 +9,6 @@
from huggingface_hub import snapshot_download
import threading

- # Disabling parallelism to avoid deadlocks.
- os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Global model cache
model_cache = {}

@@ -20,8 +17,7 @@

# Configuration
model_repo = "rain1011/pyramid-flow-sd3" # Replace with the actual model repository on Hugging Face
model_dtype = "bf16" if torch.cuda.is_available() else "fp32" # Support bf16 and fp32

model_dtype = "bf16" # Support bf16 and fp32
variants = {
'high': 'diffusion_transformer_768p', # For high-resolution version
'low': 'diffusion_transformer_384p' # For low-resolution version
@@ -31,7 +27,7 @@
height_high = 768
width_low = 640
height_low = 384
- cpu_offloading = torch.cuda.is_available() # enable cpu_offloading by default
+ cpu_offloading = True # enable cpu_offloading by default

# Get the current working directory and create a folder to store the model
current_directory = os.getcwd()
@@ -90,8 +86,6 @@ def initialize_model(variant):

if model_dtype == "bf16":
torch_dtype_selected = torch.bfloat16
if model_dtype == "fp16":
torch_dtype_selected = torch.float16
else:
torch_dtype_selected = torch.float32

@@ -116,10 +110,6 @@ def initialize_model(variant):
model.vae.to("cuda")
model.dit.to("cuda")
model.text_encoder.to("cuda")
- elif torch.mps.is_available():
- model.vae.to("mps")
- model.dit.to("mps")
- model.text_encoder.to("mps")
else:
print("[WARNING] CUDA is not available. Proceeding without GPU.")

@@ -180,7 +170,7 @@ def progress_callback(i, m):

try:
print("[INFO] Starting text-to-video generation...")
- with torch.no_grad(), torch.autocast('cuda', enabled=torch.cuda.is_available(), dtype=torch_dtype_selected):
+ with torch.no_grad(), torch.autocast('cuda', dtype=torch_dtype_selected):
frames = model.generate(
prompt=prompt,
num_inference_steps=[20, 20, 20],
@@ -236,7 +226,7 @@ def progress_callback(i, m):

try:
print("[INFO] Starting image-to-video generation...")
- with torch.no_grad(), torch.autocast('cuda', enabled=torch.cuda.is_available(), dtype=torch_dtype_selected):
+ with torch.no_grad(), torch.autocast('cuda', dtype=torch_dtype_selected):
frames = model.generate_i2v(
prompt=prompt,
input_image=image,
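Taken together, the app.py changes remove every Apple Silicon code path: the dtype no longer falls back to fp32 off-CUDA, cpu_offloading is switched on unconditionally, the model is moved to CUDA or left on CPU, and autocast no longer gates on torch.cuda.is_available(). A minimal sketch of the restored flow, using only what the hunks show (move_model_to_device is a hypothetical wrapper, and the generate call is left as a comment because the surrounding setup is elided):

import torch

model_dtype = "bf16"  # fixed again; no fp32 fallback for non-CUDA machines
torch_dtype_selected = torch.bfloat16 if model_dtype == "bf16" else torch.float32
cpu_offloading = True  # enabled unconditionally after the revert

def move_model_to_device(model):
    # Hypothetical wrapper around the moves shown in the diff: CUDA or CPU only,
    # the elif torch.mps.is_available() branch is gone.
    if torch.cuda.is_available():
        model.vae.to("cuda")
        model.dit.to("cuda")
        model.text_encoder.to("cuda")
    else:
        print("[WARNING] CUDA is not available. Proceeding without GPU.")

# Generation always requests CUDA autocast with the selected dtype;
# the enabled=torch.cuda.is_available() escape hatch was removed:
# with torch.no_grad(), torch.autocast('cuda', dtype=torch_dtype_selected):
#     frames = model.generate(prompt=prompt, num_inference_steps=[20, 20, 20], ...)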
diffusion_schedulers/scheduling_flow_matching.py: 3 additions & 3 deletions
@@ -176,7 +176,7 @@ def set_begin_index(self, begin_index: int = 0):
def _sigma_to_t(self, sigma):
return sigma * self.config.num_train_timesteps

- def set_timesteps(self, num_inference_steps: int, stage_index: int, device: Union[str, torch.device] = None, dtype: torch.dtype = None):
+ def set_timesteps(self, num_inference_steps: int, stage_index: int, device: Union[str, torch.device] = None):
"""
Setting the timesteps and sigmas for each stage
"""
@@ -191,7 +191,7 @@ def set_timesteps(self, num_inference_steps: int, stage_index: int, device: Unio
timesteps = np.linspace(
timestep_max, timestep_min, num_inference_steps,
)
- self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=dtype)
+ self.timesteps = torch.from_numpy(timesteps).to(device=device)

stage_sigmas = self.sigmas_per_stage[stage_index]
sigma_max = stage_sigmas[0].item()
@@ -200,7 +200,7 @@ def set_timesteps(self, num_inference_steps: int, stage_index: int, device: Unio
ratios = np.linspace(
sigma_max, sigma_min, num_inference_steps
)
- sigmas = torch.from_numpy(ratios).to(device=device, dtype=dtype)
+ sigmas = torch.from_numpy(ratios).to(device=device)
self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])

self._step_index = None
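With the dtype parameter reverted out of set_timesteps, the timestep and sigma tensors are built with a target device only and keep the float64 dtype that torch.from_numpy infers from the NumPy arrays. A standalone sketch of that construction; the stage bounds below are placeholder values, not the scheduler's real configuration:

import numpy as np
import torch

num_inference_steps = 20
timestep_max, timestep_min = 1000.0, 1.0  # placeholders for the per-stage timestep bounds
sigma_max, sigma_min = 1.0, 0.0           # placeholders for the sigmas_per_stage values
device = "cpu"

# Only the device is specified; the dtype stays float64 from NumPy.
timesteps = torch.from_numpy(
    np.linspace(timestep_max, timestep_min, num_inference_steps)
).to(device=device)

ratios = np.linspace(sigma_max, sigma_min, num_inference_steps)
sigmas = torch.from_numpy(ratios).to(device=device)
sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])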
pyramid_dit/modeling_pyramid_mmdit.py: 1 addition & 1 deletion
@@ -28,7 +28,7 @@
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
assert dim % 2 == 0, "The dimension must be even."

- scale = torch.arange(0, dim, 2, dtype=pos.dtype, device=pos.device) / dim
+ scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
omega = 1.0 / (theta**scale)

batch_size, seq_length = pos.shape
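The single-line rope change puts the frequency scale back into float64, which jy0205#108 had presumably switched to pos.dtype because the MPS backend does not support float64 tensors. A runnable fragment covering just the lines shown in the hunk (rope_frequencies is a hypothetical name for this excerpt, not the repository's function, which continues beyond the lines shown):

import torch

def rope_frequencies(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
    # Hypothetical excerpt of the rope() prologue shown in the hunk.
    assert dim % 2 == 0, "The dimension must be even."
    # Restored behavior: frequencies are computed in float64, not pos.dtype.
    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
    omega = 1.0 / (theta ** scale)
    return omega

pos = torch.arange(16).unsqueeze(0)                      # dummy positions
print(rope_frequencies(pos, dim=64, theta=10000).shape)  # torch.Size([32])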
pyramid_dit/pyramid_dit_for_video_gen_pipeline.py: 9 additions & 16 deletions
@@ -147,9 +147,8 @@ def _enable_sequential_cpu_offload(self, model):
cpu_offload(model, device, offload_buffers=offload_buffers)

def enable_sequential_cpu_offload(self):
- if torch.cuda.is_available():
- self._enable_sequential_cpu_offload(self.text_encoder)
- self._enable_sequential_cpu_offload(self.dit)
+ self._enable_sequential_cpu_offload(self.text_encoder)
+ self._enable_sequential_cpu_offload(self.dit)

def load_checkpoint(self, checkpoint_path, model_key='model', **kwargs):
checkpoint = torch.load(checkpoint_path, map_location='cpu')
@@ -249,7 +248,7 @@ def generate_one_unit(
intermed_latents = []

for i_s in range(len(stages)):
- self.scheduler.set_timesteps(num_inference_steps[i_s], i_s, device=device, dtype=dtype)
+ self.scheduler.set_timesteps(num_inference_steps[i_s], i_s, device=device)
timesteps = self.scheduler.timesteps

if i_s > 0:
@@ -337,7 +336,7 @@ def generate_i2v(
if self.sequential_offload_enabled and not cpu_offloading:
print("Warning: overriding cpu_offloading set to false, as it's needed for sequential cpu offload")
cpu_offloading=True
- device = torch.device("cuda") if torch.cuda.is_available() else torch.device("mps") if torch.mps.is_available() else torch.device("cpu")
+ device = self.device if not cpu_offloading else torch.device("cuda")
dtype = self.dtype
if cpu_offloading:
# skip caring about the text encoder here as its about to be used anyways.
@@ -453,11 +452,8 @@ def generate_i2v(

for unit_index in tqdm(range(1, num_units)):
gc.collect()
- if torch.cuda.is_available():
- torch.cuda.empty_cache()
- elif torch.mps.is_available():
- torch.mps.empty_cache()
+ torch.cuda.empty_cache()

if callback:
callback(unit_index, num_units)

@@ -557,7 +553,7 @@ def generate(
if self.sequential_offload_enabled and not cpu_offloading:
print("Warning: overriding cpu_offloading set to false, as it's needed for sequential cpu offload")
cpu_offloading=True
- device = torch.device("cuda") if torch.cuda.is_available() else torch.device("mps") if torch.mps.is_available() else torch.device("cpu")
+ device = self.device if not cpu_offloading else torch.device("cuda")
dtype = self.dtype
if cpu_offloading:
# skip caring about the text encoder here as its about to be used anyways.
@@ -654,11 +650,8 @@ def generate(

for unit_index in tqdm(range(num_units)):
gc.collect()
- if torch.cuda.is_available():
- torch.cuda.empty_cache()
- elif torch.mps.is_available():
- torch.mps.empty_cache()
+ torch.cuda.empty_cache()

if callback:
callback(unit_index, num_units)
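In the video generation pipeline, the revert removes the CUDA/MPS/CPU device fallback, the dtype pass-through to the scheduler, the CUDA guard around sequential CPU offload, and the per-unit MPS cache clearing. The per-unit housekeeping in both generate and generate_i2v reduces to the sketch below; num_units and callback are placeholder stand-ins for the pipeline's real values:

import gc
import torch
from tqdm import tqdm

num_units = 4     # placeholder for the number of generation units
callback = None   # placeholder progress hook

for unit_index in tqdm(range(num_units)):
    gc.collect()
    # CUDA-only again: torch.cuda.empty_cache() is a no-op without an initialized
    # CUDA context, and there is no torch.mps.empty_cache() fallback anymore.
    torch.cuda.empty_cache()
    if callback:
        callback(unit_index, num_units)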

requirements.txt: 3 additions & 4 deletions
@@ -1,11 +1,10 @@
- --extra-index-url https://download.pytorch.org/whl/nightly/cpu
wheel
- torch
- torchvision
+ torch==2.1.2
+ torchvision==0.16.2
transformers==4.39.3
accelerate==0.30.0
diffusers>=0.30.1
- numpy==1.26.4
+ numpy==1.24.4
einops
ftfy
ipython
video_vae/modeling_causal_vae.py: 0 additions & 2 deletions
@@ -361,8 +361,6 @@ def chunk_decode(self, z: torch.FloatTensor, window_size=2):

dec_list = []
for idx, frames in enumerate(frame_list):
- if torch.mps.is_available():
- torch.mps.empty_cache()
if idx == 0:
z_h = self.post_quant_conv(frames, is_init_image=True, temporal_chunk=True)
dec = self.decoder(z_h, is_init_image=True, temporal_chunk=True)
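In the VAE's chunked decode, the only change is dropping the per-chunk MPS cache flush. A sketch of the restored first-chunk handling using the calls visible in the hunk; chunk_decode_first is a hypothetical standalone wrapper, and the real method continues past the lines shown:

def chunk_decode_first(post_quant_conv, decoder, frame_list):
    # Hypothetical wrapper: no torch.mps.empty_cache() before each chunk anymore.
    dec_list = []
    for idx, frames in enumerate(frame_list):
        if idx == 0:
            # The first chunk initializes the temporal (causal) decoding state.
            z_h = post_quant_conv(frames, is_init_image=True, temporal_chunk=True)
            dec = decoder(z_h, is_init_image=True, temporal_chunk=True)
            dec_list.append(dec)
        # later chunks are handled by code outside this hunk
    return dec_list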
