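# Compile a HuggingFace BERT masked-LM training step (forward + backward) with
# functorch AOT Autograd, lowering each extracted graph to TorchScript and then
# importing it into torch-mlir as forward.mlir / backward.mlir.
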
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoModelForMaskedLM,
    AutoModelForSeq2SeqLM,
    ReformerConfig,
    BigBirdConfig,
    BertConfig,
)
import transformers
import torch
from functorch.compile import (
    memory_efficient_fusion,
    aot_module,
    draw_graph_compile,
    nop,
    min_cut_rematerialization_partition,
)
import torch.utils._pytree as pytree
import time
from torch import optim, fx
import torch.nn as nn
from torch.nn.utils import _stateless
from typing import List

#### torch-mlir imports
# To run the example, make sure the following is in your PYTHONPATH:
# 1. /path/to/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir
import torch_mlir
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import (
    RefBackendLinalgOnTensorsBackend,
)
from torch_mlir.passmanager import PassManager
from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ModuleBuilder
from annotations import forward_annotations, backward_annotations
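
# Register the HuggingFace ModelOutput dataclasses as pytree nodes so that
# AOT Autograd can flatten them into (loss, logits) and rebuild them afterwards.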
pytree._register_pytree_node(
    transformers.modeling_outputs.MaskedLMOutput,
    lambda x: ([x.loss, x.logits], None),
    lambda values, _: transformers.modeling_outputs.MaskedLMOutput(
        loss=values[0], logits=values[1]
    ),
)
pytree._register_pytree_node(
    transformers.modeling_outputs.Seq2SeqLMOutput,
    lambda x: ([x.loss, x.logits], None),
    lambda values, _: transformers.modeling_outputs.Seq2SeqLMOutput(
        loss=values[0], logits=values[1]
    ),
)
pytree._register_pytree_node(
    transformers.modeling_outputs.CausalLMOutputWithCrossAttentions,
    lambda x: ([x.loss, x.logits], None),
    lambda values, _: transformers.modeling_outputs.CausalLMOutputWithCrossAttentions(
        loss=values[0], logits=values[1]
    ),
)
pytree._register_pytree_node(
    transformers.models.longformer.modeling_longformer.LongformerMaskedLMOutput,
    lambda x: ([x.loss, x.logits], None),
    lambda values, _: transformers.models.longformer.modeling_longformer.LongformerMaskedLMOutput(
        loss=values[0], logits=values[1]
    ),
)
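
# Build a BERT masked-LM model with the default config and random token ids
# (batch size 4, sequence length 512) as training inputs.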
torch.manual_seed(42)

config = BertConfig()
model_type = AutoModelForMaskedLM
input_size = (4, 512)
device = "cpu"
dtype = torch.float
model = model_type.from_config(config).to(device, dtype=dtype)

input_ids = torch.randint(0, config.vocab_size, input_size).to(device)
decoder_ids = torch.randint(0, config.vocab_size, input_size).to(device)
train_inputs = {"input_ids": input_ids, "labels": decoder_ids}
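

# Run a single training step; the forward/backward compilers below are invoked
# the first time the traced graphs are executed.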
def bench_model(mod):
    iters = 1
    for _ in range(iters):
        mod(**train_inputs).loss.sum().backward()


def save_mlir(module, name):
    with open(name, "w") as text_file:
        text_file.write(str(module))
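

# Import a TorchScript module into torch-mlir: annotate the forward method's
# argument shapes and dtypes via ClassAnnotator, then build the MLIR module
# with ModuleBuilder and write it to disk.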
def generate_mlir_module(ts_script_module, annotations, name):
    mb = ModuleBuilder()
    class_annotator = ClassAnnotator()
    class_annotator.exportNone(ts_script_module._c._type())
    class_annotator.exportPath(ts_script_module._c._type(), ["forward"])
    class_annotator.annotateArgs(
        ts_script_module._c._type(), ["forward"], annotations
    )
    mb.import_module(ts_script_module._c, class_annotator)
    mlir_module = mb.module
    save_mlir(mlir_module, name)
    # mlir_module.dump()
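

# Build the annotation list expected by annotateArgs: the leading None covers
# the `self` argument, followed by a (shape, dtype, True) tuple per tensor input.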
def generate_annotation(inputs):
    annotations_list = [None]
    for i in inputs:
        temp_list = []
        temp_list.append(list(i.shape))
        temp_list.append(i.dtype)
        temp_list.append(True)
        annotations_list.append(tuple(temp_list))
    return annotations_list
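

# AOT Autograd compiler callbacks: script and freeze the partitioned FX graph,
# round-trip it through torch.jit.save/load (the .pt paths are user-specific),
# emit forward.mlir / backward.mlir via torch-mlir, and return the scripted
# module so that it is what actually runs.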
def ts_compiler_forward(fx_g: fx.GraphModule, inps):
    f = torch.jit.script(fx_g)
    f = torch.jit.freeze(f.eval())
    torch.jit.save(f, "/home/prashant/forw.pt")
    f = torch.jit.load("/home/prashant/forw.pt")
    generate_mlir_module(f, generate_annotation(inps), "forward.mlir")
    return f


def ts_compiler_backward(fx_g: fx.GraphModule, inps):
    f = torch.jit.script(fx_g)
    f = torch.jit.freeze(f.eval())
    torch.jit.save(f, "/home/prashant/back.pt")
    f = torch.jit.load("/home/prashant/back.pt")
    generate_mlir_module(f, generate_annotation(inps), "backward.mlir")
    return f
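

# Wrap the model with AOT Autograd: the forward and backward graphs are handed
# to the TorchScript/torch-mlir compilers above, and the min-cut
# rematerialization partitioner decides what to recompute in the backward pass.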
aot_model = aot_module(
    model,
    fw_compiler=ts_compiler_forward,
    bw_compiler=ts_compiler_backward,
    partition_fn=min_cut_rematerialization_partition,
)

bench_model(aot_model)