Some fixes

This commit is contained in:
Andrey Tkachenko 2020-07-03 16:30:08 +04:00
parent 08fb5dbf8c
commit 25785a1ada
8 changed files with 638 additions and 12 deletions

208
Cargo.lock generated
View File

@ -1,5 +1,11 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
[[package]]
name = "adler32"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "567b077b825e468cc974f0020d4082ee6e03132512f207ef1a02fd5d00d1f32d"
[[package]]
name = "aho-corasick"
version = "0.7.10"
@ -65,6 +71,18 @@ version = "1.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cf1de2fe8c75bc145a2f577add951f8134889b4795d47466a54a5c846d691693"
[[package]]
name = "bytemuck"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37fa13df2292ecb479ec23aa06f4507928bef07839be9ef15281411076629431"
[[package]]
name = "byteorder"
version = "1.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08c48aae112d48ed9f069b33538ea9e3e90aa263cfa3d1c24309612b1f7472de"
[[package]]
name = "cc"
version = "1.0.54"
@ -112,6 +130,21 @@ dependencies = [
"vec_map",
]
[[package]]
name = "color_quant"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0dbbb57365263e881e805dc77d94697c9118fd94d8da011240555aa7b23445bd"
[[package]]
name = "crc32fast"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba125de2af0df55319f41944744ad91c71113bf74a4646efff39afe1f6842db1"
dependencies = [
"cfg-if",
]
[[package]]
name = "crossbeam"
version = "0.7.3"
@ -183,6 +216,22 @@ dependencies = [
"lazy_static",
]
[[package]]
name = "deflate"
version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e7e5d2a2273fed52a7f947ee55b092c4057025d7a3e04e5ecdbd25d6c3fb1bd7"
dependencies = [
"adler32",
"byteorder",
]
[[package]]
name = "either"
version = "1.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bb1f6b1ce1c140482ea30ddd3335fc0024ac7ee112895426e0a629a6c20adfe3"
[[package]]
name = "env_logger"
version = "0.7.1"
@ -196,6 +245,16 @@ dependencies = [
"termcolor",
]
[[package]]
name = "gif"
version = "0.10.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "471d90201b3b223f3451cd4ad53e34295f16a1df17b1edf3736d47761c3981af"
dependencies = [
"color_quant",
"lzw",
]
[[package]]
name = "glob"
version = "0.3.0"
@ -229,6 +288,34 @@ dependencies = [
"quick-error",
]
[[package]]
name = "image"
version = "0.23.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b5b0553fec6407d63fe2975b794dfb099f3f790bdc958823851af37b26404ab4"
dependencies = [
"bytemuck",
"byteorder",
"gif",
"jpeg-decoder",
"num-iter",
"num-rational",
"num-traits",
"png",
"scoped_threadpool",
"tiff",
]
[[package]]
name = "jpeg-decoder"
version = "0.1.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5b47b4c4e017b01abdc5bcc126d2d1002e5a75bbe3ce73f9f4f311a916363704"
dependencies = [
"byteorder",
"rayon",
]
[[package]]
name = "lazy_static"
version = "1.4.0"
@ -266,6 +353,12 @@ dependencies = [
"cfg-if",
]
[[package]]
name = "lzw"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7d947cbb889ed21c2a84be6ffbaebf5b4e0f4340638cba0444907e38b56be084"
[[package]]
name = "maybe-uninit"
version = "2.0.0"
@ -287,6 +380,15 @@ dependencies = [
"autocfg",
]
[[package]]
name = "miniz_oxide"
version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "791daaae1ed6889560f8c4359194f56648355540573244a5448a83ba1ecc7435"
dependencies = [
"adler32",
]
[[package]]
name = "nom"
version = "5.1.1"
@ -297,6 +399,57 @@ dependencies = [
"version_check",
]
[[package]]
name = "num-integer"
version = "0.1.43"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8d59457e662d541ba17869cf51cf177c0b5f0cbf476c66bdc90bf1edac4f875b"
dependencies = [
"autocfg",
"num-traits",
]
[[package]]
name = "num-iter"
version = "0.1.41"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a6e6b7c748f995c4c29c5f5ae0248536e04a5739927c74ec0fa564805094b9f"
dependencies = [
"autocfg",
"num-integer",
"num-traits",
]
[[package]]
name = "num-rational"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a5b4d7360f362cfb50dde8143501e6940b22f644be75a4cc90b2d81968908138"
dependencies = [
"autocfg",
"num-integer",
"num-traits",
]
[[package]]
name = "num-traits"
version = "0.2.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ac267bcc07f48ee5f8935ab0d24f316fb722d7a1292e2913f0cc196b29ffd611"
dependencies = [
"autocfg",
]
[[package]]
name = "num_cpus"
version = "1.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "05499f3756671c15885fee9034446956fff3f243d6077b91e5767df161f766b3"
dependencies = [
"hermit-abi",
"libc",
]
[[package]]
name = "onnxruntime"
version = "0.1.0"
@ -304,6 +457,7 @@ dependencies = [
"bindgen",
"crossbeam",
"heck",
"image",
"lazy_static",
"structopt",
]
@ -314,6 +468,18 @@ version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "19b17cddbe7ec3f8bc800887bab5e717348c95ea2ca0b1bf0837fb964dc67099"
[[package]]
name = "png"
version = "0.16.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34ccdd66f6fe4b2433b07e4728e9a013e43233120427046e93ceb709c3a439bf"
dependencies = [
"bitflags",
"crc32fast",
"deflate",
"miniz_oxide",
]
[[package]]
name = "proc-macro-error"
version = "1.0.2"
@ -364,6 +530,31 @@ dependencies = [
"proc-macro2",
]
[[package]]
name = "rayon"
version = "1.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "62f02856753d04e03e26929f820d0a0a337ebe71f849801eea335d464b349080"
dependencies = [
"autocfg",
"crossbeam-deque",
"either",
"rayon-core",
]
[[package]]
name = "rayon-core"
version = "1.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e92e15d89083484e11353891f1af602cc661426deb9564c298b270c726973280"
dependencies = [
"crossbeam-deque",
"crossbeam-queue",
"crossbeam-utils",
"lazy_static",
"num_cpus",
]
[[package]]
name = "regex"
version = "1.3.7"
@ -388,6 +579,12 @@ version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
[[package]]
name = "scoped_threadpool"
version = "0.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1d51f5df5af43ab3f1360b429fa5e0152ac5ce8c0bd6485cae490332e96846a8"
[[package]]
name = "scopeguard"
version = "1.1.0"
@ -479,6 +676,17 @@ dependencies = [
"lazy_static",
]
[[package]]
name = "tiff"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3f3b8a87c4da944c3f27e5943289171ac71a6150a79ff6bacfff06d159dfff2f"
dependencies = [
"byteorder",
"lzw",
"miniz_oxide",
]
[[package]]
name = "unicode-segmentation"
version = "1.6.0"

