Add TensorView and TensorViewMut

This commit is contained in:
Christopher Chalmers 2020-05-29 00:15:46 +01:00
parent d13a7ad0ba
commit 4d1f3edd5d
2 changed files with 113 additions and 4 deletions

View File

@ -18,7 +18,7 @@ pub use allocator::Allocator;
// note that this be come after the macro definitions (in api)
mod value;
pub use value::{OrtType, Tensor, TensorInfo, Val};
pub use value::{OrtType, Tensor, TensorInfo, TensorView, TensorViewMut, Val};
macro_rules! ort_type {
($t:ident, $r:ident) => {
@ -616,7 +616,7 @@ mod tests {
let ro = RunOptions::new();
let input_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
let input_tensor = Tensor::new(&[3,2], input_data)?;
let input_tensor = Tensor::new(&[3, 2], input_data)?;
// immutable version
let output = session.run_raw(&ro, &[&in_name], &[input_tensor.value()], &[&out_name])?;

View File

@ -1,5 +1,5 @@
use std::fmt;
use std::ffi::c_void;
use std::fmt;
use std::ops::{Deref, DerefMut};
use crate::*;
@ -174,7 +174,6 @@ impl<T> SymbolicDim<T> {
}
}
impl<T: fmt::Debug> fmt::Debug for SymbolicDim<T> {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
@ -334,3 +333,113 @@ impl<T> std::convert::AsMut<Val> for Tensor<T> {
&mut self.val
}
}
/// A mutable view of a slice as a tensor.
pub struct TensorView<'a, T> {
slice: &'a [T],
val: Value,
shape: Vec<i64>,
}
impl<'a, T> std::ops::Deref for TensorView<'a, T> {
type Target = [T];
fn deref(&self) -> &[T] {
&self.slice
}
}
impl<'a, T> std::convert::AsRef<Val> for TensorView<'a, T> {
fn as_ref(&self) -> &Val {
&self.val
}
}
impl<'a, T> TensorView<'a, T> {
pub fn new(shape: &[usize], slice: &'a [T]) -> TensorView<'a, T>
where
T: OrtType,
{
assert!(shape.iter().product::<usize>() == slice.len());
let shape: Vec<i64> = shape.iter().map(|&x| x as i64).collect();
let raw = call!(@unsafe @ptr @expect
CreateTensorWithDataAsOrtValue,
CPU_ARENA.raw,
slice.as_ptr() as *const _ as *mut _,
(slice.len() * std::mem::size_of::<T>()) as u64,
shape.as_ptr(),
shape.len() as u64,
T::onnx_type()
);
TensorView {
slice,
val: Value { raw },
shape,
}
}
pub fn shape(&self) -> &[i64] {
&self.shape
}
}
/// A mutable view of a slice as a tensor.
pub struct TensorViewMut<'a, T> {
slice: &'a mut [T],
val: Value,
shape: Vec<i64>,
}
impl<'a, T> std::ops::Deref for TensorViewMut<'a, T> {
type Target = [T];
fn deref(&self) -> &[T] {
&self.slice
}
}
impl<'a, T> std::ops::DerefMut for TensorViewMut<'a, T> {
fn deref_mut(&mut self) -> &mut [T] {
&mut self.slice
}
}
impl<'a, T> std::convert::AsRef<Val> for TensorViewMut<'a, T> {
fn as_ref(&self) -> &Val {
&self.val
}
}
impl<'a, T> std::convert::AsMut<Val> for TensorViewMut<'a, T> {
fn as_mut(&mut self) -> &mut Val {
&mut self.val
}
}
impl<'a, T> TensorViewMut<'a, T> {
pub fn new(shape: &[usize], slice: &'a mut [T]) -> TensorViewMut<'a, T>
where
T: OrtType,
{
assert!(shape.iter().product::<usize>() == slice.len());
let shape: Vec<i64> = shape.iter().map(|&x| x as i64).collect();
let raw = call!(@unsafe @ptr @expect
CreateTensorWithDataAsOrtValue,
CPU_ARENA.raw,
slice.as_mut_ptr() as *mut _,
(slice.len() * std::mem::size_of::<T>()) as u64,
shape.as_ptr(),
shape.len() as u64,
T::onnx_type()
);
TensorViewMut {
slice,
val: Value { raw },
shape,
}
}
pub fn shape(&self) -> &[i64] {
&self.shape
}
}