Integrate distributed inference with chat/server #1381

Merged

Commits (20)
7b4d5c5  Integrate distributed inference without introducing abstraction (mreso, Nov 16, 2024)
e7670c3  Cleanup old distributed inference integration (mreso, Nov 16, 2024)
08a8e03  Merge branch 'main' into refactor/distributed_inference_without_abstr… (mreso, Nov 16, 2024)
d5bca9b  Read distribution from model_config (mreso, Nov 18, 2024)
76895cc  Declare distribution_path if args.model is not given (mreso, Nov 18, 2024)
3ef1296  Address some nits from PR review (mreso, Nov 19, 2024)
04cdfd0  Merge branch 'main' into refactor/distributed_inference_without_abstr… (mreso, Dec 2, 2024)
924a096  Merge branch 'main' into refactor/distributed_inference_without_abstr… (mreso, Dec 4, 2024)
99c33e8  Merge branch 'main' into refactor/distributed_inference_without_abstr… (Jack-Khuu, Dec 11, 2024)
773894f  Merge remote-tracking branch 'origin/main' into refactor/distributed_… (mreso, Dec 16, 2024)
7cb98c9  Added comment on model size all reduce + type hint (mreso, Dec 16, 2024)
10fb55a  Apply suggestions from code review (mreso, Dec 16, 2024)
28d7836  Make sure speculative decoding is disable for pp >1 and remark this i… (mreso, Dec 17, 2024)
68eec0b  Refactor conditions in pp (mreso, Dec 17, 2024)
3ad31e8  Rename and alter signature of setup_env to reflect that it also runs … (mreso, Dec 17, 2024)
e07b03d  Rename setup_env in server + fix condition (mreso, Dec 17, 2024)
daf902c  Merge branch 'main' into refactor/distributed_inference_without_abstr… (Jack-Khuu, Dec 19, 2024)
db5fd1b  Merge branch 'main' into refactor/distributed_inference_without_abstr… (Jack-Khuu, Dec 19, 2024)
7ac16f9  Update generate.py (Jack-Khuu, Dec 19, 2024)
7650153  Add default value to add_generation_prompt to preserve bc (mreso, Dec 19, 2024)
123 changes: 121 additions & 2 deletions torchchat/cli/builder.py
@@ -14,10 +14,17 @@
import torch
import torch._dynamo.config
import torch._inductor.config
import torch.nn as nn
import torch.distributed as dist

from torchchat.model import Model, ModelArgs, ModelType
from torchchat.distributed.utils import(
Color as color,
CUDATrackTime,
init_distributed,
GPUMemoryMonitor,
)
from torchchat.distributed.logging_utils import SingletonLogger

from torchchat.model import Model, ModelArgs, ModelType, Transformer, TransformerArgs
from torchchat.model_config.model_config import resolve_model_config
from torchchat.utils.build_utils import (
device_sync,
@@ -28,6 +35,7 @@
from torchchat.utils.measure_time import measure_time
from torchchat.utils.quantize import quantize_model


from torchtune.models.convert_weights import meta_to_tune

from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
@@ -598,6 +606,117 @@ def do_nothing(max_batch_size, max_seq_length):
model = PTEModel(config, builder_args.pte_path)
except Exception:
raise RuntimeError(f"Failed to load ET compiled {builder_args.pte_path}")
elif builder_args.distributed:
# Using params_table to identify the model to load, for example "Meta-Llama-3.1-8B".
#TODO This is a hacky way to please the distributed loading api and needs to be replaced
NAME_TO_DISTRIBUTION = {
"Meta-Llama-3-8B": "meta-llama/Meta-Llama-3-8B-Instruct",
"Meta-Llama-3.1-8B": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"Meta-Llama-3-70B": "meta-llama/Meta-Llama-3-70B-Instruct",
"Meta-Llama-3.1-70B": "meta-llama/Meta-Llama-3.1-70B-Instruct",

}

mikekgfb (Contributor) commented on Nov 18, 2024:
What's the gap to using the models described in model_config/models.json (as implied by the TODO comment)?

This definitely should not be part of the present PR, but I think that, as a north star, it would be desirable to grab the same models (and the download and management infra, etc.) for non-distributed and distributed.

mreso (Contributor, Author) replied:

Thanks for the comment, @mikekgfb. The gap wasn't that big; args.model just wasn't accessible at that point, and I wanted to take a deeper look to fix it properly. I removed the crude hack and now save the distribution_path when creating the builder_args. Still not sure if this is the intended way of locating the checkpoint, though.

mikekgfb (Contributor) replied:

Not sure what you're trying to do. If it's just that the dict/json data structure describing the mapping isn't in scope, maybe what you want is some methods that give you the relevant info?

Also, these seem to be mappings of short names to HF network paths; should we not have a way to pick them from the local filesystem (since the torchchat CLI already manages downloads and all that)? Oh, and if the answer is "we have bigger fish to fry, hooking this up is not the highest priority", I will wholeheartedly agree. This is more about understanding the context of this PR.

Where I'm lacking context is how you go from all the weights being available locally on one node to reading those weights on another node. And maybe that's why you prefer to straight up pick the files from HF? (Although local distribution from an already downloaded set of weights probably has higher bandwidth?) Again, there are much bigger fish to fry, and I think this PR is a good step in the direction of frying those fish ;)

mreso (Contributor, Author) replied:

Oh I see, let me try to fill in some context for this PR. Previously, the distributed inference solution lived in its own script, within the torchchat repo but completely separate from torchchat.py. Distributed provides its own utils to load either HF or torchchat weights (the torchchat part is currently broken, IIRC). In a previous PR (#1327), I enabled the use of torchchat.py generate with a distributed model. This PR only progresses the integration into the CLI by enabling chat/server, but stops short of replacing the weight-loading mechanics, which are still custom to distributed.

So, yes, for this PR I was only looking for a quick and dirty way to map arg.model_name -> "huggingface distribution str" (without actually having model_name at hand) to load the weights from the HF cache. I have now modified the PR to use the information provided in model_config/models.json, as you suggested. In a follow-up PR we should then alter torchchat/distributed/checkpoint_utils.py to leverage torchchat infra (e.g. builder_args) to locate and access the files.
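
Aside, not part of the diff: a minimal sketch of the models.json-based direction described above. resolve_model_config is the torchchat helper already imported in builder.py; the distribution_path attribute is assumed to carry the Hugging Face repo id recorded in models.json.

from torchchat.model_config.model_config import resolve_model_config

def name_to_distribution(model_name: str) -> str:
    # Look the name up in model_config/models.json and return its HF repo id,
    # e.g. "llama3.1" -> "meta-llama/Meta-Llama-3.1-8B-Instruct".
    model_config = resolve_model_config(model_name)
    return str(model_config.distribution_path)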

# TODO: Use information in builder_args directly to build model and load weights
assert builder_args.params_table
try:
distribution = NAME_TO_DISTRIBUTION[builder_args.params_table]
except KeyError as e:
print(f"Unknown params_table: {builder_args.params_table}. Suported model names are: llama3.1, llama3, llama2-7b-chat")
raise e

pp_degree = builder_args.pp
tp_degree = builder_args.tp

init_distributed()
rank = dist.get_rank()
torch.cuda.set_device(rank % torch.cuda.device_count())

logger = SingletonLogger.get_logger()

gpu_memory_monitor = GPUMemoryMonitor("cuda")
logger.info(f"{color.yellow} {gpu_memory_monitor.get_device_info()}{color.reset}")

# Model-level config
if builder_args.params_table:
model_config = ModelArgs.from_table(builder_args.params_table)
else:
raise NotImplementedError()
# Transformer-level config
config = TransformerArgs.from_params(model_config.transformer_args["text"])
logger.info(f"Transformer Config: {config}")

#TODO: Move into head of file after solving circular import
from torchchat.distributed.checkpoint_utils import (
load_model_weights,
)

# Validate pipeline degree
assert config.n_layers % pp_degree == 0

# Create device mesh
device_mesh = dist.init_device_mesh(
"cuda",
(pp_degree, tp_degree),
mesh_dim_names=("pp", "tp")
)
tp_mesh = device_mesh["tp"]
pp_mesh = device_mesh["pp"]
logger.info(f"Created device mesh: {device_mesh}\n{tp_mesh=}, {pp_mesh=}")

pp_rank = pp_mesh.get_local_rank()
logger.info(f"{pp_degree=}, {tp_degree=}")

# Assuming same number of GPUs per node
device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")

# Fill in PP configs
config.stage_idx = pp_rank
config.n_stages = pp_degree

with torch.device("meta"):
# TODO: we should create model instead of Transformer
model = Transformer(config)

# Distribute model on TP mesh
# (Surprisingly, this works even though model is on meta device and mesh is of
# cuda devices)
model.distribute(tp_mesh)
if rank == 0:
logger.info(f"Model: {model}")

# Load weights
logger.info(f"Loading weights for {pp_rank=} on {device=}")
with CUDATrackTime() as timer:
load_model_weights(model, distribution, device, config, builder_args.chpt_from)

logger.info(
f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
)

# Setup KV caches (after model distribution)
# The number of cache lanes is the same as the maximum number of
# micro-batches that can be "in flight" in parallel -- imagine each
# micro-batch takes 1 "pipeline lane," they need distinct KV cache spaces.
# When decoding is done for certain micro-batches, we can reuse the KV cache
# lanes.
# TODO: bump up the lane count
pipeline_lanes = 1
seqlen_prefill=1024
with device:
model.setup_caches(1, seqlen_prefill, cache_lanes=pipeline_lanes)

# info on stage size and params
# stage_size = get_module_size(model)
# stage_size_formatted = bytes_to_readable(stage_size)
# stage_num_params = get_num_params(model)
# logger.info(
# f"Stage {rank} has {color.blue}{stage_num_params} params{color.reset}, Size: {color.blue}{stage_size_formatted}{color.reset}"
# )
model.eval()

model.text_transformer_args = None
model.config.model_type = model_config.model_type
model.device_mesh = device_mesh
else:
with measure_time("Time to load model: {time:.02f} seconds"):
model = _load_model(builder_args)
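
For context on how the new builder branch above gets exercised: builder_args.distributed together with builder_args.pp and builder_args.tp corresponds to torchchat's distributed CLI options, so an invocation roughly like python3 torchchat.py chat llama3.1 --distributed --pp 2 --tp 2 on a multi-GPU host is what routes model construction through this code path. The exact flag spelling here is an assumption based on torchchat's distributed inference docs rather than on this diff.
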
32 changes: 32 additions & 0 deletions torchchat/distributed/checkpoint_utils.py
@@ -17,6 +17,7 @@
from torch.distributed._tensor import DTensor
from torchchat.distributed.dtensor_utils import convert_to_dtensor
from torchchat.cli.builder import BuilderArgs, _load_checkpoint
from torchchat.model import ModelArgs


_DEFAULT_SAFETENSOR_FILE_NAME = "model.safetensors.index.json"
@@ -450,3 +451,34 @@ def load_weights_from_torchchat_format(stage_module, distribution, device, model
# Fill state dict into stage module
stage_module.load_state_dict(stage_state_dict, strict=False, assign=True)
logger.info(f"Successfully loaded {len(updated_states)} weights into stage module")


def load_model_weights(
stage_module: torch.nn.Module,
distribution: str,
device: torch.device,
model_config: ModelArgs,
chpt_from: str,
):
"""Load the weights from the safetensor file(s) into the model stage.
Model config is needed b/c we permute wq and wk weights based on attn heads.

Args:
stage_module (torch.nn.Module): The model stage to load the weights into.
distribution (str): The distribution name, e.g. "meta-llama/Meta-Llama-3-8B-Instruct".
device (torch.device): The device to load the weights onto.
model_config (ModelArgs): The model config.
chpt_from (str): The checkpoint format to load the weights from, e.g. "torchchat" or "hf".
"""
if chpt_from == "hf":
# This format stands for: index file + multiple binary files
load_weights_from_hf_format(stage_module, distribution, device, model_config)
elif chpt_from == "torchchat":
# This format stands for:
# single binary file, OR
# multiple binary files without index files.
load_weights_from_torchchat_format(
stage_module, distribution, device, model_config
)
else:
raise ValueError(f"Unknown checkpoint format: {chpt_from}")
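
Not part of the diff: a brief usage sketch of the new entry point, mirroring the call site added in torchchat/cli/builder.py above. The config construction is copied from that file; the single-GPU device choice is illustrative, and the tensor-parallel model.distribute step from builder.py is omitted for brevity.

import torch
from torchchat.distributed.checkpoint_utils import load_model_weights
from torchchat.model import ModelArgs, Transformer, TransformerArgs

# Build the transformer config the same way builder.py does.
model_config = ModelArgs.from_table("Meta-Llama-3.1-8B")
config = TransformerArgs.from_params(model_config.transformer_args["text"])

with torch.device("meta"):
    stage = Transformer(config)  # pipeline stage with unmaterialized parameters

device = torch.device("cuda:0")
load_model_weights(
    stage_module=stage,
    distribution="meta-llama/Meta-Llama-3.1-8B-Instruct",
    device=device,
    model_config=config,  # used to permute wq/wk based on attention head counts
    chpt_from="hf",       # or "torchchat", per the dispatch above
)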