View File

@ -10,6 +10,9 @@ lazy_static = "1.4"
structopt = "0.3.14"
crossbeam = "0.7.3"
[dev-dependencies]
image = "0.23.6"
[build-dependencies]
bindgen = "0.53.2"
heck = "0.3"

View File

@ -27,6 +27,10 @@ fn main() {
.clang_args(clang_args)
.parse_callbacks(Box::new(bindgen::CargoCallbacks))
.whitelist_function("OrtGetApiBase")
.whitelist_function("OrtSessionOptionsAppendExecutionProvider_CPU")
.whitelist_function("OrtSessionOptionsAppendExecutionProvider_OpenVINO")
.whitelist_function("OrtSessionOptionsAppendExecutionProvider_CUDA")
.whitelist_function("OrtSessionOptionsAppendExecutionProvider_Tensorrt")
.whitelist_var("ORT_.*")
.whitelist_recursively(true)
.blacklist_type("__int64_t")

View File

@ -1 +1,6 @@
#include <onnxruntime_c_api.h>
#include <onnxruntime/core/session/onnxruntime_c_api.h>
ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_CPU, _In_ OrtSessionOptions* options, int use_arena) ORT_ALL_ARGS_NONNULL;
ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_CUDA, _In_ OrtSessionOptions* options, int device_id);
ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_Tensorrt, _In_ OrtSessionOptions* options, int device_id);
ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_OpenVINO, _In_ OrtSessionOptions* options, const char* device_id);

