Use string version in run, seems to add ~300 ns overhead
This commit is contained in:
parent
255adf82ad
commit
198c9b3aa2
@ -1,4 +1,3 @@
|
|||||||
use std::ffi::CStr;
|
|
||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
|
|
||||||
use onnxruntime::*;
|
use onnxruntime::*;
|
||||||
@ -84,7 +83,7 @@ fn tensor_mut(elem_type: OnnxTensorElementDataType, dims: &[usize]) -> Box<dyn A
|
|||||||
fn tensor_with_size(
|
fn tensor_with_size(
|
||||||
info: &TensorInfo,
|
info: &TensorInfo,
|
||||||
named_sizes: &mut HashMap<String, usize>,
|
named_sizes: &mut HashMap<String, usize>,
|
||||||
) -> Box<dyn AsRef<Val>> {
|
) -> Box<dyn AsRef<Val> + Sync> {
|
||||||
let (ty, dims) = tensor_size(info, named_sizes);
|
let (ty, dims) = tensor_size(info, named_sizes);
|
||||||
use OnnxTensorElementDataType::*;
|
use OnnxTensorElementDataType::*;
|
||||||
match ty {
|
match ty {
|
||||||
@ -123,7 +122,7 @@ fn main() -> Result<()> {
|
|||||||
eprintln!("description: {}", metadata.description());
|
eprintln!("description: {}", metadata.description());
|
||||||
|
|
||||||
let mut input_names: Vec<OrtString> = vec![];
|
let mut input_names: Vec<OrtString> = vec![];
|
||||||
let mut input_tensors: Vec<Box<dyn AsRef<Val>>> = vec![];
|
let mut input_tensors: Vec<Box<dyn AsRef<Val> + Sync>> = vec![];
|
||||||
|
|
||||||
for (i, input) in session.inputs().enumerate() {
|
for (i, input) in session.inputs().enumerate() {
|
||||||
if let Some(tensor_info) = input.tensor_info() {
|
if let Some(tensor_info) = input.tensor_info() {
|
||||||
@ -151,10 +150,6 @@ fn main() -> Result<()> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let in_names: Vec<&CStr> = input_names.iter().map(|x| x.as_c_str()).collect();
|
|
||||||
let in_vals: Vec<&Val> = input_tensors.iter().map(|x| x.as_ref().as_ref()).collect();
|
|
||||||
let out_names: Vec<&CStr> = output_names.iter().map(|x| x.as_c_str()).collect();
|
|
||||||
|
|
||||||
crossbeam::scope(|s| {
|
crossbeam::scope(|s| {
|
||||||
let mut workers = vec![];
|
let mut workers = vec![];
|
||||||
for i in 0..opt.workers {
|
for i in 0..opt.workers {
|
||||||
@ -167,22 +162,34 @@ fn main() -> Result<()> {
|
|||||||
.iter()
|
.iter()
|
||||||
.map(|(elem_type, size)| tensor_mut(*elem_type, size))
|
.map(|(elem_type, size)| tensor_mut(*elem_type, size))
|
||||||
.collect();
|
.collect();
|
||||||
let mut out_vals: Vec<&mut Val> = output_tensors
|
|
||||||
.iter_mut()
|
let inputs = input_names
|
||||||
.map(|x| x.as_mut().as_mut())
|
.iter()
|
||||||
.collect();
|
.zip(input_tensors.iter())
|
||||||
|
.map(|(nm, val)| (nm.as_str(), val.as_ref().as_ref()));
|
||||||
|
|
||||||
|
let outputs = output_names
|
||||||
|
.iter()
|
||||||
|
.zip(output_tensors.iter_mut())
|
||||||
|
.map(|(nm, val)| (nm.as_str(), val.as_mut().as_mut()));
|
||||||
|
|
||||||
// warmup run
|
// warmup run
|
||||||
session
|
session.run(&ro, inputs, outputs).expect("run");
|
||||||
.run_mut(&ro, &in_names, &in_vals[..], &out_names, &mut out_vals[..])
|
|
||||||
.expect("run");
|
|
||||||
|
|
||||||
let mut times = vec![];
|
let mut times = vec![];
|
||||||
for _ in 0..opt.runs {
|
for _ in 0..opt.runs {
|
||||||
let before = Instant::now();
|
let before = Instant::now();
|
||||||
session
|
let inputs = input_names
|
||||||
.run_mut(&ro, &in_names, &in_vals[..], &out_names, &mut out_vals[..])
|
.iter()
|
||||||
.expect("run");
|
.zip(input_tensors.iter())
|
||||||
|
.map(|(nm, val)| (nm.as_str(), val.as_ref().as_ref()));
|
||||||
|
|
||||||
|
let outputs = output_names
|
||||||
|
.iter()
|
||||||
|
.zip(output_tensors.iter_mut())
|
||||||
|
.map(|(nm, val)| (nm.as_str(), val.as_mut().as_mut()));
|
||||||
|
|
||||||
|
session.run(&ro, inputs, outputs).expect("run");
|
||||||
times.push(before.elapsed());
|
times.push(before.elapsed());
|
||||||
}
|
}
|
||||||
let total: Duration = times.iter().sum();
|
let total: Duration = times.iter().sum();
|
||||||
|
@ -564,12 +564,6 @@ impl OrtString {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl OrtString {
|
|
||||||
pub fn as_str(&self) -> &str {
|
|
||||||
self.to_str().expect("OrtString::as_str")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
unsafe impl Send for OrtString {}
|
unsafe impl Send for OrtString {}
|
||||||
unsafe impl Sync for OrtString {}
|
unsafe impl Sync for OrtString {}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user