From c5bdf9c9c04b1df507a4600f83ff07d2a8c23416 Mon Sep 17 00:00:00 2001 From: Christopher Chalmers Date: Mon, 1 Jun 2020 00:22:27 +0100 Subject: [PATCH] Change the internals of OrtString for cheap &str and &Cstr conversions --- src/bin/run.rs | 10 +++---- src/lib.rs | 73 ++++++++++++++++++++++++++++++++++++++++---------- 2 files changed, 64 insertions(+), 19 deletions(-) diff --git a/src/bin/run.rs b/src/bin/run.rs index d6c2948..26a744b 100644 --- a/src/bin/run.rs +++ b/src/bin/run.rs @@ -1,4 +1,4 @@ -use std::ffi::{CStr, CString}; +use std::ffi::CStr; use std::time::{Duration, Instant}; use onnxruntime::*; @@ -122,24 +122,24 @@ fn main() -> Result<()> { eprintln!("domain: {}", metadata.domain()); eprintln!("description: {}", metadata.description()); - let mut input_names: Vec = vec![]; + let mut input_names: Vec = vec![]; let mut input_tensors: Vec>> = vec![]; for (i, input) in session.inputs().enumerate() { if let Some(tensor_info) = input.tensor_info() { - input_names.push(input.name().to_owned()); + input_names.push(input.name()); input_tensors.push(tensor_with_size(&tensor_info, &mut map)); } else { println!("input {}: {:?} {:?}", i, &*input.name(), input.onnx_type()); } } - let mut output_names: Vec = vec![]; + let mut output_names: Vec = vec![]; let mut output_sizes: Vec<(OnnxTensorElementDataType, Vec)> = vec![]; for (i, output) in session.outputs().enumerate() { if let Some(tensor_info) = output.tensor_info() { - output_names.push(output.name().to_owned()); + output_names.push(output.name()); output_sizes.push(tensor_size(&tensor_info, &mut map)); } else { println!( diff --git a/src/lib.rs b/src/lib.rs index 38e712b..acc30c8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -266,7 +266,7 @@ impl<'a> ArgumentInfo<'a> { call!(@unsafe @ptr @expect SessionGetOverridableInitializerName, sess, ix, alloc.as_ptr()) } }; - OrtString { raw } + unsafe { OrtString::new(raw) } } /// The index of this argument @@ -451,22 +451,22 @@ impl ModelMetadata { let alloc = Allocator::default(); let raw = call!(@unsafe @ptr @expect ModelMetadataGetProducerName, self.raw, alloc.as_ptr()); - OrtString { raw } + unsafe { OrtString::new(raw) } } pub fn graph_name(&self) -> OrtString { let alloc = Allocator::default(); let raw = call!(@unsafe @ptr @expect ModelMetadataGetGraphName, self.raw, alloc.as_ptr()); - OrtString { raw } + unsafe { OrtString::new(raw) } } pub fn domain(&self) -> OrtString { let alloc = Allocator::default(); let raw = call!(@unsafe @ptr @expect ModelMetadataGetDomain, self.raw, alloc.as_ptr()); - OrtString { raw } + unsafe { OrtString::new(raw) } } pub fn description(&self) -> OrtString { let alloc = Allocator::default(); let raw = call!(@unsafe @ptr @expect ModelMetadataGetDescription, self.raw, alloc.as_ptr()); - OrtString { raw } + unsafe { OrtString::new(raw) } } pub fn lookup_custom(&self, key: &str) -> Option { let alloc = Allocator::default(); @@ -480,7 +480,7 @@ impl ModelMetadata { if raw.is_null() { None } else { - Some(OrtString { raw }) + Some(unsafe { OrtString::new(raw) }) } } pub fn version(&self) -> i64 { @@ -495,23 +495,58 @@ fn cstr_ptrs(slice: &[&CStr]) -> Vec<*const c_char> { /// An ort string with the default allocator pub struct OrtString { - raw: *const c_char, + raw: *const u8, + len: usize, +} + +impl OrtString { + unsafe fn new(raw: *const c_char) -> OrtString { + // get length and confirm it's valid utf8 + let cstr = CStr::from_ptr(raw); + let str = std::str::from_utf8(cstr.to_bytes()).expect("OrtString::new"); + OrtString { + len: str.len(), + raw: raw as *const u8, + } + } + + pub fn as_str(&self) -> &str { + unsafe { std::str::from_utf8_unchecked(self.as_bytes()) } + } + + pub fn as_bytes_with_nul(&self) -> &[u8] { + unsafe { std::slice::from_raw_parts(self.raw, self.len + 1) } + } + + pub fn as_bytes(&self) -> &[u8] { + unsafe { std::slice::from_raw_parts(self.raw, self.len) } + } + + pub fn as_c_str(&self) -> &CStr { + unsafe { CStr::from_bytes_with_nul_unchecked(self.as_bytes_with_nul()) } + } } unsafe impl Send for OrtString {} unsafe impl Sync for OrtString {} -impl std::ops::Deref for OrtString { - type Target = CStr; +impl AsRef for OrtString { + fn as_ref(&self) -> &CStr { + self.as_c_str() + } +} - fn deref(&self) -> &CStr { - unsafe { CStr::from_ptr(self.raw) } +impl std::ops::Deref for OrtString { + type Target = str; + + fn deref(&self) -> &str { + self.as_str() } } impl fmt::Display for OrtString { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(fmt, "{}", self.to_string_lossy()) + write!(fmt, "{}", self.as_str()) } } @@ -677,7 +712,12 @@ mod tests { let input_tensor = Tensor::new(&[3, 2], input_data)?; // immutable version - let output = session.run_raw(&ro, &[&in_name], &[input_tensor.value()], &[&out_name])?; + let output = session.run_raw( + &ro, + &[&in_name.as_c_str()], + &[input_tensor.value()], + &[out_name.as_c_str()], + )?; let output_value = output.into_iter().next().unwrap(); let output_tensor = output_value.as_tensor::().ok().expect("as_tensor"); @@ -690,7 +730,12 @@ mod tests { // mutable version let mut output_tensor = Tensor::::init(&[3, 1], 0.0)?; - run!(session, &ro, "X": &input_tensor, "Y": &mut output_tensor)?; + run!( + session, + &ro, + in_name: &input_tensor, + "Y": &mut output_tensor + )?; assert_eq!( &output_tensor[..],