This repository is a PyTorch and PyTorch Lightning implementation of
- Perceiver: General Perception with Iterative Attention (paper, video)
- Perceiver IO: A General Architecture for Structured Inputs & Outputs (paper, blog post)
- General-purpose, long-context autoregressive modeling with Perceiver AR (paper, blog post)
The codebase is modular and designed for easy extension to new tasks and datasets. The integration with PyTorch Lightning supports model training at scale. The command line interface is implemented with the Lightning CLI.
Pretrained models can be imported from the 🤗 Hub. Datasets used for model training are 🤗 Datasets wrapped into PyTorch Lightning data modules. For NLP tasks, this library also supports 🤗 fast tokenizers and the 🤗 Perceiver UTF-8 bytes tokenizer.
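For example, the Perceiver UTF-8 bytes tokenizer can be loaded by its 🤗 Hub name; a minimal sketch, assuming the transformers package is available (the data modules below accept the same Hub name via their tokenizer argument):

from transformers import AutoTokenizer

# Load the UTF-8 bytes tokenizer published on the 🤗 Hub
tokenizer = AutoTokenizer.from_pretrained("deepmind/language-perceiver")

# Input text is tokenized into raw UTF-8 bytes
input_ids = tokenizer("Perceiver IO works on raw UTF-8 bytes").input_ids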
Installation with pip:

pip install perceiver-io[image,text]
Installation from source requires Miniconda and Poetry (1.2.0 or higher):
conda env create -f environment.yml
conda activate perceiver-io
poetry install --all-extras
A Docker image is also available:

docker pull ghcr.io/krasserm/perceiver-io:latest
See Docker image for details.
Here's a minimal example of autoregressive language modeling with Perceiver AR. A small language model (30.7M parameters) is trained on the WikiText-103-raw dataset and then used to generate text from a prompt. Input text is tokenized into raw UTF-8 bytes, and the model also predicts the raw UTF-8 bytes of generated text. More details about Perceiver AR and Perceiver IO model construction, training and inference are covered in the documentation.
Using the Lightning CLI, model training can be started with:
python -m perceiver.scripts.text.clm fit \
--model.num_latents=512 \
--model.num_channels=512 \
--model.num_self_attention_layers=8 \
--model.cross_attention_dropout=0.5 \
--data=WikiTextDataModule \
--data.tokenizer=deepmind/language-perceiver \
--data.max_seq_len=4096 \
--data.batch_size=16 \
--data.task=clm \
--optimizer=Adam \
--optimizer.lr=2e-4 \
--trainer.max_steps=5000 \
--trainer.accelerator=gpu \
--trainer.devices=1 \
--trainer.accumulate_grad_batches=4
You can also do this programmatically with the PyTorch Lightning Trainer:
from torch.optim import Adam
from perceiver.data.text.wikitext import WikiTextDataModule, Task
from perceiver.model.text.clm import LitCausalLanguageModel, CausalLanguageModelConfig
import pytorch_lightning as pl
# Lightning WikiText data module
data = WikiTextDataModule(
    tokenizer="deepmind/language-perceiver",
    max_seq_len=4096,
    batch_size=16,
    task=Task.clm,
)
# Language model configuration object
model_config = CausalLanguageModelConfig(
    vocab_size=data.vocab_size,
    max_seq_len=data.max_seq_len,
    num_latents=512,
    num_channels=512,
    num_self_attention_layers=8,
    cross_attention_dropout=0.5,
)
def configure_optimizers(self):
    return Adam(self.parameters(), lr=2e-4)

# Associate optimizer factory with Lightning module (not predefined there)
setattr(LitCausalLanguageModel, "configure_optimizers", configure_optimizers)
# Lightning module of language model (a Perceiver AR)
lit_model = LitCausalLanguageModel.create(model_config)
# Instantiate Lightning Trainer
trainer = pl.Trainer(accelerator="gpu", devices=1, max_steps=5000, accumulate_grad_batches=4)
# Train model (will also preprocess dataset if used for the first time)
trainer.fit(lit_model, datamodule=data)
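Optionally, a checkpoint callback can be added to the Trainer so that the trained model can later be restored from a checkpoint file, as in the inference example further below. A minimal sketch using the standard PyTorch Lightning ModelCheckpoint callback (the monitored metric name "val_loss" is an assumption, adjust it to the metric actually logged):

from pytorch_lightning.callbacks import ModelCheckpoint

# Keep the last checkpoint and the one with the lowest validation loss
# ("val_loss" is an assumed metric name)
checkpoint_callback = ModelCheckpoint(monitor="val_loss", save_top_k=1, save_last=True)

trainer = pl.Trainer(
    accelerator="gpu",
    devices=1,
    max_steps=5000,
    accumulate_grad_batches=4,
    callbacks=[checkpoint_callback],
)
trainer.fit(lit_model, datamodule=data)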
If you instead want to use plain PyTorch (without PyTorch Lightning, except for data sources):
from perceiver.model.text.clm import CausalLanguageModel
import torch.nn.functional as F
from torch.optim import Adam
data = ...
data.prepare_data()
data.setup()
model_config = ...
# Plain PyTorch module of language model
model = CausalLanguageModel(config=model_config)
model.train()
optim = Adam(model.parameters(), lr=2e-4)
# Simplified training loop compared to previous examples
# (no gradient accumulation, epochs instead of max_steps, ...)
for epoch in range(4):
    for labels_ids, input_ids, _ in data.train_dataloader():
        logits = model(input_ids)
        loss = F.cross_entropy(logits.permute(0, 2, 1), labels_ids[:, -model_config.num_latents:])
        loss.backward()
        optim.step()
        optim.zero_grad()
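The trained weights can then be persisted with plain PyTorch; a minimal sketch (the file name is just an example):

import torch

# Save the trained weights; reload them later with model.load_state_dict(torch.load(...))
torch.save(model.state_dict(), "causal-language-model.pt")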
For inference, the trained model is loaded from a Lightning checkpoint and used to generate text from a prompt:

from perceiver.model.text.clm import LitCausalLanguageModel
data = ...
# Load Lightning module from training checkpoint
lit_model = LitCausalLanguageModel.load_from_checkpoint("/path/to/checkpoint")
# Obtain trained plain PyTorch model
model = lit_model.model.eval()
# Get text preprocessor from data module
preproc = data.text_preprocessor()
# Tokenize a sample prompt
prompt, _ = preproc.preprocess("A man was reading a book on a sunny day until he sudden")
# Generate tokens from prompt via top-k sampling where k = f(vocab_size, threshold)
generated = model.generate(num=512, prompt=prompt[None, ...], threshold=0.9)[0]
# Decode generated tokens
generated_text = data.tokenizer.decode(generated)
You can also run text generation interactively in the Colab notebook.