Some fixes
This commit is contained in:
parent
08fb5dbf8c
commit
25785a1ada
208
Cargo.lock
generated
208
Cargo.lock
generated
@ -1,5 +1,11 @@
|
||||
# This file is automatically @generated by Cargo.
|
||||
# It is not intended for manual editing.
|
||||
[[package]]
|
||||
name = "adler32"
|
||||
version = "1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "567b077b825e468cc974f0020d4082ee6e03132512f207ef1a02fd5d00d1f32d"
|
||||
|
||||
[[package]]
|
||||
name = "aho-corasick"
|
||||
version = "0.7.10"
|
||||
@ -65,6 +71,18 @@ version = "1.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cf1de2fe8c75bc145a2f577add951f8134889b4795d47466a54a5c846d691693"
|
||||
|
||||
[[package]]
|
||||
name = "bytemuck"
|
||||
version = "1.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "37fa13df2292ecb479ec23aa06f4507928bef07839be9ef15281411076629431"
|
||||
|
||||
[[package]]
|
||||
name = "byteorder"
|
||||
version = "1.3.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "08c48aae112d48ed9f069b33538ea9e3e90aa263cfa3d1c24309612b1f7472de"
|
||||
|
||||
[[package]]
|
||||
name = "cc"
|
||||
version = "1.0.54"
|
||||
@ -112,6 +130,21 @@ dependencies = [
|
||||
"vec_map",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "color_quant"
|
||||
version = "1.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0dbbb57365263e881e805dc77d94697c9118fd94d8da011240555aa7b23445bd"
|
||||
|
||||
[[package]]
|
||||
name = "crc32fast"
|
||||
version = "1.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ba125de2af0df55319f41944744ad91c71113bf74a4646efff39afe1f6842db1"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crossbeam"
|
||||
version = "0.7.3"
|
||||
@ -183,6 +216,22 @@ dependencies = [
|
||||
"lazy_static",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "deflate"
|
||||
version = "0.8.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e7e5d2a2273fed52a7f947ee55b092c4057025d7a3e04e5ecdbd25d6c3fb1bd7"
|
||||
dependencies = [
|
||||
"adler32",
|
||||
"byteorder",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "either"
|
||||
version = "1.5.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bb1f6b1ce1c140482ea30ddd3335fc0024ac7ee112895426e0a629a6c20adfe3"
|
||||
|
||||
[[package]]
|
||||
name = "env_logger"
|
||||
version = "0.7.1"
|
||||
@ -196,6 +245,16 @@ dependencies = [
|
||||
"termcolor",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "gif"
|
||||
version = "0.10.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "471d90201b3b223f3451cd4ad53e34295f16a1df17b1edf3736d47761c3981af"
|
||||
dependencies = [
|
||||
"color_quant",
|
||||
"lzw",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "glob"
|
||||
version = "0.3.0"
|
||||
@ -229,6 +288,34 @@ dependencies = [
|
||||
"quick-error",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "image"
|
||||
version = "0.23.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b5b0553fec6407d63fe2975b794dfb099f3f790bdc958823851af37b26404ab4"
|
||||
dependencies = [
|
||||
"bytemuck",
|
||||
"byteorder",
|
||||
"gif",
|
||||
"jpeg-decoder",
|
||||
"num-iter",
|
||||
"num-rational",
|
||||
"num-traits",
|
||||
"png",
|
||||
"scoped_threadpool",
|
||||
"tiff",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "jpeg-decoder"
|
||||
version = "0.1.19"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5b47b4c4e017b01abdc5bcc126d2d1002e5a75bbe3ce73f9f4f311a916363704"
|
||||
dependencies = [
|
||||
"byteorder",
|
||||
"rayon",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "lazy_static"
|
||||
version = "1.4.0"
|
||||
@ -266,6 +353,12 @@ dependencies = [
|
||||
"cfg-if",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "lzw"
|
||||
version = "0.10.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7d947cbb889ed21c2a84be6ffbaebf5b4e0f4340638cba0444907e38b56be084"
|
||||
|
||||
[[package]]
|
||||
name = "maybe-uninit"
|
||||
version = "2.0.0"
|
||||
@ -287,6 +380,15 @@ dependencies = [
|
||||
"autocfg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "miniz_oxide"
|
||||
version = "0.3.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "791daaae1ed6889560f8c4359194f56648355540573244a5448a83ba1ecc7435"
|
||||
dependencies = [
|
||||
"adler32",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "nom"
|
||||
version = "5.1.1"
|
||||
@ -297,6 +399,57 @@ dependencies = [
|
||||
"version_check",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-integer"
|
||||
version = "0.1.43"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8d59457e662d541ba17869cf51cf177c0b5f0cbf476c66bdc90bf1edac4f875b"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-iter"
|
||||
version = "0.1.41"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7a6e6b7c748f995c4c29c5f5ae0248536e04a5739927c74ec0fa564805094b9f"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"num-integer",
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-rational"
|
||||
version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a5b4d7360f362cfb50dde8143501e6940b22f644be75a4cc90b2d81968908138"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"num-integer",
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-traits"
|
||||
version = "0.2.12"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ac267bcc07f48ee5f8935ab0d24f316fb722d7a1292e2913f0cc196b29ffd611"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num_cpus"
|
||||
version = "1.13.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "05499f3756671c15885fee9034446956fff3f243d6077b91e5767df161f766b3"
|
||||
dependencies = [
|
||||
"hermit-abi",
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "onnxruntime"
|
||||
version = "0.1.0"
|
||||
@ -304,6 +457,7 @@ dependencies = [
|
||||
"bindgen",
|
||||
"crossbeam",
|
||||
"heck",
|
||||
"image",
|
||||
"lazy_static",
|
||||
"structopt",
|
||||
]
|
||||
@ -314,6 +468,18 @@ version = "0.1.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "19b17cddbe7ec3f8bc800887bab5e717348c95ea2ca0b1bf0837fb964dc67099"
|
||||
|
||||
[[package]]
|
||||
name = "png"
|
||||
version = "0.16.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "34ccdd66f6fe4b2433b07e4728e9a013e43233120427046e93ceb709c3a439bf"
|
||||
dependencies = [
|
||||
"bitflags",
|
||||
"crc32fast",
|
||||
"deflate",
|
||||
"miniz_oxide",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "proc-macro-error"
|
||||
version = "1.0.2"
|
||||
@ -364,6 +530,31 @@ dependencies = [
|
||||
"proc-macro2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rayon"
|
||||
version = "1.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "62f02856753d04e03e26929f820d0a0a337ebe71f849801eea335d464b349080"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"crossbeam-deque",
|
||||
"either",
|
||||
"rayon-core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rayon-core"
|
||||
version = "1.7.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e92e15d89083484e11353891f1af602cc661426deb9564c298b270c726973280"
|
||||
dependencies = [
|
||||
"crossbeam-deque",
|
||||
"crossbeam-queue",
|
||||
"crossbeam-utils",
|
||||
"lazy_static",
|
||||
"num_cpus",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "regex"
|
||||
version = "1.3.7"
|
||||
@ -388,6 +579,12 @@ version = "1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
|
||||
|
||||
[[package]]
|
||||
name = "scoped_threadpool"
|
||||
version = "0.1.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1d51f5df5af43ab3f1360b429fa5e0152ac5ce8c0bd6485cae490332e96846a8"
|
||||
|
||||
[[package]]
|
||||
name = "scopeguard"
|
||||
version = "1.1.0"
|
||||
@ -479,6 +676,17 @@ dependencies = [
|
||||
"lazy_static",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tiff"
|
||||
version = "0.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3f3b8a87c4da944c3f27e5943289171ac71a6150a79ff6bacfff06d159dfff2f"
|
||||
dependencies = [
|
||||
"byteorder",
|
||||
"lzw",
|
||||
"miniz_oxide",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "unicode-segmentation"
|
||||
version = "1.6.0"
|
||||
|
@ -10,6 +10,9 @@ lazy_static = "1.4"
|
||||
structopt = "0.3.14"
|
||||
crossbeam = "0.7.3"
|
||||
|
||||
[dev-dependencies]
|
||||
image = "0.23.6"
|
||||
|
||||
[build-dependencies]
|
||||
bindgen = "0.53.2"
|
||||
heck = "0.3"
|
||||
|
4
build.rs
4
build.rs
@ -27,6 +27,10 @@ fn main() {
|
||||
.clang_args(clang_args)
|
||||
.parse_callbacks(Box::new(bindgen::CargoCallbacks))
|
||||
.whitelist_function("OrtGetApiBase")
|
||||
.whitelist_function("OrtSessionOptionsAppendExecutionProvider_CPU")
|
||||
.whitelist_function("OrtSessionOptionsAppendExecutionProvider_OpenVINO")
|
||||
.whitelist_function("OrtSessionOptionsAppendExecutionProvider_CUDA")
|
||||
.whitelist_function("OrtSessionOptionsAppendExecutionProvider_Tensorrt")
|
||||
.whitelist_var("ORT_.*")
|
||||
.whitelist_recursively(true)
|
||||
.blacklist_type("__int64_t")
|
||||
|
@ -1 +1,6 @@
|
||||
#include <onnxruntime_c_api.h>
|
||||
#include <onnxruntime/core/session/onnxruntime_c_api.h>
|
||||
|
||||
ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_CPU, _In_ OrtSessionOptions* options, int use_arena) ORT_ALL_ARGS_NONNULL;
|
||||
ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_CUDA, _In_ OrtSessionOptions* options, int device_id);
|
||||
ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_Tensorrt, _In_ OrtSessionOptions* options, int device_id);
|
||||
ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_OpenVINO, _In_ OrtSessionOptions* options, const char* device_id);
|
||||
|
356
examples/yolo.rs
Normal file
356
examples/yolo.rs
Normal file
@ -0,0 +1,356 @@
|
||||
use std::ffi::CStr;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use onnxruntime::*;
|
||||
use structopt::{clap, StructOpt};
|
||||
use std::path::PathBuf;
|
||||
|
||||
#[structopt(
|
||||
name = "run",
|
||||
about = "Run a benchmark on an onnx model. Each worker runs the model in a loop in its own
|
||||
thead. Once done it will print the average time to run the model.",
|
||||
setting = clap::AppSettings::ColoredHelp
|
||||
)]
|
||||
#[derive(StructOpt)]
|
||||
struct Opt {
|
||||
/// The path to the onnx files to benchmark
|
||||
onnx: Vec<String>,
|
||||
|
||||
/// A comma separated list of symbolic_dimension=value. If a symbolic dimension is not
|
||||
/// specified, 1 will be used.
|
||||
#[structopt(long)]
|
||||
dims: Option<String>,
|
||||
|
||||
/// The number of worker threads to spawn
|
||||
#[structopt(long, default_value = "1")]
|
||||
workers: usize,
|
||||
|
||||
/// The number of runs each worker will
|
||||
#[structopt(long, default_value = "1")]
|
||||
runs: usize,
|
||||
}
|
||||
|
||||
// #[derive(Debug, Clone)]
|
||||
// struct ModelInfo {
|
||||
// pub path: PathBuf,
|
||||
// pub name: String,
|
||||
// pub description: String,
|
||||
// pub version: i64,
|
||||
// pub input_shape: Vec<i64>,
|
||||
// pub output_shape: Vec<i64>,
|
||||
// pub input_name: String,
|
||||
// pub output_name: String,
|
||||
// }
|
||||
|
||||
// struct OnnxModelConfig {
|
||||
// pub model_path: PathBuf,
|
||||
// pub parallel: bool,
|
||||
// pub use_cuda: bool,
|
||||
// pub use_cpu: bool,
|
||||
// pub use_tensorrt: bool,
|
||||
// pub cpu_use_arena:bool,
|
||||
// pub tensorrt_device_index: i32,
|
||||
// pub cuda_device_index: i32,
|
||||
// pub logging_level: LoggingLevel,
|
||||
// pub logging_id: String,
|
||||
// }
|
||||
|
||||
// struct OnnxModel {
|
||||
// info: Arc<ModelInfo>,
|
||||
// env: Env,
|
||||
// options: SessionOptions,
|
||||
// session: Session,
|
||||
// }
|
||||
|
||||
// impl OnnxModel {
|
||||
// pub fn new(config: OnnxModelConfig) -> Result<Self, Error> {
|
||||
// let env = Env::new(config.logging_level, &config.logging_id)?;
|
||||
// let mut options = SessionOptions::new()?;
|
||||
|
||||
// if config.parallel {
|
||||
// options.set_execution_mode(ExecutionMode::Parallel)?;
|
||||
// }
|
||||
|
||||
// if config.use_cpu {
|
||||
// options.add_cpu(config.cpu_use_arena);
|
||||
// }
|
||||
|
||||
// if config.use_cuda {
|
||||
// options.add_cuda(config.cuda_device_index);
|
||||
// }
|
||||
|
||||
// if config.use_tensorrt {
|
||||
// options.add_tensorrt(config.tensorrt_device_index);
|
||||
// }
|
||||
|
||||
// let session = Session::new(&env, config.model_path.to_str().unwrap(), &options)?;
|
||||
// let metadata = session.metadata();
|
||||
// let inputs = session.inputs();
|
||||
// let outputs = session.outputs();
|
||||
|
||||
// if inputs.len() != 1 {
|
||||
// return Err(Error::ModelWrongInputsCount);
|
||||
// }
|
||||
|
||||
// if outputs.len() != 1 {
|
||||
// return Err(Error::ModelWrongOutputsCount);
|
||||
// }
|
||||
|
||||
// let input = session.input(0);
|
||||
// let input_info = input
|
||||
// .tensor_info()
|
||||
// .or_else(|| Error::ModelWrongInputType)?;
|
||||
// let input_name = input.name().to_string();
|
||||
// let input_shape: Vec<i64> = input_info
|
||||
// .symbolic_dims()
|
||||
// .map(|d| match d {
|
||||
// SymbolicDim::Symbolic(_) => -1i64,
|
||||
// SymbolicDim::Fixed(x) => x as i64,
|
||||
// })
|
||||
// .collect();
|
||||
|
||||
// let output = session.output(0);
|
||||
// let output_info = output
|
||||
// .tensor_info()
|
||||
// .or_else(|| Error::ModelWrongInputType)?;
|
||||
// let output_name = output.name().to_string();
|
||||
// let output_shape: Vec<i64> = output_info
|
||||
// .symbolic_dims()
|
||||
// .map(|d| match d {
|
||||
// SymbolicDim::Symbolic(_) => -1i64,
|
||||
// SymbolicDim::Fixed(x) => x as i64,
|
||||
// })
|
||||
// .collect();
|
||||
|
||||
// let info = ModelInfo {
|
||||
// path: path.into(),
|
||||
// name: format!("{} exported by {}", metadata.graph_name(), metadata.producer_name()),
|
||||
// description: metadata.description().to_string(),
|
||||
// version: metadata.version(),
|
||||
// input_name,
|
||||
// input_shape,
|
||||
// output_name,
|
||||
// output_shape,
|
||||
// };
|
||||
|
||||
// Ok(OnnxModel {
|
||||
// info: Arc::new(info),
|
||||
// env,
|
||||
// options,
|
||||
// session,
|
||||
// })
|
||||
// }
|
||||
// }
|
||||
|
||||
|
||||
// pub enum ModelCommand {
|
||||
// AddOnnxModel(usize, OnnxModelConfig),
|
||||
// Inference(usize, )
|
||||
// }
|
||||
|
||||
// struct ModelPool {
|
||||
// counter: usize,
|
||||
// sender: Sender<ModelCommand>,
|
||||
// }
|
||||
|
||||
// impl ModelPool {
|
||||
// pub fn add(&mut self, name: &str, file: Path) -> Result<&AsyncModel, Error> {
|
||||
|
||||
// }
|
||||
|
||||
// pub fn get_model(&self, name: &str) -> Option<&AsyncModel> {
|
||||
|
||||
// }
|
||||
// }
|
||||
|
||||
|
||||
// struct AsyncModel {
|
||||
|
||||
// }
|
||||
|
||||
|
||||
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
fn key_val_parse(str: &str) -> HashMap<String, usize> {
|
||||
let mut map = HashMap::new();
|
||||
if str.is_empty() {
|
||||
return map;
|
||||
}
|
||||
for key_val in str.split(',') {
|
||||
let mut iter = key_val.split('=');
|
||||
let key = iter.next().expect("no =");
|
||||
let val = iter
|
||||
.next()
|
||||
.expect("nothing after =")
|
||||
.parse()
|
||||
.expect("parse error");
|
||||
assert!(iter.next().is_none(), "more than 1 =");
|
||||
map.insert(key.to_owned(), val);
|
||||
}
|
||||
map
|
||||
}
|
||||
|
||||
/// Get the size of a tensor, substituting symbolic dimentions.
|
||||
fn tensor_size(
|
||||
info: &TensorInfo,
|
||||
named_sizes: &mut HashMap<String, usize>,
|
||||
) -> (OnnxTensorElementDataType, Vec<usize>) {
|
||||
let dims = info
|
||||
.symbolic_dims()
|
||||
.map(|d| match d {
|
||||
SymbolicDim::Symbolic(name) => {
|
||||
let name = name.to_str().unwrap();
|
||||
named_sizes.get(name).cloned().unwrap_or_else(|| {
|
||||
eprintln!("name {} not specified, setting to 1", name);
|
||||
named_sizes.insert(name.to_owned(), 1);
|
||||
1
|
||||
})
|
||||
}
|
||||
SymbolicDim::Fixed(x) => x,
|
||||
})
|
||||
.collect();
|
||||
(info.elem_type(), dims)
|
||||
}
|
||||
|
||||
fn tensor_mut(name: &str, elem_type: OnnxTensorElementDataType, dims: &[usize]) -> Box<dyn AsMut<Val>> {
|
||||
use OnnxTensorElementDataType::*;
|
||||
|
||||
println!("{:?} {} {:?}", elem_type, name, dims);
|
||||
|
||||
match elem_type {
|
||||
Float => Box::new(Tensor::<f32>::init(dims, 0.0).unwrap()),
|
||||
Int64 => Box::new(Tensor::<i64>::init(dims, 0).unwrap()),
|
||||
Int32 => Box::new(Tensor::<i32>::init(dims, 0).unwrap()),
|
||||
t => panic!("Unsupported type {:?}", t),
|
||||
}
|
||||
}
|
||||
|
||||
fn load_image(filename: &str, height: usize, width: usize) -> Vec<f32> {
|
||||
let img = image::open(filename).unwrap()
|
||||
// .resize_exact(width as _, height as _, image::imageops::FilterType::Triangle)
|
||||
.into_rgb();
|
||||
|
||||
img.as_flat_samples()
|
||||
.to_vec()
|
||||
.samples
|
||||
.into_iter()
|
||||
.map(|p| (p as f32) / 255.0)
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn tensor_with_size(
|
||||
name: &str,
|
||||
info: &TensorInfo,
|
||||
named_sizes: &mut HashMap<String, usize>,
|
||||
) -> Box<dyn AsRef<Val>> {
|
||||
let (ty, dims) = tensor_size(info, named_sizes);
|
||||
use OnnxTensorElementDataType::*;
|
||||
|
||||
println!("{:?} {} {:?}", ty, name, dims);
|
||||
match ty {
|
||||
Float => match name {
|
||||
"input" => Box::new(Tensor::<f32>::new(&dims, load_image("/home/andrey/Images/me.jpg", dims[2], dims[3])).unwrap()),
|
||||
_ => Box::new(Tensor::<f32>::init(&dims, 0.0).unwrap()),
|
||||
},
|
||||
Int64 => Box::new(Tensor::<i64>::init(&dims, 0).unwrap()),
|
||||
Int32 => Box::new(Tensor::<i32>::init(&dims, 0).unwrap()),
|
||||
t => panic!("Unsupported type {:?}", t),
|
||||
}
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
let env = Env::new(LoggingLevel::Fatal, "test")?;
|
||||
let opt = Opt::from_args();
|
||||
|
||||
let mut so = SessionOptions::new()?;
|
||||
|
||||
// so.set_execution_mode(ExecutionMode::Parallel)?;
|
||||
// so.add_tensorrt(0);
|
||||
// so.add_cuda(0);
|
||||
// so.add_cpu(true);
|
||||
|
||||
let mut map = if let Some(dims) = &opt.dims {
|
||||
key_val_parse(dims)
|
||||
} else {
|
||||
HashMap::new()
|
||||
};
|
||||
|
||||
let batch_size = 4;
|
||||
|
||||
map.insert("batch_size".into(), batch_size);
|
||||
|
||||
for path in &opt.onnx {
|
||||
println!("model {:?}", path);
|
||||
let session = match Session::new(&env, path, &so) {
|
||||
Ok(sess) => sess,
|
||||
Err(err) => {
|
||||
eprintln!("error: {}\n", err);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let metadata = session.metadata();
|
||||
eprintln!("name: {}", metadata.producer_name());
|
||||
eprintln!("graph_name: {}", metadata.graph_name());
|
||||
eprintln!("domain: {}", metadata.domain());
|
||||
eprintln!("description: {}", metadata.description());
|
||||
|
||||
let mut input_names: Vec<OrtString> = vec![];
|
||||
let mut input_tensors: Vec<Box<dyn AsRef<Val>>> = vec![];
|
||||
|
||||
for (i, input) in session.inputs().enumerate() {
|
||||
if let Some(tensor_info) = input.tensor_info() {
|
||||
input_names.push(input.name());
|
||||
input_tensors.push(tensor_with_size(input.name().as_str(), &tensor_info, &mut map));
|
||||
} else {
|
||||
println!("input {}: {:?} {:?}", i, &*input.name(), input.onnx_type());
|
||||
}
|
||||
}
|
||||
|
||||
let mut output_names: Vec<OrtString> = vec![];
|
||||
let mut output_sizes: Vec<(OnnxTensorElementDataType, Vec<usize>)> = vec![];
|
||||
|
||||
for (i, output) in session.outputs().enumerate() {
|
||||
if let Some(tensor_info) = output.tensor_info() {
|
||||
output_names.push(output.name());
|
||||
output_sizes.push(tensor_size(&tensor_info, &mut map));
|
||||
} else {
|
||||
println!(
|
||||
"output {}: {:?} {:?}",
|
||||
i,
|
||||
&*output.name(),
|
||||
output.onnx_type()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
let in_names: Vec<&CStr> = input_names.iter().map(|x| x.as_c_str()).collect();
|
||||
let in_vals: Vec<&Val> = input_tensors.iter().map(|x| x.as_ref().as_ref()).collect();
|
||||
let out_names: Vec<&CStr> = output_names.iter().map(|x| x.as_c_str()).collect();
|
||||
|
||||
let ro = RunOptions::new();
|
||||
|
||||
let before = Instant::now();
|
||||
|
||||
let mut res = session
|
||||
.run_raw(&ro, &in_names, &in_vals[..], &out_names)
|
||||
.expect("run");
|
||||
|
||||
let tensor = match res.pop().unwrap().as_tensor::<f32>() {
|
||||
Ok(t) => t,
|
||||
_ => panic!("something went wrong")
|
||||
};
|
||||
|
||||
println!("[{:?}] {}", tensor.dims(), before.elapsed().as_millis())
|
||||
|
||||
// println!("{:?}", out_vals[0].as_slice::<f32>());
|
||||
// let out_vals[0].into();
|
||||
// let total: Duration = times.iter().sum()
|
||||
// let avg = total / (times.len() as u32);
|
||||
// eprintln!("worker {} avg time: {:.2} ms", i, avg.as_secs_f64() * 1e3);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
@ -77,6 +77,7 @@ fn tensor_mut(elem_type: OnnxTensorElementDataType, dims: &[usize]) -> Box<dyn A
|
||||
match elem_type {
|
||||
Float => Box::new(Tensor::<f32>::init(dims, 0.0).unwrap()),
|
||||
Int64 => Box::new(Tensor::<i64>::init(dims, 0).unwrap()),
|
||||
Int32 => Box::new(Tensor::<i32>::init(dims, 0).unwrap()),
|
||||
t => panic!("Unsupported type {:?}", t),
|
||||
}
|
||||
}
|
||||
@ -90,6 +91,7 @@ fn tensor_with_size(
|
||||
match ty {
|
||||
Float => Box::new(Tensor::<f32>::init(&dims, 0.0).unwrap()),
|
||||
Int64 => Box::new(Tensor::<i64>::init(&dims, 0).unwrap()),
|
||||
Int32 => Box::new(Tensor::<i32>::init(&dims, 0).unwrap()),
|
||||
t => panic!("Unsupported type {:?}", t),
|
||||
}
|
||||
}
|
||||
|
57
src/lib.rs
57
src/lib.rs
@ -8,7 +8,7 @@ pub mod sys;
|
||||
// Re-export enums
|
||||
pub use sys::{
|
||||
AllocatorType, ErrorCode, ExecutionMode, GraphOptimizationLevel, LoggingLevel, MemType,
|
||||
OnnxTensorElementDataType, OnnxType,
|
||||
OnnxTensorElementDataType, OnnxType,
|
||||
};
|
||||
|
||||
#[macro_use]
|
||||
@ -160,6 +160,41 @@ impl SessionOptions {
|
||||
let raw = call!(@unsafe @ptr CreateSessionOptions)?;
|
||||
Ok(SessionOptions { raw })
|
||||
}
|
||||
|
||||
pub fn add_cpu(&self, use_arena: bool) {
|
||||
let so = self.raw;
|
||||
let status = unsafe {
|
||||
crate::sys::SessionOptionsAppendExecutionProvider_CPU(so, if use_arena { 1 } else { 0 })
|
||||
};
|
||||
|
||||
if !status.is_null() {
|
||||
panic!("!!!");
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add_cuda(&self, device_id: i32) {
|
||||
|
||||
let so = self.raw;
|
||||
let status = unsafe {
|
||||
crate::sys::SessionOptionsAppendExecutionProvider_CUDA(so, device_id)
|
||||
};
|
||||
|
||||
if !status.is_null() {
|
||||
panic!("!!!");
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add_tensorrt(&self, device_id: i32) {
|
||||
|
||||
let so = self.raw;
|
||||
let status = unsafe {
|
||||
crate::sys::SessionOptionsAppendExecutionProvider_Tensorrt(so, device_id)
|
||||
};
|
||||
|
||||
if !status.is_null() {
|
||||
panic!("!!!");
|
||||
}
|
||||
}
|
||||
|
||||
options! {
|
||||
fn enable_mem_pattern() { EnableMemPattern };
|
||||
@ -171,6 +206,8 @@ impl SessionOptions {
|
||||
fn disable_cpu_mem_arena() { DisableCpuMemArena };
|
||||
fn set_session_log_id(log_id: &str) { SetSessionLogId };
|
||||
fn en_prof(path: &CStr | .as_ptr()) { EnableProfiling };
|
||||
fn set_execution_mode(mode: ExecutionMode) { SetSessionExecutionMode };
|
||||
|
||||
fn set_session_log_verbosity_level(verbosity_level: i32) { SetSessionLogVerbosityLevel };
|
||||
fn set_session_log_severity_level(severity_level: i32) { SetSessionLogSeverityLevel };
|
||||
fn set_session_graph_optimization_level(graph_optimization_level: GraphOptimizationLevel)
|
||||
@ -422,7 +459,7 @@ impl Session {
|
||||
assert_eq!(input_names.len(), inputs.len());
|
||||
|
||||
let output_size = output_names.len() as u64;
|
||||
let mut raw_outputs: *mut sys::Value = ptr::null_mut();
|
||||
let mut raw_outputs: Box<[*mut sys::Value]> = (0..output_size).map(|_| ptr::null_mut()).collect();
|
||||
call!(@unsafe
|
||||
Run,
|
||||
self.raw,
|
||||
@ -432,17 +469,15 @@ impl Session {
|
||||
inputs.len() as u64,
|
||||
cstr_ptrs(output_names).as_ptr(),
|
||||
output_size,
|
||||
&mut raw_outputs
|
||||
raw_outputs.as_mut_ptr() as *mut *mut sys::Value
|
||||
)?;
|
||||
|
||||
unsafe {
|
||||
Ok(
|
||||
std::slice::from_raw_parts(&raw_outputs, output_size as usize)
|
||||
.iter()
|
||||
.map(|v| Value { raw: *v })
|
||||
.collect(),
|
||||
)
|
||||
}
|
||||
Ok(
|
||||
raw_outputs
|
||||
.into_iter()
|
||||
.map(|v| Value { raw: *v })
|
||||
.collect()
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
|
13
src/value.rs
13
src/value.rs
@ -39,6 +39,19 @@ impl Val {
|
||||
call!(@unsafe @ptr @expect GetTensorMutableData, self.raw())
|
||||
}
|
||||
|
||||
pub fn as_slice<T: OrtType>(&self) -> Option<&[T]> {
|
||||
let st = self.shape_and_type();
|
||||
if st.elem_type() == T::onnx_type() {
|
||||
let len = st.dims().into_iter().map(|x|x as usize).product();
|
||||
let data = self.tensor_data();
|
||||
Some(unsafe {
|
||||
&*std::ptr::slice_from_raw_parts(data as *mut T, len)
|
||||
})
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub fn shape_and_type(&self) -> TensorTypeAndShapeInfo {
|
||||
let raw = call!(@unsafe @ptr @expect GetTensorTypeAndShape, self.raw());
|
||||
TensorTypeAndShapeInfo { raw }
|
||||
|
Loading…
Reference in New Issue
Block a user