Add dynamic shape support to AOT driver & compiler (#72995)
Summary:
Pull Request resolved: #72995

Add the ability to specify which input dimensions need to be dynamic.
Example: if the dimension of size 115 in input sizes "1,115;1" can be dynamic, specify dynamic_dims as "115".
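
Below is a minimal standalone sketch (not part of this PR; parseDynamicDims is an illustrative stand-in for the parseInputDynamicShapes helper added in aot_compiler.cpp) of how the comma-separated flag value is interpreted. The parsed size values are passed to makeShapesSymbolic, so input dimensions whose size matches one of them are compiled as symbolic.

#include <cstdint>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

// Parse a comma-separated dynamic_dims string into the size values that
// should be treated as symbolic during AOT compilation.
std::vector<int64_t> parseDynamicDims(const std::string& s) {
  std::vector<int64_t> dims;
  std::stringstream ss(s);
  std::string tok;
  while (std::getline(ss, tok, ',')) {
    if (!tok.empty()) {
      dims.push_back(std::stoll(tok));
    }
  }
  return dims;
}

int main() {
  // For input sizes "1,115;1" and dynamic_dims "115", every dimension whose
  // size is 115 becomes dynamic.
  for (auto d : parseDynamicDims("115")) {
    std::cout << d << "\n"; // prints 115
  }
}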

Also recompile and update the CI models and some asm code, since the old ones no longer compile with the compiler changes in context.cpp.

Test Plan: Compiles and runs the BI Bytedoc model with and without dynamic inputs.
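
For context on what running "with dynamic inputs" means at the spec level, here is a standalone sketch that mirrors the new InputSpec::validate logic in context.cpp below (the dtype check is omitted, and validateSizes is an illustrative name, not the PR's API): a spec size of 0 matches any runtime size.

#include <cstdint>
#include <iostream>
#include <vector>

// A spec size of 0 marks a dynamic dimension and matches any runtime size;
// all other dimensions must match exactly.
bool validateSizes(
    const std::vector<int64_t>& spec_sizes,
    const std::vector<int64_t>& input_sizes) {
  if (spec_sizes.size() != input_sizes.size()) {
    return false;
  }
  for (size_t i = 0; i < spec_sizes.size(); i++) {
    if (spec_sizes[i] != 0 && spec_sizes[i] != input_sizes[i]) {
      return false;
    }
  }
  return true;
}

int main() {
  std::vector<int64_t> spec = {1, 0}; // second dimension compiled as dynamic
  std::cout << validateSizes(spec, {1, 115}) << "\n"; // 1: dynamic dim accepts 115
  std::cout << validateSizes(spec, {1, 640}) << "\n"; // 1: dynamic dim accepts 640
  std::cout << validateSizes(spec, {2, 115}) << "\n"; // 0: static dim must match
}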

Reviewed By: ZolotukhinM

Differential Revision: D34233121

fbshipit-source-id: 35095e549ebd6d3bec98b9abb3f0764366a0ff6f
(cherry picked from commit 33166a9)
priyaramani authored and pytorchmergebot committed Feb 24, 2022
1 parent 5a7778c commit ac97e95
Showing 4 changed files with 105 additions and 25 deletions.
5 changes: 5 additions & 0 deletions binaries/aot_model_compiler.cc
@@ -36,6 +36,10 @@ C10_DEFINE_string(
"Input memory format."
"If multiple inputs needed, use semicolon to separate."
"Supported values: contiguous, channels_last");
C10_DEFINE_string(
dynamic_dims,
"",
"Comma separated dimensions of input tensors that can be dynamic");
C10_DEFINE_string(method_name, "forward", "The name of the method.");
C10_DEFINE_string(
output_llvm,
@@ -68,6 +72,7 @@ c10::Dict<c10::IValue, c10::IValue> createCompileSpec() {
method_spec.insert("sizes", FLAGS_input_dims);
method_spec.insert("types", FLAGS_input_types);
method_spec.insert("memory_formats", FLAGS_input_memory_formats);
method_spec.insert("dynamic_sizes", FLAGS_dynamic_dims);
method_spec.insert("asmfile", FLAGS_output_llvm);
method_spec.insert("model_name", FLAGS_model_name);
method_spec.insert("model_version", FLAGS_model_version);
10 changes: 7 additions & 3 deletions test/mobile/nnc/test_nnc_backend.cpp
@@ -23,7 +23,9 @@ c10::Dict<c10::IValue, c10::IValue> create_compile_spec(
const std::string& method_name,
const std::string& model_name,
const std::string& input_shapes,
const std::string& input_types) {
const std::string& input_types,
const std::string& memory_formats,
const std::string& dynamic_sizes) {
c10::Dict<c10::IValue, c10::IValue> method_spec(
c10::StringType::get(), c10::AnyType::get());

@@ -33,6 +35,8 @@ c10::Dict<c10::IValue, c10::IValue> create_compile_spec(
method_spec.insert("model_version", "v1");
method_spec.insert("asmfile", "fake_nnc_model.s");
method_spec.insert("arch", "x86-64");
method_spec.insert("memory_formats", memory_formats);
method_spec.insert("dynamic_sizes", dynamic_sizes);

c10::Dict<c10::IValue, c10::IValue> compile_spec(
c10::StringType::get(), c10::AnyType::get());
@@ -63,7 +67,7 @@ REGISTER_NNC_KERNEL(

TEST(NNCBackendTest, AOTCompileThenExecute) {
torch::jit::Module m("m");
auto param = torch::ones({});
auto param = torch::ones({1});
m.register_parameter("param", param, false);
m.define(R"(
def forward(self, input):
@@ -77,7 +81,7 @@ TEST(NNCBackendTest, AOTCompileThenExecute) {

// Compile the model with NNC.
auto compile_spec = create_compile_spec(
"forward", "_add_kernel_nnc_fake_model", "4,4", "float");
"forward", "_add_kernel_nnc_fake_model", "4,4", "float", "", "");
auto any_dict_ty =
c10::DictType::create(c10::StringType::get(), c10::AnyType::get());
auto frozen_m = torch::jit::freeze_module(m.clone());
83 changes: 67 additions & 16 deletions torch/csrc/jit/mobile/nnc/aot_compiler.cpp
@@ -43,9 +43,18 @@ std::vector<int64_t> getConstSizes(const BufPtr b) {

// Construct input-specs vector from the inputs of the original graph
std::vector<mobile::nnc::InputSpec> toInputSpecs(
const std::shared_ptr<Graph>& g) {
const std::shared_ptr<tensorexpr::TensorExprKernel>& kernel) {
const std::shared_ptr<Graph>& g = kernel->graph();
std::vector<mobile::nnc::InputSpec> specs;
for (auto v : g->inputs()) {

// Graph inputs include scalar values for symbolic shapes, for which we
// don't need input specs. These scalar values come last among the graph
// inputs
auto num_inputs =
g->inputs().size() - kernel->getSymbolicShapeInputs().size();

for (int i = 0; i < num_inputs; i++) {
auto v = g->inputs()[i];
const auto& t = v->type();
mobile::nnc::InputSpec spec;
TORCH_CHECK(t->kind() == TypeKind::TensorType, "Unsupported input type");
@@ -120,7 +129,7 @@ std::unique_ptr<Function> compileMethod(
const std::vector<at::ScalarType>& types) {
auto func = std::make_unique<Function>();
func->set_name(method_name);
func->set_input_specs(toInputSpecs(kernel->graph()));
func->set_input_specs(toInputSpecs(kernel));

auto params = c10::impl::GenericList(c10::AnyType::get());
auto const_descriptors = kernel->getConstantDescriptors();
@@ -177,18 +186,33 @@ std::pair<std::unique_ptr<Function>, const std::string> aotCompile(
std::shared_ptr<Graph>& g,
const std::vector<std::vector<int64_t>>& sizes,
const std::vector<at::ScalarType>& types,
const std::string& kernel_func_name) {
const std::string& kernel_func_name,
const std::vector<int64_t>& symbolic_ind) {
GRAPH_DEBUG("Input sizes ", sizes);
GRAPH_DEBUG("Input types ", types);
GRAPH_DEBUG("Method name ", method_name);
GRAPH_DEBUG("Kernel func name ", kernel_func_name);

std::shared_ptr<tensorexpr::TensorExprKernel> kernel =
std::make_shared<tensorexpr::TensorExprKernel>(
TensorExprKernel(g, kernel_func_name));
GRAPH_DEBUG("Symbolic indices ", symbolic_ind);

std::shared_ptr<tensorexpr::TensorExprKernel> kernel;
std::vector<torch::jit::StrideInput> stride_desc = {
torch::jit::StrideInput::TENSOR_CONT};
std::unordered_map<
const torch::jit::Value*,
std::vector<torch::jit::StrideInput>>
symbolic_strides;
if (!symbolic_ind.empty()) {
for (auto i : g->inputs()) {
symbolic_strides[i] = stride_desc;
}
for (auto o : g->outputs()) {
symbolic_strides[o] = stride_desc;
}
}
kernel = std::make_shared<tensorexpr::TensorExprKernel>(TensorExprKernel(
g, kernel_func_name, {}, symbolic_ind, false, symbolic_strides));

const std::string compiled_assembly = kernel->getCodeText();

auto func = compileMethod(kernel, method_name, sizes, types);
return std::make_pair(std::move(func), compiled_assembly);
}
@@ -271,6 +295,17 @@ std::vector<at::MemoryFormat> parseInputMemoryFormats(
return memFormats;
}

std::vector<int64_t> parseInputDynamicShapes(
const std::string& dynamic_dims_s) {
std::vector<std::string> dynamic_dims_list = split(',', dynamic_dims_s);
std::vector<int64_t> dynamic_dims;
dynamic_dims.reserve(dynamic_dims_list.size());
for (const auto& dim : dynamic_dims_list) {
dynamic_dims.push_back(c10::stoi(dim));
}
return dynamic_dims;
}

std::string getNncKernelId(
const std::string& model_name,
const std::string& model_version,
@@ -288,9 +323,12 @@ std::string getNncKernelFuncName(
return "nnc_" + model_name + "_" + model_version + "_" + method_name;
}

std::shared_ptr<Graph> preprocessGraphPasses(
// Preprocess the graph and returns the processed graph and
// symbolic values if dynamic input shapes are specified
std::pair<std::shared_ptr<Graph>, std::vector<int64_t>> preprocessGraphPasses(
std::shared_ptr<Graph>& graph,
const std::vector<c10::optional<at::Tensor>>& example_inputs) {
const std::vector<c10::optional<at::Tensor>>& example_inputs,
const std::vector<int64_t>& dynamic_sizes) {
GRAPH_DEBUG("Before preprocessing graph passes: ", *graph);
torch::jit::RemoveTensorMutation(graph);
torch::jit::EliminateDeadCode(graph->block());
@@ -321,8 +359,12 @@ std::shared_ptr<Graph> preprocessGraphPasses(
RemoveTensorMutation(graph);
EliminateDeadCode(graph);
LowerAllTuples(graph);

auto sym_val =
torch::jit::tensorexpr::makeShapesSymbolic(graph, dynamic_sizes);

GRAPH_DEBUG("After preprocessing graph passes: ", *graph);
return graph;
return std::make_pair(graph, sym_val);
}

std::vector<c10::optional<at::Tensor>> generateExampleInputs(
@@ -335,8 +377,7 @@ std::vector<c10::optional<at::Tensor>> generateExampleInputs(
const auto dtype = at::dtype(inputTypes[i]);
const auto memory_format = inputMemoryFormats[i];
example_inputs.emplace_back(
at::rand(inputShapes[i], at::TensorOptions(dtype))
.contiguous(memory_format));
at::rand(inputShapes[i]).to(dtype).contiguous(memory_format));
}
return example_inputs;
}
@@ -364,6 +405,8 @@ c10::IValue preprocess(

auto sizes = parseInputShapes(*method_spec.at("sizes").toString());
auto types = parseInputTypes(*method_spec.at("types").toString());
auto dynamic_sizes =
parseInputDynamicShapes(*method_spec.at("dynamic_sizes").toString());

std::string memory_formats_str = method_spec.contains("memory_formats")
? (*method_spec.at("memory_formats").toString()).string()
@@ -374,12 +417,20 @@ : parseInputMemoryFormats(memory_formats_str);
: parseInputMemoryFormats(memory_formats_str);

auto example_inputs = generateExampleInputs(sizes, types, memory_formats);
graph = preprocessGraphPasses(graph, example_inputs);
auto preprocessed =
preprocessGraphPasses(graph, example_inputs, dynamic_sizes);

auto kernel_func_name =
getNncKernelFuncName(model_name, model_version, method_name);
auto processed_graph = preprocessed.first;
auto sym_values = preprocessed.second;
auto compiled = torch::jit::mobile::nnc::aotCompile(
method_name, graph, sizes, types, kernel_func_name);
method_name,
processed_graph,
sizes,
types,
kernel_func_name,
sym_values);
writeOutputLlvmAssembly(compiled.second, asmfile_name);
auto func = std::move(compiled.first);
func->set_nnc_kernel_id(
32 changes: 26 additions & 6 deletions torch/csrc/jit/mobile/nnc/context.cpp
@@ -41,7 +41,17 @@ c10::IValue InputSpec::serialize() const {
}

bool InputSpec::validate(const at::Tensor& input) const {
return input.sizes() == sizes_ && input.scalar_type() == dtype_;
if (sizes_.size() != input.sizes().size() || input.scalar_type() != dtype_) {
return false;
}
auto spec_sizes = sizes_;
for (int i = 0; i < spec_sizes.size(); i++) {
// InputSpec size 0 means that the dimension is dynamic
if (spec_sizes[i] != 0 && spec_sizes[i] != input.sizes()[i]) {
return false;
}
}
return true;
}

OutputSpec::OutputSpec(const c10::IValue& value) {
@@ -136,6 +146,14 @@ Function::Function(const c10::IValue& value) {

// memory_plan_
memory_plan_ = MemoryPlan(dict.at("memory_plan"));

// symbolic shape positions
for (const auto& sym_shape_pos :
dict.at("sym_shape_pos").toTupleRef().elements()) {
auto sym_shape_elements = sym_shape_pos.toTupleRef().elements();
sym_shape_positions_.emplace_back(
sym_shape_elements[0].toInt(), sym_shape_elements[1].toInt());
}
}

c10::IValue Function::serialize() const {
@@ -185,18 +203,20 @@ void Function::init_execution_state() const {
ExecutionState state;
memory_plan_.allocate(&state);

// The arguments vector consists of 4 sections: inputs, outputs, parameters
// and buffers.
// The arguments vector consists of 5 sections: inputs, symbolic shapes,
// outputs, parameters and buffers.
auto input_args = input_specs_.size();
auto sym_shape_args = sym_shape_positions_.size();
auto output_args = output_specs_.size();
auto param_args = parameters_.size();
auto buffer_args = state.preallocations_.size();

auto& arguments = state.arguments_;
arguments.reserve(input_args + output_args + param_args + buffer_args);
arguments.reserve(
input_args + sym_shape_args + output_args + param_args + buffer_args);

// Keep empty slots to fill in inputs/outputs pointers at execution time.
arguments.resize(input_args + output_args);
arguments.resize(input_args + sym_shape_args + output_args);

// Fill in parameters as untyped raw pointers.
// The underlying storage of the parameters should be owned by `parameters_`,
@@ -233,7 +253,7 @@ c10::impl::GenericList Function::run(

// Fill in input tensors.
TORCH_CHECK(
input_specs_.size() == (inputs.size() + sym_shape_positions_.size()),
input_specs_.size() == inputs.size(),
"Input size doesn't match the spec, expect: ",
input_specs_.size(),
" actual: ",
