From 2076b9e127e305b0efc16f610aa61fc6420a992c Mon Sep 17 00:00:00 2001 From: Christopher Chalmers Date: Sun, 17 May 2020 13:36:07 +0100 Subject: [PATCH] First test with rust api --- src/lib.rs | 128 +++++++++++++++++++++++++++++++++++++++++++++++---- src/value.rs | 3 -- 2 files changed, 120 insertions(+), 11 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 95aef8a..1ea0105 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,4 @@ +use std::ffi::c_void; use std::ffi::{self, CString}; use std::ptr; @@ -248,7 +249,7 @@ impl Clone for SessionOptions { } impl Session { - pub fn new(env: Env, model_path: &str, options: SessionOptions) -> Result { + pub fn new(env: &Env, model_path: &str, options: &SessionOptions) -> Result { let model_path = CString::new(model_path)?; let mut raw = ptr::null_mut(); unsafe { @@ -264,9 +265,44 @@ impl Session { Ok(Session { raw }) } + pub fn input_name(&self, ix: u64) -> Result { + let alloc = Allocator::default(); + let mut raw = ptr::null_mut(); + + unsafe { + checked_call!(SessionGetInputName, self.raw, ix, alloc.raw, &mut raw)?; + } + + Ok(OrtString { raw }) + } + + pub fn output_name(&self, ix: u64) -> Result { + let alloc = Allocator::default(); + let mut raw = ptr::null_mut(); + + unsafe { + checked_call!(SessionGetOutputName, self.raw, ix, alloc.raw, &mut raw)?; + } + + Ok(OrtString { raw }) + } + + // pub fn run( + // &self, + // options: &RunOptions, + // input_names: InputNames, + // inputs: Inputs, + // output_names: OutputNames, + // ) -> Result> + // where InputNames: Iterator, + // Inputs: Iterator, + // OutputNames: Iterator, + // { + // } + pub fn run_raw( &self, - options: RunOptions, + options: &RunOptions, input_names: &[&str], inputs: &[&Value], output_names: &[&str], @@ -327,18 +363,94 @@ impl Default for Allocator { } } +impl Allocator { + pub unsafe fn free(&self, ptr: *mut c_void) { + checked_call!(AllocatorFree, self.raw, ptr).expect("AllocatorFree"); + } +} + +use std::ffi::CStr; +use std::os::raw::c_char; + +/// An ort string with the default allocator +pub struct OrtString { + raw: *const c_char, +} + +impl std::ops::Deref for OrtString { + type Target = CStr; + + fn deref(&self) -> &CStr { + unsafe { CStr::from_ptr(self.raw) } + } +} + +impl Drop for OrtString { + fn drop(&mut self) { + let alloc = Allocator::default(); + unsafe { alloc.free(self.raw as _) } + } +} + impl MemoryInfo { pub fn cpu_memory_info(alloc_type: AllocatorType, mem_type: MemType) -> Result { let mut raw = ptr::null_mut(); unsafe { - checked_call!( - CreateCpuMemoryInfo, - alloc_type, - mem_type, - &mut raw - )?; + checked_call!(CreateCpuMemoryInfo, alloc_type, mem_type, &mut raw)?; } Ok(MemoryInfo { raw }) } } + +impl RunOptions { + pub fn new() -> RunOptions { + let mut raw = ptr::null_mut(); + unsafe { + checked_call!(CreateRunOptions, &mut raw).expect("CreateRunOptions"); + } + + RunOptions { raw } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn mat_mul() -> Result<()> { + let env = Env::new(LoggingLevel::Warning, "test")?; + + let so = SessionOptions::new()?; + + let model_path = "../onnxruntime/test/testdata/matmul_1.onnx"; + let session = Session::new(&env, model_path, &so)?; + let in_name = session.input_name(0)?; + let out_name = session.output_name(0)?; + + let ro = RunOptions::new(); + + let mem_info = MemoryInfo::cpu_memory_info(AllocatorType::ArenaAllocator, MemType::Cpu)?; + + let input_names = vec![in_name.to_str().unwrap()]; + let output_names = vec![out_name.to_str().unwrap()]; + + let input_data: Vec = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; + let input_shape = vec![3, 2]; + + let input_tensor = Tensor::new(mem_info, input_shape, input_data)?; + + let output = session.run_raw(&ro, &input_names, &[input_tensor.value()], &output_names)?; + + let output_value = output.into_iter().next().unwrap(); + let output_tensor = output_value.as_tensor::().ok().expect("as_tensor"); + + assert_eq!( + &output_tensor[..], + &[1. + 2. * 2., 3. + 2. * 4., 5. + 2. * 6.] + ); + + Ok(()) + } +} diff --git a/src/value.rs b/src/value.rs index b612681..6e5e1f0 100644 --- a/src/value.rs +++ b/src/value.rs @@ -87,12 +87,10 @@ impl TensorTypeAndShapeInfo { /// Return the number of elements specified by the tensor shape. Return a negative value if /// unknown (i.e., any dimension is negative.) /// - /// ``` /// [] -> 1 /// [1,3,4] -> 12 /// [2,0,4] -> 0 /// [-1,3,4] -> -1 - /// ``` pub fn elem_count(&self) -> isize { let mut count = 0; unsafe { @@ -273,4 +271,3 @@ macro_rules! values { } }; } -