[python] [orttraining] Add utility to export a graph to compute gradients (microsoft#8125)
juharris authored Feb 18, 2022
1 parent 6f0640a commit 742694f
Showing 9 changed files with 292 additions and 4 deletions.
14 changes: 14 additions & 0 deletions cmake/onnxruntime_python.cmake
@@ -281,6 +281,12 @@ if (onnxruntime_ENABLE_TRAINING)
file(GLOB onnxruntime_python_amp_srcs CONFIGURE_DEPENDS
"${ORTTRAINING_SOURCE_DIR}/python/training/amp/*.py"
)
file(GLOB onnxruntime_python_experimental_srcs CONFIGURE_DEPENDS
"${ORTTRAINING_SOURCE_DIR}/python/training/experimental/*.py"
)
file(GLOB onnxruntime_python_gradient_graph_srcs CONFIGURE_DEPENDS
"${ORTTRAINING_SOURCE_DIR}/python/training/experimental/gradient_graph/*.py"
)
file(GLOB onnxruntime_python_optim_srcs CONFIGURE_DEPENDS
"${ORTTRAINING_SOURCE_DIR}/python/training/optim/*.py"
)
@@ -553,6 +559,8 @@ if (onnxruntime_ENABLE_TRAINING)
TARGET onnxruntime_pybind11_state POST_BUILD
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/amp
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/experimental
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/experimental/gradient_graph
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/optim
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/experimental
@@ -573,6 +581,12 @@ if (onnxruntime_ENABLE_TRAINING)
COMMAND ${CMAKE_COMMAND} -E copy
${onnxruntime_python_amp_srcs}
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/amp/
COMMAND ${CMAKE_COMMAND} -E copy
${onnxruntime_python_experimental_srcs}
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/experimental/
COMMAND ${CMAKE_COMMAND} -E copy
${onnxruntime_python_gradient_graph_srcs}
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/experimental/gradient_graph/
COMMAND ${CMAKE_COMMAND} -E copy
${onnxruntime_python_optim_srcs}
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/optim/
50 changes: 50 additions & 0 deletions orttraining/orttraining/python/orttraining_pybind_state.cc
@@ -9,11 +9,14 @@
#include <pybind11/stl_bind.h>

#include "core/common/parse_string.h"
#include "core/graph/model.h"
#include "core/session/environment.h"
#include "orttraining/core/session/training_session.h"
#include "orttraining/core/agent/training_agent.h"
#include "orttraining/core/graph/gradient_config.h"
#include "orttraining/core/graph/optimizer_config.h"
#include "orttraining/core/framework/communication/mpi/mpi_context.h"
#include "orttraining/core/framework/gradient_graph_builder.h"
#include "orttraining/core/framework/ortmodule_graph_builder.h"
#include "orttraining/core/graph/gradient_definition_registry.h"
#include "python/onnxruntime_pybind_mlvalue.h"
@@ -126,6 +129,15 @@ struct TrainingConfigurationResult {
optional<std::string> loss_scale_input_name;
};

struct PyGradientGraphBuilder {
std::unique_ptr<GradientGraphBuilder> builder;
std::shared_ptr<Model> model;
std::unique_ptr<logging::Logger> logger;
std::unique_ptr<GradientGraphConfiguration> gradient_graph_config;
PyGradientGraphBuilder(std::unique_ptr<GradientGraphBuilder> builder_, std::shared_ptr<Model> model_, std::unique_ptr<logging::Logger> logger_, std::unique_ptr<GradientGraphConfiguration> gradient_graph_config_)
: builder(std::move(builder_)), model(std::move(model_)), logger(std::move(logger_)), gradient_graph_config(std::move(gradient_graph_config_)) {}
};

// TODO: this method does not handle parallel optimization.
TrainingConfigurationResult ConfigureSessionForTraining(
training::PipelineTrainingSession* sess, TrainingParameters& parameters) {
@@ -760,6 +772,44 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn
return ortmodule_graph_builder->GetGraphInfo();
});

// Provide a convenient and well-documented way to make a gradient graph.
// It's possible to get the gradient graph through ORTModule by leveraging some "private" fields and not-so-well-documented APIs, so we provide this explicit and tested way to get the gradient graph.
py::class_<PyGradientGraphBuilder> gradient_graph_builder(m, "GradientGraphBuilder", R"pbdoc(A utility for making a gradient graph that can be used to help train a model.)pbdoc");
// Set up methods to match the C++ `GradientGraphBuilder` interface.
gradient_graph_builder.def(py::init([](
const py::bytes& serialized_model,
const std::unordered_set<std::string>& y_node_arg_names,
const std::unordered_set<std::string>& x_node_arg_names,
const std::string loss_node_arg_name) {
std::shared_ptr<Model> model;
auto logger = logging::LoggingManager::DefaultLogger();
ONNX_NAMESPACE::ModelProto model_proto;
std::istringstream model_istream(serialized_model);
ORT_THROW_IF_ERROR(Model::Load(model_istream, &model_proto));
ORT_THROW_IF_ERROR(Model::Load(model_proto, model, nullptr, logger));
GradientGraphConfiguration gradient_graph_config{};
gradient_graph_config.set_gradients_as_graph_outputs = true;
// Keep these objects alive; the builder holds references to them.
auto gradient_graph_config_ptr = std::make_unique<GradientGraphConfiguration>(gradient_graph_config);
auto logger_ptr = std::make_unique<logging::Logger>(logger);

auto builder = std::make_unique<GradientGraphBuilder>(
&model->MainGraph(),
y_node_arg_names,
x_node_arg_names,
loss_node_arg_name,
*gradient_graph_config_ptr,
*logger_ptr);

return std::make_unique<PyGradientGraphBuilder>(std::move(builder), std::move(model), std::move(logger_ptr), std::move(gradient_graph_config_ptr));
}))
.def("build", [](PyGradientGraphBuilder* gradient_graph_builder) {
ORT_THROW_IF_ERROR(gradient_graph_builder->builder->Build());
})
.def("save", [](PyGradientGraphBuilder* gradient_graph_builder, const std::string& path) {
ORT_THROW_IF_ERROR(Model::Save(*(gradient_graph_builder->model), path));
});

py::class_<GradientNodeAttributeDefinition> gradient_node_attribute_definition(
m, "GradientNodeAttributeDefinition", R"pbdoc(Attribute definition for gradient graph nodes.)pbdoc");

6 changes: 3 additions & 3 deletions orttraining/orttraining/python/training/__init__.py
@@ -3,13 +3,13 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------

from onnxruntime.capi._pybind_state import TrainingParameters
from onnxruntime.capi._pybind_state import PropagateCastOpsStrategy
from onnxruntime.capi._pybind_state import PropagateCastOpsStrategy, TrainingParameters
from onnxruntime.capi.training.training_session import TrainingSession

# Options need to be imported before `ORTTrainer`.
from .orttrainer_options import ORTTrainerOptions
from .orttrainer import ORTTrainer, TrainStepInfo
from . import amp, checkpoint, optim, model_desc_validation
from . import amp, checkpoint, model_desc_validation, optim


try:
1 change: 1 addition & 0 deletions orttraining/orttraining/python/training/experimental/__init__.py
@@ -0,0 +1 @@
from .gradient_graph._gradient_graph_tools import export_gradient_graph
0 changes: 0 additions & 0 deletions orttraining/orttraining/python/training/experimental/gradient_graph/__init__.py
Empty file.
84 changes: 84 additions & 0 deletions orttraining/orttraining/python/training/experimental/gradient_graph/_gradient_graph_tools.py
@@ -0,0 +1,84 @@
import io
from pathlib import Path
from typing import Any, Callable, Optional, Union

import torch
from onnxruntime.capi._pybind_state import GradientGraphBuilder
from torch.onnx import TrainingMode

from ...ortmodule._custom_op_symbolic_registry import CustomOpSymbolicRegistry


def export_gradient_graph(
model: torch.nn.Module,
loss_fn: Callable[[Any, Any], Any],
example_input: torch.Tensor,
example_labels: torch.Tensor,
gradient_graph_path: Union[Path, str],
opset_version=12) -> None:
r"""
Build a gradient graph for `model` so that you can output gradients in an inference session when given specific input and corresponding labels.
Args:
model (torch.nn.Module): A gradient graph will be built for this model.
loss_fn (Callable[[Any, Any], Any]): A function to compute the loss given the model's output and the `example_labels`.
Predefined loss functions such as `torch.nn.CrossEntropyLoss()` will work, but the resulting graph might not load in other environments such as an InferenceSession in ONNX Runtime Web; in that case, use a custom Python method instead.
example_input (torch.Tensor): Example input that you would give your model for inference/prediction.
example_labels (torch.Tensor): The expected labels for `example_input`.
This could be the output of your model when given `example_input`, but it might differ if your loss function expects labels in a different format (e.g. class indices for cross-entropy loss).
gradient_graph_path (Union[Path, str]): The path to where you would like to save the gradient graph.
opset_version (int): See `torch.onnx.export`.
"""

# Make sure that loss nodes that expect multiple outputs are set up.
CustomOpSymbolicRegistry.register_all()

if not isinstance(gradient_graph_path, str):
gradient_graph_path = str(gradient_graph_path)

class WrapperModule(torch.nn.Module):
def forward(self, model_input, expected_labels, *model_params):
for param, set_param in zip(model.parameters(), model_params):
param.data = set_param.data
output = model(model_input)
loss = loss_fn(output, expected_labels)
return output, loss

wrapped_model = WrapperModule()

dynamic_axes = {
'input': {0: 'batch_size', },
'labels': {0: 'batch_size', },
'output': {0: 'batch_size', },
}

args = (example_input, example_labels, *tuple(model.parameters()))
model_param_names = tuple(name for name, _ in model.named_parameters())
input_names = ['input', 'labels', *model_param_names]
nodes_needing_gradients = set(
name for name, param in model.named_parameters()
if param.requires_grad)

f = io.BytesIO()
torch.onnx.export(
wrapped_model, args,
f,
export_params=True,
opset_version=opset_version, do_constant_folding=False,
training=TrainingMode.TRAINING,
input_names=input_names,
output_names=['output', 'loss'],
dynamic_axes=dynamic_axes)

exported_model = f.getvalue()
builder = GradientGraphBuilder(exported_model,
{'loss'},
nodes_needing_gradients,
'loss')
builder.build()
builder.save(gradient_graph_path)
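
With the helper above, exporting a gradient graph takes only a few lines. A small sketch, assuming a toy model, a loss callable, and an output path chosen for illustration (the tested flow is in orttraining_test_experimental_gradient_graph.py below):

import torch

from onnxruntime.training.experimental import export_gradient_graph

# Sketch only: the model, loss function, and output path are assumptions.
model = torch.nn.Sequential(torch.nn.Linear(10, 5), torch.nn.ReLU(), torch.nn.Linear(5, 2))

def loss_fn(output, target):
    # Any callable taking (model output, labels) works; as the docstring notes,
    # a hand-written loss may be needed for environments like ONNX Runtime Web.
    return torch.nn.functional.cross_entropy(output, target)

example_input = torch.randn(1, 10)
example_labels = torch.tensor([1])

export_gradient_graph(model, loss_fn, example_input, example_labels, 'gradient_graph.onnx')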
@@ -121,10 +121,20 @@ def run_ortmodule_experimental_json_config_tests(cwd, log):
run_subprocess(command, cwd=cwd, log=log).check_returncode()


def run_experimental_gradient_graph_tests(cwd, log):
log.debug("Running: Experimental Gradient Graph Export Tests")

command = [sys.executable, '-m', 'pytest', '-sv',
'orttraining_test_experimental_gradient_graph.py']

run_subprocess(command, cwd=cwd, log=log).check_returncode()


def run_data_sampler_tests(cwd, log):
log.debug('Running: Data sampler tests')

command = [sys.executable, '-m', 'pytest', '-sv', 'orttraining_test_sampler.py']
command = [sys.executable, '-m', 'pytest',
'-sv', 'orttraining_test_sampler.py']

run_subprocess(command, cwd=cwd, log=log).check_returncode()

@@ -161,6 +171,8 @@ def main():

run_data_sampler_tests(cwd, log)

run_experimental_gradient_graph_tests(cwd, log)

return 0


125 changes: 125 additions & 0 deletions orttraining_test_experimental_gradient_graph.py
@@ -0,0 +1,125 @@
import os
import unittest
from pathlib import Path

import numpy as np
import onnx
import onnxruntime
import torch
from onnxruntime.training.experimental import export_gradient_graph


class NeuralNet(torch.nn.Module):
r"""
Simple example model.
"""

def __init__(self,
input_size: int,
embedding_size: int,
hidden_size: int,
num_classes: int):
super(NeuralNet, self).__init__()

self.frozen_layer = torch.nn.Linear(
input_size, embedding_size, bias=False)
# Freeze a layer (mainly to test that gradients don't get output for it).
self.frozen_layer.requires_grad_(False)

self.fc1 = torch.nn.Linear(embedding_size, hidden_size)
self.relu = torch.nn.ReLU()
self.fc2 = torch.nn.Linear(hidden_size, num_classes)

def forward(self, x):
out = self.frozen_layer(x)
out = self.fc1(out)
out = self.relu(out)
out = self.fc2(out)
return out


def to_numpy(tensor):
return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()


def binary_cross_entropy_loss(inp, target):
loss = -torch.sum(target * torch.log2(inp[:, 0]) +
(1-target) * torch.log2(inp[:, 1]))
return loss


class GradientGraphBuilderTest(unittest.TestCase):
def test_save(self):
# We need a custom loss function to load the graph in an InferenceSession in ONNX Runtime Web.
# You can still make the gradient graph with torch.nn.CrossEntropyLoss() and this test will pass.
loss_fn = binary_cross_entropy_loss
input_size = 10
model = NeuralNet(input_size=input_size, embedding_size=20, hidden_size=5,
num_classes=2)
directory_path = Path(os.path.dirname(__file__)).resolve()

gradient_graph_path = directory_path/'gradient_graph_model.onnx'

batch_size = 1
example_input = torch.randn(
batch_size, input_size, requires_grad=True)
example_labels = torch.tensor([1])

export_gradient_graph(
model, loss_fn, example_input, example_labels, gradient_graph_path)

onnx_model = onnx.load(str(gradient_graph_path))
onnx.checker.check_model(onnx_model)

# Expected inputs: input, labels, and the model's parameters.
self.assertEqual(
1 + 1 + sum(1 for _ in model.parameters()), len(onnx_model.graph.input))

# Expected outputs: prediction, loss, and a gradient for each parameter that requires one.
self.assertEqual(
1 + 1 + sum(1 if p.requires_grad else 0 for p in model.parameters()), len(onnx_model.graph.output))

torch_out = model(example_input)

try:
ort_session = onnxruntime.InferenceSession(str(gradient_graph_path))
except ValueError:
# Sometimes it is required to pass the available providers.
from onnxruntime.capi import _pybind_state as C
available_providers = C.get_available_providers()
ort_session = onnxruntime.InferenceSession(str(gradient_graph_path), providers=available_providers)

ort_inputs = {
onnx_model.graph.input[0].name: to_numpy(example_input),
onnx_model.graph.input[1].name: to_numpy(example_labels),
}

for name, param in model.named_parameters():
ort_inputs[name] = to_numpy(param.data)

ort_outs = ort_session.run(None, ort_inputs)
onnx_output_names = [node.name for node in onnx_model.graph.output]
onnx_name_to_output = dict(zip(onnx_output_names, ort_outs))

ort_output = onnx_name_to_output['output']
np.testing.assert_allclose(
to_numpy(torch_out), ort_output, rtol=1e-03, atol=1e-05)

torch_loss = loss_fn(torch_out, example_labels)
ort_loss = onnx_name_to_output['loss']
np.testing.assert_allclose(
to_numpy(torch_loss), ort_loss, rtol=1e-03, atol=1e-05)

# Make sure the gradients have the right shape.
model_param_names = tuple(
name for name, param in model.named_parameters() if param.requires_grad)
self.assertEqual(4, len(model_param_names))

for name, param in model.named_parameters():
if param.requires_grad:
grad = onnx_name_to_output[name + '_grad']
self.assertEqual(param.size(), grad.shape)


if __name__ == '__main__':
unittest.main()
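
The `<parameter name>_grad` outputs verified above are plain tensors, so they can drive a training step without PyTorch autograd. A rough sketch of one SGD update, reusing `ort_session`, `ort_inputs`, `onnx_model`, and `model` from the test; the learning rate is an arbitrary illustration and this loop is not part of the commit:

# Sketch only: apply one plain-SGD step using the gradient graph's outputs.
learning_rate = 0.01
output_names = [output.name for output in onnx_model.graph.output]
outputs = dict(zip(output_names, ort_session.run(None, ort_inputs)))
with torch.no_grad():
    for name, param in model.named_parameters():
        if param.requires_grad:
            grad = torch.from_numpy(outputs[name + '_grad'])
            param -= learning_rate * grad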
2 changes: 2 additions & 0 deletions setup.py
@@ -388,6 +388,8 @@ def run(self):
if enable_training:
packages.extend(['onnxruntime.training',
'onnxruntime.training.amp',
'onnxruntime.training.experimental',
'onnxruntime.training.experimental.gradient_graph',
'onnxruntime.training.optim',
'onnxruntime.training.ortmodule',
'onnxruntime.training.ortmodule.experimental',
