Add TensorView and TensorViewMut
This commit is contained in:
parent
d13a7ad0ba
commit
4d1f3edd5d
@ -18,7 +18,7 @@ pub use allocator::Allocator;
|
||||
|
||||
// note that this be come after the macro definitions (in api)
|
||||
mod value;
|
||||
pub use value::{OrtType, Tensor, TensorInfo, Val};
|
||||
pub use value::{OrtType, Tensor, TensorInfo, TensorView, TensorViewMut, Val};
|
||||
|
||||
macro_rules! ort_type {
|
||||
($t:ident, $r:ident) => {
|
||||
@ -616,7 +616,7 @@ mod tests {
|
||||
let ro = RunOptions::new();
|
||||
|
||||
let input_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
|
||||
let input_tensor = Tensor::new(&[3,2], input_data)?;
|
||||
let input_tensor = Tensor::new(&[3, 2], input_data)?;
|
||||
|
||||
// immutable version
|
||||
let output = session.run_raw(&ro, &[&in_name], &[input_tensor.value()], &[&out_name])?;
|
||||
|
113
src/value.rs
113
src/value.rs
@ -1,5 +1,5 @@
|
||||
use std::fmt;
|
||||
use std::ffi::c_void;
|
||||
use std::fmt;
|
||||
use std::ops::{Deref, DerefMut};
|
||||
|
||||
use crate::*;
|
||||
@ -174,7 +174,6 @@ impl<T> SymbolicDim<T> {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
impl<T: fmt::Debug> fmt::Debug for SymbolicDim<T> {
|
||||
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
@ -334,3 +333,113 @@ impl<T> std::convert::AsMut<Val> for Tensor<T> {
|
||||
&mut self.val
|
||||
}
|
||||
}
|
||||
|
||||
/// A mutable view of a slice as a tensor.
|
||||
pub struct TensorView<'a, T> {
|
||||
slice: &'a [T],
|
||||
val: Value,
|
||||
shape: Vec<i64>,
|
||||
}
|
||||
|
||||
impl<'a, T> std::ops::Deref for TensorView<'a, T> {
|
||||
type Target = [T];
|
||||
|
||||
fn deref(&self) -> &[T] {
|
||||
&self.slice
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T> std::convert::AsRef<Val> for TensorView<'a, T> {
|
||||
fn as_ref(&self) -> &Val {
|
||||
&self.val
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T> TensorView<'a, T> {
|
||||
pub fn new(shape: &[usize], slice: &'a [T]) -> TensorView<'a, T>
|
||||
where
|
||||
T: OrtType,
|
||||
{
|
||||
assert!(shape.iter().product::<usize>() == slice.len());
|
||||
let shape: Vec<i64> = shape.iter().map(|&x| x as i64).collect();
|
||||
let raw = call!(@unsafe @ptr @expect
|
||||
CreateTensorWithDataAsOrtValue,
|
||||
CPU_ARENA.raw,
|
||||
slice.as_ptr() as *const _ as *mut _,
|
||||
(slice.len() * std::mem::size_of::<T>()) as u64,
|
||||
shape.as_ptr(),
|
||||
shape.len() as u64,
|
||||
T::onnx_type()
|
||||
);
|
||||
TensorView {
|
||||
slice,
|
||||
val: Value { raw },
|
||||
shape,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn shape(&self) -> &[i64] {
|
||||
&self.shape
|
||||
}
|
||||
}
|
||||
|
||||
/// A mutable view of a slice as a tensor.
|
||||
pub struct TensorViewMut<'a, T> {
|
||||
slice: &'a mut [T],
|
||||
val: Value,
|
||||
shape: Vec<i64>,
|
||||
}
|
||||
|
||||
impl<'a, T> std::ops::Deref for TensorViewMut<'a, T> {
|
||||
type Target = [T];
|
||||
|
||||
fn deref(&self) -> &[T] {
|
||||
&self.slice
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T> std::ops::DerefMut for TensorViewMut<'a, T> {
|
||||
fn deref_mut(&mut self) -> &mut [T] {
|
||||
&mut self.slice
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T> std::convert::AsRef<Val> for TensorViewMut<'a, T> {
|
||||
fn as_ref(&self) -> &Val {
|
||||
&self.val
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T> std::convert::AsMut<Val> for TensorViewMut<'a, T> {
|
||||
fn as_mut(&mut self) -> &mut Val {
|
||||
&mut self.val
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T> TensorViewMut<'a, T> {
|
||||
pub fn new(shape: &[usize], slice: &'a mut [T]) -> TensorViewMut<'a, T>
|
||||
where
|
||||
T: OrtType,
|
||||
{
|
||||
assert!(shape.iter().product::<usize>() == slice.len());
|
||||
let shape: Vec<i64> = shape.iter().map(|&x| x as i64).collect();
|
||||
let raw = call!(@unsafe @ptr @expect
|
||||
CreateTensorWithDataAsOrtValue,
|
||||
CPU_ARENA.raw,
|
||||
slice.as_mut_ptr() as *mut _,
|
||||
(slice.len() * std::mem::size_of::<T>()) as u64,
|
||||
shape.as_ptr(),
|
||||
shape.len() as u64,
|
||||
T::onnx_type()
|
||||
);
|
||||
TensorViewMut {
|
||||
slice,
|
||||
val: Value { raw },
|
||||
shape,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn shape(&self) -> &[i64] {
|
||||
&self.shape
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user