feat: enhance aimrt_py rpc context handling and method overloads (#48)

* feat: add new methods to RpcContext and RpcContextRef

Enhance functionality by introducing CheckUsed, SetUsed, Reset, GetFunctionName, and SetFunctionName methods for better state management and function identification.

* feat: add RPC service details to ServiceBase class

Enhance the ServiceBase class by adding methods to retrieve RPC type and service name, along with the ability to set the service name. This improves accessibility and flexibility for RPC configurations.

* feat: simplify service function type definition

Introduce type aliases for service function return and parameter types to enhance code readability and maintainability. This change reduces redundancy and clarifies the expected function signatures, streamlining future development.

* feat: enhance RPC framework with proxy support

Add support for `ProxyBase` in the RPC framework, enabling more flexible service management and context handling in Python. Update the `RpcContext` definition to use shared pointers for better memory management.

* feat: enhance rpc context handling and method overloads

Improve the handling of RPC context by adding overloads for method arguments, ensuring type safety and clarity in usage. This change simplifies the implementation of service proxies and makes it easier to work with different context types.

* feat: enhance context handling in RPC proxy

Add default context reference to the `NewContextSharedPtr` method, simplifying context management in RPC calls for improved usability.
This commit is contained in:
zhangyi1357 2024-10-24 15:49:40 +08:00 committed by GitHub
parent e2e77060af
commit cb01a34047
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 79 additions and 16 deletions

View File

@ -58,8 +58,11 @@ inline void ExportRpcContext(const pybind11::object& m) {
using aimrt::rpc::Context;
using aimrt::rpc::ContextRef;
pybind11::class_<Context>(m, "RpcContext")
pybind11::class_<Context, std::shared_ptr<Context>>(m, "RpcContext")
.def(pybind11::init<>())
.def("CheckUsed", &Context::CheckUsed)
.def("SetUsed", &Context::SetUsed)
.def("Reset", &Context::Reset)
.def("GetType", &Context::GetType)
.def("Timeout", &Context::Timeout)
.def("SetTimeout", &Context::SetTimeout)
@ -70,6 +73,8 @@ inline void ExportRpcContext(const pybind11::object& m) {
.def("SetToAddr", &Context::SetToAddr)
.def("GetSerializationType", &Context::GetSerializationType)
.def("SetSerializationType", &Context::SetSerializationType)
.def("GetFunctionName", &Context::GetFunctionName)
.def("SetFunctionName", &Context::SetFunctionName)
.def("ToString", &Context::ToString);
pybind11::class_<ContextRef>(m, "RpcContextRef")
@ -78,6 +83,8 @@ inline void ExportRpcContext(const pybind11::object& m) {
.def(pybind11::init<Context*>())
.def(pybind11::init<const std::shared_ptr<Context>&>())
.def("__bool__", &ContextRef::operator bool)
.def("CheckUsed", &ContextRef::CheckUsed)
.def("SetUsed", &ContextRef::SetUsed)
.def("GetType", &ContextRef::GetType)
.def("Timeout", &ContextRef::Timeout)
.def("SetTimeout", &ContextRef::SetTimeout)
@ -88,15 +95,20 @@ inline void ExportRpcContext(const pybind11::object& m) {
.def("SetToAddr", &ContextRef::SetToAddr)
.def("GetSerializationType", &ContextRef::GetSerializationType)
.def("SetSerializationType", &ContextRef::SetSerializationType)
.def("GetFunctionName", &ContextRef::GetFunctionName)
.def("SetFunctionName", &ContextRef::SetFunctionName)
.def("ToString", &ContextRef::ToString);
}
using ServiceFuncReturnType = std::tuple<aimrt::rpc::Status, std::string>;
using ServiceFuncType = std::function<ServiceFuncReturnType(aimrt::rpc::ContextRef, const pybind11::bytes&)>;
inline void PyRpcServiceBaseRegisterServiceFunc(
aimrt::rpc::ServiceBase& service,
std::string_view func_name,
const std::shared_ptr<const PyTypeSupport>& req_type_support,
const std::shared_ptr<const PyTypeSupport>& rsp_type_support,
std::function<std::tuple<aimrt::rpc::Status, std::string>(aimrt::rpc::ContextRef, const pybind11::bytes&)>&& service_func) {
ServiceFuncType&& service_func) {
static std::vector<std::shared_ptr<const PyTypeSupport>> py_ts_vec;
py_ts_vec.emplace_back(req_type_support);
py_ts_vec.emplace_back(rsp_type_support);
@ -142,6 +154,9 @@ inline void ExportRpcServiceBase(pybind11::object m) {
pybind11::class_<ServiceBase>(std::move(m), "ServiceBase")
.def(pybind11::init<std::string_view, std::string_view>())
.def("RpcType", &ServiceBase::RpcType)
.def("ServiceName", &ServiceBase::ServiceName)
.def("SetServiceName", &ServiceBase::SetServiceName)
.def("RegisterServiceFunc", &PyRpcServiceBaseRegisterServiceFunc);
}
@ -204,4 +219,20 @@ inline void ExportRpcHandleRef(pybind11::object m) {
.def("RegisterClientFunc", &PyRpcHandleRefRegisterClientFunc)
.def("Invoke", &PyRpcHandleRefInvoke);
}
inline void ExportRpcProxyBase(pybind11::object m) {
using aimrt::rpc::ContextRef;
using aimrt::rpc::ProxyBase;
using aimrt::rpc::RpcHandleRef;
pybind11::class_<ProxyBase>(std::move(m), "ProxyBase")
.def(pybind11::init<RpcHandleRef, std::string_view, std::string_view>())
.def("RpcType", &ProxyBase::RpcType)
.def("ServiceName", &ProxyBase::ServiceName)
.def("SetServiceName", &ProxyBase::SetServiceName)
.def("NewContextSharedPtr", &ProxyBase::NewContextSharedPtr, pybind11::arg("ctx_ref") = ContextRef())
.def("GetDefaultContextSharedPtr", &ProxyBase::GetDefaultContextSharedPtr)
.def("SetDefaultContextSharedPtr", &ProxyBase::SetDefaultContextSharedPtr);
}
} // namespace aimrt::runtime::python_runtime

View File

@ -54,6 +54,7 @@ PYBIND11_MODULE(aimrt_python_runtime, m) {
ExportRpcContext(m);
ExportRpcServiceBase(m);
ExportRpcHandleRef(m);
ExportRpcProxyBase(m);
// parameter
ExportParameter(m);

View File

@ -7,14 +7,19 @@
import sys
from google.protobuf.compiler import plugin_pb2 as plugin
from google.protobuf.compiler.plugin_pb2 import CodeGeneratorRequest as CodeGeneratorRequest
from google.protobuf.compiler.plugin_pb2 import CodeGeneratorResponse as CodeGeneratorResponse
from google.protobuf.descriptor_pb2 import FileDescriptorProto
from google.protobuf.compiler.plugin_pb2 import \
CodeGeneratorRequest as CodeGeneratorRequest
from google.protobuf.compiler.plugin_pb2 import \
CodeGeneratorResponse as CodeGeneratorResponse
from google.protobuf.descriptor_pb2 import \
FileDescriptorProto as FileDescriptorProto
class AimRTCodeGenerator:
t_pyfile: str = r"""# This file was generated by protoc-gen-aimrt_rpc which is a self-defined pb compiler plugin, do not edit it!!!
from typing import overload
import aimrt_py
import google.protobuf
import {{py_package_name}}
@ -79,20 +84,46 @@ class {{service_name}}(aimrt_py.ServiceBase):
{{service end}}
{{for service begin}}
class {{service_name}}Proxy:
class {{service_name}}Proxy(aimrt_py.ProxyBase):
def __init__(self, rpc_handle_ref=aimrt_py.RpcHandleRef()):
super().__init__(rpc_handle_ref, "pb", "{{package_name}}.{{service_name}}")
self.rpc_handle_ref = rpc_handle_ref
{{for method begin}}
def {{rpc_func_name}}(self, ctx, req):
if(type(ctx) == aimrt_py.RpcContext):
@overload
def {{rpc_func_name}}(
self, req: {{full_rpc_req_py_name}}
) -> tuple[aimrt_py.RpcStatus, {{full_rpc_rsp_py_name}}]: ...
@overload
def {{rpc_func_name}}(
self, ctx_ref: aimrt_py.RpcContext, req: {{full_rpc_req_py_name}}
) -> tuple[aimrt_py.RpcStatus, {{full_rpc_rsp_py_name}}]: ...
@overload
def {{rpc_func_name}}(
self, ctx_ref: aimrt_py.RpcContextRef, req: {{full_rpc_req_py_name}}
) -> tuple[aimrt_py.RpcStatus, {{full_rpc_rsp_py_name}}]: ...
def {{rpc_func_name}}(self, *args):
if len(args) == 1:
ctx = super().NewContextSharedPtr()
req = args[0]
elif len(args) == 2:
ctx = args[0]
req = args[1]
else:
raise TypeError(f"{{rpc_func_name}} expects 1 or 2 arguments, got {len(args)}")
if isinstance(ctx, aimrt_py.RpcContext):
ctx_ref = aimrt_py.RpcContextRef(ctx)
elif(type(ctx) == aimrt_py.RpcContextRef):
elif isinstance(ctx, aimrt_py.RpcContextRef):
ctx_ref = ctx
else:
raise TypeError("ctx must be 'aimrt_py.RpcContext' or 'aimrt_py.RpcContextRef'")
raise TypeError(f"ctx must be 'aimrt_py.RpcContext' or 'aimrt_py.RpcContextRef', got {type(ctx)}")
if(ctx_ref):
if(ctx_ref.GetSerializationType() == ""):
if ctx_ref:
if ctx_ref.GetSerializationType() == "":
ctx_ref.SetSerializationType("pb")
else:
real_ctx = aimrt_py.RpcContext()
@ -105,9 +136,9 @@ class {{service_name}}Proxy:
try:
req_str = ""
if(serialization_type == "pb"):
if serialization_type == "pb":
req_str = req.SerializeToString()
elif(serialization_type == "json"):
elif serialization_type == "json":
req_str = google.protobuf.json_format.MessageToJson(req)
else:
return (aimrt_py.RpcStatus(aimrt_py.RpcStatusRetCode.CLI_INVALID_SERIALIZATION_TYPE), rsp)
@ -118,9 +149,9 @@ class {{service_name}}Proxy:
ctx_ref, req_str)
try:
if(serialization_type == "pb"):
if serialization_type == "pb":
rsp.ParseFromString(rsp_str)
elif(serialization_type == "json"):
elif serialization_type == "json":
google.protobuf.json_format.Parse(rsp_str, rsp)
else:
return (aimrt_py.RpcStatus(aimrt_py.RpcStatusRetCode.CLI_INVALID_SERIALIZATION_TYPE), rsp)