Integrate distributed inference with chat/server #1381

Merged

Commits (20 total; changes shown below are from 9 commits)
7b4d5c5  Integrate distributed inference without introducing abstraction (mreso, Nov 16, 2024)
e7670c3  Cleanup old distributed inference integration (mreso, Nov 16, 2024)
08a8e03  Merge branch 'main' into refactor/distributed_inference_without_abstr… (mreso, Nov 16, 2024)
d5bca9b  Read distribution from model_config (mreso, Nov 18, 2024)
76895cc  Declare distribution_path if args.model is not given (mreso, Nov 18, 2024)
3ef1296  Address some nits from PR review (mreso, Nov 19, 2024)
04cdfd0  Merge branch 'main' into refactor/distributed_inference_without_abstr… (mreso, Dec 2, 2024)
924a096  Merge branch 'main' into refactor/distributed_inference_without_abstr… (mreso, Dec 4, 2024)
99c33e8  Merge branch 'main' into refactor/distributed_inference_without_abstr… (Jack-Khuu, Dec 11, 2024)
773894f  Merge remote-tracking branch 'origin/main' into refactor/distributed_… (mreso, Dec 16, 2024)
7cb98c9  Added comment on model size all reduce + type hint (mreso, Dec 16, 2024)
10fb55a  Apply suggestions from code review (mreso, Dec 16, 2024)
28d7836  Make sure speculative decoding is disable for pp >1 and remark this i… (mreso, Dec 17, 2024)
68eec0b  Refactor conditions in pp (mreso, Dec 17, 2024)
3ad31e8  Rename and alter signature of setup_env to reflect that it also runs … (mreso, Dec 17, 2024)
e07b03d  Rename setup_env in server + fix condition (mreso, Dec 17, 2024)
daf902c  Merge branch 'main' into refactor/distributed_inference_without_abstr… (Jack-Khuu, Dec 19, 2024)
db5fd1b  Merge branch 'main' into refactor/distributed_inference_without_abstr… (Jack-Khuu, Dec 19, 2024)
7ac16f9  Update generate.py (Jack-Khuu, Dec 19, 2024)
7650153  Add default value to add_generation_prompt to preserve bc (mreso, Dec 19, 2024)
111 changes: 109 additions & 2 deletions torchchat/cli/builder.py
@@ -14,10 +14,17 @@
import torch
import torch._dynamo.config
import torch._inductor.config
import torch.nn as nn
import torch.distributed as dist

from torchchat.model import Model, ModelArgs, ModelType
from torchchat.distributed.utils import (
Color as color,
CUDATrackTime,
init_distributed,
GPUMemoryMonitor,
)
from torchchat.distributed.logging_utils import SingletonLogger

from torchchat.model import Model, ModelArgs, ModelType, Transformer, TransformerArgs
from torchchat.model_config.model_config import resolve_model_config
from torchchat.utils.build_utils import (
device_sync,
@@ -28,6 +35,7 @@
from torchchat.utils.measure_time import measure_time
from torchchat.utils.quantize import quantize_model


from torchtune.models.convert_weights import meta_to_tune

from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
@@ -56,6 +64,7 @@ class BuilderArgs:
pp: int = 1
tp: int = 1
chpt_from: str = "hf"
distribution_path: Optional[str] = None
is_chat_model: bool = False
prefill_possible: bool = False
dynamic_shapes: bool = False
@@ -107,6 +116,7 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":

checkpoint_path = args.checkpoint_path
params_table = args.params_table
distribution_path = None
if args.model: # Using a named, well-known model
model_config = resolve_model_config(args.model)

@@ -121,6 +131,8 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
model_config.transformer_params_key or model_config.name.split("/")[-1]
)

distribution_path = model_config.distribution_path

dso_path = getattr(args, "dso_path", None)
pte_path = getattr(args, "pte_path", None)
aoti_package_path = getattr(args, "aoti_package_path", None)
@@ -186,6 +198,7 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
pp=pp,
tp=tp,
chpt_from=chpt_from,
distribution_path=distribution_path,
is_chat_model=is_chat_model,
dynamic_shapes=getattr(args, "dynamic_shapes", False),
max_seq_length=getattr(args, "max_seq_length", None),
@@ -598,6 +611,100 @@ def do_nothing(max_batch_size, max_seq_length):
model = PTEModel(config, builder_args.pte_path)
except Exception:
raise RuntimeError(f"Failed to load ET compiled {builder_args.pte_path}")
elif builder_args.distributed:
pp_degree = builder_args.pp
tp_degree = builder_args.tp

init_distributed()
rank = dist.get_rank()
torch.cuda.set_device(rank % torch.cuda.device_count())

logger = SingletonLogger.get_logger()

gpu_memory_monitor = GPUMemoryMonitor("cuda")
logger.info(f"{color.yellow} {gpu_memory_monitor.get_device_info()}{color.reset}")

# Model-level config
if builder_args.params_table:
model_config = ModelArgs.from_table(builder_args.params_table)
else:
raise NotImplementedError()
# Transformer-level config
config = TransformerArgs.from_params(model_config.transformer_args["text"])
logger.info(f"Transformer Config: {config}")

#TODO: Move into head of file after solving circular import
from torchchat.distributed.checkpoint_utils import (
load_model_weights,
)

# Validate pipeline degree
assert config.n_layers % pp_degree == 0

# Create device mesh
device_mesh = dist.init_device_mesh(
"cuda",
(pp_degree, tp_degree),
mesh_dim_names=("pp", "tp")
)
tp_mesh = device_mesh["tp"]
pp_mesh = device_mesh["pp"]
logger.info(f"Created device mesh: {device_mesh}\n{tp_mesh=}, {pp_mesh=}")

pp_rank = pp_mesh.get_local_rank()
logger.info(f"{pp_degree=}, {tp_degree=}")

# Assuming same number of GPUs per node
device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")

# Fill in PP configs
config.stage_idx = pp_rank
config.n_stages = pp_degree

with torch.device("meta"):
# TODO: we should create model instead of Transformer
model = Transformer(config)

# Distribute model on TP mesh
# (Surprisingly, this works even though model is on meta device and mesh is of
# cuda devices)
model.distribute(tp_mesh)
if rank == 0:
logger.info(f"Model: {model}")

# Load weights
logger.info(f"Loading weights for {pp_rank=} on {device=}")
with CUDATrackTime() as timer:
load_model_weights(model, builder_args.distribution_path, device, config, builder_args.chpt_from)

logger.info(
f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
)

# Setup KV caches (after model distribution)
# The number of cache lanes is the same as the maximum number of
# micro-batches that can be "in flight" in parallel -- imagine each
# micro-batch takes 1 "pipeline lane," they need distinct KV cache spaces.
# When decoding is done for certain micro-batches, we can reuse the KV cache
# lanes.
# TODO: bump up the lane count
pipeline_lanes = 1
seqlen_prefill = 1024
with device:
model.setup_caches(1, seqlen_prefill, cache_lanes=pipeline_lanes)

# info on stage size and params
# stage_size = get_module_size(model)
# stage_size_formatted = bytes_to_readable(stage_size)
# stage_num_params = get_num_params(model)
# logger.info(
# f"Stage {rank} has {color.blue}{stage_num_params} params{color.reset}, Size: {color.blue}{stage_size_formatted}{color.reset}"
# )
model.eval()

model.text_transformer_args = None
model.config.model_type = model_config.model_type
model.device_mesh = device_mesh
else:
with measure_time("Time to load model: {time:.02f} seconds"):
model = _load_model(builder_args)
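For readers skimming the diff: the distributed branch above builds a 2-D device mesh whose first dimension indexes pipeline stages and whose second indexes tensor-parallel ranks, then slices out the `pp` and `tp` submeshes. The sketch below is illustrative only; the degrees, the `torchrun` launch, and the world size are assumptions, not part of this PR.

```python
# Illustrative sketch (not part of this PR): how a 2-D (pp, tp) mesh
# decomposes into the submeshes used in builder.py above.
# Assumes launch via torchrun with world_size == pp_degree * tp_degree.
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

pp_degree, tp_degree = 2, 4  # assumed degrees, for illustration only

dist.init_process_group(backend="nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank % torch.cuda.device_count())

# Rows are pipeline stages, columns are tensor-parallel ranks.
mesh = init_device_mesh("cuda", (pp_degree, tp_degree), mesh_dim_names=("pp", "tp"))
pp_mesh, tp_mesh = mesh["pp"], mesh["tp"]

# Each rank belongs to exactly one pipeline stage and one TP group.
print(f"{rank=} pp_rank={pp_mesh.get_local_rank()} tp_rank={tp_mesh.get_local_rank()}")

dist.destroy_process_group()
```

Because every rank lands in exactly one row (pipeline stage) and one column (TP group), the builder can fill `config.stage_idx` directly from `pp_mesh.get_local_rank()` and shard the stage over `tp_mesh` with `model.distribute(tp_mesh)`.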
32 changes: 32 additions & 0 deletions torchchat/distributed/checkpoint_utils.py
@@ -17,6 +17,7 @@
from torch.distributed._tensor import DTensor
from torchchat.distributed.dtensor_utils import convert_to_dtensor
from torchchat.cli.builder import BuilderArgs, _load_checkpoint
from torchchat.model import ModelArgs


_DEFAULT_SAFETENSOR_FILE_NAME = "model.safetensors.index.json"
@@ -450,3 +451,34 @@ def load_weights_from_torchchat_format(stage_module, distribution, device, model
# Fill state dict into stage module
stage_module.load_state_dict(stage_state_dict, strict=False, assign=True)
logger.info(f"Successfully loaded {len(updated_states)} weights into stage module")


def load_model_weights(
stage_module: torch.nn.Module,
distribution: str,
device: torch.device,
model_config: ModelArgs,
chpt_from: str,
):
"""Load the weights from the safetensor file(s) into the model stage.
Model config is needed b/c we permute wq and wk weights based on attn heads.

Args:
stage_module (torch.nn.Module): The model stage to load the weights into.
distribution (str): The distribution name, e.g. "meta-llama/Meta-Llama-3-8B-Instruct".
device (torch.device): The device to load the weights onto.
model_config (ModelArgs): The model config.
chpt_from (str): The checkpoint format to load the weights from, e.g. "torchchat" or "hf".
"""
if chpt_from == "hf":
# This format stands for: index file + multiple binary files
load_weights_from_hf_format(stage_module, distribution, device, model_config)
elif chpt_from == "torchchat":
# This format stands for:
# single binary file, OR
# multiple binary files without index files.
load_weights_from_torchchat_format(
stage_module, distribution, device, model_config
)
else:
raise ValueError(f"Unknown checkpoint format: {chpt_from}")
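The head-count dependence mentioned in the docstring comes from the rotary-embedding layout difference between HF-style and Meta-style Llama checkpoints. A minimal sketch of the kind of head-aware permutation involved is below; `permute_for_rotary` is a hypothetical name and not the exact helper this PR calls.

```python
# Illustrative sketch, not the helper used in this PR: HF Llama checkpoints
# order each head's wq/wk rows differently from the Meta layout, so converting
# between them needs the attention head count from the model config.
import torch

def permute_for_rotary(w: torch.Tensor, n_heads: int, dim1: int, dim2: int) -> torch.Tensor:
    # View each head's rows as (pairs, 2), swap those two axes, and flatten:
    # the interleaved (even, odd) rotary rows become two contiguous halves.
    return (
        w.view(n_heads, dim1 // n_heads // 2, 2, dim2)
        .transpose(1, 2)
        .reshape(dim1, dim2)
    )

# Example: a toy 4-head attention with model dim 32 (head_dim = 8).
wq = torch.randn(32, 32)
wq_permuted = permute_for_rotary(wq, n_heads=4, dim1=32, dim2=32)
assert wq_permuted.shape == wq.shape
```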