diff --git a/src/bin/run.rs b/src/bin/run.rs index 26a744b..43741a2 100644 --- a/src/bin/run.rs +++ b/src/bin/run.rs @@ -1,4 +1,3 @@ -use std::ffi::CStr; use std::time::{Duration, Instant}; use onnxruntime::*; @@ -84,7 +83,7 @@ fn tensor_mut(elem_type: OnnxTensorElementDataType, dims: &[usize]) -> Box, -) -> Box> { +) -> Box + Sync> { let (ty, dims) = tensor_size(info, named_sizes); use OnnxTensorElementDataType::*; match ty { @@ -123,7 +122,7 @@ fn main() -> Result<()> { eprintln!("description: {}", metadata.description()); let mut input_names: Vec = vec![]; - let mut input_tensors: Vec>> = vec![]; + let mut input_tensors: Vec + Sync>> = vec![]; for (i, input) in session.inputs().enumerate() { if let Some(tensor_info) = input.tensor_info() { @@ -151,10 +150,6 @@ fn main() -> Result<()> { } } - let in_names: Vec<&CStr> = input_names.iter().map(|x| x.as_c_str()).collect(); - let in_vals: Vec<&Val> = input_tensors.iter().map(|x| x.as_ref().as_ref()).collect(); - let out_names: Vec<&CStr> = output_names.iter().map(|x| x.as_c_str()).collect(); - crossbeam::scope(|s| { let mut workers = vec![]; for i in 0..opt.workers { @@ -167,22 +162,34 @@ fn main() -> Result<()> { .iter() .map(|(elem_type, size)| tensor_mut(*elem_type, size)) .collect(); - let mut out_vals: Vec<&mut Val> = output_tensors - .iter_mut() - .map(|x| x.as_mut().as_mut()) - .collect(); + + let inputs = input_names + .iter() + .zip(input_tensors.iter()) + .map(|(nm, val)| (nm.as_str(), val.as_ref().as_ref())); + + let outputs = output_names + .iter() + .zip(output_tensors.iter_mut()) + .map(|(nm, val)| (nm.as_str(), val.as_mut().as_mut())); // warmup run - session - .run_mut(&ro, &in_names, &in_vals[..], &out_names, &mut out_vals[..]) - .expect("run"); + session.run(&ro, inputs, outputs).expect("run"); let mut times = vec![]; for _ in 0..opt.runs { let before = Instant::now(); - session - .run_mut(&ro, &in_names, &in_vals[..], &out_names, &mut out_vals[..]) - .expect("run"); + let inputs = input_names + .iter() + .zip(input_tensors.iter()) + .map(|(nm, val)| (nm.as_str(), val.as_ref().as_ref())); + + let outputs = output_names + .iter() + .zip(output_tensors.iter_mut()) + .map(|(nm, val)| (nm.as_str(), val.as_mut().as_mut())); + + session.run(&ro, inputs, outputs).expect("run"); times.push(before.elapsed()); } let total: Duration = times.iter().sum(); diff --git a/src/lib.rs b/src/lib.rs index 654aaec..a56494d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -564,12 +564,6 @@ impl OrtString { } } -impl OrtString { - pub fn as_str(&self) -> &str { - self.to_str().expect("OrtString::as_str") - } -} - unsafe impl Send for OrtString {} unsafe impl Sync for OrtString {}