Add some examples

This commit is contained in:
Christopher Chalmers 2020-05-31 01:24:52 +01:00
parent 4b6afb6170
commit 3bdc2c93d0
2 changed files with 250 additions and 0 deletions

84
examples/info.rs Normal file
View File

@ -0,0 +1,84 @@
use onnxruntime::*;
use structopt::StructOpt;
#[derive(StructOpt)]
struct Opt {
/// The path to the gru onnx file
// #[structopt(long)]
onnx: Vec<String>,
// #[structopt(long)]
// ort_profile: Option<String>,
// #[structopt(long, default_value="1")]
// batch_size: i64,
// #[structopt(long, default_value="1")]
// runs: usize,
// #[structopt(long, default_value="1")]
// workers: usize,
}
fn main() -> Result<()> {
let env = Env::new(LoggingLevel::Fatal, "test")?;
let opt = Opt::from_args();
let so = SessionOptions::new()?;
for path in &opt.onnx {
println!("model {:?}", path);
let session = match Session::new(&env, path, &so) {
Ok(sess) => sess,
Err(err) => {
eprintln!("error: {}\n", err);
continue;
}
};
let metadata = session.metadata();
eprintln!("name: {}", metadata.producer_name());
eprintln!("graph_name: {}", metadata.graph_name());
eprintln!("domain: {}", metadata.domain());
eprintln!("description: {}", metadata.description());
for (i, input) in session.inputs().enumerate() {
if let Some(tensor_info) = input.tensor_info() {
let dims = tensor_info.symbolic_dims().collect::<Vec<_>>();
println!(
"input {}: {:?} {:?} {:?}",
i,
&*input.name(),
dims,
tensor_info.elem_type()
)
} else {
println!("input {}: {:?} {:?}", i, &*input.name(), input.onnx_type());
}
}
for (i, output) in session.outputs().enumerate() {
if let Some(tensor_info) = output.tensor_info() {
let dims = tensor_info.symbolic_dims().collect::<Vec<_>>();
println!(
"output {}: {:?} {:?} {:?}",
i,
&*output.name(),
dims,
tensor_info.elem_type()
)
} else {
println!(
"input {}: {:?} {:?}",
i,
&*output.name(),
output.onnx_type()
);
}
}
for (i, output) in session.overridable_initializers().enumerate() {
println!("init {}: {:?}", i, &*output.name())
}
println!();
}
Ok(())
}

166
examples/run.rs Normal file
View File

@ -0,0 +1,166 @@
use onnxruntime::*;
use structopt::StructOpt;
#[derive(StructOpt)]
struct Opt {
/// The path to the gru onnx file
// #[structopt(long)]
onnx: Vec<String>,
#[structopt(long, default_value = "")]
dims: String,
// #[structopt(long, default_value="1")]
// workers: usize,
}
use std::collections::HashMap;
fn key_val_parse(str: &str) -> HashMap<&str, usize> {
let mut map = HashMap::new();
if str.is_empty() {
return map;
}
for key_val in str.split(',') {
let mut iter = key_val.split('=');
let key = iter.next().expect("no =");
let val = iter
.next()
.expect("nothing after =")
.parse()
.expect("parse error");
assert!(iter.next().is_none(), "more than 1 =");
map.insert(key, val);
}
map
}
fn tensor_with_size(info: &TensorInfo, named_sizes: &HashMap<&str, usize>) -> Box<dyn AsRef<Val>> {
use OnnxTensorElementDataType::*;
let dims = info
.symbolic_dims()
.map(|d| match d {
SymbolicDim::Symbolic(name) => {
let name = name.to_str().unwrap();
*named_sizes.get(&name).unwrap_or_else(|| {
eprintln!("name {} not specified, setting to 1", name);
&1
})
}
SymbolicDim::Fixed(x) => x,
})
.collect::<Vec<usize>>();
match info.elem_type() {
Float => Box::new(Tensor::<f32>::init(&dims, 0.0).unwrap()),
Int64 => Box::new(Tensor::<i64>::init(&dims, 0).unwrap()),
t => panic!("Unsupported type {:?}", t),
}
}
fn tensor_with_size_mut(
info: &TensorInfo,
named_sizes: &HashMap<&str, usize>,
) -> Box<dyn AsMut<Val>> {
use OnnxTensorElementDataType::*;
let dims = info
.symbolic_dims()
.map(|d| match d {
SymbolicDim::Symbolic(name) => {
let name = name.to_str().unwrap();
*named_sizes.get(&name).unwrap_or_else(|| {
eprintln!("name {} not specified, setting to 1", name);
&1
})
}
SymbolicDim::Fixed(x) => x,
})
.collect::<Vec<usize>>();
match info.elem_type() {
Float => Box::new(Tensor::<f32>::init(&dims, 0.0).unwrap()),
Int64 => Box::new(Tensor::<i64>::init(&dims, 0).unwrap()),
t => panic!("Unsupported type {:?}", t),
}
}
use std::ffi::{CStr, CString};
fn main() -> Result<()> {
let env = Env::new(LoggingLevel::Fatal, "test")?;
let opt = Opt::from_args();
let so = SessionOptions::new()?;
let map = key_val_parse(&opt.dims);
for path in &opt.onnx {
println!("model {:?}", path);
let session = match Session::new(&env, path, &so) {
Ok(sess) => sess,
Err(err) => {
eprintln!("error: {}\n", err);
continue;
}
};
let metadata = session.metadata();
eprintln!("name: {}", metadata.producer_name());
eprintln!("graph_name: {}", metadata.graph_name());
eprintln!("domain: {}", metadata.domain());
eprintln!("description: {}", metadata.description());
let mut input_names: Vec<CString> = vec![];
let mut input_tensors: Vec<Box<dyn AsRef<Val>>> = vec![];
let mut output_names: Vec<CString> = vec![];
let mut output_tensors: Vec<Box<dyn AsMut<Val>>> = vec![];
for (i, input) in session.inputs().enumerate() {
if let Some(tensor_info) = input.tensor_info() {
input_names.push(input.name().to_owned());
input_tensors.push(tensor_with_size(&tensor_info, &map));
} else {
println!("input {}: {:?} {:?}", i, &*input.name(), input.onnx_type());
}
}
for (i, output) in session.outputs().enumerate() {
if let Some(tensor_info) = output.tensor_info() {
output_names.push(output.name().to_owned());
output_tensors.push(tensor_with_size_mut(&tensor_info, &map));
} else {
println!(
"output {}: {:?} {:?}",
i,
&*output.name(),
output.onnx_type()
);
}
}
let so = RunOptions::new();
let in_names: Vec<&CStr> = input_names.iter().map(|x| x.as_c_str()).collect();
let in_vals: Vec<&Val> = input_tensors.iter().map(|x| x.as_ref().as_ref()).collect();
let out_names: Vec<&CStr> = output_names.iter().map(|x| x.as_c_str()).collect();
let mut out_vals: Vec<&mut Val> = output_tensors
.iter_mut()
.map(|x| x.as_mut().as_mut())
.collect();
session
.run_mut(&so, &in_names, &in_vals[..], &out_names, &mut out_vals[..])
.expect("run");
// pub fn run_mut(
// &self,
// options: &RunOptions,
// input_names: &[&CStr],
// inputs: &[&Val],
// output_names: &[&CStr],
// outputs: &mut [&mut Val],
for (i, output) in session.overridable_initializers().enumerate() {
println!("init {}: {:?}", i, &*output.name())
}
println!();
}
Ok(())
}