First test with rust api
This commit is contained in:
parent
4f382eb831
commit
2076b9e127
128
src/lib.rs
128
src/lib.rs
@ -1,3 +1,4 @@
|
||||
use std::ffi::c_void;
|
||||
use std::ffi::{self, CString};
|
||||
use std::ptr;
|
||||
|
||||
@ -248,7 +249,7 @@ impl Clone for SessionOptions {
|
||||
}
|
||||
|
||||
impl Session {
|
||||
pub fn new(env: Env, model_path: &str, options: SessionOptions) -> Result<Self> {
|
||||
pub fn new(env: &Env, model_path: &str, options: &SessionOptions) -> Result<Self> {
|
||||
let model_path = CString::new(model_path)?;
|
||||
let mut raw = ptr::null_mut();
|
||||
unsafe {
|
||||
@ -264,9 +265,44 @@ impl Session {
|
||||
Ok(Session { raw })
|
||||
}
|
||||
|
||||
pub fn input_name(&self, ix: u64) -> Result<OrtString> {
|
||||
let alloc = Allocator::default();
|
||||
let mut raw = ptr::null_mut();
|
||||
|
||||
unsafe {
|
||||
checked_call!(SessionGetInputName, self.raw, ix, alloc.raw, &mut raw)?;
|
||||
}
|
||||
|
||||
Ok(OrtString { raw })
|
||||
}
|
||||
|
||||
pub fn output_name(&self, ix: u64) -> Result<OrtString> {
|
||||
let alloc = Allocator::default();
|
||||
let mut raw = ptr::null_mut();
|
||||
|
||||
unsafe {
|
||||
checked_call!(SessionGetOutputName, self.raw, ix, alloc.raw, &mut raw)?;
|
||||
}
|
||||
|
||||
Ok(OrtString { raw })
|
||||
}
|
||||
|
||||
// pub fn run(
|
||||
// &self,
|
||||
// options: &RunOptions,
|
||||
// input_names: InputNames,
|
||||
// inputs: Inputs,
|
||||
// output_names: OutputNames,
|
||||
// ) -> Result<Vec<Value>>
|
||||
// where InputNames: Iterator<Item=&str>,
|
||||
// Inputs: Iterator<Item=&Value>,
|
||||
// OutputNames: Iterator<Item=&str>,
|
||||
// {
|
||||
// }
|
||||
|
||||
pub fn run_raw(
|
||||
&self,
|
||||
options: RunOptions,
|
||||
options: &RunOptions,
|
||||
input_names: &[&str],
|
||||
inputs: &[&Value],
|
||||
output_names: &[&str],
|
||||
@ -327,18 +363,94 @@ impl Default for Allocator {
|
||||
}
|
||||
}
|
||||
|
||||
impl Allocator {
|
||||
pub unsafe fn free(&self, ptr: *mut c_void) {
|
||||
checked_call!(AllocatorFree, self.raw, ptr).expect("AllocatorFree");
|
||||
}
|
||||
}
|
||||
|
||||
use std::ffi::CStr;
|
||||
use std::os::raw::c_char;
|
||||
|
||||
/// An ort string with the default allocator
|
||||
pub struct OrtString {
|
||||
raw: *const c_char,
|
||||
}
|
||||
|
||||
impl std::ops::Deref for OrtString {
|
||||
type Target = CStr;
|
||||
|
||||
fn deref(&self) -> &CStr {
|
||||
unsafe { CStr::from_ptr(self.raw) }
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for OrtString {
|
||||
fn drop(&mut self) {
|
||||
let alloc = Allocator::default();
|
||||
unsafe { alloc.free(self.raw as _) }
|
||||
}
|
||||
}
|
||||
|
||||
impl MemoryInfo {
|
||||
pub fn cpu_memory_info(alloc_type: AllocatorType, mem_type: MemType) -> Result<Self> {
|
||||
let mut raw = ptr::null_mut();
|
||||
unsafe {
|
||||
checked_call!(
|
||||
CreateCpuMemoryInfo,
|
||||
alloc_type,
|
||||
mem_type,
|
||||
&mut raw
|
||||
)?;
|
||||
checked_call!(CreateCpuMemoryInfo, alloc_type, mem_type, &mut raw)?;
|
||||
}
|
||||
|
||||
Ok(MemoryInfo { raw })
|
||||
}
|
||||
}
|
||||
|
||||
impl RunOptions {
|
||||
pub fn new() -> RunOptions {
|
||||
let mut raw = ptr::null_mut();
|
||||
unsafe {
|
||||
checked_call!(CreateRunOptions, &mut raw).expect("CreateRunOptions");
|
||||
}
|
||||
|
||||
RunOptions { raw }
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn mat_mul() -> Result<()> {
|
||||
let env = Env::new(LoggingLevel::Warning, "test")?;
|
||||
|
||||
let so = SessionOptions::new()?;
|
||||
|
||||
let model_path = "../onnxruntime/test/testdata/matmul_1.onnx";
|
||||
let session = Session::new(&env, model_path, &so)?;
|
||||
let in_name = session.input_name(0)?;
|
||||
let out_name = session.output_name(0)?;
|
||||
|
||||
let ro = RunOptions::new();
|
||||
|
||||
let mem_info = MemoryInfo::cpu_memory_info(AllocatorType::ArenaAllocator, MemType::Cpu)?;
|
||||
|
||||
let input_names = vec![in_name.to_str().unwrap()];
|
||||
let output_names = vec![out_name.to_str().unwrap()];
|
||||
|
||||
let input_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
|
||||
let input_shape = vec![3, 2];
|
||||
|
||||
let input_tensor = Tensor::new(mem_info, input_shape, input_data)?;
|
||||
|
||||
let output = session.run_raw(&ro, &input_names, &[input_tensor.value()], &output_names)?;
|
||||
|
||||
let output_value = output.into_iter().next().unwrap();
|
||||
let output_tensor = output_value.as_tensor::<f32>().ok().expect("as_tensor");
|
||||
|
||||
assert_eq!(
|
||||
&output_tensor[..],
|
||||
&[1. + 2. * 2., 3. + 2. * 4., 5. + 2. * 6.]
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
@ -87,12 +87,10 @@ impl TensorTypeAndShapeInfo {
|
||||
/// Return the number of elements specified by the tensor shape. Return a negative value if
|
||||
/// unknown (i.e., any dimension is negative.)
|
||||
///
|
||||
/// ```
|
||||
/// [] -> 1
|
||||
/// [1,3,4] -> 12
|
||||
/// [2,0,4] -> 0
|
||||
/// [-1,3,4] -> -1
|
||||
/// ```
|
||||
pub fn elem_count(&self) -> isize {
|
||||
let mut count = 0;
|
||||
unsafe {
|
||||
@ -273,4 +271,3 @@ macro_rules! values {
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user