First test with rust api

This commit is contained in:
Christopher Chalmers 2020-05-17 13:36:07 +01:00
parent 4f382eb831
commit 2076b9e127
2 changed files with 120 additions and 11 deletions

View File

@ -1,3 +1,4 @@
use std::ffi::c_void;
use std::ffi::{self, CString};
use std::ptr;
@ -248,7 +249,7 @@ impl Clone for SessionOptions {
}
impl Session {
pub fn new(env: Env, model_path: &str, options: SessionOptions) -> Result<Self> {
pub fn new(env: &Env, model_path: &str, options: &SessionOptions) -> Result<Self> {
let model_path = CString::new(model_path)?;
let mut raw = ptr::null_mut();
unsafe {
@ -264,9 +265,44 @@ impl Session {
Ok(Session { raw })
}
pub fn input_name(&self, ix: u64) -> Result<OrtString> {
let alloc = Allocator::default();
let mut raw = ptr::null_mut();
unsafe {
checked_call!(SessionGetInputName, self.raw, ix, alloc.raw, &mut raw)?;
}
Ok(OrtString { raw })
}
pub fn output_name(&self, ix: u64) -> Result<OrtString> {
let alloc = Allocator::default();
let mut raw = ptr::null_mut();
unsafe {
checked_call!(SessionGetOutputName, self.raw, ix, alloc.raw, &mut raw)?;
}
Ok(OrtString { raw })
}
// pub fn run(
// &self,
// options: &RunOptions,
// input_names: InputNames,
// inputs: Inputs,
// output_names: OutputNames,
// ) -> Result<Vec<Value>>
// where InputNames: Iterator<Item=&str>,
// Inputs: Iterator<Item=&Value>,
// OutputNames: Iterator<Item=&str>,
// {
// }
pub fn run_raw(
&self,
options: RunOptions,
options: &RunOptions,
input_names: &[&str],
inputs: &[&Value],
output_names: &[&str],
@ -327,18 +363,94 @@ impl Default for Allocator {
}
}
impl Allocator {
pub unsafe fn free(&self, ptr: *mut c_void) {
checked_call!(AllocatorFree, self.raw, ptr).expect("AllocatorFree");
}
}
use std::ffi::CStr;
use std::os::raw::c_char;
/// An ort string with the default allocator
pub struct OrtString {
raw: *const c_char,
}
impl std::ops::Deref for OrtString {
type Target = CStr;
fn deref(&self) -> &CStr {
unsafe { CStr::from_ptr(self.raw) }
}
}
impl Drop for OrtString {
fn drop(&mut self) {
let alloc = Allocator::default();
unsafe { alloc.free(self.raw as _) }
}
}
impl MemoryInfo {
pub fn cpu_memory_info(alloc_type: AllocatorType, mem_type: MemType) -> Result<Self> {
let mut raw = ptr::null_mut();
unsafe {
checked_call!(
CreateCpuMemoryInfo,
alloc_type,
mem_type,
&mut raw
)?;
checked_call!(CreateCpuMemoryInfo, alloc_type, mem_type, &mut raw)?;
}
Ok(MemoryInfo { raw })
}
}
impl RunOptions {
pub fn new() -> RunOptions {
let mut raw = ptr::null_mut();
unsafe {
checked_call!(CreateRunOptions, &mut raw).expect("CreateRunOptions");
}
RunOptions { raw }
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn mat_mul() -> Result<()> {
let env = Env::new(LoggingLevel::Warning, "test")?;
let so = SessionOptions::new()?;
let model_path = "../onnxruntime/test/testdata/matmul_1.onnx";
let session = Session::new(&env, model_path, &so)?;
let in_name = session.input_name(0)?;
let out_name = session.output_name(0)?;
let ro = RunOptions::new();
let mem_info = MemoryInfo::cpu_memory_info(AllocatorType::ArenaAllocator, MemType::Cpu)?;
let input_names = vec![in_name.to_str().unwrap()];
let output_names = vec![out_name.to_str().unwrap()];
let input_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
let input_shape = vec![3, 2];
let input_tensor = Tensor::new(mem_info, input_shape, input_data)?;
let output = session.run_raw(&ro, &input_names, &[input_tensor.value()], &output_names)?;
let output_value = output.into_iter().next().unwrap();
let output_tensor = output_value.as_tensor::<f32>().ok().expect("as_tensor");
assert_eq!(
&output_tensor[..],
&[1. + 2. * 2., 3. + 2. * 4., 5. + 2. * 6.]
);
Ok(())
}
}

View File

@ -87,12 +87,10 @@ impl TensorTypeAndShapeInfo {
/// Return the number of elements specified by the tensor shape. Return a negative value if
/// unknown (i.e., any dimension is negative.)
///
/// ```
/// [] -> 1
/// [1,3,4] -> 12
/// [2,0,4] -> 0
/// [-1,3,4] -> -1
/// ```
pub fn elem_count(&self) -> isize {
let mut count = 0;
unsafe {
@ -273,4 +271,3 @@ macro_rules! values {
}
};
}