Add resize tensor function

This commit is contained in:
Christopher Chalmers 2020-05-22 09:27:14 +01:00
parent f20f294de5
commit 8abd94d828

View File

@ -192,6 +192,18 @@ impl<T: OrtType> Tensor<T> {
pub fn value_mut(&mut self) -> &mut Value {
&mut self.val
}
// must be owned or will panic, don't give it negative dims
pub fn resize(&mut self, dims: Vec<i64>)
where T: Clone + Default
{
let len = dims.iter().product::<i64>();
let owned = self.owned.as_mut().expect("Tensor::resize not owned");
owned.resize(len as usize, T::default());
unsafe {
self.value_mut().shape_and_type().set_dims(&dims);
}
}
}
impl<T> std::borrow::Borrow<[T]> for Tensor<T> {