Only warn once when a size is missing

This commit is contained in:
Christopher Chalmers 2020-05-31 23:32:38 +01:00
parent e249d09c56
commit 01143dc038

View File

@ -1,3 +1,4 @@
use std::ffi::{CStr, CString};
use std::time::{Duration, Instant};
use onnxruntime::*;
@ -29,7 +30,7 @@ struct Opt {
use std::collections::HashMap;
fn key_val_parse(str: &str) -> HashMap<&str, usize> {
fn key_val_parse(str: &str) -> HashMap<String, usize> {
let mut map = HashMap::new();
if str.is_empty() {
return map;
@ -43,67 +44,62 @@ fn key_val_parse(str: &str) -> HashMap<&str, usize> {
.parse()
.expect("parse error");
assert!(iter.next().is_none(), "more than 1 =");
map.insert(key, val);
map.insert(key.to_owned(), val);
}
map
}
fn tensor_with_size(info: &TensorInfo, named_sizes: &HashMap<&str, usize>) -> Box<dyn AsRef<Val>> {
use OnnxTensorElementDataType::*;
let dims = info
.symbolic_dims()
.map(|d| match d {
SymbolicDim::Symbolic(name) => {
let name = name.to_str().unwrap();
*named_sizes.get(&name).unwrap_or_else(|| {
eprintln!("name {} not specified, setting to 1", name);
&1
})
}
SymbolicDim::Fixed(x) => x,
})
.collect::<Vec<usize>>();
match info.elem_type() {
Float => Box::new(Tensor::<f32>::init(&dims, 0.0).unwrap()),
Int64 => Box::new(Tensor::<i64>::init(&dims, 0).unwrap()),
t => panic!("Unsupported type {:?}", t),
}
}
fn tensor_with_size_mut(
/// Get the size of a tensor, substituting symbolic dimentions.
fn tensor_size(
info: &TensorInfo,
named_sizes: &HashMap<&str, usize>,
) -> Box<dyn AsMut<Val>> {
use OnnxTensorElementDataType::*;
named_sizes: &mut HashMap<String, usize>,
) -> (OnnxTensorElementDataType, Vec<usize>) {
let dims = info
.symbolic_dims()
.map(|d| match d {
SymbolicDim::Symbolic(name) => {
let name = name.to_str().unwrap();
*named_sizes.get(&name).unwrap_or_else(|| {
named_sizes.get(name).cloned().unwrap_or_else(|| {
eprintln!("name {} not specified, setting to 1", name);
&1
named_sizes.insert(name.to_owned(), 1);
1
})
}
SymbolicDim::Fixed(x) => x,
})
.collect::<Vec<usize>>();
match info.elem_type() {
.collect();
(info.elem_type(), dims)
}
fn tensor_mut(elem_type: OnnxTensorElementDataType, dims: &[usize]) -> Box<dyn AsMut<Val>> {
use OnnxTensorElementDataType::*;
match elem_type {
Float => Box::new(Tensor::<f32>::init(dims, 0.0).unwrap()),
Int64 => Box::new(Tensor::<i64>::init(dims, 0).unwrap()),
t => panic!("Unsupported type {:?}", t),
}
}
fn tensor_with_size(
info: &TensorInfo,
named_sizes: &mut HashMap<String, usize>,
) -> Box<dyn AsRef<Val>> {
let (ty, dims) = tensor_size(info, named_sizes);
use OnnxTensorElementDataType::*;
match ty {
Float => Box::new(Tensor::<f32>::init(&dims, 0.0).unwrap()),
Int64 => Box::new(Tensor::<i64>::init(&dims, 0).unwrap()),
t => panic!("Unsupported type {:?}", t),
}
}
use std::ffi::{CStr, CString};
fn main() -> Result<()> {
let env = Env::new(LoggingLevel::Fatal, "test")?;
let opt = Opt::from_args();
let so = SessionOptions::new()?;
let map = if let Some(dims) = &opt.dims {
let mut map = if let Some(dims) = &opt.dims {
key_val_parse(dims)
} else {
HashMap::new()
@ -131,15 +127,28 @@ fn main() -> Result<()> {
for (i, input) in session.inputs().enumerate() {
if let Some(tensor_info) = input.tensor_info() {
input_names.push(input.name().to_owned());
input_tensors.push(tensor_with_size(&tensor_info, &map));
input_tensors.push(tensor_with_size(&tensor_info, &mut map));
} else {
println!("input {}: {:?} {:?}", i, &*input.name(), input.onnx_type());
}
}
let output_names: Vec<CString> = session
.outputs()
.map(|output| output.name().to_owned())
.collect();
let mut output_names: Vec<CString> = vec![];
let mut output_sizes: Vec<(OnnxTensorElementDataType, Vec<usize>)> = vec![];
for (i, output) in session.outputs().enumerate() {
if let Some(tensor_info) = output.tensor_info() {
output_names.push(output.name().to_owned());
output_sizes.push(tensor_size(&tensor_info, &mut map));
} else {
println!(
"output {}: {:?} {:?}",
i,
&*output.name(),
output.onnx_type()
);
}
}
let in_names: Vec<&CStr> = input_names.iter().map(|x| x.as_c_str()).collect();
let in_vals: Vec<&Val> = input_tensors.iter().map(|x| x.as_ref().as_ref()).collect();
@ -153,19 +162,10 @@ fn main() -> Result<()> {
let i = i;
let ro = RunOptions::new();
// allocate output vectors
let mut output_tensors: Vec<Box<dyn AsMut<Val>>> = vec![];
for (i, output) in session.outputs().enumerate() {
if let Some(tensor_info) = output.tensor_info() {
output_tensors.push(tensor_with_size_mut(&tensor_info, &map));
} else {
println!(
"output {}: {:?} {:?}",
i,
&*output.name(),
output.onnx_type()
);
}
}
let mut output_tensors: Vec<_> = output_sizes
.iter()
.map(|(elem_type, size)| tensor_mut(*elem_type, size))
.collect();
let mut out_vals: Vec<&mut Val> = output_tensors
.iter_mut()
.map(|x| x.as_mut().as_mut())