Change the internals of OrtString for cheap &str and &Cstr conversions
This commit is contained in:
parent
8622da3088
commit
c5bdf9c9c0
@ -1,4 +1,4 @@
|
||||
use std::ffi::{CStr, CString};
|
||||
use std::ffi::CStr;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use onnxruntime::*;
|
||||
@ -122,24 +122,24 @@ fn main() -> Result<()> {
|
||||
eprintln!("domain: {}", metadata.domain());
|
||||
eprintln!("description: {}", metadata.description());
|
||||
|
||||
let mut input_names: Vec<CString> = vec![];
|
||||
let mut input_names: Vec<OrtString> = vec![];
|
||||
let mut input_tensors: Vec<Box<dyn AsRef<Val>>> = vec![];
|
||||
|
||||
for (i, input) in session.inputs().enumerate() {
|
||||
if let Some(tensor_info) = input.tensor_info() {
|
||||
input_names.push(input.name().to_owned());
|
||||
input_names.push(input.name());
|
||||
input_tensors.push(tensor_with_size(&tensor_info, &mut map));
|
||||
} else {
|
||||
println!("input {}: {:?} {:?}", i, &*input.name(), input.onnx_type());
|
||||
}
|
||||
}
|
||||
|
||||
let mut output_names: Vec<CString> = vec![];
|
||||
let mut output_names: Vec<OrtString> = vec![];
|
||||
let mut output_sizes: Vec<(OnnxTensorElementDataType, Vec<usize>)> = vec![];
|
||||
|
||||
for (i, output) in session.outputs().enumerate() {
|
||||
if let Some(tensor_info) = output.tensor_info() {
|
||||
output_names.push(output.name().to_owned());
|
||||
output_names.push(output.name());
|
||||
output_sizes.push(tensor_size(&tensor_info, &mut map));
|
||||
} else {
|
||||
println!(
|
||||
|
73
src/lib.rs
73
src/lib.rs
@ -266,7 +266,7 @@ impl<'a> ArgumentInfo<'a> {
|
||||
call!(@unsafe @ptr @expect SessionGetOverridableInitializerName, sess, ix, alloc.as_ptr())
|
||||
}
|
||||
};
|
||||
OrtString { raw }
|
||||
unsafe { OrtString::new(raw) }
|
||||
}
|
||||
|
||||
/// The index of this argument
|
||||
@ -451,22 +451,22 @@ impl ModelMetadata {
|
||||
let alloc = Allocator::default();
|
||||
let raw =
|
||||
call!(@unsafe @ptr @expect ModelMetadataGetProducerName, self.raw, alloc.as_ptr());
|
||||
OrtString { raw }
|
||||
unsafe { OrtString::new(raw) }
|
||||
}
|
||||
pub fn graph_name(&self) -> OrtString {
|
||||
let alloc = Allocator::default();
|
||||
let raw = call!(@unsafe @ptr @expect ModelMetadataGetGraphName, self.raw, alloc.as_ptr());
|
||||
OrtString { raw }
|
||||
unsafe { OrtString::new(raw) }
|
||||
}
|
||||
pub fn domain(&self) -> OrtString {
|
||||
let alloc = Allocator::default();
|
||||
let raw = call!(@unsafe @ptr @expect ModelMetadataGetDomain, self.raw, alloc.as_ptr());
|
||||
OrtString { raw }
|
||||
unsafe { OrtString::new(raw) }
|
||||
}
|
||||
pub fn description(&self) -> OrtString {
|
||||
let alloc = Allocator::default();
|
||||
let raw = call!(@unsafe @ptr @expect ModelMetadataGetDescription, self.raw, alloc.as_ptr());
|
||||
OrtString { raw }
|
||||
unsafe { OrtString::new(raw) }
|
||||
}
|
||||
pub fn lookup_custom(&self, key: &str) -> Option<OrtString> {
|
||||
let alloc = Allocator::default();
|
||||
@ -480,7 +480,7 @@ impl ModelMetadata {
|
||||
if raw.is_null() {
|
||||
None
|
||||
} else {
|
||||
Some(OrtString { raw })
|
||||
Some(unsafe { OrtString::new(raw) })
|
||||
}
|
||||
}
|
||||
pub fn version(&self) -> i64 {
|
||||
@ -495,23 +495,58 @@ fn cstr_ptrs(slice: &[&CStr]) -> Vec<*const c_char> {
|
||||
|
||||
/// An ort string with the default allocator
|
||||
pub struct OrtString {
|
||||
raw: *const c_char,
|
||||
raw: *const u8,
|
||||
len: usize,
|
||||
}
|
||||
|
||||
impl OrtString {
|
||||
unsafe fn new(raw: *const c_char) -> OrtString {
|
||||
// get length and confirm it's valid utf8
|
||||
let cstr = CStr::from_ptr(raw);
|
||||
let str = std::str::from_utf8(cstr.to_bytes()).expect("OrtString::new");
|
||||
OrtString {
|
||||
len: str.len(),
|
||||
raw: raw as *const u8,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_str(&self) -> &str {
|
||||
unsafe { std::str::from_utf8_unchecked(self.as_bytes()) }
|
||||
}
|
||||
|
||||
pub fn as_bytes_with_nul(&self) -> &[u8] {
|
||||
unsafe { std::slice::from_raw_parts(self.raw, self.len + 1) }
|
||||
}
|
||||
|
||||
pub fn as_bytes(&self) -> &[u8] {
|
||||
unsafe { std::slice::from_raw_parts(self.raw, self.len) }
|
||||
}
|
||||
|
||||
pub fn as_c_str(&self) -> &CStr {
|
||||
unsafe { CStr::from_bytes_with_nul_unchecked(self.as_bytes_with_nul()) }
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl Send for OrtString {}
|
||||
unsafe impl Sync for OrtString {}
|
||||
|
||||
impl std::ops::Deref for OrtString {
|
||||
type Target = CStr;
|
||||
impl AsRef<CStr> for OrtString {
|
||||
fn as_ref(&self) -> &CStr {
|
||||
self.as_c_str()
|
||||
}
|
||||
}
|
||||
|
||||
fn deref(&self) -> &CStr {
|
||||
unsafe { CStr::from_ptr(self.raw) }
|
||||
impl std::ops::Deref for OrtString {
|
||||
type Target = str;
|
||||
|
||||
fn deref(&self) -> &str {
|
||||
self.as_str()
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for OrtString {
|
||||
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(fmt, "{}", self.to_string_lossy())
|
||||
write!(fmt, "{}", self.as_str())
|
||||
}
|
||||
}
|
||||
|
||||
@ -677,7 +712,12 @@ mod tests {
|
||||
let input_tensor = Tensor::new(&[3, 2], input_data)?;
|
||||
|
||||
// immutable version
|
||||
let output = session.run_raw(&ro, &[&in_name], &[input_tensor.value()], &[&out_name])?;
|
||||
let output = session.run_raw(
|
||||
&ro,
|
||||
&[&in_name.as_c_str()],
|
||||
&[input_tensor.value()],
|
||||
&[out_name.as_c_str()],
|
||||
)?;
|
||||
|
||||
let output_value = output.into_iter().next().unwrap();
|
||||
let output_tensor = output_value.as_tensor::<f32>().ok().expect("as_tensor");
|
||||
@ -690,7 +730,12 @@ mod tests {
|
||||
// mutable version
|
||||
let mut output_tensor = Tensor::<f32>::init(&[3, 1], 0.0)?;
|
||||
|
||||
run!(session, &ro, "X": &input_tensor, "Y": &mut output_tensor)?;
|
||||
run!(
|
||||
session,
|
||||
&ro,
|
||||
in_name: &input_tensor,
|
||||
"Y": &mut output_tensor
|
||||
)?;
|
||||
|
||||
assert_eq!(
|
||||
&output_tensor[..],
|
||||
|
Loading…
Reference in New Issue
Block a user