Onnxruntime c接口说明及动态调用示例
背景:
需要onnx模型推理的功能,直接引用onnxruntime代码会引起编译问题。所以考虑动态加载onnxruntime的动态库完成。C++的接口依然需要源码依赖,所以考虑使用onnxruntime的c接口。
1.How to access Onnxruntime C API:
要访问c api,需要拿到 c api的函数指针,而onnxruntime 的所有capi定义在 一个结构体中:
1 2 3 4 5 6 7 8 9 |
struct OrtApi { OrtStatus*(ORT_API_CALL* SetIntraOpNumThreads)(_Inout_ OrtSessionOptions* options, int intra_op_num_threads); OrtStatus*(ORT_API_CALL* SetSessionGraphOptimizationLevel)(_Inout_ OrtSessionOptions* options, GraphOptimizationLevel graph_optimization_level)NO_EXCEPTION; OrtStatus*(ORT_API_CALL* SetDimensions)(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count)NO_EXCEPTION; OrtStatus*(ORT_API_CALL* GetTensorElementType)(_In_ const OrtTensorTypeAndShapeInfo*, _Out_ enum ONNXTensorElementDataType* out)NO_EXCEPTION; OrtStatus*(ORT_API_CALL* GetDimensionsCount)(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ size_t* out)NO_EXCEPTION; OrtStatus*(ORT_API_CALL* GetDimensions)(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length)NO_EXCEPTION; OrtStatus*(ORT_API_CALL* GetSymbolicDimensions)(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ const char** dim_params, size_t dim_params_length)NO_EXCEPTION; }; |
上述仅随意列举了用到的函数指针,由于onnxruntime在c api的接口实现中将初始化一个全局静态OrtApi对象,因而这些函数指针并不需要一一获取,只用获取到这个OrtApi对象即可通过访问成员的方式使用对应的函数,极大方便了接口使用。而获取这个对象是通过一个OrtGetApiBase函数实现。见如下定义。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 |
// Interface definition struct OrtApi; typedef struct OrtApi OrtApi; struct OrtApiBase { const OrtApi*(ORT_API_CALL* GetApi)(uint32_t version)NO_EXCEPTION; // Pass in ORT_API_VERSION const char*(ORT_API_CALL* GetVersionString)() NO_EXCEPTION; }; typedef struct OrtApiBase OrtApiBase; ORT_EXPORT const OrtApiBase* ORT_API_CALL OrtGetApiBase() NO_EXCEPTION; // Interface Implementation // onnxruntime/onnxruntime/core/session/onnxruntime_c_api.cc static constexpr OrtApiBase ort_api_base = { &OrtApis::GetApi, &OrtApis::GetVersionString, }; const OrtApiBase* ORT_API_CALL OrtGetApiBase() NO_EXCEPTION { return &ort_api_base; } |
首先通过导出函数OrtGetApiBase函数获取指向一个全局静态结构体:OrtApiBase结构体的指针,该结构体内定义了获取包含所用到的函数指针的全局静态OrtApi结构体对象的函数指针GetApi。
因此可以通过如下代码
1 2 3 4 5 6 |
auto pApiBase = loader.GetFuncPointer<OrtGetApiBasePtr>( spLibHandle, kOrtGetApiBaseName); if( nullptr == pApiBase){ // TODO exception. } const OrtApi* pOrt = pApiBase()->GetApi(ORT_API_VERSION); |
获取需要的OrtApi结构体对象,从而能访问所需的函数指针。
另外onnxruntime定义了一系列接口用到的结构体,以及对应的资源释放函数 ReleaseXXX。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 |
#define ORT_RUNTIME_CLASS(X) \ struct Ort##X; \ typedef struct Ort##X Ort##X; // ORT_API(void, OrtRelease##X, _Frees_ptr_opt_ Ort##X* input); #define ORT_CLASS_RELEASE(X) void(ORT_API_CALL * Release##X)(_Frees_ptr_opt_ Ort##X * input) // The actual types defined have an Ort prefix ORT_RUNTIME_CLASS(Env); ORT_RUNTIME_CLASS(Status); // nullptr for Status* indicates success ORT_RUNTIME_CLASS(MemoryInfo); ORT_RUNTIME_CLASS(Session); //Don't call OrtReleaseSession from Dllmain (because session owns a thread pool) ORT_RUNTIME_CLASS(Value); ORT_RUNTIME_CLASS(RunOptions); ORT_RUNTIME_CLASS(TypeInfo); ORT_RUNTIME_CLASS(TensorTypeAndShapeInfo); ORT_RUNTIME_CLASS(SessionOptions); ORT_RUNTIME_CLASS(CustomOpDomain); // Another partial definition struct OrtApi { ORT_CLASS_RELEASE(Env); ORT_CLASS_RELEASE(Status); // nullptr for Status* indicates success ORT_CLASS_RELEASE(MemoryInfo); ORT_CLASS_RELEASE(Session); //Don't call OrtReleaseSession from Dllmain (because session owns a thread pool) ORT_CLASS_RELEASE(Value); ORT_CLASS_RELEASE(RunOptions); ORT_CLASS_RELEASE(TypeInfo); ORT_CLASS_RELEASE(TensorTypeAndShapeInfo); ORT_CLASS_RELEASE(SessionOptions); ORT_CLASS_RELEASE(CustomOpDomain); }; |
2. How to Use the interface:
2.1 How to create session:
有如下两个接口可以创建session:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 |
// TODO: document the path separator convention? '/' vs '\' // TODO: should specify the access characteristics of model_path. Is this read // only during the execution of OrtCreateSession, or does the OrtSession // retain a handle to the file/directory and continue to access throughout the // OrtSession lifetime? // What sort of access is needed to model_path : read or read/write? OrtStatus*(ORT_API_CALL* CreateSession)( _In_ const OrtEnv* env, _In_ const ORTCHAR_T* model_path, _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out)NO_EXCEPTION; OrtStatus*(ORT_API_CALL* CreateSessionFromArray)( _In_ const OrtEnv* env, _In_ const void* model_data, size_t model_data_length, _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out)NO_EXCEPTION; |
这两接口区别是模型的传入类型不一样,由于这里我们不是从文件加载模型进行推理,所以这里采用CreateSessionFromArray接口。
a. const OrtEnv* env:
要使用该接口,需要先创建执行环境,
1 2 3 4 5 6 7 |
/** * \param out Should be freed by `OrtReleaseEnv` after use */ OrtStatus*(ORT_API_CALL* CreateEnv)( OrtLoggingLevel default_logging_level, _In_ const char* logid, _Outptr_ OrtEnv** out)NO_EXCEPTION ORT_ALL_ARGS_NONNULL; |
可以用该函数创建,参数就是日志级别和日志标示。
b. const void* model_data && size_t model_data_length:
model_data即模型二进制表示的指针,而要从ModelProto获取二进制表示,包括model_data_len是二进制表示的大小信息,都需要从ONNX protobuf的定义里找对应的函数,代码如下:
1 2 3 4 5 |
size_t model_size = mp.ByteSizeLong(); std::vector<uint8_t> mp_buff(model_size); mp.SerializeToArray(mp_buff.data(), mp_buff.size()); c. const OrtSessionOptions* options: |
会话配置,可以通过下述接口获取,
1 2 3 4 5 6 |
/** * \return A pointer of the newly created object. The pointer should be freed * by OrtReleaseSessionOptions after use */ OrtStatus*(ORT_API_CALL* CreateSessionOptions)( _Outptr_ OrtSessionOptions** options)NO_EXCEPTION; |
至此,一个session便准备完毕。
2.2 How to use OrtValue
OrtValue是onnxruntime c api中最常见的类型,用来封装所有的数据传递结构。最主要的就是tensor,所以api中也提供了一系列相关的函数接口用于tensor相关的类型转换。
由于我们是在ONNX中调用onnxruntime的api,而且是用动态加载的方式,所以我们只用关心从onnx tensor 到OrtValue的相互转换,而OrtValue内部封装的是onnxruntime的tensor类型。
2.3 How to infer
主要是用这个接口:
OrtStatus*( * Run)( OrtSession* sess, const OrtRunOptions* run_options, const char* const* input_names, const OrtValue* const* input, size_t input_len,
const char* const* output_names, size_t output_names_len, OrtValue** output);
3. How to Convert ONNX tensor to OrtValue
先说模型infer过程中首先会遇到onnx下input tensor转换成OrtValue的问题,
核心是调用
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 |
/** * Create a tensor from an allocator. OrtReleaseValue will also release the buffer inside the output value * \param out Should be freed by calling OrtReleaseValue * \param type must be one of TENSOR_ELEMENT_DATA_TYPE_xxxx */ OrtStatus*(ORT_API_CALL* CreateTensorAsOrtValue)(_Inout_ OrtAllocator* allocator, _In_ const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type, _Outptr_ OrtValue** out)NO_EXCEPTION; /** * Create a tensor with user's buffer. You can fill the buffer either before calling this function or after. * p_data is owned by caller. OrtReleaseValue won't release p_data. * \param out Should be freed by calling OrtReleaseValue */ OrtStatus*(ORT_API_CALL* CreateTensorWithDataAsOrtValue)(_In_ const OrtMemoryInfo* info, _Inout_ void* p_data, size_t p_data_len, _In_ const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type, _Outptr_ OrtValue** out)NO_EXCEPTION; |
CreateTensorAsOrtValue / CreateTensorWithDataAsOrtValue 这两个函数进行OrtValue类型的tensor创建,这两个函数的区别一是是否需要由onnxruntime进行内存分配及其内存管理的职责。
由于input tensor 已经开辟并保持了输入数据,我们不需要onnxruntime重复开辟内存存放数据,因而我们是使用 CreateTensorWithDataAsOrtValue 函数去创建,对函数原型进行简化(去掉说明的宏)
OrtStatus*(* CreateTensorWithDataAsOrtValue)(const OrtMemoryInfo* info, void* p_data, size_t p_data_len, const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type, OrtValue** out);
依次分析入参:
a. const OrtMemoryInfo* info:可以通过接口:
OrtStatus*(* CreateCpuMemoryInfo)(enum OrtAllocatorType type, enum OrtMemType mem_type1, OrtMemoryInfo** out);
获取,而CreateCpuMemoryInfo接口只需要传入两个枚举类型即可获取。
b. void* p_data: && size_t p_data_len:
此即原onnx tensor的data指针,由于onnx tensor的数据存放成员因tensor的类型不同而不同,所以我们需要一个辅助函数来获取该data实际存储的位置,
p_data_len就是 p_data指向的内存空间的大小,根据onnxruntime 的 api实现得知此大小是字节大小,并且是0字节对齐的。
所以函数内实现先根据 ONNX tensor的类型获取到存储数据的成员vector容器,从而获取数据指针及数据数量,并根据类型计算出总大小。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 |
bool GetONNXTensorData( const Tensor& tensor, const void** data_ptr, size_t* data_len) { *data_len = tensor.size_from_dim(0); if (tensor.is_raw_data()) { *data_ptr = tensor.raw().c_str(); *data_len = tensor.raw().length(); } switch (tensor.elem_type()) { ... ... case TensorProto_DataType_INT64: *data_ptr = tensor.data<int64_t>(); *data_len *= sizeof(int64_t); case TensorProto_DataType_UINT32: case TensorProto_DataType_UINT64: *data_ptr = tensor.data<uint64_t>(); *data_len *= sizeof(uint64_t); case TensorProto_DataType_FLOAT: *data_ptr = tensor.data<float>(); *data_len *= sizeof(float); case TensorProto_DataType_DOUBLE: *data_ptr = tensor.data<double>(); *data_len *= sizeof(double); default: return false; } return true; } |
c. const int64_t* shape && size_t shape_len:
shape和shape_len及张量维度(通道大小)和维度数量(通道数量)信息,这个信息OrtValue 和 Onnx tensor的定义没有什么区别,所以直接获取:
input_tensors[i].sizes().data(),
input_tensors[i].sizes().size(),
d. ONNXTensorElementDataType type:
此参数即OrtValue对应tensor的元素类型信息,这里用一个函数进行对应枚举类型的转换
1 2 3 4 5 6 7 8 9 10 11 12 13 14 |
ONNXTensorElementDataType TensorProtoDataTypeToONNXTensorElementDataType( const TensorProto_DataType elem_type_) { ONNXTensorElementDataType type; switch (elem_type_) { case TensorProto_DataType_UNDEFINED: type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; case TensorProto_DataType_BOOL: type = ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL; case TensorProto_DataType_INT8: type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8; ... ... } } |
4. How to Convert OrtValue to ONNX tensor
infer得到的OrtValue 需要转成 ONNX tensor,onnxruntime并没有提供转换接口,需要从OrtValue中获取数据指针/大小/类型等信息构造onnx tensor
a.获取数据指针:
// This function doesn’t work with string tensor
// this is a no-copy method whose pointer is only valid until the backing
// OrtValue is free’d.
OrtStatus*( * GetTensorMutableData)( OrtValue* value, void** out);
参数不再赘述。
b. 获取tensor类型和维度信息:
首先通过GetTensorTypeAndShape获取类型和维度信息对象:
/**
* \param out Should be freed by OrtReleaseTensorTypeAndShapeInfo after use
*/
OrtStatus*(* GetTensorTypeAndShape)( const OrtValue* value, OrtTensorTypeAndShapeInfo** out);
传入即infer返回的output (OrtValue),传出OrtTensorTypeAndShapeInfo对象。
然后获取数据类型:
OrtStatus*( * GetTensorElementType)( const OrtTensorTypeAndShapeInfo*, enum ONNXTensorElementDataType* out);
拿到的类型信息再做一下转换:
1 2 3 4 5 6 7 8 9 10 11 12 |
TensorProto_DataType ONNXTensorElementDataTypeToTensorProtoDataType( const ONNXTensorElementDataType elem_type_) { TensorProto_DataType type; switch (elem_type_) { case ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED: type = TensorProto_DataType_UNDEFINED; break; case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: type = TensorProto_DataType_BOOL; break; ... } |
接着获取数据维度信息,包括维度数量和维度信息:
OrtStatus*(* GetDimensionsCount)( const OrtTensorTypeAndShapeInfo* info, size_t* out);
OrtStatus*(* GetDimensions)( const OrtTensorTypeAndShapeInfo* info, int64_t* dim_values, size_t dim_values_length);
拿到上述信息后开始构造 ONNX tensor,这里还是要根据类型进行数据存储:
1 2 3 4 5 6 7 8 9 10 11 12 13 |
bool SetTensorData(Tensor* tensor, void* data_ptr, size_t length) { switch (tensor->elem_type()) { case ONNX_NAMESPACE::TensorProto_DataType_INT32: { memcpy( tensor->int32s().data(), data_ptr, length * sizeof(std::remove_reference<decltype( tensor->int32s())>::type::value_type)); break; } .... } |
1 |
等上述过程结束,别忘了释放所有资源。借助RAII方法管理资源,参见github代码。
上述过程我简单写了个代码,放到GITHUB上:
代码地址:https://github.com/atp798/onnx_dynamic_load.git