Refactor Onnx runtime Server to only use public APIs #1271
Conversation
This is a big PR and hard to review because it seems to be an amalgamation of several things. It would be easier to review if you did this in 3 PRs:
I think it would also speed up the review process.
onnxruntime/server/executor.cc (outdated)

      auto logger = env_->GetLogger(request_id_);

      size_t cpu_tensor_length = 0;
-     auto status = onnxruntime::utils::GetSizeInBytesFromTensorProto<0>(input_tensor, &cpu_tensor_length);
+     auto status = onnxruntime::server::GetSizeInBytesFromTensorProto<0>(input_tensor, &cpu_tensor_length);
There is OrtGetTensorMemSizeInBytesFromTensorProto.
I chose not to use those for two reasons:
- Those work on serialized protobufs, and for us to support JSON and gRPC we need to be able to control the deserialization.
- I got the impression from Changming that we should treat the TensorProto as internal to ORT and not depend on it specifically. @snnn ?
onnxruntime/server/executor.cc (outdated)

-     status = onnxruntime::utils::TensorProtoToMLValue(onnxruntime::Env::Default(), nullptr, input_tensor,
-                                                       onnxruntime::MemBuffer(buf, cpu_tensor_length, *cpu_allocator_info),
-                                                       ml_value, deleter);
+     status = onnxruntime::server::TensorProtoToMLValue(input_tensor,
There is OrtTensorProtoToOrtValue.
Same reasoning as above.
-     for (size_t i = 0, count = 1 + ((tensor.Size() - 1) / sizeof(int32_t)); i < count; ++i) {
-       tensor_proto.add_int32_data(i32data[i]);
+     for (size_t i = 0, count = elem_count; i < count; ++i) {
+       tensor_proto.add_int32_data(reinterpret_cast<const uint16_t*>(data)[i]);
You mean uint32_t here?
No, float16 is supposed to be bit-cast to uint16. From onnx-ml.proto:

    // For int32, uint8, int8, uint16, int16, bool, and float16 values
    // float16 values must be bit-wise converted to an uint16_t prior
    // to writing to the buffer.
    // When this field is present, the data_type field MUST be
    // INT32, INT16, INT8, UINT16, UINT8, BOOL, or FLOAT16
    repeated int32 int32_data = 5 [packed = true];
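To make the packing rule above concrete, here is a minimal, self-contained sketch (the helper name `PackFloat16AsInt32Data` is hypothetical, not part of the PR): each float16 element's raw bits are treated as a `uint16_t` and widened into the `int32_data` field without any numeric conversion.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper illustrating the onnx-ml.proto rule quoted above:
// float16 values are bit-wise reinterpreted as uint16_t, then stored in
// the (packed) int32_data field. No floating-point conversion occurs.
std::vector<int32_t> PackFloat16AsInt32Data(const uint16_t* fp16_bits, size_t elem_count) {
  std::vector<int32_t> int32_data;
  int32_data.reserve(elem_count);
  for (size_t i = 0; i < elem_count; ++i) {
    // Widen the raw 16-bit pattern to int32; the bit pattern is preserved.
    int32_data.push_back(static_cast<int32_t>(fp16_bits[i]));
  }
  return int32_data;
}
```

This is why the diff casts the buffer with `reinterpret_cast<const uint16_t*>(data)` rather than reading it as floats.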
Should I maybe add a comment?
-     if (!status.IsOK()) {
-       logger->error("GetSizeInBytesFromTensorProto() failed. Error Message: {}", status.ToString());
-       return GenerateProtobufStatus(status, "GetSizeInBytesFromTensorProto() failed: " + status.ToString());
+     try {
Why don't we use the return value here? try/catch will hurt performance.
He is using try/catch because GetSizeInBytesFromTensorProto doesn't return anything.
The C++ API uses exceptions, so I decided to use exceptions for the converter as well. IIRC, C++ exceptions are generally zero-cost on the happy path, so this will only hurt performance when a call actually fails.
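A minimal sketch of the exception-based style being discussed, under stated assumptions: the function names (`GetTensorSizeInBytesOrThrow`, `DescribeTensor`) are hypothetical stand-ins, not the PR's actual helpers. The converter throws on bad input instead of returning a status object, so the success path carries no error-propagation branches; the caller converts the exception to an error response only when something goes wrong.

```cpp
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <string>

// Hypothetical converter in the exception-throwing style: no Status return,
// failures surface as exceptions. The happy path is a plain multiply.
size_t GetTensorSizeInBytesOrThrow(size_t elem_count, size_t elem_size) {
  if (elem_size != 0 && elem_count > SIZE_MAX / elem_size) {
    throw std::overflow_error("tensor byte size overflows size_t");
  }
  return elem_count * elem_size;
}

// Caller mirrors the try/catch shape shown in the diff above: translate the
// exception into an error message only on the failure path.
std::string DescribeTensor(size_t elem_count, size_t elem_size) {
  try {
    return "size=" + std::to_string(GetTensorSizeInBytesOrThrow(elem_count, elem_size));
  } catch (const std::exception& e) {
    return std::string("error: ") + e.what();
  }
}
```

With table-based unwinding (the common "zero-cost" model on mainstream compilers), the try block itself adds no instructions on the non-throwing path; the cost is paid only when an exception propagates.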
/azp run
Azure Pipelines successfully started running 22 pipeline(s).
Description: Refactors the ONNX Runtime server to use only public APIs. This is the first step toward switching to dynamic linking.
Motivation and Context
The ONNX Runtime server currently must be statically linked against ONNX Runtime. It also uses internal (non-stable) APIs, which limits iteration speed. This PR moves to using only public APIs so the two can be more loosely coupled.