from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoModelForMaskedLM,
    AutoModelForSeq2SeqLM,
    ReformerConfig,
    BigBirdConfig,
    BertConfig,
)
import transformers
import torch
from functorch.compile import (
memory_efficient_fusion,
aot_module,
draw_graph_compile,
nop,
min_cut_rematerialization_partition,
)
import torch.utils._pytree as pytree
import time
from torch import optim, fx
import torch.nn as nn
from torch.nn.utils import _stateless
from typing import List
#### torch-mlir imports
# To run the example, make sure the following are in your PYTHONPATH:
# 1. /path/to/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir
import torch_mlir
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import (
RefBackendLinalgOnTensorsBackend,
)
from torch_mlir.passmanager import PassManager
from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ModuleBuilder
from annotations import forward_annotations, backward_annotations
# AOT autograd flattens model outputs with pytree, so register the
# HuggingFace ModelOutput dataclasses we may encounter as pytree nodes.
pytree._register_pytree_node(
    transformers.modeling_outputs.MaskedLMOutput,
    lambda x: ([x.loss, x.logits], None),
    lambda values, _: transformers.modeling_outputs.MaskedLMOutput(
        loss=values[0], logits=values[1]
    ),
)
pytree._register_pytree_node(
    transformers.modeling_outputs.Seq2SeqLMOutput,
    lambda x: ([x.loss, x.logits], None),
    lambda values, _: transformers.modeling_outputs.Seq2SeqLMOutput(
        loss=values[0], logits=values[1]
    ),
)
pytree._register_pytree_node(
    transformers.modeling_outputs.CausalLMOutputWithCrossAttentions,
    lambda x: ([x.loss, x.logits], None),
    lambda values, _: transformers.modeling_outputs.CausalLMOutputWithCrossAttentions(
        loss=values[0], logits=values[1]
    ),
)
pytree._register_pytree_node(
    transformers.models.longformer.modeling_longformer.LongformerMaskedLMOutput,
    lambda x: ([x.loss, x.logits], None),
    lambda values, _: transformers.models.longformer.modeling_longformer.LongformerMaskedLMOutput(
        loss=values[0], logits=values[1]
    ),
)
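
# Added sanity check (not in the original gist): verify that the
# MaskedLMOutput registration above round-trips, i.e. that
# unflatten(flatten(x)) puts loss and logits back in the right slots.
# The tensor values here are arbitrary placeholders.
_out = transformers.modeling_outputs.MaskedLMOutput(
    loss=torch.tensor(0.0), logits=torch.zeros(1, 2)
)
_leaves, _spec = pytree.tree_flatten(_out)
_rebuilt = pytree.tree_unflatten(_leaves, _spec)
assert _rebuilt.loss is _out.loss and _rebuilt.logits is _out.logits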
# Build a default BERT masked-LM model and random integer token inputs;
# passing `labels` makes the HuggingFace forward pass also return a loss.
torch.manual_seed(42)
config = BertConfig()
model_type = AutoModelForMaskedLM
input_size = (4, 512)
device = "cpu"
dtype = torch.float
model = model_type.from_config(config).to(device, dtype=dtype)
input_ids = torch.randint(0, config.vocab_size, input_size).to(device)
decoder_ids = torch.randint(0, config.vocab_size, input_size).to(device)
train_inputs = {"input_ids": input_ids, "labels": decoder_ids}
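
# Added sanity check (not in the original gist): run one eager-mode
# forward/backward first, so that any failure below can be attributed to
# the AOT/MLIR path rather than to the model or inputs.
model(**train_inputs).loss.sum().backward()
model.zero_grad(set_to_none=True)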
def bench_model(mod):
    # One forward/backward pass; with the AOT-wrapped module this is what
    # triggers the forward and backward compilers below.
    iters = 1
    for _ in range(iters):
        mod(**train_inputs).loss.sum().backward()
def save_mlir(module, name):
    with open(name, "w") as text_file:
        text_file.write(str(module))
def generate_mlir_module(ts_script_module, annotations, name):
    # Import the frozen TorchScript module into torch-mlir's `torch`
    # dialect, annotating the forward method's argument shapes and dtypes.
    mb = ModuleBuilder()
    class_annotator = ClassAnnotator()
    class_annotator.exportNone(ts_script_module._c._type())
    class_annotator.exportPath(ts_script_module._c._type(), ["forward"])
    class_annotator.annotateArgs(ts_script_module._c._type(), ["forward"], annotations)
    mb.import_module(ts_script_module._c, class_annotator)
    mlir_module = mb.module
    save_mlir(mlir_module, name)
    # mlir_module.dump()
def generate_annotation(inputs):
    # torch-mlir argument annotations: the first entry (for `self`) is
    # None; each tensor argument gets a (shape, dtype, value-semantics)
    # tuple.
    annotations_list = [None]
    for i in inputs:
        annotations_list.append((list(i.shape), i.dtype, True))
    return annotations_list
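
# Added example (not in the original gist) of the annotation format: the
# two (4, 512) integer tensors above yield a None entry for `self`
# followed by one (shape, dtype, True) tuple per tensor argument.
assert generate_annotation([input_ids, decoder_ids]) == [
    None,
    ([4, 512], torch.int64, True),
    ([4, 512], torch.int64, True),
]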
def ts_compiler_forward(fx_g: fx.GraphModule, inps):
    # Script and freeze the captured forward graph, round-trip it through
    # TorchScript serialization, then emit it as torch-dialect MLIR.
    f = torch.jit.script(fx_g)
    f = torch.jit.freeze(f.eval())
    torch.jit.save(f, "/home/prashant/forw.pt")
    f = torch.jit.load("/home/prashant/forw.pt")
    generate_mlir_module(f, generate_annotation(inps), "forward.mlir")
    return f
def ts_compiler_backward(fx_g: fx.GraphModule, inps):
    # Same pipeline for the captured backward graph.
    f = torch.jit.script(fx_g)
    f = torch.jit.freeze(f.eval())
    torch.jit.save(f, "/home/prashant/back.pt")
    f = torch.jit.load("/home/prashant/back.pt")
    generate_mlir_module(f, generate_annotation(inps), "backward.mlir")
    return f
# Wrap the model so functorch's AOT autograd hands the captured forward
# and backward graphs to the TorchScript->MLIR compilers above.
aot_model = aot_module(
    model,
    fw_compiler=ts_compiler_forward,
    bw_compiler=ts_compiler_backward,
    partition_fn=min_cut_rematerialization_partition,
)
bench_model(aot_model)
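
# Hedged sketch (not in the original gist): the unused PassManager and
# RefBackendLinalgOnTensorsBackend imports above suggest the emitted MLIR
# is meant to be lowered and executed. The pipeline names below are an
# assumption based on torch-mlir circa early 2022, and `mlir_module`
# stands for a module like the one built inside generate_mlir_module.
def run_on_refbackend(mlir_module, *numpy_args):
    with mlir_module.context:
        pm = PassManager.parse(
            "torchscript-module-to-torch-backend-pipeline,"
            "torch-backend-to-linalg-on-tensors-backend-pipeline"
        )
        pm.run(mlir_module)
    backend = RefBackendLinalgOnTensorsBackend()
    jit_module = backend.load(backend.compile(mlir_module))
    return jit_module.forward(*numpy_args)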