Change the internals of OrtString for cheap &str and &Cstr conversions

This commit is contained in:
Christopher Chalmers 2020-06-01 00:22:27 +01:00
parent 8622da3088
commit c5bdf9c9c0
2 changed files with 64 additions and 19 deletions

View File

@ -1,4 +1,4 @@
use std::ffi::{CStr, CString};
use std::ffi::CStr;
use std::time::{Duration, Instant};
use onnxruntime::*;
@ -122,24 +122,24 @@ fn main() -> Result<()> {
eprintln!("domain: {}", metadata.domain());
eprintln!("description: {}", metadata.description());
let mut input_names: Vec<CString> = vec![];
let mut input_names: Vec<OrtString> = vec![];
let mut input_tensors: Vec<Box<dyn AsRef<Val>>> = vec![];
for (i, input) in session.inputs().enumerate() {
if let Some(tensor_info) = input.tensor_info() {
input_names.push(input.name().to_owned());
input_names.push(input.name());
input_tensors.push(tensor_with_size(&tensor_info, &mut map));
} else {
println!("input {}: {:?} {:?}", i, &*input.name(), input.onnx_type());
}
}
let mut output_names: Vec<CString> = vec![];
let mut output_names: Vec<OrtString> = vec![];
let mut output_sizes: Vec<(OnnxTensorElementDataType, Vec<usize>)> = vec![];
for (i, output) in session.outputs().enumerate() {
if let Some(tensor_info) = output.tensor_info() {
output_names.push(output.name().to_owned());
output_names.push(output.name());
output_sizes.push(tensor_size(&tensor_info, &mut map));
} else {
println!(

View File

@ -266,7 +266,7 @@ impl<'a> ArgumentInfo<'a> {
call!(@unsafe @ptr @expect SessionGetOverridableInitializerName, sess, ix, alloc.as_ptr())
}
};
OrtString { raw }
unsafe { OrtString::new(raw) }
}
/// The index of this argument
@ -451,22 +451,22 @@ impl ModelMetadata {
let alloc = Allocator::default();
let raw =
call!(@unsafe @ptr @expect ModelMetadataGetProducerName, self.raw, alloc.as_ptr());
OrtString { raw }
unsafe { OrtString::new(raw) }
}
pub fn graph_name(&self) -> OrtString {
let alloc = Allocator::default();
let raw = call!(@unsafe @ptr @expect ModelMetadataGetGraphName, self.raw, alloc.as_ptr());
OrtString { raw }
unsafe { OrtString::new(raw) }
}
pub fn domain(&self) -> OrtString {
let alloc = Allocator::default();
let raw = call!(@unsafe @ptr @expect ModelMetadataGetDomain, self.raw, alloc.as_ptr());
OrtString { raw }
unsafe { OrtString::new(raw) }
}
pub fn description(&self) -> OrtString {
let alloc = Allocator::default();
let raw = call!(@unsafe @ptr @expect ModelMetadataGetDescription, self.raw, alloc.as_ptr());
OrtString { raw }
unsafe { OrtString::new(raw) }
}
pub fn lookup_custom(&self, key: &str) -> Option<OrtString> {
let alloc = Allocator::default();
@ -480,7 +480,7 @@ impl ModelMetadata {
if raw.is_null() {
None
} else {
Some(OrtString { raw })
Some(unsafe { OrtString::new(raw) })
}
}
pub fn version(&self) -> i64 {
@ -495,23 +495,58 @@ fn cstr_ptrs(slice: &[&CStr]) -> Vec<*const c_char> {
/// An ort string with the default allocator
pub struct OrtString {
raw: *const c_char,
raw: *const u8,
len: usize,
}
impl OrtString {
unsafe fn new(raw: *const c_char) -> OrtString {
// get length and confirm it's valid utf8
let cstr = CStr::from_ptr(raw);
let str = std::str::from_utf8(cstr.to_bytes()).expect("OrtString::new");
OrtString {
len: str.len(),
raw: raw as *const u8,
}
}
pub fn as_str(&self) -> &str {
unsafe { std::str::from_utf8_unchecked(self.as_bytes()) }
}
pub fn as_bytes_with_nul(&self) -> &[u8] {
unsafe { std::slice::from_raw_parts(self.raw, self.len + 1) }
}
pub fn as_bytes(&self) -> &[u8] {
unsafe { std::slice::from_raw_parts(self.raw, self.len) }
}
pub fn as_c_str(&self) -> &CStr {
unsafe { CStr::from_bytes_with_nul_unchecked(self.as_bytes_with_nul()) }
}
}
unsafe impl Send for OrtString {}
unsafe impl Sync for OrtString {}
impl std::ops::Deref for OrtString {
type Target = CStr;
impl AsRef<CStr> for OrtString {
fn as_ref(&self) -> &CStr {
self.as_c_str()
}
}
fn deref(&self) -> &CStr {
unsafe { CStr::from_ptr(self.raw) }
impl std::ops::Deref for OrtString {
type Target = str;
fn deref(&self) -> &str {
self.as_str()
}
}
impl fmt::Display for OrtString {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(fmt, "{}", self.to_string_lossy())
write!(fmt, "{}", self.as_str())
}
}
@ -677,7 +712,12 @@ mod tests {
let input_tensor = Tensor::new(&[3, 2], input_data)?;
// immutable version
let output = session.run_raw(&ro, &[&in_name], &[input_tensor.value()], &[&out_name])?;
let output = session.run_raw(
&ro,
&[&in_name.as_c_str()],
&[input_tensor.value()],
&[out_name.as_c_str()],
)?;
let output_value = output.into_iter().next().unwrap();
let output_tensor = output_value.as_tensor::<f32>().ok().expect("as_tensor");
@ -690,7 +730,12 @@ mod tests {
// mutable version
let mut output_tensor = Tensor::<f32>::init(&[3, 1], 0.0)?;
run!(session, &ro, "X": &input_tensor, "Y": &mut output_tensor)?;
run!(
session,
&ro,
in_name: &input_tensor,
"Y": &mut output_tensor
)?;
assert_eq!(
&output_tensor[..],