Skip to content

Commit

Permalink
feat: enhance aimrt_py rpc context handling and method overloads (Aim…
Browse files Browse the repository at this point in the history
…RT#48)

* feat: add new methods to RpcContext and RpcContextRef

Enhance functionality by introducing CheckUsed, SetUsed, Reset, GetFunctionName, and SetFunctionName methods for better state management and function identification.

* feat: add RPC service details to ServiceBase class

Enhance the ServiceBase class by adding methods to retrieve RPC type and service name, along with the ability to set the service name. This improves accessibility and flexibility for RPC configurations.

* feat: simplify service function type definition

Introduce type aliases for service function return and parameter types to enhance code readability and maintainability. This change reduces redundancy and clarifies the expected function signatures, streamlining future development.

* feat: enhance RPC framework with proxy support

Add support for `ProxyBase` in the RPC framework, enabling more flexible service management and context handling in Python. Update the `RpcContext` definition to use shared pointers for better memory management.

* feat: enhance rpc context handling and method overloads

Improve the handling of RPC context by adding overloads for method arguments, ensuring type safety and clarity in usage. This change simplifies the implementation of service proxies and makes it easier to work with different context types.

* feat: enhance context handling in RPC proxy

Add default context reference to the `NewContextSharedPtr` method, simplifying context management in RPC calls for improved usability.
  • Loading branch information
zhangyi1357 authored Oct 24, 2024
1 parent e2e7706 commit cb01a34
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 16 deletions.
35 changes: 33 additions & 2 deletions src/runtime/python_runtime/export_rpc.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,11 @@ inline void ExportRpcContext(const pybind11::object& m) {
using aimrt::rpc::Context;
using aimrt::rpc::ContextRef;

pybind11::class_<Context>(m, "RpcContext")
pybind11::class_<Context, std::shared_ptr<Context>>(m, "RpcContext")
.def(pybind11::init<>())
.def("CheckUsed", &Context::CheckUsed)
.def("SetUsed", &Context::SetUsed)
.def("Reset", &Context::Reset)
.def("GetType", &Context::GetType)
.def("Timeout", &Context::Timeout)
.def("SetTimeout", &Context::SetTimeout)
Expand All @@ -70,6 +73,8 @@ inline void ExportRpcContext(const pybind11::object& m) {
.def("SetToAddr", &Context::SetToAddr)
.def("GetSerializationType", &Context::GetSerializationType)
.def("SetSerializationType", &Context::SetSerializationType)
.def("GetFunctionName", &Context::GetFunctionName)
.def("SetFunctionName", &Context::SetFunctionName)
.def("ToString", &Context::ToString);

pybind11::class_<ContextRef>(m, "RpcContextRef")
Expand All @@ -78,6 +83,8 @@ inline void ExportRpcContext(const pybind11::object& m) {
.def(pybind11::init<Context*>())
.def(pybind11::init<const std::shared_ptr<Context>&>())
.def("__bool__", &ContextRef::operator bool)
.def("CheckUsed", &ContextRef::CheckUsed)
.def("SetUsed", &ContextRef::SetUsed)
.def("GetType", &ContextRef::GetType)
.def("Timeout", &ContextRef::Timeout)
.def("SetTimeout", &ContextRef::SetTimeout)
Expand All @@ -88,15 +95,20 @@ inline void ExportRpcContext(const pybind11::object& m) {
.def("SetToAddr", &ContextRef::SetToAddr)
.def("GetSerializationType", &ContextRef::GetSerializationType)
.def("SetSerializationType", &ContextRef::SetSerializationType)
.def("GetFunctionName", &ContextRef::GetFunctionName)
.def("SetFunctionName", &ContextRef::SetFunctionName)
.def("ToString", &ContextRef::ToString);
}

using ServiceFuncReturnType = std::tuple<aimrt::rpc::Status, std::string>;
using ServiceFuncType = std::function<ServiceFuncReturnType(aimrt::rpc::ContextRef, const pybind11::bytes&)>;

inline void PyRpcServiceBaseRegisterServiceFunc(
aimrt::rpc::ServiceBase& service,
std::string_view func_name,
const std::shared_ptr<const PyTypeSupport>& req_type_support,
const std::shared_ptr<const PyTypeSupport>& rsp_type_support,
std::function<std::tuple<aimrt::rpc::Status, std::string>(aimrt::rpc::ContextRef, const pybind11::bytes&)>&& service_func) {
ServiceFuncType&& service_func) {
static std::vector<std::shared_ptr<const PyTypeSupport>> py_ts_vec;
py_ts_vec.emplace_back(req_type_support);
py_ts_vec.emplace_back(rsp_type_support);
Expand Down Expand Up @@ -142,6 +154,9 @@ inline void ExportRpcServiceBase(pybind11::object m) {

pybind11::class_<ServiceBase>(std::move(m), "ServiceBase")
.def(pybind11::init<std::string_view, std::string_view>())
.def("RpcType", &ServiceBase::RpcType)
.def("ServiceName", &ServiceBase::ServiceName)
.def("SetServiceName", &ServiceBase::SetServiceName)
.def("RegisterServiceFunc", &PyRpcServiceBaseRegisterServiceFunc);
}

Expand Down Expand Up @@ -204,4 +219,20 @@ inline void ExportRpcHandleRef(pybind11::object m) {
.def("RegisterClientFunc", &PyRpcHandleRefRegisterClientFunc)
.def("Invoke", &PyRpcHandleRefInvoke);
}

inline void ExportRpcProxyBase(pybind11::object m) {
using aimrt::rpc::ContextRef;
using aimrt::rpc::ProxyBase;
using aimrt::rpc::RpcHandleRef;

pybind11::class_<ProxyBase>(std::move(m), "ProxyBase")
.def(pybind11::init<RpcHandleRef, std::string_view, std::string_view>())
.def("RpcType", &ProxyBase::RpcType)
.def("ServiceName", &ProxyBase::ServiceName)
.def("SetServiceName", &ProxyBase::SetServiceName)
.def("NewContextSharedPtr", &ProxyBase::NewContextSharedPtr, pybind11::arg("ctx_ref") = ContextRef())
.def("GetDefaultContextSharedPtr", &ProxyBase::GetDefaultContextSharedPtr)
.def("SetDefaultContextSharedPtr", &ProxyBase::SetDefaultContextSharedPtr);
}

} // namespace aimrt::runtime::python_runtime
1 change: 1 addition & 0 deletions src/runtime/python_runtime/python_runtime_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ PYBIND11_MODULE(aimrt_python_runtime, m) {
ExportRpcContext(m);
ExportRpcServiceBase(m);
ExportRpcHandleRef(m);
ExportRpcProxyBase(m);

// parameter
ExportParameter(m);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,19 @@
import sys

from google.protobuf.compiler import plugin_pb2 as plugin
from google.protobuf.compiler.plugin_pb2 import CodeGeneratorRequest as CodeGeneratorRequest
from google.protobuf.compiler.plugin_pb2 import CodeGeneratorResponse as CodeGeneratorResponse
from google.protobuf.descriptor_pb2 import FileDescriptorProto
from google.protobuf.compiler.plugin_pb2 import \
CodeGeneratorRequest as CodeGeneratorRequest
from google.protobuf.compiler.plugin_pb2 import \
CodeGeneratorResponse as CodeGeneratorResponse
from google.protobuf.descriptor_pb2 import \
FileDescriptorProto as FileDescriptorProto


class AimRTCodeGenerator:
t_pyfile: str = r"""# This file was generated by protoc-gen-aimrt_rpc which is a self-defined pb compiler plugin, do not edit it!!!
from typing import overload
import aimrt_py
import google.protobuf
import {{py_package_name}}
Expand Down Expand Up @@ -79,20 +84,46 @@ def {{rpc_func_name}}(self, ctx_ref, req):
{{service end}}
{{for service begin}}
class {{service_name}}Proxy:
class {{service_name}}Proxy(aimrt_py.ProxyBase):
def __init__(self, rpc_handle_ref=aimrt_py.RpcHandleRef()):
super().__init__(rpc_handle_ref, "pb", "{{package_name}}.{{service_name}}")
self.rpc_handle_ref = rpc_handle_ref
{{for method begin}}
def {{rpc_func_name}}(self, ctx, req):
if(type(ctx) == aimrt_py.RpcContext):
@overload
def {{rpc_func_name}}(
self, req: {{full_rpc_req_py_name}}
) -> tuple[aimrt_py.RpcStatus, {{full_rpc_rsp_py_name}}]: ...
@overload
def {{rpc_func_name}}(
self, ctx_ref: aimrt_py.RpcContext, req: {{full_rpc_req_py_name}}
) -> tuple[aimrt_py.RpcStatus, {{full_rpc_rsp_py_name}}]: ...
@overload
def {{rpc_func_name}}(
self, ctx_ref: aimrt_py.RpcContextRef, req: {{full_rpc_req_py_name}}
) -> tuple[aimrt_py.RpcStatus, {{full_rpc_rsp_py_name}}]: ...
def {{rpc_func_name}}(self, *args):
if len(args) == 1:
ctx = super().NewContextSharedPtr()
req = args[0]
elif len(args) == 2:
ctx = args[0]
req = args[1]
else:
raise TypeError(f"{{rpc_func_name}} expects 1 or 2 arguments, got {len(args)}")
if isinstance(ctx, aimrt_py.RpcContext):
ctx_ref = aimrt_py.RpcContextRef(ctx)
elif(type(ctx) == aimrt_py.RpcContextRef):
elif isinstance(ctx, aimrt_py.RpcContextRef):
ctx_ref = ctx
else:
raise TypeError("ctx must be 'aimrt_py.RpcContext' or 'aimrt_py.RpcContextRef'")
raise TypeError(f"ctx must be 'aimrt_py.RpcContext' or 'aimrt_py.RpcContextRef', got {type(ctx)}")
if(ctx_ref):
if(ctx_ref.GetSerializationType() == ""):
if ctx_ref:
if ctx_ref.GetSerializationType() == "":
ctx_ref.SetSerializationType("pb")
else:
real_ctx = aimrt_py.RpcContext()
Expand All @@ -105,9 +136,9 @@ def {{rpc_func_name}}(self, ctx, req):
try:
req_str = ""
if(serialization_type == "pb"):
if serialization_type == "pb":
req_str = req.SerializeToString()
elif(serialization_type == "json"):
elif serialization_type == "json":
req_str = google.protobuf.json_format.MessageToJson(req)
else:
return (aimrt_py.RpcStatus(aimrt_py.RpcStatusRetCode.CLI_INVALID_SERIALIZATION_TYPE), rsp)
Expand All @@ -118,9 +149,9 @@ def {{rpc_func_name}}(self, ctx, req):
ctx_ref, req_str)
try:
if(serialization_type == "pb"):
if serialization_type == "pb":
rsp.ParseFromString(rsp_str)
elif(serialization_type == "json"):
elif serialization_type == "json":
google.protobuf.json_format.Parse(rsp_str, rsp)
else:
return (aimrt_py.RpcStatus(aimrt_py.RpcStatusRetCode.CLI_INVALID_SERIALIZATION_TYPE), rsp)
Expand Down

0 comments on commit cb01a34

Please sign in to comment.