From 68f0217c4a780f7e48aa381e799334deee6cbf2f Mon Sep 17 00:00:00 2001 From: Christopher Chalmers Date: Sat, 6 Jun 2020 09:59:21 +0100 Subject: [PATCH] Add CustomOp api --- src/custom_op.rs | 173 +++++++++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 9 ++- 2 files changed, 180 insertions(+), 2 deletions(-) create mode 100644 src/custom_op.rs diff --git a/src/custom_op.rs b/src/custom_op.rs new file mode 100644 index 0000000..8673743 --- /dev/null +++ b/src/custom_op.rs @@ -0,0 +1,173 @@ +use std::ffi::c_void; + +use super::*; + +impl CustomOpDomain { + pub fn new(domain: &str) -> CustomOpDomain { + let domain = CString::new(domain).unwrap(); + let raw = call!(@unsafe @ptr @expect CreateCustomOpDomain, domain.as_ptr()); + CustomOpDomain { raw } + } + + pub fn add(&mut self, op: &mut CustomOpWithContext) { + call!(@unsafe @expect CustomOpDomain_Add, self.raw, op.as_ptr()); + } +} + +pub trait Kernel { + /// The compute function for a custom op kernel. The KernelContext is used to obtain the inputs + /// and create the outputs. + fn compute(&mut self, context: &mut KernelContext); +} + +pub trait CustomOperation { + type CustomKernel: Kernel; + fn create_kernel(&mut self, info: &KernelInfo) -> Self::CustomKernel; + fn name(&self) -> &str; + fn execution_provider_type(&self) -> Option<&str>; + fn input_type(&self, index: usize) -> OnnxTensorElementDataType; + fn input_type_count(&self) -> usize; + fn output_type(&self, index: usize) -> OnnxTensorElementDataType; + fn output_type_count(&self) -> usize; +} + +impl KernelInfo { + pub fn attribute_float(&self, name: &str) -> Result { + let name = CString::new(name).expect("KernelInfo::attribute_float"); + call!(@unsafe @arg 0.0; KernelInfoGetAttribute_float, self, name.as_ptr()) + } + pub fn attribute_i64(&self, name: &str) -> Result { + let name = CString::new(name).expect("KernelInfo::attribute_i64"); + call!(@unsafe @int KernelInfoGetAttribute_int64, self, name.as_ptr()) + } +} + +impl KernelContext { + pub fn input_count(&self) -> usize { + call!(@unsafe @int @expect KernelContext_GetInputCount, self) as usize + } + pub fn output_count(&self) -> usize { + call!(@unsafe @int @expect KernelContext_GetOutputCount, self) as usize + } + pub fn get_input(&self, index: usize) -> &Val { + let raw = + call!(@unsafe @arg ptr::null(); @expect KernelContext_GetInput, self, index as u64); + unsafe { &*(raw as *const Val) } + } + pub fn get_output(&mut self, index: usize, dims: &[usize]) -> &mut Val { + let dims: Vec = dims.iter().map(|&d| d as i64).collect(); + let raw = call!(@unsafe @ptr @expect KernelContext_GetOutput, + self, + index as u64, + dims.as_ptr(), + dims.len() as u64); + unsafe { &mut *(raw as *mut Val) } + } +} + +/// A custom op with the context following it. This is done in a C struct because the internal +/// functions of `CustomOp` provide the pointer of the `CustomOp`. Adding a offset to this we can +/// get back the context. +#[repr(C)] +pub struct CustomOpWithContext { + ort_custom_op: CustomOp, + name: CString, + execution_provider_type: Option, + ctx: Ctx, +} + +impl CustomOpWithContext { + pub fn as_ptr(&mut self) -> *mut CustomOp { + &mut self.ort_custom_op + } +} + +#[allow(non_snake_case)] +pub fn create_custom_op(ctx: Op) -> CustomOpWithContext { + unsafe extern "C" fn CreateKernel( + op: *mut CustomOp, + api: *const sys::Api, + info: *const KernelInfo, + ) -> *mut c_void { + let _ = api; + let op_with_ctx = (op as *mut CustomOpWithContext).as_mut().unwrap(); + let info = info.as_ref().unwrap(); + let kernel = Box::new(op_with_ctx.ctx.create_kernel(info)); + Box::into_raw(kernel) as _ + } + + unsafe extern "C" fn GetName(op: *mut CustomOp) -> *const c_char { + let op_with_ctx = (op as *mut CustomOpWithContext).as_mut().unwrap(); + op_with_ctx.name.as_ptr() + } + + unsafe extern "C" fn GetExecutionProviderType( + op: *mut CustomOp, + ) -> *const c_char { + let op_with_ctx = (op as *mut CustomOpWithContext).as_mut().unwrap(); + op_with_ctx + .execution_provider_type + .as_ref() + .map_or(ptr::null(), |cstr| cstr.as_ptr()) + } + + unsafe extern "C" fn GetInputType( + op: *mut CustomOp, + index: u64, + ) -> OnnxTensorElementDataType { + let op_with_ctx = (op as *mut CustomOpWithContext).as_mut().unwrap(); + op_with_ctx.ctx.input_type(index as usize) + } + + unsafe extern "C" fn GetInputTypeCount(op: *mut CustomOp) -> u64 { + let op_with_ctx = (op as *mut CustomOpWithContext).as_mut().unwrap(); + op_with_ctx.ctx.input_type_count() as u64 + } + + unsafe extern "C" fn GetOutputType( + op: *mut CustomOp, + index: u64, + ) -> OnnxTensorElementDataType { + let op_with_ctx = (op as *mut CustomOpWithContext).as_mut().unwrap(); + op_with_ctx.ctx.output_type(index as usize) + } + + unsafe extern "C" fn GetOutputTypeCount(op: *mut CustomOp) -> u64 { + let op_with_ctx = (op as *mut CustomOpWithContext).as_mut().unwrap(); + op_with_ctx.ctx.output_type_count() as u64 + } + + unsafe extern "C" fn KernelCompute( + kernel: *mut c_void, + context: *mut KernelContext, + ) { + let kernel = (kernel as *mut Op::CustomKernel).as_mut().unwrap(); + kernel.compute(context.as_mut().unwrap()) + } + + unsafe extern "C" fn KernelDestroy(kernel: *mut c_void) { + drop(Box::from_raw(kernel as *mut Op::CustomKernel)); + } + + let ort_custom_op = CustomOp { + version: sys::ORT_API_VERSION, + CreateKernel: Some(CreateKernel::), + GetName: Some(GetName::), + GetExecutionProviderType: Some(GetExecutionProviderType::), + GetInputType: Some(GetInputType::), + GetInputTypeCount: Some(GetInputTypeCount::), + GetOutputType: Some(GetOutputType::), + GetOutputTypeCount: Some(GetOutputTypeCount::), + KernelCompute: Some(KernelCompute::), + KernelDestroy: Some(KernelDestroy::), + }; + + CustomOpWithContext { + ort_custom_op, + name: CString::new(ctx.name()).unwrap(), + execution_provider_type: ctx + .execution_provider_type() + .map(|str| CString::new(str).unwrap()), + ctx, + } +} diff --git a/src/lib.rs b/src/lib.rs index 38e712b..1495d9b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -7,13 +7,16 @@ use std::ptr; pub mod sys; // Re-export enums pub use sys::{ - AllocatorType, ErrorCode, ExecutionMode, GraphOptimizationLevel, LoggingLevel, MemType, - OnnxTensorElementDataType, OnnxType, + AllocatorType, CustomOp, ErrorCode, ExecutionMode, GraphOptimizationLevel, KernelContext, + KernelInfo, LoggingLevel, MemType, OnnxTensorElementDataType, OnnxType, }; #[macro_use] mod api; +mod custom_op; +pub use custom_op::*; + mod allocator; pub use allocator::Allocator; @@ -177,6 +180,8 @@ impl SessionOptions { { SetSessionGraphOptimizationLevel }; fn set_intra_op_num_threads(intra_op_num_threads: i32) { SetIntraOpNumThreads }; fn set_inter_op_num_threads(intra_op_num_threads: i32) { SetInterOpNumThreads }; + /// XXX not yet safe becuase of lifetimes + fn add_custom_op_domain(custom_op_domain: &mut CustomOpDomain | .raw()) { AddCustomOpDomain }; } }