From 63d0428e32a737f324724d83c7cd1d62fd98185b Mon Sep 17 00:00:00 2001 From: Andrey Tkachenko Date: Tue, 16 Jul 2019 00:10:36 +0400 Subject: [PATCH] Got rid of AbstractLayer --- yarnn-examples/mnist/src/main.rs | 6 +- yarnn-models/mnist/src/lib.rs | 5 +- yarnn-models/vgg16/src/lib.rs | 1 - yarnn/src/layer.rs | 138 +++++++++--------------------- yarnn/src/layers/avgpool2d.rs | 66 ++++++++------ yarnn/src/layers/chain.rs | 92 ++++++++++++++------ yarnn/src/layers/conv2d.rs | 77 ++++++++++------- yarnn/src/layers/flatten.rs | 46 ++++++---- yarnn/src/layers/linear.rs | 79 ++++++++++------- yarnn/src/layers/maxpool2d.rs | 59 ++++++++----- yarnn/src/layers/relu.rs | 47 ++++++---- yarnn/src/layers/sigmoid.rs | 49 +++++++---- yarnn/src/layers/softmax.rs | 49 +++++++---- yarnn/src/layers/zeropadding2d.rs | 42 +++++---- yarnn/src/macros.rs | 126 +++++++++++++++++---------- yarnn/src/optimizer.rs | 7 +- yarnn/src/tensor.rs | 4 +- 17 files changed, 515 insertions(+), 378 deletions(-) diff --git a/yarnn-examples/mnist/src/main.rs b/yarnn-examples/mnist/src/main.rs index 2903ec5..6bf9fe7 100644 --- a/yarnn-examples/mnist/src/main.rs +++ b/yarnn-examples/mnist/src/main.rs @@ -43,6 +43,7 @@ fn main() { // let mut model = MnistDenseModel::new(28, 28, 1); let mut model = MnistConvModel::new(28, 28, 1); + model.init(&backend); println!("{}", &model); @@ -52,7 +53,7 @@ fn main() { let loss = CrossEntropyLoss::new(); let Mnist { trn_img, trn_lbl, tst_img, tst_lbl, .. } = MnistBuilder::new() - .base_path("../../dataset/mnist") + .base_path("./datasets/mnist") .label_format_digit() .finalize(); @@ -101,7 +102,8 @@ fn main() { model.forward(&backend, &inputs, &mut train_ctx); loss.derivative(&backend, &mut deltas, train_ctx.outputs(), &targets); model.backward(&backend, &deltas, &inputs, &mut train_ctx); - model.update(&backend, &optimizer, &inputs, &deltas, &mut train_ctx); + model.calc_gradients(&backend, &deltas, &inputs, &mut train_ctx); + model.optimize(&backend, &optimizer); } model.forward(&backend, &inputs0, &mut test_ctx); diff --git a/yarnn-models/mnist/src/lib.rs b/yarnn-models/mnist/src/lib.rs index 258f4dc..687dbdf 100644 --- a/yarnn-models/mnist/src/lib.rs +++ b/yarnn-models/mnist/src/lib.rs @@ -1,12 +1,10 @@ #![feature(trait_alias)] - pub use self::dense::MnistDenseModel; pub use self::conv::MnistConvModel; mod dense { use yarnn::model; - use yarnn::layer::*; use yarnn::layers::*; model! { @@ -30,7 +28,6 @@ mod dense { mod conv { use yarnn::model; - use yarnn::layer::*; use yarnn::layers::*; model! { @@ -62,4 +59,4 @@ mod conv { } } } -} \ No newline at end of file +} diff --git a/yarnn-models/vgg16/src/lib.rs b/yarnn-models/vgg16/src/lib.rs index 1fa8329..f90103d 100644 --- a/yarnn-models/vgg16/src/lib.rs +++ b/yarnn-models/vgg16/src/lib.rs @@ -1,7 +1,6 @@ #![feature(trait_alias)] use yarnn::layers::*; -use yarnn::layer::*; use yarnn::model; model! 
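Editor's note on the training-loop hunk above: the single `model.update(...)` call is split into an explicit gradient pass and an optimizer pass, and parameter initialisation is no longer lazy, so `model.init(&backend)` must now be called once up front. A minimal sketch of one training step under the new API, reusing the bindings already present in the example (`backend`, `optimizer`, `loss`, `model`, `train_ctx`, `inputs`, `targets`, `deltas`):

```rust
// One-time, explicit parameter initialisation (previously done lazily by LayerImpl):
model.init(&backend);

// One training step under the reworked API:
model.forward(&backend, &inputs, &mut train_ctx);                 // fills train_ctx outputs
loss.derivative(&backend, &mut deltas, train_ctx.outputs(), &targets);
model.backward(&backend, &deltas, &inputs, &mut train_ctx);       // fills train_ctx deltas
model.calc_gradients(&backend, &deltas, &inputs, &mut train_ctx); // accumulate parameter grads
model.optimize(&backend, &optimizer);                             // apply them
```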
{ diff --git a/yarnn/src/layer.rs b/yarnn/src/layer.rs index a4ef57b..543bf0d 100644 --- a/yarnn/src/layer.rs +++ b/yarnn/src/layer.rs @@ -1,14 +1,20 @@ use crate::backend::Backend; -use crate::optimizer::{Optimizer, Optimizable}; +use crate::optimizer::Optimizer; use crate::tensor::{Tensor, TensorShape}; -use core::marker::PhantomData; +// use core::marker::PhantomData; -pub trait Layer> { - type Config: Default; +pub trait Layer + where B: Backend, + O: Optimizer +{ + type Context: LayerContext; + fn name(&self) -> &str; - fn create(input_shape: TensorShape, cfg: Self::Config) -> Self; + fn param_count(&self) -> usize { + 0 + } #[inline] fn init(&mut self, _backend: &B) {} @@ -20,34 +26,39 @@ pub trait Layer> { self.input_shape() } - fn forward(&self, backend: &B, y: &mut B::Tensor, x: &B::Tensor); - fn backward(&self, backend: &B, dx: &mut B::Tensor, dy: &B::Tensor, x: &B::Tensor, y: &B::Tensor); + fn forward(&self, backend: &B, x: &B::Tensor, ctx: &mut Self::Context); + fn backward(&mut self, backend: &B, dy: &B::Tensor, x: &B::Tensor, ctx: &mut Self::Context); + + #[inline] + fn calc_gradients(&mut self, _backend: &B, _dy: &B::Tensor, _x: &B::Tensor, _ctx: &mut Self::Context) {} + + #[inline] + fn optimize(&mut self, _backend: &B, _optimizer: &O) {} + + fn fmt(&self, f: &mut core::fmt::Formatter, padding: usize) -> core::fmt::Result { + writeln!(f, "{}{} -> {}[{}] -> {}", "".repeat(padding), self.input_shape(), self.name(), self.param_count(), self.output_shape())?; + + Ok(()) + } } -/// Temporary solution until I find a solution with problem of inference with specializations -impl Optimizable for T - where T: Layer, - B: Backend, +pub trait LayerExt: Layer + where B: Backend, O: Optimizer { - default fn calc_gradients(&mut self, _backend: &B, _inputs: &B::Tensor, _deltas: &B::Tensor) {} - default fn optimize(&mut self, _backend: &B, _optimizer: &O) {} -} + type Config: Default; -pub trait AbstractLayer, O: Optimizer>: core::fmt::Display { - type Context: LayerContext; + fn create(input_shape: TensorShape, cfg: Self::Config) -> Self; - fn forward(&mut self, backend: &B, inputs: &B::Tensor, ctx: &mut Self::Context); - fn backward(&mut self, backend: &B, deltas: &B::Tensor, inputs: &B::Tensor, ctx: &mut Self::Context); - fn update(&mut self, backend: &B, optimizer: &O, inputs: &B::Tensor, deltas: &B::Tensor, ctx: &mut Self::Context); - #[inline] - fn add_layer>(self, cfg: L::Config) -> crate::layers::Chain> + fn add_layer>(self, cfg: L::Config) -> crate::layers::Chain where Self: Sized { + let shape = self.output_shape(); + crate::layers::Chain::new( self, - LayerImpl::new(L::create(().into(), cfg)), + L::create(shape, cfg), ) } } @@ -57,14 +68,15 @@ pub trait LayerContext>: Default { fn deltas(&self) -> &B::Tensor; } -pub struct CommonLayerContext + +pub struct DefaultLayerContext where B: Backend, { pub outputs: B::Tensor, pub deltas: B::Tensor, } -impl Default for CommonLayerContext +impl Default for DefaultLayerContext where B: Backend, { fn default() -> Self { @@ -75,10 +87,10 @@ impl Default for CommonLayerContext } } -impl CommonLayerContext +impl DefaultLayerContext where B: Backend, { - pub fn update_deltas_bs(&mut self, bs: u32, input_shape: &TensorShape) { + pub fn update_deltas_shape(&mut self, bs: u32, input_shape: &TensorShape) { let mut new_deltas_shape = TensorShape::new1d(bs); new_deltas_shape.append(input_shape.clone()); @@ -87,7 +99,7 @@ impl CommonLayerContext } } - pub fn update_outputs_bs(&mut self, bs: u32, output_shape: &TensorShape) { + pub fn 
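Editor's note on the new `Layer` / `LayerExt` split defined in layer.rs: a layer now owns its forward/backward/gradient/optimize steps and a `Context` type directly, with no `LayerImpl` adapter. Below is a sketch of a minimal parameter-free layer against this API. The generic arities (`Layer<N, B, O>`, `Backend<N>`, `Optimizer<N, B>`, `DefaultLayerContext<N, B>`, etc.) are assumptions, since the diff context elides most angle brackets; the layer name `Passthrough` is hypothetical.

```rust
use core::marker::PhantomData;
use yarnn::backend::{Backend, BackendCopy};
use yarnn::layer::{Layer, LayerExt, DefaultLayerContext};
use yarnn::optimizer::Optimizer;
use yarnn::tensor::{Tensor, TensorShape};

#[derive(Default)]
pub struct PassthroughConfig;

pub struct Passthrough<N, B> {
    input_shape: TensorShape,
    _m: PhantomData<fn(N, B)>,
}

impl<N, B, O> Layer<N, B, O> for Passthrough<N, B>
    where B: Backend<N> + BackendCopy<N>,
          O: Optimizer<N, B>,
{
    type Context = DefaultLayerContext<N, B>;

    fn name(&self) -> &str { "Passthrough" }

    fn input_shape(&self) -> TensorShape {
        self.input_shape.clone()
    }

    // output_shape(), param_count(), calc_gradients() and optimize() fall back to
    // the trait defaults, which is all a parameter-free layer needs.

    fn forward(&self, backend: &B, x: &B::Tensor, ctx: &mut Self::Context) {
        // resize the per-batch output buffer, then let the backend fill it
        ctx.update_outputs_shape(x.shape().get(0), &self.input_shape);
        backend.copy(&mut ctx.outputs, x);
    }

    fn backward(&mut self, backend: &B, dy: &B::Tensor, _x: &B::Tensor, ctx: &mut Self::Context) {
        ctx.update_deltas_shape(dy.shape().get(0), &self.input_shape);
        backend.copy(&mut ctx.deltas, dy);
    }
}

impl<N, B, O> LayerExt<N, B, O> for Passthrough<N, B>
    where B: Backend<N> + BackendCopy<N>,
          O: Optimizer<N, B>,
{
    type Config = PassthroughConfig;

    fn create(input_shape: TensorShape, _cfg: Self::Config) -> Self {
        Passthrough { input_shape, _m: PhantomData }
    }
}
```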
update_outputs_shape(&mut self, bs: u32, output_shape: &TensorShape) { let mut new_output_shape = TensorShape::new1d(bs); new_output_shape.append(output_shape.clone()); @@ -98,7 +110,7 @@ impl CommonLayerContext } } -impl LayerContext for CommonLayerContext +impl LayerContext for DefaultLayerContext where B: Backend, { #[inline] @@ -110,72 +122,4 @@ impl LayerContext for CommonLayerContext fn deltas(&self) -> &B::Tensor { &self.deltas } -} - -pub struct LayerImpl - where B: Backend, - O: Optimizer -{ - pub layer: L, - initialized: bool, - _m: PhantomData, -} - -impl LayerImpl - where B: Backend, - O: Optimizer, - L: Layer + Optimizable -{ - pub fn new(layer: L) -> Self { - Self { - layer, - initialized: false, - _m: Default::default(), - } - } -} - -impl core::fmt::Display for LayerImpl - where B: Backend, - O: Optimizer, - L: Layer -{ - fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { - write!(f, "{} -> ", self.layer.input_shape())?; - write!(f, "{}", self.layer.name())?; - writeln!(f, " -> {}", self.layer.output_shape())?; - - Ok(()) - } -} - -impl AbstractLayer for LayerImpl - where B: Backend, - O: Optimizer, - L: Layer + Optimizable -{ - type Context = CommonLayerContext; - - #[inline] - fn forward(&mut self, backend: &B, inputs: &B::Tensor, ctx: &mut Self::Context) { - if !self.initialized { - self.initialized = true; - self.layer.init(&backend); - } - - ctx.update_outputs_bs(inputs.shape().get(0), &self.layer.output_shape()); - self.layer.forward(&backend, &mut ctx.outputs, inputs); - } - - #[inline] - fn backward(&mut self, backend: &B, deltas: &B::Tensor, inputs: &B::Tensor, ctx: &mut Self::Context) { - ctx.update_deltas_bs(deltas.shape().get(0), &self.layer.input_shape()); - self.layer.backward(&backend, &mut ctx.deltas, deltas, inputs, &ctx.outputs); - } - - #[inline] - fn update(&mut self, backend: &B, optimizer: &O, inputs: &B::Tensor, deltas: &B::Tensor, _ctx: &mut Self::Context) { - self.layer.calc_gradients(&backend, inputs, deltas); - self.layer.optimize(&backend, &optimizer); - } -} +} \ No newline at end of file diff --git a/yarnn/src/layers/avgpool2d.rs b/yarnn/src/layers/avgpool2d.rs index c9a74c8..c207831 100644 --- a/yarnn/src/layers/avgpool2d.rs +++ b/yarnn/src/layers/avgpool2d.rs @@ -1,20 +1,20 @@ use crate::tensor::{Tensor, TensorShape}; -use crate::layer::Layer; +use crate::layer::{Layer, LayerExt, DefaultLayerContext}; use crate::backend::{Backend, PaddingKind, BackendAvgPool2d, Conv2dInfo}; +use crate::optimizer::Optimizer; + use core::marker::PhantomData; pub struct AvgPool2dConfig { pub pool: (u32, u32), - pub strides: (u32, u32), - pub padding: PaddingKind, + pub strides: Option<(u32, u32)>, } impl Default for AvgPool2dConfig { fn default() -> Self { Self { pool: (2, 2), - strides: (2, 2), - padding: PaddingKind::Valid, + strides: None, } } } @@ -27,29 +27,16 @@ pub struct AvgPool2d _m: PhantomData } -impl Layer for AvgPool2d - where B: Backend + BackendAvgPool2d +impl Layer for AvgPool2d + where B: Backend + BackendAvgPool2d, + O: Optimizer { - type Config = AvgPool2dConfig; + type Context = DefaultLayerContext; fn name(&self) -> &str { "AvgPool2d" } - fn create(input_shape: TensorShape, config: Self::Config) -> Self { - assert!(input_shape.dims == 3); - - AvgPool2d { - input_shape, - conv_info: Conv2dInfo { - kernel: config.pool, - strides: config.strides, - padding: config.padding, - }, - _m: Default::default(), - } - } - #[inline] fn input_shape(&self) -> TensorShape { self.input_shape.clone() @@ -72,7 +59,11 @@ impl Layer for AvgPool2d 
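Editor's note on `AvgPool2dConfig`: the explicit `padding` field is gone (padding is fixed to `Valid` in `create()`) and `strides` is now optional, falling back to the pool size when unset. A small sketch of the two forms, assuming the `*Config` structs are reachable from `yarnn::layers` like the layer types (if not, they live in the individual layer modules):

```rust
use yarnn::layers::AvgPool2dConfig;

fn main() {
    // strides: None falls back to the pool size, i.e. non-overlapping 2x2 windows
    let _halving = AvgPool2dConfig { pool: (2, 2), strides: None };

    // explicit strides give overlapping 3x3 windows moving one pixel at a time
    let _overlapping = AvgPool2dConfig { pool: (3, 3), strides: Some((1, 1)) };
}
```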
} #[inline] - fn forward(&self, backend: &B, y: &mut B::Tensor, x: &B::Tensor) { + fn forward(&self, backend: &B, x: &B::Tensor, ctx: &mut Self::Context) { + ctx.update_outputs_shape(x.shape().get(0), &Layer::::output_shape(self)); + + let y = &mut ctx.outputs; + assert_eq!(y.shape().dims, 4); assert_eq!(x.shape().dims, 4); @@ -80,10 +71,35 @@ impl Layer for AvgPool2d } #[inline] - fn backward(&self, backend: &B, dx: &mut B::Tensor, dy: &B::Tensor, x: &B::Tensor, _: &B::Tensor) { + fn backward(&mut self, backend: &B, dy: &B::Tensor, x: &B::Tensor, ctx: &mut Self::Context) { + ctx.update_deltas_shape(x.shape().get(0), &self.input_shape); + + let dx = &mut ctx.deltas; + assert_eq!(dy.shape().dims, 4); assert_eq!(dx.shape().dims, 4); backend.avg_pool2d_backprop(dx, dy, x, &self.conv_info); } -} \ No newline at end of file +} + +impl LayerExt for AvgPool2d + where B: Backend + BackendAvgPool2d, + O: Optimizer +{ + type Config = AvgPool2dConfig; + + fn create(input_shape: TensorShape, config: Self::Config) -> Self { + assert!(input_shape.dims == 3); + + AvgPool2d { + input_shape, + conv_info: Conv2dInfo { + kernel: config.pool, + strides: config.strides.unwrap_or(config.pool), + padding: PaddingKind::Valid, + }, + _m: Default::default(), + } + } +} diff --git a/yarnn/src/layers/chain.rs b/yarnn/src/layers/chain.rs index 7e895ea..a5311e6 100644 --- a/yarnn/src/layers/chain.rs +++ b/yarnn/src/layers/chain.rs @@ -1,6 +1,7 @@ use crate::backend::Backend; -use crate::layer::{AbstractLayer, LayerContext}; +use crate::layer::{Layer, LayerContext}; use crate::optimizer::Optimizer; +use crate::tensor::TensorShape; use core::marker::PhantomData; @@ -47,8 +48,8 @@ impl LayerContext for ChainContext pub struct Chain where B: Backend, O: Optimizer, - L: AbstractLayer, - R: AbstractLayer, + L: Layer, + R: Layer, { left: L, right: R, @@ -58,8 +59,8 @@ pub struct Chain impl Chain where B: Backend, O: Optimizer, - L: AbstractLayer, - R: AbstractLayer, + L: Layer, + R: Layer, { pub fn new(left: L, right: R) -> Self { Self { @@ -70,30 +71,56 @@ impl Chain } } -impl core::fmt::Display for Chain +// impl core::fmt::Display for Chain +// where B: Backend, +// O: Optimizer, +// L: Layer, +// R: Layer, +// { +// fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { +// self.left.fmt(f)?; +// self.right.fmt(f)?; + +// Ok(()) +// } +// } + +impl Layer for Chain where B: Backend, O: Optimizer, - L: AbstractLayer, - R: AbstractLayer, -{ - fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { - self.left.fmt(f)?; - self.right.fmt(f)?; - - Ok(()) - } -} - -impl AbstractLayer for Chain - where B: Backend, - O: Optimizer, - L: AbstractLayer, - R: AbstractLayer, + L: Layer, + R: Layer, { type Context = ChainContext; #[inline] - fn forward(&mut self, backend: &B, inputs: &B::Tensor, ctx: &mut Self::Context) { + fn name(&self) -> &str { + "Chain Layer" + } + + #[inline] + fn param_count(&self) -> usize { + self.left.param_count() + self.right.param_count() + } + + #[inline] + fn init(&mut self, backend: &B) { + self.left.init(backend); + self.right.init(backend); + } + + #[inline] + fn input_shape(&self) -> TensorShape { + self.left.input_shape() + } + + #[inline] + fn output_shape(&self) -> TensorShape { + self.right.output_shape() + } + + #[inline] + fn forward(&self, backend: &B, inputs: &B::Tensor, ctx: &mut Self::Context) { self.left.forward(backend, inputs, &mut ctx.left); self.right.forward(backend, ctx.left.outputs(), &mut ctx.right); } @@ -105,8 +132,21 @@ impl AbstractLayer for Chain } 
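Editor's note on `Chain`: since it now implements `Layer` itself, composition is uniform, and its `ChainContext` is what carries intermediate activations and deltas between the two halves. A comment-only sketch of the data flow implemented by the methods above and below (the backward threading is inferred from the context accessors, so treat it as a sketch):

```rust
// Data flow through Chain<L, R> and its ChainContext:
//
//   forward:        x  ──L──► ctx.left.outputs() ──R──► ctx.right.outputs()  (== ctx.outputs())
//   backward:       dy ──R──► ctx.right.deltas() ──L──► ctx.left.deltas()    (== ctx.deltas())
//   calc_gradients: L uses (x, ctx.right.deltas()),  R uses (ctx.left.outputs(), dy)
//   optimize:       simply recurses into both halves
//   param_count:    left.param_count() + right.param_count()
```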
#[inline] - fn update(&mut self, backend: &B, optimizer: &O, inputs: &B::Tensor, deltas: &B::Tensor, ctx: &mut Self::Context) { - self.left.update(backend, optimizer, inputs, ctx.right.deltas(), &mut ctx.left); - self.right.update(backend, optimizer, ctx.left.outputs(), deltas, &mut ctx.right); + fn calc_gradients(&mut self, backend: &B, deltas: &B::Tensor, inputs: &B::Tensor, ctx: &mut Self::Context) { + self.left.calc_gradients(backend, ctx.right.deltas(), inputs, &mut ctx.left); + self.right.calc_gradients(backend, deltas, ctx.left.outputs(), &mut ctx.right); + } + + #[inline] + fn optimize(&mut self, backend: &B, optimizer: &O) { + self.left.optimize(backend, optimizer); + self.right.optimize(backend, optimizer); + } + + fn fmt(&self, f: &mut core::fmt::Formatter, padding: usize) -> core::fmt::Result { + self.left.fmt(f, padding)?; + self.right.fmt(f, padding)?; + + Ok(()) } } \ No newline at end of file diff --git a/yarnn/src/layers/conv2d.rs b/yarnn/src/layers/conv2d.rs index fdac335..5de8e07 100644 --- a/yarnn/src/layers/conv2d.rs +++ b/yarnn/src/layers/conv2d.rs @@ -1,8 +1,8 @@ use crate::tensor::{Tensor, TensorShape}; -use crate::layer::Layer; +use crate::layer::{Layer, LayerExt, DefaultLayerContext}; use crate::params::Params; use crate::backend::{Backend, Conv2dInfo, PaddingKind, BackendBias, BackendConv2d, BackendScale}; -use crate::optimizer::{Optimizable, Optimizer}; +use crate::optimizer::Optimizer; pub struct Conv2dConfig { pub filters: u32, @@ -36,32 +36,24 @@ pub struct Conv2d biases: Params, } -impl Layer for Conv2d - where B: Backend + BackendConv2d + BackendBias, +impl Layer for Conv2d + where B: Backend + BackendConv2d + BackendBias + BackendScale, O: Optimizer { - type Config = Conv2dConfig; + type Context = DefaultLayerContext; fn name(&self) -> &str { "Conv2d" } - - fn create(input_shape: TensorShape, cfg: Self::Config) -> Self { - assert!(input_shape.dims == 3); - Conv2d { - input_shape, - units: cfg.filters, - conv_info: Conv2dInfo { - kernel: cfg.kernel, - padding: cfg.padding, - strides: cfg.strides, - }, - use_biases: cfg.biases, - filters: Params::new((cfg.filters, cfg.kernel.0, cfg.kernel.1)), - biases: Params::new((cfg.filters, )), + #[inline] + fn param_count(&self) -> usize { + if self.use_biases { + self.filters.params.shape().size() + self.biases.params.shape().size() + } else { + self.filters.params.shape().size() } - } + } fn init(&mut self, backend: &B) { self.filters.init_random(backend, self.conv_info.kernel.0 * self.conv_info.kernel.1 + self.filters.params.shape().get(0)); @@ -93,11 +85,12 @@ impl Layer for Conv2d } #[inline] - fn forward(&self, backend: &B, y: &mut B::Tensor, x: &B::Tensor) { - assert_eq!(y.shape().dims, 4); + fn forward(&self, backend: &B, x: &B::Tensor, ctx: &mut Self::Context) { assert_eq!(x.shape().dims, 4); - backend.conv2d_forward(y, x, &self.filters.params, &self.conv_info); + ctx.update_outputs_shape(x.shape().get(0), &self.output_shape()); + + backend.conv2d_forward(&mut ctx.outputs, x, &self.filters.params, &self.conv_info); if self.use_biases { unimplemented!(); @@ -106,20 +99,16 @@ impl Layer for Conv2d } #[inline] - fn backward(&self, backend: &B, dx: &mut B::Tensor, dy: &B::Tensor, _: &B::Tensor, _: &B::Tensor) { + fn backward(&mut self, backend: &B, dy: &B::Tensor, x: &B::Tensor, ctx: &mut Self::Context) { assert_eq!(dy.shape().dims, 4); - assert_eq!(dx.shape().dims, 4); + + ctx.update_deltas_shape(x.shape().get(0), &self.input_shape); - backend.conv2d_backward_input(dx, dy, &self.filters.params, &self.conv_info); 
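Editor's note on `param_count()`: it is new on `Layer` and feeds the tree printout; `Conv2d` and `Linear` count their `Params` shapes, parameter-free layers report the default 0, and `Chain` sums both halves. A worked count for two hypothetical configurations (the values are illustrative, not taken from the models above):

```rust
fn main() {
    // Conv2d { filters: 8, kernel: (3, 3), biases: false }: the patch constructs its
    // filter Params with shape (filters, kh, kw), i.e. (8, 3, 3).
    let conv2d_params = 8 * 3 * 3;          // 72

    // Linear { units: 10, biases: true } fed 128 inputs: weights (128, 10) + biases (10,).
    let linear_params = 128 * 10 + 10;      // 1_290

    // Chain::param_count() just adds left + right, so a whole model reports the total.
    assert_eq!(conv2d_params + linear_params, 1_362);
}
```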
+ backend.conv2d_backward_input(&mut ctx.deltas, dy, &self.filters.params, &self.conv_info); } -} -impl Optimizable for Conv2d - where B: Backend + BackendConv2d + BackendBias + BackendScale, - O: Optimizer -{ #[inline] - fn calc_gradients(&mut self, backend: &B, x: &B::Tensor, dy: &B::Tensor) { + fn calc_gradients(&mut self, backend: &B, dy: &B::Tensor, x: &B::Tensor, _ctx: &mut Self::Context) { assert_eq!(dy.shape().dims, 4); assert_eq!(x.shape().dims, 4); @@ -138,4 +127,28 @@ impl Optimizable for Conv2d // optimizer.update_params(backend, &mut self.biases.ctx, &mut self.biases.params, &self.biases.grads); } } +} + +impl LayerExt for Conv2d + where B: Backend + BackendConv2d + BackendBias + BackendScale, + O: Optimizer +{ + type Config = Conv2dConfig; + + fn create(input_shape: TensorShape, cfg: Self::Config) -> Self { + assert!(input_shape.dims == 3); + + Conv2d { + input_shape, + units: cfg.filters, + conv_info: Conv2dInfo { + kernel: cfg.kernel, + padding: cfg.padding, + strides: cfg.strides, + }, + use_biases: cfg.biases, + filters: Params::new((cfg.filters, cfg.kernel.0, cfg.kernel.1)), + biases: Params::new((cfg.filters, )), + } + } } \ No newline at end of file diff --git a/yarnn/src/layers/flatten.rs b/yarnn/src/layers/flatten.rs index c4434e8..be5e349 100644 --- a/yarnn/src/layers/flatten.rs +++ b/yarnn/src/layers/flatten.rs @@ -1,8 +1,10 @@ -use crate::tensor::TensorShape; +use crate::tensor::{TensorShape, Tensor}; use crate::backend::{Backend, BackendCopy}; -use crate::layer::Layer; +use crate::layer::{Layer, LayerExt, DefaultLayerContext}; +use crate::optimizer::Optimizer; use core::marker::PhantomData; + #[derive(Default)] pub struct FlattenConfig; @@ -13,22 +15,16 @@ pub struct Flatten _x: PhantomData, } -impl Layer for Flatten - where B: Backend + BackendCopy +impl Layer for Flatten + where B: Backend + BackendCopy, + O: Optimizer { - type Config = FlattenConfig; + type Context = DefaultLayerContext; fn name(&self) -> &str { "Flatten" } - fn create(input_shape: TensorShape, _cfg: Self::Config) -> Self { - Flatten { - input_shape, - _x: Default::default() - } - } - #[inline] fn input_shape(&self) -> TensorShape { self.input_shape.clone() @@ -40,12 +36,30 @@ impl Layer for Flatten } #[inline] - fn forward(&self, backend: &B, y: &mut B::Tensor, x: &B::Tensor) { - backend.copy(y, x); + fn forward(&self, backend: &B, x: &B::Tensor, ctx: &mut Self::Context) { + ctx.update_outputs_shape(x.shape().get(0), &Layer::::output_shape(self)); + + backend.copy(&mut ctx.outputs, x); } #[inline] - fn backward(&self, backend: &B, dx: &mut B::Tensor, dy: &B::Tensor, _: &B::Tensor, _: &B::Tensor) { - backend.copy(dx, dy); + fn backward(&mut self, backend: &B, dy: &B::Tensor, x: &B::Tensor, ctx: &mut Self::Context) { + ctx.update_deltas_shape(x.shape().get(0), &self.input_shape); + + backend.copy(&mut ctx.deltas, dy); + } +} + +impl LayerExt for Flatten + where B: Backend + BackendCopy, + O: Optimizer +{ + type Config = FlattenConfig; + + fn create(input_shape: TensorShape, _cfg: Self::Config) -> Self { + Flatten { + input_shape, + _x: Default::default() + } } } \ No newline at end of file diff --git a/yarnn/src/layers/linear.rs b/yarnn/src/layers/linear.rs index 87e4de8..7331748 100644 --- a/yarnn/src/layers/linear.rs +++ b/yarnn/src/layers/linear.rs @@ -1,8 +1,8 @@ use crate::tensor::{Tensor, TensorShape}; -use crate::layer::Layer; +use crate::layer::{Layer, LayerExt, DefaultLayerContext}; use crate::params::Params; use crate::backend::{Backend, BackendGemm, BackendBias, BackendScale}; -use 
crate::optimizer::{Optimizable, Optimizer}; +use crate::optimizer::Optimizer; pub struct LinearConfig { pub units: u32, @@ -29,30 +29,25 @@ pub struct Linear biases: Params, } -impl Layer for Linear - where B: Backend + BackendGemm + BackendBias, +impl Layer for Linear + where B: Backend + BackendGemm + BackendBias + BackendScale, O: Optimizer { - type Config = LinearConfig; + type Context = DefaultLayerContext; fn name(&self) -> &str { "Linear" } - - fn create(input_shape: TensorShape, cfg: Self::Config) -> Self { - assert!(input_shape.dims == 1); - let inputs = input_shape.get(0); - - Linear { - inputs, - outputs: cfg.units, - use_biases: cfg.biases, - weights: Params::new((inputs, cfg.units)), - biases: Params::new((cfg.units, )), + #[inline] + fn param_count(&self) -> usize { + if self.use_biases { + self.weights.params.shape().size() + self.biases.params.shape().size() + } else { + self.weights.params.shape().size() } - } - + } + fn init(&mut self, backend: &B) { self.weights.init_random(backend, self.inputs + self.outputs); if self.use_biases { @@ -71,33 +66,32 @@ impl Layer for Linear } #[inline] - fn forward(&self, backend: &B, y: &mut B::Tensor, x: &B::Tensor) { - backend.matmul(y, x, &self.weights.params); + fn forward(&self, backend: &B, x: &B::Tensor, ctx: &mut Self::Context) { + ctx.update_outputs_shape(x.shape().get(0), &self.output_shape()); + + backend.matmul(&mut ctx.outputs, x, &self.weights.params); if self.use_biases { - backend.bias_add(y, &self.biases.params); + backend.bias_add(&mut ctx.outputs, &self.biases.params); } } #[inline] - fn backward(&self, backend: &B, dx: &mut B::Tensor, dy: &B::Tensor, _: &B::Tensor, _: &B::Tensor) { - backend.matmul_nt(dx, dy, &self.weights.params); + fn backward(&mut self, backend: &B, dy: &B::Tensor, x: &B::Tensor, ctx: &mut Self::Context) { + ctx.update_deltas_shape(x.shape().get(0), &self.input_shape()); + + backend.matmul_nt(&mut ctx.deltas, dy, &self.weights.params); } -} -impl Optimizable for Linear - where B: Backend + BackendGemm + BackendBias + BackendScale, - O: Optimizer -{ - fn calc_gradients(&mut self, backend: &B, inputs: &B::Tensor, deltas: &B::Tensor) { - let prescaler = 1.0 / inputs.shape().get(0) as f32; + fn calc_gradients(&mut self, backend: &B, dy: &B::Tensor, x: &B::Tensor, _ctx: &mut Self::Context) { + let prescaler = 1.0 / x.shape().get(0) as f32; - backend.matmul_tn(&mut self.weights.grads, inputs, deltas); + backend.matmul_tn(&mut self.weights.grads, x, dy); backend.scale(&mut self.weights.grads, backend.scalar_f32(prescaler)); if self.use_biases { backend.scale(&mut self.biases.grads, backend.scalar_f32(prescaler)); - backend.bias_grad(&mut self.biases.grads, deltas); + backend.bias_grad(&mut self.biases.grads, &dy); } } @@ -109,4 +103,25 @@ impl Optimizable for Linear optimizer.update_params(backend, &mut self.biases.ctx, &mut self.biases.params, &mut self.biases.grads); } } +} + +impl LayerExt for Linear + where B: Backend + BackendGemm + BackendBias + BackendScale, + O: Optimizer +{ + type Config = LinearConfig; + + fn create(input_shape: TensorShape, cfg: Self::Config) -> Self { + assert!(input_shape.dims == 1); + + let inputs = input_shape.get(0); + + Linear { + inputs, + outputs: cfg.units, + use_biases: cfg.biases, + weights: Params::new((inputs, cfg.units)), + biases: Params::new((cfg.units, )), + } + } } \ No newline at end of file diff --git a/yarnn/src/layers/maxpool2d.rs b/yarnn/src/layers/maxpool2d.rs index d5e349b..ca6565a 100644 --- a/yarnn/src/layers/maxpool2d.rs +++ 
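Editor's note on the `Linear` gradient pass: the `prescaler` divides by the batch size, so gradients are averaged over the mini-batch rather than summed. A tiny self-contained illustration of what `matmul_tn` plus `scale(1/batch)` amount to, with plain loops standing in for the backend kernels:

```rust
fn main() {
    // x: (batch, inputs) = (2, 2), dy: (batch, outputs) = (2, 1)
    let batch = 2usize;
    let x  = [[1.0f32, 2.0], [3.0, 4.0]];
    let dy = [[0.5f32], [1.5]];

    // dW = xᵀ · dy, i.e. what backend.matmul_tn(&mut grads, x, dy) produces ...
    let mut dw = [[0.0f32; 1]; 2];
    for b in 0..batch {
        for i in 0..2 {
            for o in 0..1 {
                dw[i][o] += x[b][i] * dy[b][o];
            }
        }
    }

    // ... followed by backend.scale(&mut grads, 1/batch): average over the mini-batch.
    for row in dw.iter_mut() {
        for v in row.iter_mut() {
            *v /= batch as f32;
        }
    }

    assert_eq!(dw, [[2.5], [3.5]]);
}
```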
b/yarnn/src/layers/maxpool2d.rs @@ -1,6 +1,7 @@ -use crate::tensor::TensorShape; -use crate::layer::Layer; +use crate::tensor::{TensorShape, Tensor}; +use crate::layer::{Layer, LayerExt, DefaultLayerContext}; use crate::backend::{Backend, PaddingKind, BackendMaxPool2d, Conv2dInfo}; +use crate::optimizer::Optimizer; use core::marker::PhantomData; pub struct MaxPool2dConfig { @@ -25,29 +26,16 @@ pub struct MaxPool2d _m: PhantomData } -impl Layer for MaxPool2d +impl Layer for MaxPool2d where B: Backend + BackendMaxPool2d, + O: Optimizer { - type Config = MaxPool2dConfig; - + type Context = DefaultLayerContext; + fn name(&self) -> &str { "MaxPool2d" } - fn create(input_shape: TensorShape, config: Self::Config) -> Self { - assert!(input_shape.dims == 3); - - MaxPool2d { - input_shape, - conv_info: Conv2dInfo { - kernel: config.pool, - strides: config.strides.unwrap_or(config.pool), - padding: PaddingKind::Valid, - }, - _m: Default::default(), - } - } - #[inline] fn input_shape(&self) -> TensorShape { self.input_shape.clone() @@ -70,12 +58,37 @@ impl Layer for MaxPool2d } #[inline] - fn forward(&self, backend: &B, y: &mut B::Tensor, x: &B::Tensor) { - backend.max_pool2d(y, x, &self.conv_info) + fn forward(&self, backend: &B, x: &B::Tensor, ctx: &mut Self::Context) { + ctx.update_outputs_shape(x.shape().get(0), &Layer::::output_shape(self)); + + backend.max_pool2d(&mut ctx.outputs, x, &self.conv_info) } #[inline] - fn backward(&self, backend: &B, dx: &mut B::Tensor, dy: &B::Tensor, x: &B::Tensor, _: &B::Tensor) { - backend.max_pool2d_backprop(dx, dy, x, &self.conv_info); + fn backward(&mut self, backend: &B, dy: &B::Tensor, x: &B::Tensor, ctx: &mut Self::Context) { + ctx.update_deltas_shape(x.shape().get(0), &self.input_shape); + + backend.max_pool2d_backprop(&mut ctx.deltas, dy, x, &self.conv_info); + } +} + +impl LayerExt for MaxPool2d + where B: Backend + BackendMaxPool2d, + O: Optimizer +{ + type Config = MaxPool2dConfig; + + fn create(input_shape: TensorShape, config: Self::Config) -> Self { + assert!(input_shape.dims == 3); + + MaxPool2d { + input_shape, + conv_info: Conv2dInfo { + kernel: config.pool, + strides: config.strides.unwrap_or(config.pool), + padding: PaddingKind::Valid, + }, + _m: Default::default(), + } } } \ No newline at end of file diff --git a/yarnn/src/layers/relu.rs b/yarnn/src/layers/relu.rs index 8450a63..a064fa7 100644 --- a/yarnn/src/layers/relu.rs +++ b/yarnn/src/layers/relu.rs @@ -1,6 +1,7 @@ -use crate::tensor::TensorShape; +use crate::tensor::{Tensor, TensorShape}; use crate::backend::{Backend, BackendReLu}; -use crate::layer::Layer; +use crate::layer::{Layer, LayerExt, DefaultLayerContext}; +use crate::optimizer::Optimizer; use core::marker::PhantomData; #[derive(Default)] @@ -13,34 +14,46 @@ pub struct ReLu _x: PhantomData, } -impl Layer for ReLu - where B: Backend + BackendReLu +impl Layer for ReLu + where B: Backend + BackendReLu, + O: Optimizer { - type Config = ReLuConfig; - + type Context = DefaultLayerContext; + fn name(&self) -> &str { "ReLU" } - fn create(input_shape: TensorShape, _cfg: Self::Config) -> Self { - ReLu { - input_shape, - _x: Default::default() - } - } - #[inline] fn input_shape(&self) -> TensorShape { self.input_shape.clone() } #[inline] - fn forward(&self, backend: &B, y: &mut B::Tensor, x: &B::Tensor) { - backend.relu(y, x); + fn forward(&self, backend: &B, x: &B::Tensor, ctx: &mut Self::Context) { + ctx.update_outputs_shape(x.shape().get(0), &self.input_shape); + + backend.relu(&mut ctx.outputs, x); } #[inline] - fn backward(&self, 
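Editor's note on the pooling layers: both now use valid (un-padded) windows, so the spatial output extent follows the usual `(input - kernel) / stride + 1` rule. The exact `output_shape()` body is elided in the hunk, so treat this as the assumed arithmetic rather than a copy of crate code:

```rust
// Output extent of a valid (un-padded) pooling window along one spatial axis.
fn pooled_extent(input: u32, kernel: u32, stride: u32) -> u32 {
    (input - kernel) / stride + 1
}

fn main() {
    // 28x28 MNIST feature map, 2x2 pool, strides defaulting to the pool size:
    assert_eq!(pooled_extent(28, 2, 2), 14);
    // overlapping 3x3 pool with stride 1:
    assert_eq!(pooled_extent(28, 3, 1), 26);
}
```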
backend: &B, dx: &mut B::Tensor, dy: &B::Tensor, _x: &B::Tensor, y: &B::Tensor) { - backend.relu_grad(dx, y, dy); + fn backward(&mut self, backend: &B, dy: &B::Tensor, _: &B::Tensor, ctx: &mut Self::Context) { + ctx.update_deltas_shape(dy.shape().get(0), &self.input_shape); + + backend.relu_grad(&mut ctx.deltas, &ctx.outputs, dy); + } +} + +impl LayerExt for ReLu + where B: Backend + BackendReLu, + O: Optimizer +{ + type Config = ReLuConfig; + + fn create(input_shape: TensorShape, _cfg: Self::Config) -> Self { + ReLu { + input_shape, + _x: Default::default() + } } } \ No newline at end of file diff --git a/yarnn/src/layers/sigmoid.rs b/yarnn/src/layers/sigmoid.rs index 2af70f6..de11bd9 100644 --- a/yarnn/src/layers/sigmoid.rs +++ b/yarnn/src/layers/sigmoid.rs @@ -1,6 +1,7 @@ -use crate::tensor::TensorShape; +use crate::tensor::{Tensor, TensorShape}; use crate::backend::{Backend, BackendSigmoid}; -use crate::layer::Layer; +use crate::layer::{Layer, LayerExt, DefaultLayerContext}; +use crate::optimizer::Optimizer; use core::marker::PhantomData; #[derive(Default)] @@ -13,21 +14,15 @@ pub struct Sigmoid _x: PhantomData, } -impl Layer for Sigmoid - where B: Backend + BackendSigmoid -{ - type Config = SigmoidConfig; - +impl Layer for Sigmoid + where B: Backend + BackendSigmoid, + O: Optimizer +{ + type Context = DefaultLayerContext; + fn name(&self) -> &str { "Sigmoid" } - - fn create(input_shape: TensorShape, _cfg: Self::Config) -> Self { - Sigmoid { - input_shape, - _x: Default::default() - } - } #[inline] fn input_shape(&self) -> TensorShape { @@ -35,12 +30,30 @@ impl Layer for Sigmoid } #[inline] - fn forward(&self, backend: &B, y: &mut B::Tensor, x: &B::Tensor) { - backend.sigmoid(y, x); + fn forward(&self, backend: &B, x: &B::Tensor, ctx: &mut Self::Context) { + ctx.update_outputs_shape(x.shape().get(0), &self.input_shape); + + backend.sigmoid(&mut ctx.outputs, x); } #[inline] - fn backward(&self, backend: &B, dx: &mut B::Tensor, dy: &B::Tensor, _x: &B::Tensor, y: &B::Tensor) { - backend.sigmoid_grad(dx, y, dy); + fn backward(&mut self, backend: &B, dy: &B::Tensor, _: &B::Tensor, ctx: &mut Self::Context) { + ctx.update_deltas_shape(dy.shape().get(0), &self.input_shape); + + backend.sigmoid_grad(&mut ctx.deltas, &ctx.outputs, dy); + } +} + +impl LayerExt for Sigmoid + where B: Backend + BackendSigmoid, + O: Optimizer +{ + type Config = SigmoidConfig; + + fn create(input_shape: TensorShape, _cfg: Self::Config) -> Self { + Sigmoid { + input_shape, + _x: Default::default() + } } } \ No newline at end of file diff --git a/yarnn/src/layers/softmax.rs b/yarnn/src/layers/softmax.rs index 7f202c1..2e0ce99 100644 --- a/yarnn/src/layers/softmax.rs +++ b/yarnn/src/layers/softmax.rs @@ -1,6 +1,7 @@ -use crate::tensor::TensorShape; +use crate::tensor::{Tensor, TensorShape}; use crate::backend::{Backend, BackendSoftmax}; -use crate::layer::Layer; +use crate::layer::{Layer, LayerExt, DefaultLayerContext}; +use crate::optimizer::Optimizer; use core::marker::PhantomData; #[derive(Default)] @@ -13,34 +14,46 @@ pub struct Softmax _x: PhantomData, } -impl Layer for Softmax - where B: Backend + BackendSoftmax -{ - type Config = SoftmaxConfig; - +impl Layer for Softmax + where B: Backend + BackendSoftmax, + O: Optimizer +{ + type Context = DefaultLayerContext; + fn name(&self) -> &str { "Softmax" } - fn create(input_shape: TensorShape, _cfg: Self::Config) -> Self { - Softmax { - input_shape, - _x: Default::default() - } - } - #[inline] fn input_shape(&self) -> TensorShape { self.input_shape.clone() } #[inline] - 
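Editor's note on the activation layers: `backward()` now reads the layer's own cached outputs from the context (`relu_grad` and `sigmoid_grad` take `y`, not `x`), which works because these gradients are functions of `y` alone. Scalar stand-ins for the backend kernels, just to show the identities being relied on:

```rust
fn sigmoid(x: f32) -> f32 { 1.0 / (1.0 + (-x).exp()) }

// d/dx sigmoid(x) = y * (1 - y), where y = sigmoid(x)
fn sigmoid_grad_from_y(y: f32, dy: f32) -> f32 { dy * y * (1.0 - y) }

// d/dx relu(x) = 1 where y > 0, else 0
fn relu_grad_from_y(y: f32, dy: f32) -> f32 { if y > 0.0 { dy } else { 0.0 } }

fn main() {
    let y = sigmoid(0.3);
    assert!((sigmoid_grad_from_y(y, 1.0) - y * (1.0 - y)).abs() < 1e-6);
    assert_eq!(relu_grad_from_y(2.0, 0.7), 0.7);
    assert_eq!(relu_grad_from_y(0.0, 0.7), 0.0);
}
```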
fn forward(&self, backend: &B, y: &mut B::Tensor, x: &B::Tensor) { - backend.softmax(y, x); + fn forward(&self, backend: &B, x: &B::Tensor, ctx: &mut Self::Context) { + ctx.update_outputs_shape(x.shape().get(0), &self.input_shape); + + backend.softmax(&mut ctx.outputs, x); } #[inline] - fn backward(&self, backend: &B, dx: &mut B::Tensor, dy: &B::Tensor, _: &B::Tensor, _: &B::Tensor) { - backend.copy(dx, dy); + fn backward(&mut self, backend: &B, dy: &B::Tensor, _: &B::Tensor, ctx: &mut Self::Context) { + ctx.update_deltas_shape(dy.shape().get(0), &self.input_shape); + + backend.copy(&mut ctx.deltas, dy); + } +} + +impl LayerExt for Softmax + where B: Backend + BackendSoftmax, + O: Optimizer +{ + type Config = SoftmaxConfig; + + fn create(input_shape: TensorShape, _cfg: Self::Config) -> Self { + Softmax { + input_shape, + _x: Default::default() + } } } \ No newline at end of file diff --git a/yarnn/src/layers/zeropadding2d.rs b/yarnn/src/layers/zeropadding2d.rs index d10d3ba..760a820 100644 --- a/yarnn/src/layers/zeropadding2d.rs +++ b/yarnn/src/layers/zeropadding2d.rs @@ -1,6 +1,7 @@ -use crate::tensor::TensorShape; +use crate::tensor::{Tensor, TensorShape}; use crate::backend::{Backend, BackendCopy}; -use crate::layer::Layer; +use crate::layer::{Layer, LayerExt, DefaultLayerContext}; +use crate::optimizer::Optimizer; use core::marker::PhantomData; #[derive(Default)] @@ -16,23 +17,16 @@ pub struct ZeroPadding2d _x: PhantomData, } -impl Layer for ZeroPadding2d - where B: Backend + BackendCopy +impl Layer for ZeroPadding2d + where B: Backend + BackendCopy, + O: Optimizer { - type Config = ZeroPadding2dConfig; + type Context = DefaultLayerContext; fn name(&self) -> &str { "ZeroPadding2d" } - fn create(input_shape: TensorShape, config: Self::Config) -> Self { - ZeroPadding2d { - input_shape, - config, - _x: Default::default() - } - } - #[inline] fn input_shape(&self) -> TensorShape { self.input_shape.clone() @@ -50,12 +44,30 @@ impl Layer for ZeroPadding2d } #[inline] - fn forward(&self, _backend: &B, _y: &mut B::Tensor, _x: &B::Tensor) { + fn forward(&self, _backend: &B, x: &B::Tensor, ctx: &mut Self::Context) { + ctx.update_outputs_shape(x.shape().get(0), &Layer::::output_shape(self)); // backend.copy_with_padding(y, x, 0.0, (self.config.0, self.config.1, self.config.0, self.config.1)); } #[inline] - fn backward(&self, _backend: &B, _dx: &mut B::Tensor, _dy: &B::Tensor, _: &B::Tensor, _: &B::Tensor) { + fn backward(&mut self, _backend: &B, dy: &B::Tensor, _x: &B::Tensor, ctx: &mut Self::Context) { + ctx.update_deltas_shape(dy.shape().get(0), &self.input_shape); // backend.copy(dx, dy); } +} + +impl LayerExt for ZeroPadding2d + where B: Backend + BackendCopy, + O: Optimizer +{ + type Config = ZeroPadding2dConfig; + + fn create(input_shape: TensorShape, config: Self::Config) -> Self { + ZeroPadding2d { + input_shape, + config, + _x: Default::default() + } + } + } \ No newline at end of file diff --git a/yarnn/src/macros.rs b/yarnn/src/macros.rs index e888e6b..36e4e19 100644 --- a/yarnn/src/macros.rs +++ b/yarnn/src/macros.rs @@ -8,22 +8,17 @@ macro_rules! sequential_type { #[macro_export] macro_rules! 
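Editor's note on `Softmax::backward` being a plain copy: this only makes sense when the layer is paired with a loss whose derivative already folds in the softmax Jacobian, as the MNIST example above presumably does with `CrossEntropyLoss`; for cross-entropy on softmax outputs, d(loss)/d(logits) = p − t. A scalar finite-difference check of that identity (illustrative code, not the crate's backend kernels):

```rust
fn softmax(xs: &[f32]) -> Vec<f32> {
    let m = xs.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = xs.iter().map(|x| (x - m).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.into_iter().map(|e| e / sum).collect()
}

fn cross_entropy(logits: &[f32], target: &[f32]) -> f32 {
    softmax(logits)
        .iter()
        .zip(target.iter())
        .map(|(p, t)| -(t * p.ln()))
        .sum()
}

fn main() {
    let logits = [1.0f32, 2.0, 0.5];
    let target = [0.0f32, 1.0, 0.0];
    let p = softmax(&logits);

    // Numeric gradient of the combined softmax + cross-entropy loss ...
    let eps = 1e-3f32;
    for i in 0..logits.len() {
        let mut bumped = logits;
        bumped[i] += eps;
        let numeric = (cross_entropy(&bumped, &target) - cross_entropy(&logits, &target)) / eps;
        // ... matches p - t, which is exactly what the loss hands back as `dy`,
        // so Softmax::backward only has to copy it through.
        assert!((numeric - (p[i] - target[i])).abs() < 1e-2);
    }
}
```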
sequential_type_impl { - ($t:ty {$($tt:tt)*}) => { - $crate::layer::LayerImpl - }; + ($t:ty {$($tt:tt)*}) => ($t); + ($t:ty {$($xx:tt)*}, $($tt:tt)*) => { $crate::layers::Chain, - $crate::sequential_type_impl!($($tt)*) + $t, $crate::sequential_type_impl!($($tt)*) > }; - ($t:ty) => { - $crate::layer::LayerImpl - }; + ($t:ty) => ($t); ($t:ty, $($tt:tt)*) => { $crate::layers::Chain, - $crate::sequential_type_impl!($($tt)*) + $t, $crate::sequential_type_impl!($($tt)*) > }; } @@ -40,35 +35,29 @@ macro_rules! sequential { #[macro_export] macro_rules! sequential_impl { ($p:expr, $t:ty { $($name:ident : $val:expr),* }) => {{ - #[allow(unused_imports)] - use core::convert::TryInto; - #[allow(unused_mut)] - let mut params = <$t as $crate::layer::Layer<_, _>>::Config::default(); + let mut params = <$t as $crate::layer::LayerExt>::Config::default(); $( - params.$name = ($val).try_into().unwrap_or($val); + params.$name = core::convert::TryInto::try_into($val).unwrap_or($val); )* - $crate::layer::LayerImpl::new(<$t as $crate::layer::Layer<_, _>>::create( + <$t as $crate::layer::LayerExt>::create( $p, params - )) + ) }}; ($p:expr, $t:ty { $($name:ident : $val:expr),* }, $($tt:tt)*) => {{ - #[allow(unused_imports)] - use core::convert::TryInto; - #[allow(unused_mut)] - let mut params = <$t as $crate::layer::Layer<_, _>>::Config::default(); + let mut params = <$t as $crate::layer::LayerExt>::Config::default(); $( - params.$name = ($val).try_into().unwrap_or($val);; + params.$name = core::convert::TryInto::try_into($val).unwrap_or($val); )* - let layer = $crate::layer::LayerImpl::new(<$t as $crate::layer::Layer<_, _>>::create( + let layer = <$t as $crate::layer::LayerExt>::create( $p, params - )); + ); - let prev_shape = layer.layer.output_shape(); + let prev_shape = $crate::layer::Layer::::output_shape(&layer); $crate::layers::Chain::new( layer, $crate::sequential_impl! { prev_shape, $($tt)* }, @@ -94,12 +83,12 @@ macro_rules! sequential_type_ctx { #[macro_export] macro_rules! sequential_type_ctx_impl { ($t:ty {$($xx:tt)*}) => { - $crate::layer::CommonLayerContext + $crate::layer::DefaultLayerContext }; ($t:ty {$($xx:tt)*}, $($tt:tt)*) => { $crate::layers::ChainContext, + $crate::layer::DefaultLayerContext, $crate::sequential_type_ctx_impl!($($tt)*) > }; @@ -149,38 +138,86 @@ macro_rules! 
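Editor's note on the macro rework: the type built by `sequential_type_impl!` is now a nest of bare layer types instead of `LayerImpl` wrappers, and `sequential_impl!` builds the matching value by threading shapes through `LayerExt::create`. A hand-written sketch of the expansion (generic arguments elided, exact nesting inferred from the macro arms):

```rust
// before: Chain<.., LayerImpl<.., Linear<..>>,
//                  Chain<.., LayerImpl<.., Sigmoid<..>>,
//                           LayerImpl<.., Softmax<..>>>>
//
// after:  Chain<.., Linear<..>,
//                  Chain<.., Sigmoid<..>, Softmax<..>>>
//
// On the value side, sequential_impl! calls
//   <T as LayerExt>::create(prev_shape, config)
// for each element and nests the results with Chain::new(), where prev_shape is
// the previous layer's output_shape().
```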
model_impl { } } - impl core::fmt::Display for $name - where B: $crate::backend::Backend + $trait, - O: $crate::optimizer::Optimizer - { - fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { - writeln!(f, "{} {{", stringify!($name))?; - write!(f, "{}", self.inner)?; - writeln!(f, "}}")?; + // impl core::fmt::Display for $name + // where B: $crate::backend::Backend + $trait, + // O: $crate::optimizer::Optimizer + // { + // fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { + // writeln!(f, "{} {{", stringify!($name))?; + // write!(f, "{}", self.inner)?; + // writeln!(f, "}}")?; - Ok(()) - } - } + // Ok(()) + // } + // } - impl $crate::layer::AbstractLayer for $name + impl $crate::layer::Layer for $name where B: $crate::backend::Backend + $trait, O: $crate::optimizer::Optimizer { type Context = ctx::$name; #[inline] - fn forward(&mut self, backend: &B, inputs: &B::Tensor, ctx: &mut Self::Context) { - $crate::layer::AbstractLayer::forward(&mut self.inner, backend, inputs, ctx) + fn name(&self) -> &str { + stringify!($name) + } + + #[inline] + fn init(&mut self, backend: &B) { + self.inner.init(backend); + } + + #[inline] + fn param_count(&self) -> usize { + self.inner.param_count() + } + + #[inline] + fn input_shape(&self) -> $crate::tensor::TensorShape { + self.inner.input_shape() + } + + #[inline] + fn output_shape(&self) -> $crate::tensor::TensorShape { + self.inner.output_shape() + } + + #[inline] + fn forward(&self, backend: &B, inputs: &B::Tensor, ctx: &mut Self::Context) { + self.inner.forward(backend, inputs, ctx); } #[inline] fn backward(&mut self, backend: &B, deltas: &B::Tensor, inputs: &B::Tensor, ctx: &mut Self::Context) { - $crate::layer::AbstractLayer::backward(&mut self.inner, backend, deltas, inputs, ctx); + self.inner.backward(backend, deltas, inputs, ctx); } #[inline] - fn update(&mut self, backend: &B, optimizer: &O, inputs: &B::Tensor, deltas: &B::Tensor, ctx: &mut Self::Context) { - $crate::layer::AbstractLayer::update(&mut self.inner, backend, optimizer, inputs, deltas, ctx); + fn calc_gradients(&mut self, backend: &B, deltas: &B::Tensor, inputs: &B::Tensor, ctx: &mut Self::Context) { + self.inner.calc_gradients(backend, deltas, inputs, ctx); + } + + #[inline] + fn optimize(&mut self, backend: &B, optimizer: &O) { + self.inner.optimize(backend, optimizer); + } + + fn fmt(&self, f: &mut core::fmt::Formatter, padding: usize) -> core::fmt::Result { + writeln!(f, "{}{}[{}] {{", "", self.name(), self.param_count())?; + self.inner.fmt(f, padding + 2)?; + write!(f, "}}")?; + + Ok(()) + } + } + + impl core::fmt::Display for $name + where B: $crate::backend::Backend + $trait, + O: $crate::optimizer::Optimizer + { + #[inline] + fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { + >::fmt(self, f, 0) } } }; @@ -192,6 +229,7 @@ macro_rules! 
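Editor's note on the generated model type: it now implements `Layer` itself, and `Display` delegates to `Layer::fmt(self, f, 0)`, so `println!("{}", &model)` prints the nested tree with per-layer parameter counts. One reviewer nit worth flagging: `"".repeat(padding)` repeats the empty string, so the `padding` argument has no visible effect; `" ".repeat(padding)` was presumably intended. A sketch of the output format implied by the `fmt()` bodies (shapes and counts illustrative):

```rust
// println!("{}", &model);
//
// renders one line per leaf layer in the form
//
//   <input_shape> -> <Name>[<param_count>] -> <output_shape>
//
// wrapped in "<ModelName>[<total params>] { ... }"; nested models would indent
// by `padding`, once the "".repeat(padding) call is changed to " ".repeat(padding).
```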
model { mod tmp { pub trait BackendDefault = $crate::backend::BackendReLu + $crate::backend::BackendBias + + $crate::backend::BackendScale + $crate::backend::BackendSigmoid + $crate::backend::BackendSoftmax + $crate::backend::BackendGemm diff --git a/yarnn/src/optimizer.rs b/yarnn/src/optimizer.rs index 01f393c..4234518 100644 --- a/yarnn/src/optimizer.rs +++ b/yarnn/src/optimizer.rs @@ -19,9 +19,4 @@ impl <'a, N, B: Backend, O: Optimizer> Optimizer for &'a O { fn update_params(&self, backend: &B, ctx: &mut Self::Context, params: &mut B::Tensor, grads: &mut B::Tensor) { (**self).update_params(backend, ctx, params, grads) } -} - -pub trait Optimizable, O: Optimizer> { - fn calc_gradients(&mut self, backend: &B, inputs: &B::Tensor, deltas: &B::Tensor); - fn optimize(&mut self, backend: &B, optimizer: &O); -} +} \ No newline at end of file diff --git a/yarnn/src/tensor.rs b/yarnn/src/tensor.rs index 84058d0..cebcec3 100644 --- a/yarnn/src/tensor.rs +++ b/yarnn/src/tensor.rs @@ -121,8 +121,8 @@ impl TensorShape { } } - pub fn append>(&mut self, s: S) -> &mut Self { - let s = s.into(); + pub fn append>(&mut self, s: S) -> &mut Self { + let s = s.borrow(); let sd = self.dims; for i in 0 .. s.dims {
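Editor's note on the final hunks: the standalone `Optimizable` trait (and the specialization-dependent blanket impl) is gone, its two methods now being default methods on `Layer`; and `TensorShape::append` takes anything that borrows a `TensorShape` instead of something convertible into one, so callers can pass a reference without cloning. A minimal sketch of the new `append` call forms:

```rust
use yarnn::tensor::TensorShape;

fn main() {
    let features = TensorShape::new1d(784);

    let mut batched = TensorShape::new1d(64);
    batched.append(&features);   // by reference: no clone, via Borrow<TensorShape>

    let mut batched2 = TensorShape::new1d(64);
    batched2.append(features);   // by value still compiles (T: Borrow<T>)
}
```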