Update torchao.md: use auto-compilation (#35490)
* Update torchao.md: use auto-compilation

* Update torchao.md: indicate updating transformers to the latest

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
martin0258 and SunMarc authored Jan 14, 2025
1 parent 4b8d1f7 commit 715fdd6
Showing 1 changed file with 7 additions and 10 deletions.
17 changes: 7 additions & 10 deletions docs/source/en/quantization/torchao.md
@@ -16,7 +16,8 @@ rendered properly in your Markdown viewer.
Before you begin, make sure the following libraries are installed with their latest version:

```bash
-pip install --upgrade torch torchao
+# Updating 🤗 Transformers to the latest version, as the example script below uses the new auto compilation
+pip install --upgrade torch torchao transformers
```

By default, the weights are loaded in full precision (torch.float32) regardless of the actual data type the weights are stored in such as torch.float16. Set `torch_dtype="auto"` to load the weights in the data type defined in a model's `config.json` file to automatically load the most memory-optimal data type.
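For reference (not part of the diff), the load step this paragraph describes looks roughly like the sketch below. The checkpoint name is a placeholder, and the `TorchAoConfig("int4_weight_only", group_size=128)` settings are inferred from the `int4wo-128` label used in the benchmark further down, so treat it as an illustration rather than the file's exact contents.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig

# placeholder checkpoint; substitute the model you want to quantize
model_name = "meta-llama/Meta-Llama-3-8B"

# int4 weight-only quantization with group size 128 ("int4wo-128")
quantization_config = TorchAoConfig("int4_weight_only", group_size=128)

# torch_dtype="auto" loads the weights in the dtype recorded in the checkpoint's config.json
quantized_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="cuda",
    quantization_config=quantization_config,
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
```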
@@ -35,12 +36,8 @@ tokenizer = AutoTokenizer.from_pretrained(model_name)
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

-# compile the quantized model to get speedup
-import torchao
-torchao.quantization.utils.recommended_inductor_config_setter()
-quantized_model = torch.compile(quantized_model, mode="max-autotune")
-
-output = quantized_model.generate(**input_ids, max_new_tokens=10)
+# auto-compile the quantized model with `cache_implementation="static"` to get speedup
+output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
print(tokenizer.decode(output[0], skip_special_tokens=True))

# benchmark the performance
@@ -59,11 +56,11 @@ def benchmark_fn(f, *args, **kwargs):
return f"{(t0.blocked_autorange().mean):.3f}"

MAX_NEW_TOKENS = 1000
print("int4wo-128 model:", benchmark_fn(quantized_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS))
print("int4wo-128 model:", benchmark_fn(quantized_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS, cache_implementation="static"))

bf16_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cuda", torch_dtype=torch.bfloat16)
-bf16_model = torch.compile(bf16_model, mode="max-autotune")
-print("bf16 model:", benchmark_fn(bf16_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS))
+output = bf16_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static") # auto-compile
+print("bf16 model:", benchmark_fn(bf16_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS, cache_implementation="static"))

```
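Taken together, the change swaps the manual `torch.compile(..., mode="max-autotune")` step for the auto-compilation that an up-to-date transformers release triggers when `generate()` is asked for a static KV cache. A minimal sketch of the updated call pattern, assuming the `quantized_model` and `tokenizer` variables from the load step above:

```python
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

# cache_implementation="static" switches generate() to a static KV cache,
# which lets recent transformers versions compile the decoding forward pass
# automatically; no explicit torch.compile or inductor tuning is required.
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
print(tokenizer.decode(output[0], skip_special_tokens=True))
```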

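The benchmark hunk only shows the tail of `benchmark_fn`; the collapsed lines define the `t0` timer it returns from. A self-contained sketch of such a helper, assuming it is built on `torch.utils.benchmark.Timer` (an assumption, since those lines are not part of this diff):

```python
import torch
import torch.utils.benchmark as benchmark

def benchmark_fn(f, *args, **kwargs):
    # warm up once so compilation and cache setup do not skew the timing
    f(*args, **kwargs)

    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)",
        globals={"f": f, "args": args, "kwargs": kwargs},
        num_threads=torch.get_num_threads(),
    )
    # mean runtime in seconds, formatted to three decimal places
    return f"{(t0.blocked_autorange().mean):.3f}"
```

With a helper like this in place, the two `print(...)` lines in the diff compare the int4 weight-only model against the bf16 baseline under the same static-cache generation path.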
