Commit: [python] [orttraining] Add utility to export a graph to compute gradients (microsoft#8125)
Showing 9 changed files with 292 additions and 4 deletions.
1 change: 1 addition & 0 deletions
orttraining/orttraining/python/training/experimental/__init__.py
@@ -0,0 +1 @@
from .gradient_graph._gradient_graph_tools import export_gradient_graph
Empty file.
84 changes: 84 additions & 0 deletions
orttraining/orttraining/python/training/experimental/gradient_graph/_gradient_graph_tools.py
@@ -0,0 +1,84 @@
import io
from pathlib import Path
from typing import Any, Callable, Union

import torch
from onnxruntime.capi._pybind_state import GradientGraphBuilder
from torch.onnx import TrainingMode

from ...ortmodule._custom_op_symbolic_registry import CustomOpSymbolicRegistry


def export_gradient_graph(
        model: torch.nn.Module,
        loss_fn: Callable[[Any, Any], Any],
        example_input: torch.Tensor,
        example_labels: torch.Tensor,
        gradient_graph_path: Union[Path, str],
        opset_version=12) -> None:
    r"""
    Build a gradient graph for `model` so that you can output gradients in an inference session when given specific input and corresponding labels.

    Args:
        model (torch.nn.Module): A gradient graph will be built for this model.
        loss_fn (Callable[[Any, Any], Any]): A function to compute the loss given the model's output and the `example_labels`.
            Predefined loss functions such as `torch.nn.CrossEntropyLoss()` will work, but you might not be able to load the resulting graph in other environments such as an InferenceSession in ONNX Runtime Web; in that case, use a custom Python method instead.
        example_input (torch.Tensor): Example input that you would give your model for inference/prediction.
        example_labels (torch.Tensor): The expected labels for `example_input`.
            This could be the output of your model when given `example_input`, but it might be different if your loss function expects labels in a different form (e.g. when using cross entropy loss).
        gradient_graph_path (Union[Path, str]): The path to where you would like to save the gradient graph.
        opset_version (int): See `torch.onnx.export`.
    """

    # Make sure that loss nodes that expect multiple outputs are set up.
    CustomOpSymbolicRegistry.register_all()

    if not isinstance(gradient_graph_path, str):
        gradient_graph_path = str(gradient_graph_path)

    class WrapperModule(torch.nn.Module):
        def forward(self, model_input, expected_labels, *model_params):
            # Copy the given parameters into the wrapped model so that they
            # appear as graph inputs rather than baked-in constants.
            for param, set_param in zip(model.parameters(), model_params):
                param.data = set_param.data
            output = model(model_input)
            loss = loss_fn(output, expected_labels)
            return output, loss

    wrapped_model = WrapperModule()

    dynamic_axes = {
        'input': {0: 'batch_size', },
        'labels': {0: 'batch_size', },
        'output': {0: 'batch_size', },
    }

    args = (example_input, example_labels, *tuple(model.parameters()))
    model_param_names = tuple(name for name, _ in model.named_parameters())
    input_names = ['input', 'labels', *model_param_names]
    nodes_needing_gradients = set(
        name for name, param in model.named_parameters()
        if param.requires_grad)

    f = io.BytesIO()
    torch.onnx.export(
        wrapped_model, args,
        f,
        export_params=True,
        opset_version=opset_version, do_constant_folding=False,
        training=TrainingMode.TRAINING,
        input_names=input_names,
        output_names=['output', 'loss'],
        dynamic_axes=dynamic_axes)

    exported_model = f.getvalue()
    builder = GradientGraphBuilder(exported_model,
                                   {'loss'},
                                   nodes_needing_gradients,
                                   'loss')
    builder.build()
    builder.save(gradient_graph_path)
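
For context, a minimal usage sketch of the utility above (not part of the commit): the model, the loss function, and the output path below are illustrative assumptions; only `export_gradient_graph` and its signature come from the code above.

# Hypothetical usage sketch for export_gradient_graph (names are illustrative).
import torch
from onnxruntime.training.experimental import export_gradient_graph

model = torch.nn.Sequential(torch.nn.Linear(10, 2))

def mse_loss(output, labels):
    # A custom Python loss keeps the exported graph loadable in more
    # environments (e.g. ONNX Runtime Web), per the docstring above.
    return torch.sum((output - labels) ** 2)

example_input = torch.randn(1, 10)
example_labels = torch.randn(1, 2)
export_gradient_graph(model, mse_loss, example_input, example_labels,
                      'gradient_graph.onnx')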
125 changes: 125 additions & 0 deletions
orttraining/orttraining/test/python/orttraining_test_experimental_gradient_graph.py
@@ -0,0 +1,125 @@
import os
import unittest
from pathlib import Path

import numpy as np
import onnx
import onnxruntime
import torch
from onnxruntime.training.experimental import export_gradient_graph


class NeuralNet(torch.nn.Module):
    r"""
    Simple example model.
    """

    def __init__(self,
                 input_size: int,
                 embedding_size: int,
                 hidden_size: int,
                 num_classes: int):
        super(NeuralNet, self).__init__()

        self.frozen_layer = torch.nn.Linear(
            input_size, embedding_size, bias=False)
        # Freeze a layer (mainly to test that gradients don't get output for it).
        self.frozen_layer.requires_grad_(False)

        self.fc1 = torch.nn.Linear(embedding_size, hidden_size)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        out = self.frozen_layer(x)
        out = self.fc1(out)
        out = self.relu(out)
        out = self.fc2(out)
        return out


def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()


def binary_cross_entropy_loss(inp, target):
    loss = -torch.sum(target * torch.log2(inp[:, 0]) +
                      (1 - target) * torch.log2(inp[:, 1]))
    return loss


class GradientGraphBuilderTest(unittest.TestCase):
    def test_save(self):
        # We need a custom loss function to load the graph in an InferenceSession in ONNX Runtime Web.
        # You can still make the gradient graph with torch.nn.CrossEntropyLoss() and this test will pass.
        loss_fn = binary_cross_entropy_loss
        input_size = 10
        model = NeuralNet(input_size=input_size, embedding_size=20, hidden_size=5,
                          num_classes=2)
        directory_path = Path(os.path.dirname(__file__)).resolve()

        gradient_graph_path = directory_path / 'gradient_graph_model.onnx'

        batch_size = 1
        example_input = torch.randn(
            batch_size, input_size, requires_grad=True)
        example_labels = torch.tensor([1])

        export_gradient_graph(
            model, loss_fn, example_input, example_labels, gradient_graph_path)

        onnx_model = onnx.load(str(gradient_graph_path))
        onnx.checker.check_model(onnx_model)

        # Expected inputs: input, labels, and the model's parameters.
        self.assertEqual(
            1 + 1 + sum(1 for _ in model.parameters()), len(onnx_model.graph.input))

        # Expected outputs: prediction, loss, and gradients for the trainable parameters.
        self.assertEqual(
            1 + 1 + sum(1 if p.requires_grad else 0 for p in model.parameters()), len(onnx_model.graph.output))

        torch_out = model(example_input)

        try:
            ort_session = onnxruntime.InferenceSession(str(gradient_graph_path))
        except ValueError:
            # Sometimes it is required to pass the available providers.
            from onnxruntime.capi import _pybind_state as C
            available_providers = C.get_available_providers()
            ort_session = onnxruntime.InferenceSession(
                str(gradient_graph_path), providers=available_providers)

        ort_inputs = {
            onnx_model.graph.input[0].name: to_numpy(example_input),
            onnx_model.graph.input[1].name: to_numpy(example_labels),
        }

        for name, param in model.named_parameters():
            ort_inputs[name] = to_numpy(param.data)

        ort_outs = ort_session.run(None, ort_inputs)
        onnx_output_names = [node.name for node in onnx_model.graph.output]
        onnx_name_to_output = dict(zip(onnx_output_names, ort_outs))

        ort_output = onnx_name_to_output['output']
        np.testing.assert_allclose(
            to_numpy(torch_out), ort_output, rtol=1e-03, atol=1e-05)

        torch_loss = loss_fn(torch_out, example_labels)
        ort_loss = onnx_name_to_output['loss']
        np.testing.assert_allclose(
            to_numpy(torch_loss), ort_loss, rtol=1e-03, atol=1e-05)

        # Make sure the gradients have the right shape.
        model_param_names = tuple(
            name for name, param in model.named_parameters() if param.requires_grad)
        self.assertEqual(4, len(model_param_names))

        for name, param in model.named_parameters():
            if param.requires_grad:
                grad = onnx_name_to_output[name + '_grad']
                self.assertEqual(param.size(), grad.shape)


if __name__ == '__main__':
    unittest.main()
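
Beyond checking shapes, the `<name>_grad` outputs verified by the test could also drive a manual parameter update. A minimal sketch, assuming `model` and `onnx_name_to_output` are as in the test body above; the learning rate is an arbitrary illustrative choice, not something the commit prescribes.

# Hypothetical follow-on: one SGD-style step using the gradients the
# exported graph returned (learning_rate is an assumed value).
learning_rate = 0.01
for name, param in model.named_parameters():
    if param.requires_grad:
        grad = onnx_name_to_output[name + '_grad']
        with torch.no_grad():
            param -= learning_rate * torch.from_numpy(grad)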