First test with rust api
This commit is contained in:
parent
4f382eb831
commit
2076b9e127
128
src/lib.rs
128
src/lib.rs
@ -1,3 +1,4 @@
|
|||||||
|
use std::ffi::c_void;
|
||||||
use std::ffi::{self, CString};
|
use std::ffi::{self, CString};
|
||||||
use std::ptr;
|
use std::ptr;
|
||||||
|
|
||||||
@ -248,7 +249,7 @@ impl Clone for SessionOptions {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Session {
|
impl Session {
|
||||||
pub fn new(env: Env, model_path: &str, options: SessionOptions) -> Result<Self> {
|
pub fn new(env: &Env, model_path: &str, options: &SessionOptions) -> Result<Self> {
|
||||||
let model_path = CString::new(model_path)?;
|
let model_path = CString::new(model_path)?;
|
||||||
let mut raw = ptr::null_mut();
|
let mut raw = ptr::null_mut();
|
||||||
unsafe {
|
unsafe {
|
||||||
@ -264,9 +265,44 @@ impl Session {
|
|||||||
Ok(Session { raw })
|
Ok(Session { raw })
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn input_name(&self, ix: u64) -> Result<OrtString> {
|
||||||
|
let alloc = Allocator::default();
|
||||||
|
let mut raw = ptr::null_mut();
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
checked_call!(SessionGetInputName, self.raw, ix, alloc.raw, &mut raw)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(OrtString { raw })
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn output_name(&self, ix: u64) -> Result<OrtString> {
|
||||||
|
let alloc = Allocator::default();
|
||||||
|
let mut raw = ptr::null_mut();
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
checked_call!(SessionGetOutputName, self.raw, ix, alloc.raw, &mut raw)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(OrtString { raw })
|
||||||
|
}
|
||||||
|
|
||||||
|
// pub fn run(
|
||||||
|
// &self,
|
||||||
|
// options: &RunOptions,
|
||||||
|
// input_names: InputNames,
|
||||||
|
// inputs: Inputs,
|
||||||
|
// output_names: OutputNames,
|
||||||
|
// ) -> Result<Vec<Value>>
|
||||||
|
// where InputNames: Iterator<Item=&str>,
|
||||||
|
// Inputs: Iterator<Item=&Value>,
|
||||||
|
// OutputNames: Iterator<Item=&str>,
|
||||||
|
// {
|
||||||
|
// }
|
||||||
|
|
||||||
pub fn run_raw(
|
pub fn run_raw(
|
||||||
&self,
|
&self,
|
||||||
options: RunOptions,
|
options: &RunOptions,
|
||||||
input_names: &[&str],
|
input_names: &[&str],
|
||||||
inputs: &[&Value],
|
inputs: &[&Value],
|
||||||
output_names: &[&str],
|
output_names: &[&str],
|
||||||
@ -327,18 +363,94 @@ impl Default for Allocator {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Allocator {
|
||||||
|
pub unsafe fn free(&self, ptr: *mut c_void) {
|
||||||
|
checked_call!(AllocatorFree, self.raw, ptr).expect("AllocatorFree");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
use std::ffi::CStr;
|
||||||
|
use std::os::raw::c_char;
|
||||||
|
|
||||||
|
/// An ort string with the default allocator
|
||||||
|
pub struct OrtString {
|
||||||
|
raw: *const c_char,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::ops::Deref for OrtString {
|
||||||
|
type Target = CStr;
|
||||||
|
|
||||||
|
fn deref(&self) -> &CStr {
|
||||||
|
unsafe { CStr::from_ptr(self.raw) }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for OrtString {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
let alloc = Allocator::default();
|
||||||
|
unsafe { alloc.free(self.raw as _) }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl MemoryInfo {
|
impl MemoryInfo {
|
||||||
pub fn cpu_memory_info(alloc_type: AllocatorType, mem_type: MemType) -> Result<Self> {
|
pub fn cpu_memory_info(alloc_type: AllocatorType, mem_type: MemType) -> Result<Self> {
|
||||||
let mut raw = ptr::null_mut();
|
let mut raw = ptr::null_mut();
|
||||||
unsafe {
|
unsafe {
|
||||||
checked_call!(
|
checked_call!(CreateCpuMemoryInfo, alloc_type, mem_type, &mut raw)?;
|
||||||
CreateCpuMemoryInfo,
|
|
||||||
alloc_type,
|
|
||||||
mem_type,
|
|
||||||
&mut raw
|
|
||||||
)?;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(MemoryInfo { raw })
|
Ok(MemoryInfo { raw })
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl RunOptions {
|
||||||
|
pub fn new() -> RunOptions {
|
||||||
|
let mut raw = ptr::null_mut();
|
||||||
|
unsafe {
|
||||||
|
checked_call!(CreateRunOptions, &mut raw).expect("CreateRunOptions");
|
||||||
|
}
|
||||||
|
|
||||||
|
RunOptions { raw }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn mat_mul() -> Result<()> {
|
||||||
|
let env = Env::new(LoggingLevel::Warning, "test")?;
|
||||||
|
|
||||||
|
let so = SessionOptions::new()?;
|
||||||
|
|
||||||
|
let model_path = "../onnxruntime/test/testdata/matmul_1.onnx";
|
||||||
|
let session = Session::new(&env, model_path, &so)?;
|
||||||
|
let in_name = session.input_name(0)?;
|
||||||
|
let out_name = session.output_name(0)?;
|
||||||
|
|
||||||
|
let ro = RunOptions::new();
|
||||||
|
|
||||||
|
let mem_info = MemoryInfo::cpu_memory_info(AllocatorType::ArenaAllocator, MemType::Cpu)?;
|
||||||
|
|
||||||
|
let input_names = vec![in_name.to_str().unwrap()];
|
||||||
|
let output_names = vec![out_name.to_str().unwrap()];
|
||||||
|
|
||||||
|
let input_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
|
||||||
|
let input_shape = vec![3, 2];
|
||||||
|
|
||||||
|
let input_tensor = Tensor::new(mem_info, input_shape, input_data)?;
|
||||||
|
|
||||||
|
let output = session.run_raw(&ro, &input_names, &[input_tensor.value()], &output_names)?;
|
||||||
|
|
||||||
|
let output_value = output.into_iter().next().unwrap();
|
||||||
|
let output_tensor = output_value.as_tensor::<f32>().ok().expect("as_tensor");
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
&output_tensor[..],
|
||||||
|
&[1. + 2. * 2., 3. + 2. * 4., 5. + 2. * 6.]
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -87,12 +87,10 @@ impl TensorTypeAndShapeInfo {
|
|||||||
/// Return the number of elements specified by the tensor shape. Return a negative value if
|
/// Return the number of elements specified by the tensor shape. Return a negative value if
|
||||||
/// unknown (i.e., any dimension is negative.)
|
/// unknown (i.e., any dimension is negative.)
|
||||||
///
|
///
|
||||||
/// ```
|
|
||||||
/// [] -> 1
|
/// [] -> 1
|
||||||
/// [1,3,4] -> 12
|
/// [1,3,4] -> 12
|
||||||
/// [2,0,4] -> 0
|
/// [2,0,4] -> 0
|
||||||
/// [-1,3,4] -> -1
|
/// [-1,3,4] -> -1
|
||||||
/// ```
|
|
||||||
pub fn elem_count(&self) -> isize {
|
pub fn elem_count(&self) -> isize {
|
||||||
let mut count = 0;
|
let mut count = 0;
|
||||||
unsafe {
|
unsafe {
|
||||||
@ -273,4 +271,3 @@ macro_rules! values {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user