add support for symbolic broadcast for Add/Sub/Mul (microsoft#3743)
* add support for symbolic broadcast

* fix comment

* address feedback
xzhu1900 authored May 6, 2020
1 parent 687edd7 commit 0e59668
Showing 6 changed files with 108 additions and 21 deletions.
37 changes: 29 additions & 8 deletions onnxruntime/test/providers/provider_test_utils.h
@@ -250,19 +250,19 @@ class OpTester {
// bool and we can't get the raw data out. So those cases must use an initializer_list
template <typename T>
void AddInput(const char* name, const std::vector<int64_t>& dims, const std::initializer_list<T>& values,
bool is_initializer = false) {
AddData(input_data_, name, dims, values.begin(), values.size(), is_initializer);
bool is_initializer = false, const std::vector<std::string>* dim_params = nullptr) {
AddData(input_data_, name, dims, values.begin(), values.size(), is_initializer, false, dim_params);
}

template <typename T>
void AddInput(const char* name, const std::vector<int64_t>& dims, const std::vector<T>& values,
bool is_initializer = false) {
AddData(input_data_, name, dims, values.data(), values.size(), is_initializer);
bool is_initializer = false, const std::vector<std::string>* dim_params = nullptr) {
AddData(input_data_, name, dims, values.data(), values.size(), is_initializer, false, dim_params);
}

template <typename T>
void AddInput(const char* name, const std::vector<int64_t>& dims, const T* p_values, const size_t size, bool is_initializer = false) {
AddData(input_data_, name, dims, p_values, size, is_initializer);
void AddInput(const char* name, const std::vector<int64_t>& dims, const T* p_values, const size_t size, bool is_initializer = false, const std::vector<std::string>* dim_params = nullptr) {
AddData(input_data_, name, dims, p_values, size, is_initializer, false, dim_params);
}

// Add other registered types, possibly experimental
@@ -505,7 +505,8 @@ class OpTester {
protected:
template <typename T>
void AddData(std::vector<Data>& data, const char* name, const std::vector<int64_t>& dims, const T* values,
int64_t values_count, bool is_initializer = false, bool sort_output = false) {
int64_t values_count, bool is_initializer = false, bool sort_output = false,
const std::vector<std::string>* dim_params = nullptr) {
try {
TensorShape shape{dims};
ORT_ENFORCE(shape.Size() == values_count, values_count, " input values doesn't match tensor size of ",
@@ -529,7 +530,27 @@
OrtValue value;
value.Init(p_tensor.release(), DataTypeImpl::GetType<Tensor>(),
DataTypeImpl::GetType<Tensor>()->GetDeleteFunc());
data.push_back(Data(NodeArg(name, &type_proto), std::move(value), optional<float>(), optional<float>(), sort_output));
auto node_arg = NodeArg(name, &type_proto);
if (dim_params && !(dim_params->empty())) {
// If dim_params is present, configure node_arg's dims from dim_params; this supports symbolic dims and dim broadcast.
auto& dim_params_data = *dim_params;
onnx::TensorShapeProto new_shape;

// Currently the reserved symbolic names are hard-coded.
// TODO: when the list grows longer, consider moving it to a better place.
const static std::unordered_set<std::string> reserved_symbolic{"batch", "seq"};

for (size_t i = 0; i < dim_params_data.size(); ++i) {
if (reserved_symbolic.find(dim_params_data[i]) != reserved_symbolic.end()) {
new_shape.add_dim()->set_dim_param(dim_params_data[i]);
} else {
ASSERT_TRUE(std::stoi(dim_params_data[i]) == dims[i]);
new_shape.add_dim()->set_dim_value(dims[i]);
}
}
node_arg.SetShape(new_shape);
}
data.push_back(Data(std::move(node_arg), std::move(value), optional<float>(), optional<float>(), sort_output));
if (is_initializer) initializer_index_.push_back(data.size() - 1);
} catch (const std::exception& ex) {
std::cerr << "AddData for '" << name << "' threw: " << ex.what();
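For context, a minimal sketch of how a test can feed dim_params through the new AddInput overloads (hypothetical test snippet, not part of this commit; "seq" and "batch" are the reserved symbolic names above, and numeric strings must match the corresponding concrete dims):

// Hypothetical usage: dims stay concrete, dim_params annotates which of them are symbolic.
OpTester test("Add");
std::vector<std::string> dim_params{"batch", "2", "seq"};  // "batch"/"seq" symbolic, "2" fixed
std::vector<float> a_values(4 * 2 * 3, 1.0f);              // matches dims {4, 2, 3}
test.AddInput<float>("A", {4, 2, 3}, a_values,
                     /*is_initializer*/ false, &dim_params);
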
26 changes: 24 additions & 2 deletions orttraining/orttraining/core/graph/gradient_builder_base.cc
@@ -39,10 +39,32 @@ void ComputeBroadcastBackwardAxes(
auto A_dim = A_dims[i].dim_param(),
B_dim = B_dims[j].dim_param();
if (A_dim != B_dim) {
ORT_THROW("Error");
ORT_THROW("Error: symbolic dimension doesn't match. Expect the same symbolic but got \"",
A_dim, "\" and \"", B_dim, "\".");
}
} else if (A_dims[i].has_dim_param() && B_dims[j].has_dim_value()) {
auto A_dim = A_dims[i].dim_param();
auto B_dim = B_dims[j].dim_value();

if (B_dim != 1) {
ORT_THROW("Error: symbolic broadcasting requires the corresponding dimension to be 1. ",
"Actually got ", B_dim);
}
if (B_axes) {
B_axes->push_back(gsl::narrow_cast<int64_t>(k));
}
} else if (A_dims[i].has_dim_value() && B_dims[j].has_dim_param()) {
auto A_dim = A_dims[i].dim_value();
auto B_dim = B_dims[j].dim_param();

if (A_dim != 1) {
ORT_THROW("Error: symbolic broadcasting requires the corresponding dimension to be 1. ",
"Actually got ", A_dim);
}
if (A_axes) {
A_axes->push_back(gsl::narrow_cast<int64_t>(k));
}
}
// TODO: complete other cases

--i;
--j;
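As an illustration of the rule these new branches encode, here is a minimal standalone sketch (an assumed simplification: plain structs instead of ONNX Dimension protos, equal ranks, only B's reduction axes) for shape(A) = ("batch", 2, "seq", "seq") against shape(B) = ("batch", 1, 1, "seq"):

#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// A dimension is either a symbolic name or a concrete value (simplified stand-in
// for onnx::TensorShapeProto::Dimension).
struct Dim {
  bool symbolic;
  std::string param;
  int64_t value;
};

int main() {
  // shape(A) = ("batch", 2, "seq", "seq"), shape(B) = ("batch", 1, 1, "seq")
  std::vector<Dim> A = {{true, "batch", 0}, {false, "", 2}, {true, "seq", 0}, {true, "seq", 0}};
  std::vector<Dim> B = {{true, "batch", 0}, {false, "", 1}, {false, "", 1}, {true, "seq", 0}};

  std::vector<int64_t> B_axes;  // axes the gradient w.r.t. B must be reduced over
  for (size_t k = 0; k < A.size(); ++k) {
    if (A[k].symbolic && B[k].symbolic) {
      // matching symbolic dims: nothing to reduce (a mismatch would be an error)
    } else if (A[k].symbolic && !B[k].symbolic && B[k].value == 1) {
      B_axes.push_back(static_cast<int64_t>(k));  // symbolic vs 1: B is broadcast here
    } else if (!A[k].symbolic && !B[k].symbolic && B[k].value == 1 && A[k].value != 1) {
      B_axes.push_back(static_cast<int64_t>(k));  // ordinary numeric broadcast
    }
  }

  for (int64_t axis : B_axes) std::cout << axis << " ";  // prints: 1 2
  std::cout << "\n";
}
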
31 changes: 24 additions & 7 deletions orttraining/orttraining/test/gradient/gradient_checker.cc
@@ -129,7 +129,7 @@ inline Status GradientChecker<X_T, Y_T, JAC_T>::ComputeTheoreticalJacobianTransp

// Compute the theoretical Jacobians one row at a time by back propagating
// '1.0' for each element of 'dy', while holding all other elements of 'dy' at zero.
for (int c = 0; c < dy_size; ++c) { // for each value in the dy input vector
for (size_t c = 0; c < dy_size; ++c) { // for each value in the dy input vector
// clear OpTester input/output/initializer
op_session.ClearData();

@@ -167,7 +167,7 @@ inline Status GradientChecker<X_T, Y_T, JAC_T>::ComputeTheoreticalJacobianTransp
// inputs is treated as a vector of vectors. The parameters of the function call below, y_idx and c
// correspond to which input (dy1, dy2, etc.) and which value of the input (dy_flattened_vector[c])
// to perturb to 1.
op_session.Run(y_idx, c);
op_session.Run(y_idx, static_cast<int>(c));
auto gradients = op_session.GetFetches();

for (int x_idx = 0, grad_idx = 0; x_idx < static_cast<int>(x_num); x_idx++) {
@@ -186,7 +186,7 @@
r,
y_infos,
y_idx,
c);
static_cast<int>(c));
(*jacobian_ts)[calc_index.first][calc_index.second] = dx_flat[r];
}
}
@@ -211,19 +211,36 @@ inline Status GradientChecker<X_T, Y_T, JAC_T>::InitOpTesterWithGraph(
if (x_infos[data_index].data_type == DataTypeImpl::GetTensorType<int64_t>()) {
std::vector<int64_t> int64_data(data.size());
std::transform(data.begin(), data.end(), int64_data.begin(), [](X_T x) { return static_cast<int64_t>(x); });
op_session.AddInput<int64_t>(name.c_str(), x_infos[data_index].shape.GetDims(), int64_data);
op_session.AddInput<int64_t>(name.c_str(),
x_infos[data_index].shape.GetDims(),
int64_data,
false,
&x_infos[data_index].dim_params);
} else if (x_infos[data_index].data_type == DataTypeImpl::GetTensorType<int32_t>()) {
std::vector<int32_t> int32_data(data.size());
std::transform(data.begin(), data.end(), int32_data.begin(), [](X_T x) { return static_cast<int32_t>(x); });
op_session.AddInput<int32_t>(name.c_str(), x_infos[data_index].shape.GetDims(), int32_data);
op_session.AddInput<int32_t>(name.c_str(),
x_infos[data_index].shape.GetDims(),
int32_data,
false,
&x_infos[data_index].dim_params);
} else if (x_infos[data_index].data_type == DataTypeImpl::GetTensorType<bool>()) {
std::unique_ptr<bool[]> p_data(new bool[data.size()]);
for (size_t i = 0; i < data.size(); ++i) {
p_data[i] = static_cast<bool>(data[i]);
}
op_session.AddInput<bool>(name.c_str(), x_infos[data_index].shape.GetDims(), p_data.get(), data.size());
op_session.AddInput<bool>(name.c_str(),
x_infos[data_index].shape.GetDims(),
p_data.get(),
data.size(),
false,
&x_infos[data_index].dim_params);
} else {
op_session.AddInput<X_T>(name.c_str(), x_infos[data_index].shape.GetDims(), data);
op_session.AddInput<X_T>(name.c_str(),
x_infos[data_index].shape.GetDims(),
data,
false,
&x_infos[data_index].dim_params);
}
}

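As a side note on the perturbation scheme described in the comments above (one backpropagated '1.0' per dy element), a small numeric sketch, not ORT code: for y = W x the Jacobian is W, and feeding a one-hot dy = e_c recovers its row c:

#include <cstdio>
#include <vector>

int main() {
  // y = W x with a 2x3 Jacobian W; feeding dy = e_c and backpropagating
  // (dx = W^T dy) yields row c of the Jacobian, one row per dy element.
  const float W[2][3] = {{1, 2, 3}, {4, 5, 6}};
  for (int c = 0; c < 2; ++c) {
    std::vector<float> dy(2, 0.0f);
    dy[c] = 1.0f;  // perturb exactly one element of dy to 1
    float dx[3] = {0, 0, 0};
    for (int r = 0; r < 3; ++r)
      for (int k = 0; k < 2; ++k) dx[r] += W[k][r] * dy[k];
    std::printf("Jacobian row %d: %g %g %g\n", c, dx[0], dx[1], dx[2]);
  }
  return 0;
}
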
12 changes: 9 additions & 3 deletions orttraining/orttraining/test/gradient/gradient_checker.h
@@ -26,8 +26,13 @@ struct TensorInfo {
TensorInfo(const std::initializer_list<int64_t>& shape,
bool has_gradient = true,
std::function<float(float)>* transformer = nullptr,
MLDataType data_type = DataTypeImpl::GetTensorType<float>())
: shape(shape), has_gradient(has_gradient), transformer(transformer), data_type(data_type) {}
MLDataType data_type = DataTypeImpl::GetTensorType<float>(),
const std::vector<std::string>& dim_params = std::vector<std::string>{})
: shape(shape),
has_gradient(has_gradient),
transformer(transformer),
data_type(data_type),
dim_params(dim_params) {}

TensorInfo(const TensorShape& shape,
bool has_gradient = true,
@@ -39,6 +44,7 @@ struct TensorInfo {
bool has_gradient;
std::function<float(float)>* transformer;
MLDataType data_type;
std::vector<std::string> dim_params;
};

// TODO: This class currently assumes the inputs share types and the outputs share a type.
@@ -85,7 +91,7 @@ class GradientChecker {
const std::vector<TensorInfo>& y_infos,
std::vector<std::vector<JAC_T>>* jacobians);

std::vector<OrtValue> EvaluateFunctionAtInput(OpTester& op_tester,
const std::vector<TensorInfo>& x_infos,
const std::vector<TensorInfo>& y_infos,
std::vector<std::vector<X_T>>* x_datas,
@@ -209,7 +209,7 @@ void GradientOpTester::FillFeedsAndOutputNames(std::unordered_map<std::string, M
}
auto shape = output_data_[i].data_.Get<Tensor>().Shape();
std::vector<float> values(shape.Size(), 0.0);
if (output_index_to_use_as_loss == i) {
if (output_index_to_use_as_loss == static_cast<int>(i)) {
values[data_index_of_output] = 1.0; //set only one value to one to construct jacobian matrix
}
AddData<float>(gradient_data, (output_data_[i].def_.Name() + "_grad").c_str(), shape.GetDims(), values.data(), values.size(), true);
21 changes: 21 additions & 0 deletions orttraining/orttraining/test/gradient/gradient_ops_test.cc
@@ -185,6 +185,27 @@ void TestBroadcastableBinaryOpGrad(const std::string& op_type,
gradient_checker.ComputeGradientError(op_def, {A_info, B_info}, {Y_info}, &max_error);
EXPECT_IS_TINY(max_error);
}

// symbolic broadcast
// shape(A) = (4, 2, 1, "seq(3)"), shape(B) = (4, 2, 1, 1), ==> shape(result) = (4, 2, 1, 3)
{
TensorInfo A_info{{4, 2, 1, 3}, true, transformer, DataTypeImpl::GetTensorType<float>(), {"4", "2", "1", "seq"}};
TensorInfo B_info{{4, 2, 1, 1}, true, transformer, DataTypeImpl::GetTensorType<float>(), {"4", "2", "1", "1"}};
TensorInfo Y_info{{4, 2, 1, 3}};

gradient_checker.ComputeGradientError(op_def, {A_info, B_info}, {Y_info}, &max_error);
EXPECT_IS_TINY(max_error);
}
// symbolic broadcast + numeric broadcast
// shape(A) = ("batch(4)", 2, "seq(3)", "seq(3)"), shape(B) = ("batch(4)", 1, "seq(3)", "seq(3)"), ==> shape(result) = (4, 2, 3, 3)
{
TensorInfo A_info{{4, 2, 3, 3}, true, transformer, DataTypeImpl::GetTensorType<float>(), {"batch", "2", "seq", "seq"}};
TensorInfo B_info{{4, 1, 1, 3}, true, transformer, DataTypeImpl::GetTensorType<float>(), {"batch", "1", "1", "seq"}};
TensorInfo Y_info{{4, 2, 3, 3}};

gradient_checker.ComputeGradientError(op_def, {A_info, B_info}, {Y_info}, &max_error);
EXPECT_IS_TINY(max_error);
}
}

TEST(GradientCheckerTest, AddGrad) {
