Use string version in run, seems to add ~300 ns overhead
This commit is contained in:
parent
255adf82ad
commit
198c9b3aa2
@ -1,4 +1,3 @@
|
||||
use std::ffi::CStr;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use onnxruntime::*;
|
||||
@ -84,7 +83,7 @@ fn tensor_mut(elem_type: OnnxTensorElementDataType, dims: &[usize]) -> Box<dyn A
|
||||
fn tensor_with_size(
|
||||
info: &TensorInfo,
|
||||
named_sizes: &mut HashMap<String, usize>,
|
||||
) -> Box<dyn AsRef<Val>> {
|
||||
) -> Box<dyn AsRef<Val> + Sync> {
|
||||
let (ty, dims) = tensor_size(info, named_sizes);
|
||||
use OnnxTensorElementDataType::*;
|
||||
match ty {
|
||||
@ -123,7 +122,7 @@ fn main() -> Result<()> {
|
||||
eprintln!("description: {}", metadata.description());
|
||||
|
||||
let mut input_names: Vec<OrtString> = vec![];
|
||||
let mut input_tensors: Vec<Box<dyn AsRef<Val>>> = vec![];
|
||||
let mut input_tensors: Vec<Box<dyn AsRef<Val> + Sync>> = vec![];
|
||||
|
||||
for (i, input) in session.inputs().enumerate() {
|
||||
if let Some(tensor_info) = input.tensor_info() {
|
||||
@ -151,10 +150,6 @@ fn main() -> Result<()> {
|
||||
}
|
||||
}
|
||||
|
||||
let in_names: Vec<&CStr> = input_names.iter().map(|x| x.as_c_str()).collect();
|
||||
let in_vals: Vec<&Val> = input_tensors.iter().map(|x| x.as_ref().as_ref()).collect();
|
||||
let out_names: Vec<&CStr> = output_names.iter().map(|x| x.as_c_str()).collect();
|
||||
|
||||
crossbeam::scope(|s| {
|
||||
let mut workers = vec![];
|
||||
for i in 0..opt.workers {
|
||||
@ -167,22 +162,34 @@ fn main() -> Result<()> {
|
||||
.iter()
|
||||
.map(|(elem_type, size)| tensor_mut(*elem_type, size))
|
||||
.collect();
|
||||
let mut out_vals: Vec<&mut Val> = output_tensors
|
||||
.iter_mut()
|
||||
.map(|x| x.as_mut().as_mut())
|
||||
.collect();
|
||||
|
||||
let inputs = input_names
|
||||
.iter()
|
||||
.zip(input_tensors.iter())
|
||||
.map(|(nm, val)| (nm.as_str(), val.as_ref().as_ref()));
|
||||
|
||||
let outputs = output_names
|
||||
.iter()
|
||||
.zip(output_tensors.iter_mut())
|
||||
.map(|(nm, val)| (nm.as_str(), val.as_mut().as_mut()));
|
||||
|
||||
// warmup run
|
||||
session
|
||||
.run_mut(&ro, &in_names, &in_vals[..], &out_names, &mut out_vals[..])
|
||||
.expect("run");
|
||||
session.run(&ro, inputs, outputs).expect("run");
|
||||
|
||||
let mut times = vec![];
|
||||
for _ in 0..opt.runs {
|
||||
let before = Instant::now();
|
||||
session
|
||||
.run_mut(&ro, &in_names, &in_vals[..], &out_names, &mut out_vals[..])
|
||||
.expect("run");
|
||||
let inputs = input_names
|
||||
.iter()
|
||||
.zip(input_tensors.iter())
|
||||
.map(|(nm, val)| (nm.as_str(), val.as_ref().as_ref()));
|
||||
|
||||
let outputs = output_names
|
||||
.iter()
|
||||
.zip(output_tensors.iter_mut())
|
||||
.map(|(nm, val)| (nm.as_str(), val.as_mut().as_mut()));
|
||||
|
||||
session.run(&ro, inputs, outputs).expect("run");
|
||||
times.push(before.elapsed());
|
||||
}
|
||||
let total: Duration = times.iter().sum();
|
||||
|
@ -564,12 +564,6 @@ impl OrtString {
|
||||
}
|
||||
}
|
||||
|
||||
impl OrtString {
|
||||
pub fn as_str(&self) -> &str {
|
||||
self.to_str().expect("OrtString::as_str")
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl Send for OrtString {}
|
||||
unsafe impl Sync for OrtString {}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user