356
examples/yolo.rs Normal file
View File

@ -0,0 +1,356 @@
use std::ffi::CStr;
use std::time::{Duration, Instant};
use onnxruntime::*;
use structopt::{clap, StructOpt};
use std::path::PathBuf;
#[structopt(
name = "run",
about = "Run a benchmark on an onnx model. Each worker runs the model in a loop in its own
thead. Once done it will print the average time to run the model.",
setting = clap::AppSettings::ColoredHelp
)]
#[derive(StructOpt)]
struct Opt {
/// The path to the onnx files to benchmark
onnx: Vec<String>,
/// A comma separated list of symbolic_dimension=value. If a symbolic dimension is not
/// specified, 1 will be used.
#[structopt(long)]
dims: Option<String>,
/// The number of worker threads to spawn
#[structopt(long, default_value = "1")]
workers: usize,
/// The number of runs each worker will
#[structopt(long, default_value = "1")]
runs: usize,
}
// #[derive(Debug, Clone)]
// struct ModelInfo {
// pub path: PathBuf,
// pub name: String,
// pub description: String,
// pub version: i64,
// pub input_shape: Vec<i64>,
// pub output_shape: Vec<i64>,
// pub input_name: String,
// pub output_name: String,
// }
// struct OnnxModelConfig {
// pub model_path: PathBuf,
// pub parallel: bool,
// pub use_cuda: bool,
// pub use_cpu: bool,
// pub use_tensorrt: bool,
// pub cpu_use_arena:bool,
// pub tensorrt_device_index: i32,
// pub cuda_device_index: i32,
// pub logging_level: LoggingLevel,
// pub logging_id: String,
// }
// struct OnnxModel {
// info: Arc<ModelInfo>,
// env: Env,
// options: SessionOptions,
// session: Session,
// }
// impl OnnxModel {
// pub fn new(config: OnnxModelConfig) -> Result<Self, Error> {
// let env = Env::new(config.logging_level, &config.logging_id)?;
// let mut options = SessionOptions::new()?;
// if config.parallel {
// options.set_execution_mode(ExecutionMode::Parallel)?;
// }
// if config.use_cpu {
// options.add_cpu(config.cpu_use_arena);
// }
// if config.use_cuda {
// options.add_cuda(config.cuda_device_index);
// }
// if config.use_tensorrt {
// options.add_tensorrt(config.tensorrt_device_index);
// }
// let session = Session::new(&env, config.model_path.to_str().unwrap(), &options)?;
// let metadata = session.metadata();
// let inputs = session.inputs();
// let outputs = session.outputs();
// if inputs.len() != 1 {
// return Err(Error::ModelWrongInputsCount);
// }
// if outputs.len() != 1 {
// return Err(Error::ModelWrongOutputsCount);
// }
// let input = session.input(0);
// let input_info = input
// .tensor_info()
// .or_else(|| Error::ModelWrongInputType)?;
// let input_name = input.name().to_string();
// let input_shape: Vec<i64> = input_info
// .symbolic_dims()
// .map(|d| match d {
// SymbolicDim::Symbolic(_) => -1i64,
// SymbolicDim::Fixed(x) => x as i64,
// })
// .collect();
// let output = session.output(0);
// let output_info = output
// .tensor_info()
// .or_else(|| Error::ModelWrongInputType)?;
// let output_name = output.name().to_string();
// let output_shape: Vec<i64> = output_info
// .symbolic_dims()
// .map(|d| match d {
// SymbolicDim::Symbolic(_) => -1i64,
// SymbolicDim::Fixed(x) => x as i64,
// })
// .collect();
// let info = ModelInfo {
// path: path.into(),
// name: format!("{} exported by {}", metadata.graph_name(), metadata.producer_name()),
// description: metadata.description().to_string(),
// version: metadata.version(),
// input_name,
// input_shape,
// output_name,
// output_shape,
// };
// Ok(OnnxModel {
// info: Arc::new(info),
// env,
// options,
// session,
// })
// }
// }
// pub enum ModelCommand {
// AddOnnxModel(usize, OnnxModelConfig),
// Inference(usize, )
// }
// struct ModelPool {
// counter: usize,
// sender: Sender<ModelCommand>,
// }
// impl ModelPool {
// pub fn add(&mut self, name: &str, file: Path) -> Result<&AsyncModel, Error> {
// }
// pub fn get_model(&self, name: &str) -> Option<&AsyncModel> {
// }
// }
// struct AsyncModel {
// }
use std::collections::HashMap;
fn key_val_parse(str: &str) -> HashMap<String, usize> {
let mut map = HashMap::new();
if str.is_empty() {
return map;
}
for key_val in str.split(',') {
let mut iter = key_val.split('=');
let key = iter.next().expect("no =");
let val = iter
.next()
.expect("nothing after =")
.parse()
.expect("parse error");
assert!(iter.next().is_none(), "more than 1 =");
map.insert(key.to_owned(), val);
}
map
}
/// Get the size of a tensor, substituting symbolic dimentions.
fn tensor_size(
info: &TensorInfo,
named_sizes: &mut HashMap<String, usize>,
) -> (OnnxTensorElementDataType, Vec<usize>) {
let dims = info
.symbolic_dims()
.map(|d| match d {
SymbolicDim::Symbolic(name) => {
let name = name.to_str().unwrap();
named_sizes.get(name).cloned().unwrap_or_else(|| {
eprintln!("name {} not specified, setting to 1", name);
named_sizes.insert(name.to_owned(), 1);
1
})
}
SymbolicDim::Fixed(x) => x,
})
.collect();
(info.elem_type(), dims)
}
fn tensor_mut(name: &str, elem_type: OnnxTensorElementDataType, dims: &[usize]) -> Box<dyn AsMut<Val>> {
use OnnxTensorElementDataType::*;
println!("{:?} {} {:?}", elem_type, name, dims);
match elem_type {
Float => Box::new(Tensor::<f32>::init(dims, 0.0).unwrap()),
Int64 => Box::new(Tensor::<i64>::init(dims, 0).unwrap()),
Int32 => Box::new(Tensor::<i32>::init(dims, 0).unwrap()),
t => panic!("Unsupported type {:?}", t),
}
}
fn load_image(filename: &str, height: usize, width: usize) -> Vec<f32> {
let img = image::open(filename).unwrap()
// .resize_exact(width as _, height as _, image::imageops::FilterType::Triangle)
.into_rgb();
img.as_flat_samples()
.to_vec()
.samples
.into_iter()
.map(|p| (p as f32) / 255.0)
.collect()
}
fn tensor_with_size(
name: &str,
info: &TensorInfo,
named_sizes: &mut HashMap<String, usize>,
) -> Box<dyn AsRef<Val>> {
let (ty, dims) = tensor_size(info, named_sizes);
use OnnxTensorElementDataType::*;
println!("{:?} {} {:?}", ty, name, dims);
match ty {
Float => match name {
"input" => Box::new(Tensor::<f32>::new(&dims, load_image("/home/andrey/Images/me.jpg", dims[2], dims[3])).unwrap()),
_ => Box::new(Tensor::<f32>::init(&dims, 0.0).unwrap()),
},
Int64 => Box::new(Tensor::<i64>::init(&dims, 0).unwrap()),
Int32 => Box::new(Tensor::<i32>::init(&dims, 0).unwrap()),
t => panic!("Unsupported type {:?}", t),
}
}
fn main() -> Result<()> {
let env = Env::new(LoggingLevel::Fatal, "test")?;
let opt = Opt::from_args();
let mut so = SessionOptions::new()?;
// so.set_execution_mode(ExecutionMode::Parallel)?;
// so.add_tensorrt(0);
// so.add_cuda(0);
// so.add_cpu(true);
let mut map = if let Some(dims) = &opt.dims {
key_val_parse(dims)
} else {
HashMap::new()
};
let batch_size = 4;
map.insert("batch_size".into(), batch_size);
for path in &opt.onnx {
println!("model {:?}", path);
let session = match Session::new(&env, path, &so) {
Ok(sess) => sess,
Err(err) => {
eprintln!("error: {}\n", err);
continue;
}
};
let metadata = session.metadata();
eprintln!("name: {}", metadata.producer_name());
eprintln!("graph_name: {}", metadata.graph_name());
eprintln!("domain: {}", metadata.domain());
eprintln!("description: {}", metadata.description());
let mut input_names: Vec<OrtString> = vec![];
let mut input_tensors: Vec<Box<dyn AsRef<Val>>> = vec![];
for (i, input) in session.inputs().enumerate() {
if let Some(tensor_info) = input.tensor_info() {
input_names.push(input.name());
input_tensors.push(tensor_with_size(input.name().as_str(), &tensor_info, &mut map));
} else {
println!("input {}: {:?} {:?}", i, &*input.name(), input.onnx_type());
}
}
let mut output_names: Vec<OrtString> = vec![];
let mut output_sizes: Vec<(OnnxTensorElementDataType, Vec<usize>)> = vec![];
for (i, output) in session.outputs().enumerate() {
if let Some(tensor_info) = output.tensor_info() {
output_names.push(output.name());
output_sizes.push(tensor_size(&tensor_info, &mut map));
} else {
println!(
"output {}: {:?} {:?}",
i,
&*output.name(),
output.onnx_type()
);
}
}
let in_names: Vec<&CStr> = input_names.iter().map(|x| x.as_c_str()).collect();
let in_vals: Vec<&Val> = input_tensors.iter().map(|x| x.as_ref().as_ref()).collect();
let out_names: Vec<&CStr> = output_names.iter().map(|x| x.as_c_str()).collect();
let ro = RunOptions::new();
let before = Instant::now();
let mut res = session
.run_raw(&ro, &in_names, &in_vals[..], &out_names)
.expect("run");
let tensor = match res.pop().unwrap().as_tensor::<f32>() {
Ok(t) => t,
_ => panic!("something went wrong")
};
println!("[{:?}] {}", tensor.dims(), before.elapsed().as_millis())
// println!("{:?}", out_vals[0].as_slice::<f32>());
// let out_vals[0].into();
// let total: Duration = times.iter().sum()
// let avg = total / (times.len() as u32);
// eprintln!("worker {} avg time: {:.2} ms", i, avg.as_secs_f64() * 1e3);
}
Ok(())
}

