
update documentation to include CPU example
conceptofmind committed May 16, 2023
Parent: c7d61b1 · Commit: d0b0791
Showing 1 changed file: README.md (12 additions, 1 deletion)
@@ -4,7 +4,7 @@
## Acknowledgements
- <a href="https://github.com/CarperAI">CarperAI</a>, <a href="https://twitter.com/lcastricato">Louis Castricato</a>, and <a href="https://stability.ai/">Stability.ai</a> for the very generous sponsorship to work on machine learning research.
- <a href="https://github.com/lucidrains">Phil Wang (Lucidrains)</a> for his inspiring work and input on training and architectures.
- - <a href="https://twitter.com/dmayhem93">Dakota ("He berk reacted once")</a>, <a href="https://twitter.com/jonbtow">Guac</a>, <a href="https://twitter.com/zach_nussbaum">Zach</a>, and <a href="">Aman</a> for providing information about Huggingface and Slurm. I typically only use Apex and DeepSpeed.
+ - <a href="https://twitter.com/dmayhem93">Dakota ("He berk reacted once")</a>, <a href="https://twitter.com/jonbtow">Guac</a>, <a href="https://twitter.com/zach_nussbaum">Zach</a>, and <a href="https://twitter.com/aman_gif">Aman</a> for providing information about Huggingface and Slurm. I typically only use Apex and DeepSpeed.

## FAQ
Three different size PaLM models (150m, 410m, 1b) have been trained with 8k context length on all of <a href="https://huggingface.co/datasets/c4">C4</a>. The models are compatible with Lucidrain's <a href="https://github.com/lucidrains/toolformer-pytorch">Toolformer-pytorch</a>, <a href="https://github.com/lucidrains/PaLM-pytorch">PaLM-pytorch</a>, and <a href="https://github.com/lucidrains/PaLM-rlhf-pytorch">PaLM-rlhf-pytorch</a>. A fourth 2b model is currently being trained. These are currently the baseline versions of the models and additional training will be done at a larger scale. All of the models will be further instruction-tuned on FLAN to provide flan-PaLM models.
@@ -36,6 +36,17 @@ model = PaLM(

model.load('/palm_410m_8k_v0.pt')
```
If you would like to use the models on CPU, you can do the following:
```python
import torch

device = torch.device("cpu")

model = PaLM(
    num_tokens=50304,
    dim=1024,
    depth=24,
    dim_head=128,
    heads=8,
    flash_attn=False,
    qk_rmsnorm=False,
).to(device).eval()

# Load the checkpoint onto the CPU and restore the weights.
checkpoint = torch.load('./palm_410m_8k_v0.pt', map_location=device)
model.load_state_dict(checkpoint)
```
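As a quick sanity check that the CPU-loaded weights work, the sketch below (continuing from the snippet above) runs a single forward pass on random token ids. It assumes the model's forward pass accepts a `(batch, seq_len)` tensor of token ids and returns per-token logits; the exact interface may differ slightly in this repository.
```python
# A minimal CPU smoke test, continuing from the snippet above.
# Assumption: the forward pass takes a (batch, seq_len) tensor of token ids and returns logits.
with torch.no_grad():
    dummy_ids = torch.randint(0, 50304, (1, 64), device=device)  # batch of 1, 64 random tokens
    logits = model(dummy_ids)

print(logits.shape)  # e.g. (1, 64, 50304) if per-token logits are returned
```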
To generate text with the models you can use the command line with the following parameters (a programmatic sketch follows the list):
- prompt - Text prompt to condition the generated text on.
- seq_len - Sequence length for generated text. Default is 256.
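If you prefer to generate programmatically rather than through the command line, a rough sketch is shown below. It assumes the model exposes a PaLM-rlhf-pytorch-style `generate(seq_len, prompt=...)` method operating on token-id tensors (the README above notes compatibility with that repository); the method name, its signature, and the placeholder prompt ids are assumptions rather than something shown in this diff, and a real prompt would first be tokenized.
```python
# A hedged sketch of programmatic generation, continuing from the CPU example above.
# Assumption: a PaLM-rlhf-pytorch-style generate(seq_len, prompt=...) method that takes
# and returns tensors of token ids; replace the placeholder ids with a tokenized prompt.
prompt_ids = torch.randint(0, 50304, (1, 8), device=device)  # placeholder prompt token ids

with torch.no_grad():
    generated_ids = model.generate(256, prompt=prompt_ids)  # seq_len=256 matches the default above

print(generated_ids.shape)
```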
