Commit 5a725d9

add torch.compile by default, shows almost 1.8X improvement in throughput nice

karpathy committed Dec 30, 2022
1 parent fb52554 commit 5a725d9
Showing 4 changed files with 19 additions and 3 deletions.
README.md (7 changes: 5 additions & 2 deletions)

@@ -8,8 +8,7 @@ The simplest, fastest repository for training/finetuning medium-sized GPTs. It's
 Dependencies:
 
 - [pytorch](https://pytorch.org) <3
-- numpy <3
-- `pip install datasets` for huggingface datasets <3
+- `pip install datasets` for huggingface datasets <3 (if you want to download + preprocess OpenWebText)
 - `pip install tiktoken` for OpenAI's fast bpe code <3
 - `pip install wandb` for optional logging <3
 
@@ -68,6 +67,10 @@ I briefly tried finetuning gpt2 a bit more on our OWT and didn't notice dramatic
 
 For model benchmarking `bench.py` might be useful. It's identical to what happens in the meat of the training loop of `train.py`, but omits much of the other complexities.
 
+# efficiency notes
+
+Code by default now uses [PyTorch 2.0](https://pytorch.org/get-started/pytorch-2.0/). At the time of writing (Dec 29, 2022) this makes `torch.compile()` available in the nightly release. The improvement from the one line of code is noticeable, e.g. cutting down iteration time from ~250ms/iter to 135ms/iter. Nice work PyTorch team!
+
 ## todos
 
 A few that I'm aware of, other than the ones mentioned in code:
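For context, the one-liner the note above refers to looks like this in isolation. A minimal sketch with a toy module (not nanoGPT's actual GPT class), assuming a PyTorch 2.0 nightly where `torch.compile` is available:

```python
import torch
import torch.nn as nn

# toy stand-in for the real model; any nn.Module works
model = nn.Sequential(nn.Linear(1024, 1024), nn.GELU(), nn.Linear(1024, 1024))

model = torch.compile(model)  # the one line: wraps the module in a JIT-optimizing compiler

x = torch.randn(8, 1024)
y = model(x)  # first call triggers compilation; later calls run the optimized code
```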
bench.py (7 changes: 6 additions & 1 deletion)

@@ -14,7 +14,8 @@
 
 batch_size = 8
 block_size = 1024
-dtype = torch.float16
+dtype = torch.bfloat16
+compile_model = True
 
 # data loading init
 real_data = True
@@ -46,6 +47,10 @@ def get_batch(split):
 
 optimizer = model.configure_optimizers(weight_decay=1e-2, learning_rate=1e-4, betas=(0.9, 0.95))
 
+if compile_model:
+    print("Compiling model...")
+    model = torch.compile(model) # pytorch 2.0
+
 profile = False # use pytorch profiler, or just simple benchmarking?
 if profile:
     # useful docs on pytorch profiler:
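The "almost 1.8X" in the commit message is the kind of number this simple benchmarking produces. A minimal sketch of the before/after timing idea (illustrative only, not the repo's actual bench.py; assumes a CUDA device, a PyTorch 2.0 nightly, and a made-up toy model), with the `bfloat16` autocast mirroring the `dtype` switch above:

```python
import time
import torch
import torch.nn as nn

def ms_per_iter(model, x, iters=20):
    for _ in range(3):  # warmup; the first compiled call pays the compilation cost
        with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
            model(x)
    torch.cuda.synchronize()
    t0 = time.time()
    for _ in range(iters):
        with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
            model(x)
    torch.cuda.synchronize()
    return (time.time() - t0) / iters * 1000

model = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024)).cuda()
x = torch.randn(8, 1024, device='cuda')

eager = ms_per_iter(model, x)
compiled = ms_per_iter(torch.compile(model), x)
print(f"eager: {eager:.1f} ms/iter, compiled: {compiled:.1f} ms/iter, {eager/compiled:.2f}X")
```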
sample.py (1 change: 1 addition & 0 deletions)

@@ -21,6 +21,7 @@
 model.load_state_dict(checkpoint['model'])
 model.eval()
 model.to(device)
+model = torch.compile(model) # requires PyTorch 2.0
 
 enc = tiktoken.get_encoding("gpt2")
 #start = enc.encode("\n")
train.py (7 changes: 7 additions & 0 deletions)

@@ -59,6 +59,7 @@
 min_lr = 1e-5 # minimum learning rate
 # DDP settings
 backend = 'nccl' # 'nccl', 'gloo', etc.
+compile_model = True # use PyTorch 2.0 to compile the model to be faster
 # -----------------------------------------------------------------------------
 # poor man's Configurator. Potentially a bad idea. Example usage:
 # $ python train.py override_file --batch_size=32
@@ -156,6 +157,12 @@ def get_batch(split):
 if init_from == 'resume':
     optimizer.load_state_dict(checkpoint['optimizer'])
 
+# compile the model
+if compile_model:
+    print("compiling the model... (takes a ~minute)")
+    unoptimized_model = model
+    model = torch.compile(model) # requires PyTorch 2.0
+
 # wrap model into DDP container
 if ddp:
     model = DDP(model, device_ids=[gpu_id])
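One detail worth noting: the `unoptimized_model = model` line keeps a handle on the raw module before compiling. A plausible reason (my inference, not stated in the commit) is that in PyTorch 2.0 `torch.compile` returns a wrapper whose `state_dict` keys gain an `_orig_mod.` prefix, so checkpointing from the uncompiled module keeps the keys clean. A small sketch of the effect:

```python
import torch
import torch.nn as nn

net = nn.Linear(4, 4)
unoptimized_model = net   # keep a reference to the raw module
net = torch.compile(net)  # net is now a wrapper (OptimizedModule) around the original

print(next(iter(net.state_dict())))                # '_orig_mod.weight' (prefixed keys)
print(next(iter(unoptimized_model.state_dict())))  # 'weight' (clean keys)

# saving from the raw module yields a checkpoint loadable without the wrapper
torch.save(unoptimized_model.state_dict(), 'ckpt.pt')
```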