View File

@ -77,6 +77,7 @@ fn tensor_mut(elem_type: OnnxTensorElementDataType, dims: &[usize]) -> Box<dyn A
match elem_type {
Float => Box::new(Tensor::<f32>::init(dims, 0.0).unwrap()),
Int64 => Box::new(Tensor::<i64>::init(dims, 0).unwrap()),
Int32 => Box::new(Tensor::<i32>::init(dims, 0).unwrap()),
t => panic!("Unsupported type {:?}", t),
}
}
@ -90,6 +91,7 @@ fn tensor_with_size(
match ty {
Float => Box::new(Tensor::<f32>::init(&dims, 0.0).unwrap()),
Int64 => Box::new(Tensor::<i64>::init(&dims, 0).unwrap()),
Int32 => Box::new(Tensor::<i32>::init(&dims, 0).unwrap()),
t => panic!("Unsupported type {:?}", t),
}
}

View File

@ -8,7 +8,7 @@ pub mod sys;
// Re-export enums
pub use sys::{
AllocatorType, ErrorCode, ExecutionMode, GraphOptimizationLevel, LoggingLevel, MemType,
OnnxTensorElementDataType, OnnxType,
OnnxTensorElementDataType, OnnxType,
};
#[macro_use]
@ -160,6 +160,41 @@ impl SessionOptions {
let raw = call!(@unsafe @ptr CreateSessionOptions)?;
Ok(SessionOptions { raw })
}
pub fn add_cpu(&self, use_arena: bool) {
let so = self.raw;
let status = unsafe {
crate::sys::SessionOptionsAppendExecutionProvider_CPU(so, if use_arena { 1 } else { 0 })
};
if !status.is_null() {
panic!("!!!");
}
}
pub fn add_cuda(&self, device_id: i32) {
let so = self.raw;
let status = unsafe {
crate::sys::SessionOptionsAppendExecutionProvider_CUDA(so, device_id)
};
if !status.is_null() {
panic!("!!!");
}
}
pub fn add_tensorrt(&self, device_id: i32) {
let so = self.raw;
let status = unsafe {
crate::sys::SessionOptionsAppendExecutionProvider_Tensorrt(so, device_id)
};
if !status.is_null() {
panic!("!!!");
}
}
options! {
fn enable_mem_pattern() { EnableMemPattern };
@ -171,6 +206,8 @@ impl SessionOptions {
fn disable_cpu_mem_arena() { DisableCpuMemArena };
fn set_session_log_id(log_id: &str) { SetSessionLogId };
fn en_prof(path: &CStr | .as_ptr()) { EnableProfiling };
fn set_execution_mode(mode: ExecutionMode) { SetSessionExecutionMode };
fn set_session_log_verbosity_level(verbosity_level: i32) { SetSessionLogVerbosityLevel };
fn set_session_log_severity_level(severity_level: i32) { SetSessionLogSeverityLevel };
fn set_session_graph_optimization_level(graph_optimization_level: GraphOptimizationLevel)
@ -422,7 +459,7 @@ impl Session {
assert_eq!(input_names.len(), inputs.len());
let output_size = output_names.len() as u64;
let mut raw_outputs: *mut sys::Value = ptr::null_mut();
let mut raw_outputs: Box<[*mut sys::Value]> = (0..output_size).map(|_| ptr::null_mut()).collect();
call!(@unsafe
Run,
self.raw,
@ -432,17 +469,15 @@ impl Session {
inputs.len() as u64,
cstr_ptrs(output_names).as_ptr(),
output_size,
&mut raw_outputs
raw_outputs.as_mut_ptr() as *mut *mut sys::Value
)?;
unsafe {
Ok(
std::slice::from_raw_parts(&raw_outputs, output_size as usize)
.iter()
.map(|v| Value { raw: *v })
.collect(),
)
}
Ok(
raw_outputs
.into_iter()
.map(|v| Value { raw: *v })
.collect()
)
}
}

View File

@ -39,6 +39,19 @@ impl Val {
call!(@unsafe @ptr @expect GetTensorMutableData, self.raw())
}
pub fn as_slice<T: OrtType>(&self) -> Option<&[T]> {
let st = self.shape_and_type();
if st.elem_type() == T::onnx_type() {
let len = st.dims().into_iter().map(|x|x as usize).product();
let data = self.tensor_data();
Some(unsafe {
&*std::ptr::slice_from_raw_parts(data as *mut T, len)
})
} else {
None
}
}
pub fn shape_and_type(&self) -> TensorTypeAndShapeInfo {
let raw = call!(@unsafe @ptr @expect GetTensorTypeAndShape, self.raw());
TensorTypeAndShapeInfo { raw }