如何在 ONNXRuntime 中扩展【推理引擎】以支持自定义算子?

2026-05-27 13:151阅读0评论SEO资源
  • 内容介绍
  • 文章标签
  • 相关推荐

本文共计3079个文字,预计阅读时间需要13分钟。

如何在 ONNXRuntime 中扩展【推理引擎】以支持自定义算子?

在ONNX模型中,如果遇到一些算子不被ONNX算子库支持,我们可以利用ONNXRuntime提供的API手动添加新算子。官方文档中已对如何添加自定义算子进行了详细说明(详见:https://onnxruntime.ai/docs/reference/operators/)。以下是一个简化的步骤概述:

1. 定义算子:首先,需要定义一个新的算子类,继承自`onnxruntime::CustomOpBase`。

2.实现算子方法:在自定义算子类中,实现`Compute`方法,该方法将包含实际的计算逻辑。

3.注册算子:使用`onnxruntime::AddCustomOp`函数将自定义算子注册到ONNXRuntime中。

4.使用算子:在ONNX模型中添加使用该自定义算子的节点。

例如,假设我们要添加一个简单的自定义算子,计算两个数的和:

cpp

#include onnxruntime/core/custom_op/custom_op.h

using namespace onnxruntime;

class AddCustomOp : public CustomOpBase {public: AddCustomOp(const OpSchema& schema) : CustomOpBase(schema) {}

Status Compute(OpKernelContext* context) const override { // 获取输入 const Tensor& input1=*context->Input(0); const Tensor& input2=*context->Input(1);

// 创建输出 Tensor* output=nullptr; ORT_RETURN_IF_ERROR(context->Output(0, &output));

// 计算和 const float* input1_data=input1.Data(); const float* input2_data=input2.Data(); float* output_data=output->MutableData();

for (size_t i=0; i

return Status::OK(); }};

ONNX_REGISTER_CUSTOM_OP(AddCustomOp) .SetDoc(RDOC(Calculate the sum of two tensors.)DOC) .SetSchema(RDOC( (input1: float[*,*]) (input2: float[*,*]) (output: float[*,*]))DOC) .SetValidationSchema(RDOC( (input1: float[*,*]) (input2: float[*,*]) (output: float[*,*]))DOC) .SetProvider(com.example) .SetType(custom) .AddInput(input1, The first input tensor., float, 1, N) .AddInput(input2, The second input tensor., float, 1, N) .AddOutput(output, The output tensor., float, 1, N);

这样,你就可以在ONNX模型中使用`AddCustomOp`算子了。

如果模型中有些算子不被ONNX算子库支持,我们就需要利用ONNXRuntime提供的API手动添加新算子。在官方文档中已经对如何添加定制算子进行了介绍(onnxruntime.ai/docs/reference/operators/add-custom-op.html ),这里我们主要把源码中对应的流程给捋清楚。

在 ONNXRuntime 中添加算子共有两种方式:

  • 第一种方式是首先编译ONNXRuntime,然后利用其暴露出的 API 来添加新的定制算子,这也是本文的主要内容;
  • 第二种方式是在 Contrib 域中添加定制算子,但是添加完成后需要重新编译 ONNXRuntime,这种方式使得编译后的ONNXRuntime二进制库体积增大,本文没有针对这种方式深入介绍;

下面介绍第一种添加算子的方式。
添加定制算子(Custom Operators)主要分为三步:

  1. 创建一个定制算子域(CusttomOpDomain);
  2. 创建一个定制算子(CustomOp),并将该算子添加到定制算子域中;
  3. 将定制算子域添加到 SessionOption 中

首先看看源码中给出的定制算子样例:

// file path: onnxruntime/test/shared_lib/custom_op_utils.h // 首先定义定制算子的核 struct MyCustomKernel { MyCustomKernel(Ort::CustomOpApi ort, const OrtKernelInfo* /*info*/, void* compute_stream) : ort_(ort), compute_stream_(compute_stream) { } void Compute(OrtKernelContext* context); private: Ort::CustomOpApi ort_; void* compute_stream_; }; // 然后定义定制算子的各个操作,各个成员函数均已实现,其中 CreateKernel 会返回前面定义的算子核对象 struct MyCustomOp : Ort::CustomOpBase<MyCustomOp, MyCustomKernel> { explicit MyCustomOp(const char* provider, void* compute_stream) : provider_(provider), compute_stream_(compute_stream) {} void* CreateKernel(Ort::CustomOpApi api, const OrtKernelInfo* info) const { return new MyCustomKernel(api, info, compute_stream_); }; const char* GetName() const { return "Foo"; }; const char* GetExecutionProviderType() const { return provider_; }; size_t GetInputTypeCount() const { return 2; }; ONNXTensorElementDataType GetInputType(size_t /*index*/) const { // Both the inputs need to be necessarily of float type return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; }; size_t GetOutputTypeCount() const { return 1; }; ONNXTensorElementDataType GetOutputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; }; private: const char* provider_; void* compute_stream_; };

在上面代码中,我们看到定制算子继承自 Ort::CustomOpBase<MyCustomOp, MyCustomKernel>,这种扩展类作为模板基类的模板参数的方式又被称为CRTP,接着深入到这个模板类内部:

// file path: include/onnxruntime/core/session/onnxruntime_cxx_api.h template <typename TOp, typename TKernel> struct CustomOpBase : OrtCustomOp { CustomOpBase() { OrtCustomOp::version = ORT_API_VERSION; OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* api, const OrtKernelInfo* info) { return static_cast<const TOp*>(this_)->CreateKernel(*api, info); }; OrtCustomOp::GetName = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetName(); }; OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetExecutionProviderType(); }; OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetInputTypeCount(); }; OrtCustomOp::GetInputType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputType(index); }; OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetOutputTypeCount(); }; OrtCustomOp::GetOutputType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetOutputType(index); }; OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) { static_cast<TKernel*>(op_kernel)->Compute(context); }; OrtCustomOp::KernelDestroy = [](void* op_kernel) { delete static_cast<TKernel*>(op_kernel); }; OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputCharacteristic(index); }; OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetOutputCharacteristic(index); }; } // Default implementation of GetExecutionProviderType that returns nullptr to default to the CPU provider const char* GetExecutionProviderType() const { return nullptr; } // Default implementations of GetInputCharacteristic() and GetOutputCharacteristic() below // (inputs and outputs are required by default) OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t /*index*/) const { return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED; } OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t /*index*/) const { return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED; } };

这里的 CustomOpBase 又继承自 OrtCustomOp

// include/onnxruntime/core/session/onnxruntime_c_api.h struct OrtCustomOp; typedef struct OrtCustomOp OrtCustomOp; struct OrtCustomOp { uint32_t version; // Must be initialized to ORT_API_VERSION // This callback creates the kernel, which is a user defined parameter that is passed to the Kernel* callbacks below. void*(ORT_API_CALL* CreateKernel)(_In_ const struct OrtCustomOp* op, _In_ const OrtApi* api, _In_ const OrtKernelInfo* info); // Returns the name of the op const char*(ORT_API_CALL* GetName)(_In_ const struct OrtCustomOp* op); // Returns the type of the execution provider, return nullptr to use CPU execution provider const char*(ORT_API_CALL* GetExecutionProviderType)(_In_ const struct OrtCustomOp* op); // Returns the count and types of the input & output tensors ONNXTensorElementDataType(ORT_API_CALL* GetInputType)(_In_ const struct OrtCustomOp* op, _In_ size_t index); size_t(ORT_API_CALL* GetInputTypeCount)(_In_ const struct OrtCustomOp* op); ONNXTensorElementDataType(ORT_API_CALL* GetOutputType)(_In_ const struct OrtCustomOp* op, _In_ size_t index); size_t(ORT_API_CALL* GetOutputTypeCount)(_In_ const struct OrtCustomOp* op); // Op kernel callbacks void(ORT_API_CALL* KernelCompute)(_In_ void* op_kernel, _In_ OrtKernelContext* context); void(ORT_API_CALL* KernelDestroy)(_In_ void* op_kernel); // Returns the characteristics of the input & output tensors OrtCustomOpInputOutputCharacteristic(ORT_API_CALL* GetInputCharacteristic)(_In_ const struct OrtCustomOp* op, _In_ size_t index); OrtCustomOpInputOutputCharacteristic(ORT_API_CALL* GetOutputCharacteristic)(_In_ const struct OrtCustomOp* op, _In_ size_t index); };

可以发现,OrtCustomOp 中定义了定制算子应该实现的模式,其中的一系列回调函数由其派生类一一实现,比如上文提到的 CustomOpBase 在其构造函数中,以 lambda 函数的方式实现各个回调函数。

至此,我们已经完整地梳理了定义定制算子在源码内部是如何实现的,接下来介绍如何将定义好的定制算子使用起来。

从如下官方测试代码开始分析:

如何在 ONNXRuntime 中扩展【推理引擎】以支持自定义算子?

// file path: onnxruntime/test/shared_lib/test_inference.cc TEST(CApiTest, custom_op_handler) { std::cout << "Running custom op inference" << std::endl; std::vector<Input> inputs(1); Input& input = inputs[0]; input.name = "X"; input.dims = {3, 2}; input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; // prepare expected inputs and outputs std::vector<int64_t> expected_dims_y = {3, 2}; std::vector<float> expected_values_y = {2.0f, 4.0f, 6.0f, 8.0f, 10.0f, 12.0f}; // 创建定制算子(MyCustomOp) #ifdef USE_CUDA cudaStream_t compute_stream = nullptr; // 声明一个 cuda stream cudaStreamCreateWithFlags(&compute_stream, cudaStreamNonBlocking); // 创建一个 cuda stream MyCustomOp custom_op{onnxruntime::kCudaExecutionProvider, compute_stream}; #else MyCustomOp custom_op{onnxruntime::kCpuExecutionProvider, nullptr}; #endif // 创建定制算子域(CustomOpDomain) Ort::CustomOpDomain custom_op_domain(""); // 在定制算子域中添加定制算子 custom_op_domain.Add(&custom_op); // 进入 TestInference #ifdef USE_CUDA TestInference<float>(*ort_env, CUSTOM_OP_MODEL_URI, inputs, "Y", expected_dims_y, expected_values_y, 1, custom_op_domain, nullptr, nullptr, false, compute_stream); cudaStreamDestroy(compute_stream); #else TestInference<float>(*ort_env, CUSTOM_OP_MODEL_URI, inputs, "Y", expected_dims_y, expected_values_y, 0, custom_op_domain, nullptr); #endif }

以上代码需要特别注意的是,需要根据宏(USE_CUDA)用来判断是否使用CUDA。如果使用 CUDA:

  • 当模型运行在GPU上,而插入的是 CPU 定制算子,那么 ONNXRuntime 会在 CPU 定制算子前后分别插入两个操作 MemcpyToHost、MemcpyFromHost,这两个操作负责内存拷贝,即首先从 Device 拷贝到 Host,再从 Host 拷贝到 Device;
  • 如果插入的是 GPU 定制算子,为了确保 ORT 的 CUDA kernels 和定制 CUDA kernels 之间的同步,它们必须使用同一个 CUDA 计算流。具体细节在下一个代码继续分析。

这里创建 cuda stream 的方式是 cudaStreamCreateWithFlags,该函数和 cudaStreamCreate 不同,后者在多次调用时是串行方式执行,而前者可同步执行。如果将参数 cudaStreamNonBlocking 替换为 cudaStreamDefault,则 cudaStreamCreateWithFlags 的行为将和 cudaStreamCreate 相同。

无论是否使用CDUA,我们都需要创建定制算子(MyCustomOp)。

进入 TestInference 函数内部:

// file path: onnxruntime/test/shared_lib/test_inference.cc template <typename OutT> static void TestInference(Ort::Env& env, const std::basic_string<ORTCHAR_T>& model_uri, const std::vector<Input>& inputs, const char* output_name, const std::vector<int64_t>& expected_dims_y, const std::vector<OutT>& expected_values_y, int provider_type, OrtCustomOpDomain* custom_op_domain_ptr, const char* custom_op_library_filename, void** library_handle = nullptr, bool test_session_creation_only = false, void* cuda_compute_stream = nullptr) { Ort::SessionOptions session_options; if (provider_type == 1) { #ifdef USE_CUDA std::cout << "Running simple inference with cuda provider" << std::endl; auto cuda_options = CreateDefaultOrtCudaProviderOptionsWithCustomStream(cuda_compute_stream); session_options.AppendExecutionProvider_CUDA(cuda_options); #else ORT_UNUSED_PARAMETER(cuda_compute_stream); return; #endif } else if (provider_type == 2) { #ifdef USE_DNNL Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_Dnnl(session_options, 1)); std::cout << "Running simple inference with dnnl provider" << std::endl; #else return; #endif } else if (provider_type == 3) { #ifdef USE_NUPHAR Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_Nuphar(session_options, /*allow_unaligned_buffers*/ 1, "")); std::cout << "Running simple inference with nuphar provider" << std::endl; #else return; #endif } else { std::cout << "Running simple inference with default provider" << std::endl; } if (custom_op_domain_ptr) { session_options.Add(custom_op_domain_ptr); } if (custom_op_library_filename) { Ort::ThrowOnError(Ort::GetApi().RegisterCustomOpsLibrary(session_options, custom_op_library_filename, library_handle)); } // if session creation passes, model loads fine Ort::Session session(env, model_uri.c_str(), session_options); // caller wants to test running the model (not just loading the model) if (!test_session_creation_only) { // Now run auto default_allocator = std::make_unique<MockedOrtAllocator>(); //without preallocated output tensor RunSession<OutT>(default_allocator.get(), session, inputs, output_name, expected_dims_y, expected_values_y, nullptr); //with preallocated output tensor Ort::Value value_y = Ort::Value::CreateTensor<float>(default_allocator.get(), expected_dims_y.data(), expected_dims_y.size()); //test it twice for (int i = 0; i != 2; ++i) RunSession<OutT>(default_allocator.get(), session, inputs, output_name, expected_dims_y, expected_values_y, &value_y); } }

前文提到,如果对应EP是CUDA,需要确保 ORT 的 CUDA kernels 和定制 CUDA kernels 之间的同步。为了实现这一目标,首先通过 CreateDefaultOrtCudaProviderOptionsWithCustomStream 函数将新创建的 CUDA 计算流以 OrtCudaProviderOptions 的形式传递给 SessionOptions:

OrtCUDAProviderOptions cuda_options = CreateDefaultOrtCudaProviderOptionsWithCustomStream(cuda_compute_stream); session_options.AppendExecutionProvider_CUDA(cuda_options)

之后,将定制算子域也添加到 SessionOptions 中:

if (custom_op_domain_ptr) { session_options.Add(custom_op_domain_ptr); }

至此,SessionOptions 已经构建完成,下面创建 Session 并通过 model_uri 加载模型:

Ort::Session session(env, model_uri.c_str(), session_options);

这里的(1)Ort::Session 是在 onnxruntime_cxx_api.h 文件中声明的类,(2)对应的构造函数在 onnxruntime_cxx_inline.h 中实现,(3)实现方式是进一步调用 onnxruntime_c_api.h 中定义的 API,该 API 也仅仅是声明,(4)最终对应的实现在 onnxruntime_c_api.cc 文件中:

// (1) include/onnxruntime/core/session/onnxruntime_cxx_api.h struct Session : Base<OrtSession> { explicit Session(std::nullptr_t) {} Session(Env& env, const ORTCHAR_T* model_path, const SessionOptions& options); } // (2) include/onnxruntime/core/session/onnxruntime_cxx_inline.h inline Session::Session(Env& env, const ORTCHAR_T* model_path, const SessionOptions& options) { ThrowOnError(GetApi().CreateSession(env, model_path, options, &p_)); } // (3) include/onnxruntime/core/session/onnxruntime_c_api.h ORT_API2_STATUS(CreateSession, _In_ const OrtEnv* env, _In_ const ORTCHAR_T* model_path, _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out); // (4) onnxruntime/core/session/onnxruntime_c_api.cc ORT_API_STATUS_IMPL(OrtApis::CreateSession, _In_ const OrtEnv* env, _In_ const ORTCHAR_T* model_path, _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out) { API_IMPL_BEGIN std::unique_ptr<onnxruntime::InferenceSession> sess; OrtStatus* status = nullptr; *out = nullptr; ORT_TRY { ORT_API_RETURN_IF_ERROR(CreateSessionAndLoadModel(options, env, model_path, nullptr, 0, sess)); ORT_API_RETURN_IF_ERROR(InitializeSession(options, sess)); *out = reinterpret_cast<OrtSession*>(sess.release()); } ORT_CATCH(const std::exception& e) { ORT_HANDLE_EXCEPTION([&]() { status = OrtApis::CreateStatus(ORT_FAIL, e.what()); }); } return status; API_IMPL_END }

可以发现,Ort::Session 内部还是调用了 onnxruntime::InferenceSession

扯远了,下面回归主题。

创建 Session 完成之后,便开始运行,进入 RunSession 函数内部:

// file path: onnxruntime/test/shared_lib/test_inference.cc template <typename OutT> void RunSession(OrtAllocator* allocator, Ort::Session& session_object, const std::vector<Input>& inputs, const char* output_name, const std::vector<int64_t>& dims_y, const std::vector<OutT>& values_y, Ort::Value* output_tensor) { // 构建模型输入 std::vector<Ort::Value> ort_inputs; std::vector<const char*> input_names; for (size_t i = 0; i < inputs.size(); i++) { input_names.emplace_back(inputs[i].name); ort_inputs.emplace_back( Ort::Value::CreateTensor<float>(allocator->Info(allocator), const_cast<float*>(inputs[i].values.data()), inputs[i].values.size(), inputs[i].dims.data(), inputs[i].dims.size())); } // 运行 RUN std::vector<Ort::Value> ort_outputs; if (output_tensor) session_object.Run(Ort::RunOptions{nullptr}, input_names.data(), ort_inputs.data(), ort_inputs.size(), &output_name, output_tensor, 1); else { ort_outputs = session_object.Run(Ort::RunOptions{}, input_names.data(), ort_inputs.data(), ort_inputs.size(), &output_name, 1); ASSERT_EQ(ort_outputs.size(), 1u); output_tensor = &ort_outputs[0]; } auto type_info = output_tensor->GetTensorTypeAndShapeInfo(); ASSERT_EQ(type_info.GetShape(), dims_y); size_t total_len = type_info.GetElementCount(); ASSERT_EQ(values_y.size(), total_len); OutT* f = output_tensor->GetTensorMutableData<OutT>(); for (size_t i = 0; i != total_len; ++i) { ASSERT_EQ(values_y[i], f[i]); } }

这里使用了一些GTest中的断言来判定运行结果是否符合预期。

至此,我们已经完整地分析了定制算子从定义到使用的全部流程。

本文共计3079个文字,预计阅读时间需要13分钟。

如何在 ONNXRuntime 中扩展【推理引擎】以支持自定义算子?

在ONNX模型中,如果遇到一些算子不被ONNX算子库支持,我们可以利用ONNXRuntime提供的API手动添加新算子。官方文档中已对如何添加自定义算子进行了详细说明(详见:https://onnxruntime.ai/docs/reference/operators/)。以下是一个简化的步骤概述:

1. 定义算子:首先,需要定义一个新的算子类,继承自`onnxruntime::CustomOpBase`。

2.实现算子方法:在自定义算子类中,实现`Compute`方法,该方法将包含实际的计算逻辑。

3.注册算子:使用`onnxruntime::AddCustomOp`函数将自定义算子注册到ONNXRuntime中。

4.使用算子:在ONNX模型中添加使用该自定义算子的节点。

例如,假设我们要添加一个简单的自定义算子,计算两个数的和:

cpp

#include onnxruntime/core/custom_op/custom_op.h

using namespace onnxruntime;

class AddCustomOp : public CustomOpBase {public: AddCustomOp(const OpSchema& schema) : CustomOpBase(schema) {}

Status Compute(OpKernelContext* context) const override { // 获取输入 const Tensor& input1=*context->Input(0); const Tensor& input2=*context->Input(1);

// 创建输出 Tensor* output=nullptr; ORT_RETURN_IF_ERROR(context->Output(0, &output));

// 计算和 const float* input1_data=input1.Data(); const float* input2_data=input2.Data(); float* output_data=output->MutableData();

for (size_t i=0; i

return Status::OK(); }};

ONNX_REGISTER_CUSTOM_OP(AddCustomOp) .SetDoc(RDOC(Calculate the sum of two tensors.)DOC) .SetSchema(RDOC( (input1: float[*,*]) (input2: float[*,*]) (output: float[*,*]))DOC) .SetValidationSchema(RDOC( (input1: float[*,*]) (input2: float[*,*]) (output: float[*,*]))DOC) .SetProvider(com.example) .SetType(custom) .AddInput(input1, The first input tensor., float, 1, N) .AddInput(input2, The second input tensor., float, 1, N) .AddOutput(output, The output tensor., float, 1, N);

这样,你就可以在ONNX模型中使用`AddCustomOp`算子了。

如果模型中有些算子不被ONNX算子库支持,我们就需要利用ONNXRuntime提供的API手动添加新算子。在官方文档中已经对如何添加定制算子进行了介绍(onnxruntime.ai/docs/reference/operators/add-custom-op.html ),这里我们主要把源码中对应的流程给捋清楚。

在 ONNXRuntime 中添加算子共有两种方式:

  • 第一种方式是首先编译ONNXRuntime,然后利用其暴露出的 API 来添加新的定制算子,这也是本文的主要内容;
  • 第二种方式是在 Contrib 域中添加定制算子,但是添加完成后需要重新编译 ONNXRuntime,这种方式使得编译后的ONNXRuntime二进制库体积增大,本文没有针对这种方式深入介绍;

下面介绍第一种添加算子的方式。
添加定制算子(Custom Operators)主要分为三步:

  1. 创建一个定制算子域(CusttomOpDomain);
  2. 创建一个定制算子(CustomOp),并将该算子添加到定制算子域中;
  3. 将定制算子域添加到 SessionOption 中

首先看看源码中给出的定制算子样例:

// file path: onnxruntime/test/shared_lib/custom_op_utils.h // 首先定义定制算子的核 struct MyCustomKernel { MyCustomKernel(Ort::CustomOpApi ort, const OrtKernelInfo* /*info*/, void* compute_stream) : ort_(ort), compute_stream_(compute_stream) { } void Compute(OrtKernelContext* context); private: Ort::CustomOpApi ort_; void* compute_stream_; }; // 然后定义定制算子的各个操作,各个成员函数均已实现,其中 CreateKernel 会返回前面定义的算子核对象 struct MyCustomOp : Ort::CustomOpBase<MyCustomOp, MyCustomKernel> { explicit MyCustomOp(const char* provider, void* compute_stream) : provider_(provider), compute_stream_(compute_stream) {} void* CreateKernel(Ort::CustomOpApi api, const OrtKernelInfo* info) const { return new MyCustomKernel(api, info, compute_stream_); }; const char* GetName() const { return "Foo"; }; const char* GetExecutionProviderType() const { return provider_; }; size_t GetInputTypeCount() const { return 2; }; ONNXTensorElementDataType GetInputType(size_t /*index*/) const { // Both the inputs need to be necessarily of float type return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; }; size_t GetOutputTypeCount() const { return 1; }; ONNXTensorElementDataType GetOutputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; }; private: const char* provider_; void* compute_stream_; };

在上面代码中,我们看到定制算子继承自 Ort::CustomOpBase<MyCustomOp, MyCustomKernel>,这种扩展类作为模板基类的模板参数的方式又被称为CRTP,接着深入到这个模板类内部:

// file path: include/onnxruntime/core/session/onnxruntime_cxx_api.h template <typename TOp, typename TKernel> struct CustomOpBase : OrtCustomOp { CustomOpBase() { OrtCustomOp::version = ORT_API_VERSION; OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* api, const OrtKernelInfo* info) { return static_cast<const TOp*>(this_)->CreateKernel(*api, info); }; OrtCustomOp::GetName = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetName(); }; OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetExecutionProviderType(); }; OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetInputTypeCount(); }; OrtCustomOp::GetInputType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputType(index); }; OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetOutputTypeCount(); }; OrtCustomOp::GetOutputType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetOutputType(index); }; OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) { static_cast<TKernel*>(op_kernel)->Compute(context); }; OrtCustomOp::KernelDestroy = [](void* op_kernel) { delete static_cast<TKernel*>(op_kernel); }; OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputCharacteristic(index); }; OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetOutputCharacteristic(index); }; } // Default implementation of GetExecutionProviderType that returns nullptr to default to the CPU provider const char* GetExecutionProviderType() const { return nullptr; } // Default implementations of GetInputCharacteristic() and GetOutputCharacteristic() below // (inputs and outputs are required by default) OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t /*index*/) const { return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED; } OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t /*index*/) const { return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED; } };

这里的 CustomOpBase 又继承自 OrtCustomOp

// include/onnxruntime/core/session/onnxruntime_c_api.h struct OrtCustomOp; typedef struct OrtCustomOp OrtCustomOp; struct OrtCustomOp { uint32_t version; // Must be initialized to ORT_API_VERSION // This callback creates the kernel, which is a user defined parameter that is passed to the Kernel* callbacks below. void*(ORT_API_CALL* CreateKernel)(_In_ const struct OrtCustomOp* op, _In_ const OrtApi* api, _In_ const OrtKernelInfo* info); // Returns the name of the op const char*(ORT_API_CALL* GetName)(_In_ const struct OrtCustomOp* op); // Returns the type of the execution provider, return nullptr to use CPU execution provider const char*(ORT_API_CALL* GetExecutionProviderType)(_In_ const struct OrtCustomOp* op); // Returns the count and types of the input & output tensors ONNXTensorElementDataType(ORT_API_CALL* GetInputType)(_In_ const struct OrtCustomOp* op, _In_ size_t index); size_t(ORT_API_CALL* GetInputTypeCount)(_In_ const struct OrtCustomOp* op); ONNXTensorElementDataType(ORT_API_CALL* GetOutputType)(_In_ const struct OrtCustomOp* op, _In_ size_t index); size_t(ORT_API_CALL* GetOutputTypeCount)(_In_ const struct OrtCustomOp* op); // Op kernel callbacks void(ORT_API_CALL* KernelCompute)(_In_ void* op_kernel, _In_ OrtKernelContext* context); void(ORT_API_CALL* KernelDestroy)(_In_ void* op_kernel); // Returns the characteristics of the input & output tensors OrtCustomOpInputOutputCharacteristic(ORT_API_CALL* GetInputCharacteristic)(_In_ const struct OrtCustomOp* op, _In_ size_t index); OrtCustomOpInputOutputCharacteristic(ORT_API_CALL* GetOutputCharacteristic)(_In_ const struct OrtCustomOp* op, _In_ size_t index); };

可以发现,OrtCustomOp 中定义了定制算子应该实现的模式,其中的一系列回调函数由其派生类一一实现,比如上文提到的 CustomOpBase 在其构造函数中,以 lambda 函数的方式实现各个回调函数。

至此,我们已经完整地梳理了定义定制算子在源码内部是如何实现的,接下来介绍如何将定义好的定制算子使用起来。

从如下官方测试代码开始分析:

如何在 ONNXRuntime 中扩展【推理引擎】以支持自定义算子?

// file path: onnxruntime/test/shared_lib/test_inference.cc TEST(CApiTest, custom_op_handler) { std::cout << "Running custom op inference" << std::endl; std::vector<Input> inputs(1); Input& input = inputs[0]; input.name = "X"; input.dims = {3, 2}; input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; // prepare expected inputs and outputs std::vector<int64_t> expected_dims_y = {3, 2}; std::vector<float> expected_values_y = {2.0f, 4.0f, 6.0f, 8.0f, 10.0f, 12.0f}; // 创建定制算子(MyCustomOp) #ifdef USE_CUDA cudaStream_t compute_stream = nullptr; // 声明一个 cuda stream cudaStreamCreateWithFlags(&compute_stream, cudaStreamNonBlocking); // 创建一个 cuda stream MyCustomOp custom_op{onnxruntime::kCudaExecutionProvider, compute_stream}; #else MyCustomOp custom_op{onnxruntime::kCpuExecutionProvider, nullptr}; #endif // 创建定制算子域(CustomOpDomain) Ort::CustomOpDomain custom_op_domain(""); // 在定制算子域中添加定制算子 custom_op_domain.Add(&custom_op); // 进入 TestInference #ifdef USE_CUDA TestInference<float>(*ort_env, CUSTOM_OP_MODEL_URI, inputs, "Y", expected_dims_y, expected_values_y, 1, custom_op_domain, nullptr, nullptr, false, compute_stream); cudaStreamDestroy(compute_stream); #else TestInference<float>(*ort_env, CUSTOM_OP_MODEL_URI, inputs, "Y", expected_dims_y, expected_values_y, 0, custom_op_domain, nullptr); #endif }

以上代码需要特别注意的是,需要根据宏(USE_CUDA)用来判断是否使用CUDA。如果使用 CUDA:

  • 当模型运行在GPU上,而插入的是 CPU 定制算子,那么 ONNXRuntime 会在 CPU 定制算子前后分别插入两个操作 MemcpyToHost、MemcpyFromHost,这两个操作负责内存拷贝,即首先从 Device 拷贝到 Host,再从 Host 拷贝到 Device;
  • 如果插入的是 GPU 定制算子,为了确保 ORT 的 CUDA kernels 和定制 CUDA kernels 之间的同步,它们必须使用同一个 CUDA 计算流。具体细节在下一个代码继续分析。

这里创建 cuda stream 的方式是 cudaStreamCreateWithFlags,该函数和 cudaStreamCreate 不同,后者在多次调用时是串行方式执行,而前者可同步执行。如果将参数 cudaStreamNonBlocking 替换为 cudaStreamDefault,则 cudaStreamCreateWithFlags 的行为将和 cudaStreamCreate 相同。

无论是否使用CDUA,我们都需要创建定制算子(MyCustomOp)。

进入 TestInference 函数内部:

// file path: onnxruntime/test/shared_lib/test_inference.cc template <typename OutT> static void TestInference(Ort::Env& env, const std::basic_string<ORTCHAR_T>& model_uri, const std::vector<Input>& inputs, const char* output_name, const std::vector<int64_t>& expected_dims_y, const std::vector<OutT>& expected_values_y, int provider_type, OrtCustomOpDomain* custom_op_domain_ptr, const char* custom_op_library_filename, void** library_handle = nullptr, bool test_session_creation_only = false, void* cuda_compute_stream = nullptr) { Ort::SessionOptions session_options; if (provider_type == 1) { #ifdef USE_CUDA std::cout << "Running simple inference with cuda provider" << std::endl; auto cuda_options = CreateDefaultOrtCudaProviderOptionsWithCustomStream(cuda_compute_stream); session_options.AppendExecutionProvider_CUDA(cuda_options); #else ORT_UNUSED_PARAMETER(cuda_compute_stream); return; #endif } else if (provider_type == 2) { #ifdef USE_DNNL Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_Dnnl(session_options, 1)); std::cout << "Running simple inference with dnnl provider" << std::endl; #else return; #endif } else if (provider_type == 3) { #ifdef USE_NUPHAR Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_Nuphar(session_options, /*allow_unaligned_buffers*/ 1, "")); std::cout << "Running simple inference with nuphar provider" << std::endl; #else return; #endif } else { std::cout << "Running simple inference with default provider" << std::endl; } if (custom_op_domain_ptr) { session_options.Add(custom_op_domain_ptr); } if (custom_op_library_filename) { Ort::ThrowOnError(Ort::GetApi().RegisterCustomOpsLibrary(session_options, custom_op_library_filename, library_handle)); } // if session creation passes, model loads fine Ort::Session session(env, model_uri.c_str(), session_options); // caller wants to test running the model (not just loading the model) if (!test_session_creation_only) { // Now run auto default_allocator = std::make_unique<MockedOrtAllocator>(); //without preallocated output tensor RunSession<OutT>(default_allocator.get(), session, inputs, output_name, expected_dims_y, expected_values_y, nullptr); //with preallocated output tensor Ort::Value value_y = Ort::Value::CreateTensor<float>(default_allocator.get(), expected_dims_y.data(), expected_dims_y.size()); //test it twice for (int i = 0; i != 2; ++i) RunSession<OutT>(default_allocator.get(), session, inputs, output_name, expected_dims_y, expected_values_y, &value_y); } }

前文提到,如果对应EP是CUDA,需要确保 ORT 的 CUDA kernels 和定制 CUDA kernels 之间的同步。为了实现这一目标,首先通过 CreateDefaultOrtCudaProviderOptionsWithCustomStream 函数将新创建的 CUDA 计算流以 OrtCudaProviderOptions 的形式传递给 SessionOptions:

OrtCUDAProviderOptions cuda_options = CreateDefaultOrtCudaProviderOptionsWithCustomStream(cuda_compute_stream); session_options.AppendExecutionProvider_CUDA(cuda_options)

之后,将定制算子域也添加到 SessionOptions 中:

if (custom_op_domain_ptr) { session_options.Add(custom_op_domain_ptr); }

至此,SessionOptions 已经构建完成,下面创建 Session 并通过 model_uri 加载模型:

Ort::Session session(env, model_uri.c_str(), session_options);

这里的(1)Ort::Session 是在 onnxruntime_cxx_api.h 文件中声明的类,(2)对应的构造函数在 onnxruntime_cxx_inline.h 中实现,(3)实现方式是进一步调用 onnxruntime_c_api.h 中定义的 API,该 API 也仅仅是声明,(4)最终对应的实现在 onnxruntime_c_api.cc 文件中:

// (1) include/onnxruntime/core/session/onnxruntime_cxx_api.h struct Session : Base<OrtSession> { explicit Session(std::nullptr_t) {} Session(Env& env, const ORTCHAR_T* model_path, const SessionOptions& options); } // (2) include/onnxruntime/core/session/onnxruntime_cxx_inline.h inline Session::Session(Env& env, const ORTCHAR_T* model_path, const SessionOptions& options) { ThrowOnError(GetApi().CreateSession(env, model_path, options, &p_)); } // (3) include/onnxruntime/core/session/onnxruntime_c_api.h ORT_API2_STATUS(CreateSession, _In_ const OrtEnv* env, _In_ const ORTCHAR_T* model_path, _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out); // (4) onnxruntime/core/session/onnxruntime_c_api.cc ORT_API_STATUS_IMPL(OrtApis::CreateSession, _In_ const OrtEnv* env, _In_ const ORTCHAR_T* model_path, _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out) { API_IMPL_BEGIN std::unique_ptr<onnxruntime::InferenceSession> sess; OrtStatus* status = nullptr; *out = nullptr; ORT_TRY { ORT_API_RETURN_IF_ERROR(CreateSessionAndLoadModel(options, env, model_path, nullptr, 0, sess)); ORT_API_RETURN_IF_ERROR(InitializeSession(options, sess)); *out = reinterpret_cast<OrtSession*>(sess.release()); } ORT_CATCH(const std::exception& e) { ORT_HANDLE_EXCEPTION([&]() { status = OrtApis::CreateStatus(ORT_FAIL, e.what()); }); } return status; API_IMPL_END }

可以发现,Ort::Session 内部还是调用了 onnxruntime::InferenceSession

扯远了,下面回归主题。

创建 Session 完成之后,便开始运行,进入 RunSession 函数内部:

// file path: onnxruntime/test/shared_lib/test_inference.cc template <typename OutT> void RunSession(OrtAllocator* allocator, Ort::Session& session_object, const std::vector<Input>& inputs, const char* output_name, const std::vector<int64_t>& dims_y, const std::vector<OutT>& values_y, Ort::Value* output_tensor) { // 构建模型输入 std::vector<Ort::Value> ort_inputs; std::vector<const char*> input_names; for (size_t i = 0; i < inputs.size(); i++) { input_names.emplace_back(inputs[i].name); ort_inputs.emplace_back( Ort::Value::CreateTensor<float>(allocator->Info(allocator), const_cast<float*>(inputs[i].values.data()), inputs[i].values.size(), inputs[i].dims.data(), inputs[i].dims.size())); } // 运行 RUN std::vector<Ort::Value> ort_outputs; if (output_tensor) session_object.Run(Ort::RunOptions{nullptr}, input_names.data(), ort_inputs.data(), ort_inputs.size(), &output_name, output_tensor, 1); else { ort_outputs = session_object.Run(Ort::RunOptions{}, input_names.data(), ort_inputs.data(), ort_inputs.size(), &output_name, 1); ASSERT_EQ(ort_outputs.size(), 1u); output_tensor = &ort_outputs[0]; } auto type_info = output_tensor->GetTensorTypeAndShapeInfo(); ASSERT_EQ(type_info.GetShape(), dims_y); size_t total_len = type_info.GetElementCount(); ASSERT_EQ(values_y.size(), total_len); OutT* f = output_tensor->GetTensorMutableData<OutT>(); for (size_t i = 0; i != total_len; ++i) { ASSERT_EQ(values_y[i], f[i]); } }

这里使用了一些GTest中的断言来判定运行结果是否符合预期。

至此,我们已经完整地分析了定制算子从定义到使用的全部流程。