Got rid of AbstractLayer

Andrey Tkachenko 2019-07-16 00:10:36 +04:00
parent 82375f890b
commit 63d0428e32
17 changed files with 515 additions and 378 deletions
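In outline, this commit folds the old AbstractLayer / LayerImpl / Optimizable surface into a reshaped Layer trait, which now owns its Context plus the forward/backward/calc_gradients/optimize methods, and a new LayerExt trait carrying the construction-time pieces (Config, create, and add_layer). The following is a minimal, hedged sketch of the resulting trait split, abridged from the layer-module hunks further down; Backend, Optimizer, LayerContext, and TensorShape are reduced to placeholder stubs here so the sketch stands on its own, and are not the crate's real definitions.

// Sketch only: placeholder stubs so the trait split below compiles standalone.
pub trait Backend<N> {
    type Tensor;
}
pub trait Optimizer<N, B: Backend<N>> {}
pub trait LayerContext<N, B: Backend<N>>: Default {
    fn outputs(&self) -> &B::Tensor;
    fn deltas(&self) -> &B::Tensor;
}
#[derive(Clone, Default)]
pub struct TensorShape;

// Layer now owns its Context and the whole forward/backward/calc_gradients/
// optimize surface that AbstractLayer + Optimizable used to provide.
pub trait Layer<N, B, O>
where
    B: Backend<N>,
    O: Optimizer<N, B>,
{
    type Context: LayerContext<N, B>;

    fn name(&self) -> &str;
    fn param_count(&self) -> usize { 0 }
    fn init(&mut self, _backend: &B) {}
    fn input_shape(&self) -> TensorShape;
    fn output_shape(&self) -> TensorShape { self.input_shape() }

    fn forward(&self, backend: &B, x: &B::Tensor, ctx: &mut Self::Context);
    fn backward(&mut self, backend: &B, dy: &B::Tensor, x: &B::Tensor, ctx: &mut Self::Context);
    fn calc_gradients(&mut self, _backend: &B, _dy: &B::Tensor, _x: &B::Tensor, _ctx: &mut Self::Context) {}
    fn optimize(&mut self, _backend: &B, _optimizer: &O) {}
}

// LayerExt keeps the construction-time half (Config and create, plus the
// add_layer chaining helper, omitted here), so concrete layers implement
// Layer + LayerExt and the old LayerImpl wrapper goes away.
pub trait LayerExt<N, B, O>: Layer<N, B, O>
where
    B: Backend<N>,
    O: Optimizer<N, B>,
{
    type Config: Default;

    fn create(input_shape: TensorShape, cfg: Self::Config) -> Self;
}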

View File

@ -43,6 +43,7 @@ fn main() {
// let mut model = MnistDenseModel::new(28, 28, 1);
let mut model = MnistConvModel::new(28, 28, 1);
model.init(&backend);
println!("{}", &model);
@ -52,7 +53,7 @@ fn main() {
let loss = CrossEntropyLoss::new();
let Mnist { trn_img, trn_lbl, tst_img, tst_lbl, .. } = MnistBuilder::new()
.base_path("../../dataset/mnist")
.base_path("./datasets/mnist")
.label_format_digit()
.finalize();
@ -101,7 +102,8 @@ fn main() {
model.forward(&backend, &inputs, &mut train_ctx);
loss.derivative(&backend, &mut deltas, train_ctx.outputs(), &targets);
model.backward(&backend, &deltas, &inputs, &mut train_ctx);
model.update(&backend, &optimizer, &inputs, &deltas, &mut train_ctx);
model.calc_gradients(&backend, &deltas, &inputs, &mut train_ctx);
model.optimize(&backend, &optimizer);
}
model.forward(&backend, &inputs0, &mut test_ctx);
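The hunks above are the user-facing half of the change: the single model.update(...) call, which previously both accumulated gradients and applied the optimizer, becomes two explicit calls. For orientation, the new per-step sequence with inline notes; identifiers are the ones set up earlier in this example, and this is a sketch of the loop body, not a complete program.

model.forward(&backend, &inputs, &mut train_ctx);                     // fills train_ctx.outputs
loss.derivative(&backend, &mut deltas, train_ctx.outputs(), &targets);
model.backward(&backend, &deltas, &inputs, &mut train_ctx);           // fills train_ctx.deltas
model.calc_gradients(&backend, &deltas, &inputs, &mut train_ctx);     // first half of the old update()
model.optimize(&backend, &optimizer);                                 // second half of the old update()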

View File

@ -1,12 +1,10 @@
#![feature(trait_alias)]
pub use self::dense::MnistDenseModel;
pub use self::conv::MnistConvModel;
mod dense {
use yarnn::model;
use yarnn::layer::*;
use yarnn::layers::*;
model! {
@ -30,7 +28,6 @@ mod dense {
mod conv {
use yarnn::model;
use yarnn::layer::*;
use yarnn::layers::*;
model! {
@ -62,4 +59,4 @@ mod conv {
}
}
}
}
}

View File

@ -1,7 +1,6 @@
#![feature(trait_alias)]
use yarnn::layers::*;
use yarnn::layer::*;
use yarnn::model;
model! {

View File

@ -1,14 +1,20 @@
use crate::backend::Backend;
use crate::optimizer::{Optimizer, Optimizable};
use crate::optimizer::Optimizer;
use crate::tensor::{Tensor, TensorShape};
use core::marker::PhantomData;
// use core::marker::PhantomData;
pub trait Layer<N, B: Backend<N>> {
type Config: Default;
pub trait Layer<N, B, O>
where B: Backend<N>,
O: Optimizer<N, B>
{
type Context: LayerContext<N, B>;
fn name(&self) -> &str;
fn create(input_shape: TensorShape, cfg: Self::Config) -> Self;
fn param_count(&self) -> usize {
0
}
#[inline]
fn init(&mut self, _backend: &B) {}
@ -20,34 +26,39 @@ pub trait Layer<N, B: Backend<N>> {
self.input_shape()
}
fn forward(&self, backend: &B, y: &mut B::Tensor, x: &B::Tensor);
fn backward(&self, backend: &B, dx: &mut B::Tensor, dy: &B::Tensor, x: &B::Tensor, y: &B::Tensor);
fn forward(&self, backend: &B, x: &B::Tensor, ctx: &mut Self::Context);
fn backward(&mut self, backend: &B, dy: &B::Tensor, x: &B::Tensor, ctx: &mut Self::Context);
#[inline]
fn calc_gradients(&mut self, _backend: &B, _dy: &B::Tensor, _x: &B::Tensor, _ctx: &mut Self::Context) {}
#[inline]
fn optimize(&mut self, _backend: &B, _optimizer: &O) {}
fn fmt(&self, f: &mut core::fmt::Formatter, padding: usize) -> core::fmt::Result {
writeln!(f, "{}{} -> {}[{}] -> {}", "".repeat(padding), self.input_shape(), self.name(), self.param_count(), self.output_shape())?;
Ok(())
}
}
/// Temporary solution until I find a way around the type-inference problem with specialization
impl <T, N, B, O> Optimizable<N, B, O> for T
where T: Layer<N, B>,
B: Backend<N>,
pub trait LayerExt<N, B, O>: Layer<N, B, O>
where B: Backend<N>,
O: Optimizer<N, B>
{
default fn calc_gradients(&mut self, _backend: &B, _inputs: &B::Tensor, _deltas: &B::Tensor) {}
default fn optimize(&mut self, _backend: &B, _optimizer: &O) {}
}
type Config: Default;
pub trait AbstractLayer<N, B: Backend<N>, O: Optimizer<N, B>>: core::fmt::Display {
type Context: LayerContext<N, B>;
fn create(input_shape: TensorShape, cfg: Self::Config) -> Self;
fn forward(&mut self, backend: &B, inputs: &B::Tensor, ctx: &mut Self::Context);
fn backward(&mut self, backend: &B, deltas: &B::Tensor, inputs: &B::Tensor, ctx: &mut Self::Context);
fn update(&mut self, backend: &B, optimizer: &O, inputs: &B::Tensor, deltas: &B::Tensor, ctx: &mut Self::Context);
#[inline]
fn add_layer<L: Layer<N, B>>(self, cfg: L::Config) -> crate::layers::Chain<N, B, O, Self, LayerImpl<N, B, O, L>>
fn add_layer<L: LayerExt<N, B, O>>(self, cfg: L::Config) -> crate::layers::Chain<N, B, O, Self, L>
where Self: Sized
{
let shape = self.output_shape();
crate::layers::Chain::new(
self,
LayerImpl::new(L::create(().into(), cfg)),
L::create(shape, cfg),
)
}
}
@ -57,14 +68,15 @@ pub trait LayerContext<N, B: Backend<N>>: Default {
fn deltas(&self) -> &B::Tensor;
}
pub struct CommonLayerContext<N, B>
pub struct DefaultLayerContext<N, B>
where B: Backend<N>,
{
pub outputs: B::Tensor,
pub deltas: B::Tensor,
}
impl <N, B> Default for CommonLayerContext<N, B>
impl <N, B> Default for DefaultLayerContext<N, B>
where B: Backend<N>,
{
fn default() -> Self {
@ -75,10 +87,10 @@ impl <N, B> Default for CommonLayerContext<N, B>
}
}
impl <N, B> CommonLayerContext<N, B>
impl <N, B> DefaultLayerContext<N, B>
where B: Backend<N>,
{
pub fn update_deltas_bs(&mut self, bs: u32, input_shape: &TensorShape) {
pub fn update_deltas_shape(&mut self, bs: u32, input_shape: &TensorShape) {
let mut new_deltas_shape = TensorShape::new1d(bs);
new_deltas_shape.append(input_shape.clone());
@ -87,7 +99,7 @@ impl <N, B> CommonLayerContext<N, B>
}
}
pub fn update_outputs_bs(&mut self, bs: u32, output_shape: &TensorShape) {
pub fn update_outputs_shape(&mut self, bs: u32, output_shape: &TensorShape) {
let mut new_output_shape = TensorShape::new1d(bs);
new_output_shape.append(output_shape.clone());
@ -98,7 +110,7 @@ impl <N, B> CommonLayerContext<N, B>
}
}
impl <N, B> LayerContext<N, B> for CommonLayerContext<N, B>
impl <N, B> LayerContext<N, B> for DefaultLayerContext<N, B>
where B: Backend<N>,
{
#[inline]
@ -110,72 +122,4 @@ impl <N, B> LayerContext<N, B> for CommonLayerContext<N, B>
fn deltas(&self) -> &B::Tensor {
&self.deltas
}
}
pub struct LayerImpl <N, B, O, L>
where B: Backend<N>,
O: Optimizer<N, B>
{
pub layer: L,
initialized: bool,
_m: PhantomData<fn(N, B, O)>,
}
impl <N, B, O, L> LayerImpl<N, B, O, L>
where B: Backend<N>,
O: Optimizer<N, B>,
L: Layer<N, B> + Optimizable<N, B, O>
{
pub fn new(layer: L) -> Self {
Self {
layer,
initialized: false,
_m: Default::default(),
}
}
}
impl <N, B, O, L> core::fmt::Display for LayerImpl<N, B, O, L>
where B: Backend<N>,
O: Optimizer<N, B>,
L: Layer<N, B>
{
fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
write!(f, "{} -> ", self.layer.input_shape())?;
write!(f, "{}", self.layer.name())?;
writeln!(f, " -> {}", self.layer.output_shape())?;
Ok(())
}
}
impl <N, B, O, L> AbstractLayer<N, B, O> for LayerImpl<N, B, O, L>
where B: Backend<N>,
O: Optimizer<N, B>,
L: Layer<N, B> + Optimizable<N, B, O>
{
type Context = CommonLayerContext<N, B>;
#[inline]
fn forward(&mut self, backend: &B, inputs: &B::Tensor, ctx: &mut Self::Context) {
if !self.initialized {
self.initialized = true;
self.layer.init(&backend);
}
ctx.update_outputs_bs(inputs.shape().get(0), &self.layer.output_shape());
self.layer.forward(&backend, &mut ctx.outputs, inputs);
}
#[inline]
fn backward(&mut self, backend: &B, deltas: &B::Tensor, inputs: &B::Tensor, ctx: &mut Self::Context) {
ctx.update_deltas_bs(deltas.shape().get(0), &self.layer.input_shape());
self.layer.backward(&backend, &mut ctx.deltas, deltas, inputs, &ctx.outputs);
}
#[inline]
fn update(&mut self, backend: &B, optimizer: &O, inputs: &B::Tensor, deltas: &B::Tensor, _ctx: &mut Self::Context) {
self.layer.calc_gradients(&backend, inputs, deltas);
self.layer.optimize(&backend, &optimizer);
}
}
}

View File

@ -1,20 +1,20 @@
use crate::tensor::{Tensor, TensorShape};
use crate::layer::Layer;
use crate::layer::{Layer, LayerExt, DefaultLayerContext};
use crate::backend::{Backend, PaddingKind, BackendAvgPool2d, Conv2dInfo};
use crate::optimizer::Optimizer;
use core::marker::PhantomData;
pub struct AvgPool2dConfig {
pub pool: (u32, u32),
pub strides: (u32, u32),
pub padding: PaddingKind,
pub strides: Option<(u32, u32)>,
}
impl Default for AvgPool2dConfig {
fn default() -> Self {
Self {
pool: (2, 2),
strides: (2, 2),
padding: PaddingKind::Valid,
strides: None,
}
}
}
@ -27,29 +27,16 @@ pub struct AvgPool2d<N, B>
_m: PhantomData<fn(N, B)>
}
impl <N, B> Layer<N, B> for AvgPool2d<N, B>
where B: Backend<N> + BackendAvgPool2d<N>
impl <N, B, O> Layer<N, B, O> for AvgPool2d<N, B>
where B: Backend<N> + BackendAvgPool2d<N>,
O: Optimizer<N, B>
{
type Config = AvgPool2dConfig;
type Context = DefaultLayerContext<N, B>;
fn name(&self) -> &str {
"AvgPool2d"
}
fn create(input_shape: TensorShape, config: Self::Config) -> Self {
assert!(input_shape.dims == 3);
AvgPool2d {
input_shape,
conv_info: Conv2dInfo {
kernel: config.pool,
strides: config.strides,
padding: config.padding,
},
_m: Default::default(),
}
}
#[inline]
fn input_shape(&self) -> TensorShape {
self.input_shape.clone()
@ -72,7 +59,11 @@ impl <N, B> Layer<N, B> for AvgPool2d<N, B>
}
#[inline]
fn forward(&self, backend: &B, y: &mut B::Tensor, x: &B::Tensor) {
fn forward(&self, backend: &B, x: &B::Tensor, ctx: &mut Self::Context) {
ctx.update_outputs_shape(x.shape().get(0), &Layer::<N, B, O>::output_shape(self));
let y = &mut ctx.outputs;
assert_eq!(y.shape().dims, 4);
assert_eq!(x.shape().dims, 4);
@ -80,10 +71,35 @@ impl <N, B> Layer<N, B> for AvgPool2d<N, B>
}
#[inline]
fn backward(&self, backend: &B, dx: &mut B::Tensor, dy: &B::Tensor, x: &B::Tensor, _: &B::Tensor) {
fn backward(&mut self, backend: &B, dy: &B::Tensor, x: &B::Tensor, ctx: &mut Self::Context) {
ctx.update_deltas_shape(x.shape().get(0), &self.input_shape);
let dx = &mut ctx.deltas;
assert_eq!(dy.shape().dims, 4);
assert_eq!(dx.shape().dims, 4);
backend.avg_pool2d_backprop(dx, dy, x, &self.conv_info);
}
}
}
impl <N, B, O> LayerExt<N, B, O> for AvgPool2d<N, B>
where B: Backend<N> + BackendAvgPool2d<N>,
O: Optimizer<N, B>
{
type Config = AvgPool2dConfig;
fn create(input_shape: TensorShape, config: Self::Config) -> Self {
assert!(input_shape.dims == 3);
AvgPool2d {
input_shape,
conv_info: Conv2dInfo {
kernel: config.pool,
strides: config.strides.unwrap_or(config.pool),
padding: PaddingKind::Valid,
},
_m: Default::default(),
}
}
}

View File

@ -1,6 +1,7 @@
use crate::backend::Backend;
use crate::layer::{AbstractLayer, LayerContext};
use crate::layer::{Layer, LayerContext};
use crate::optimizer::Optimizer;
use crate::tensor::TensorShape;
use core::marker::PhantomData;
@ -47,8 +48,8 @@ impl<N, B, L, R> LayerContext<N, B> for ChainContext<N, B, L, R>
pub struct Chain<N, B, O, L, R>
where B: Backend<N>,
O: Optimizer<N, B>,
L: AbstractLayer<N, B, O>,
R: AbstractLayer<N, B, O>,
L: Layer<N, B, O>,
R: Layer<N, B, O>,
{
left: L,
right: R,
@ -58,8 +59,8 @@ pub struct Chain<N, B, O, L, R>
impl<N, B, O, L, R> Chain<N, B, O, L, R>
where B: Backend<N>,
O: Optimizer<N, B>,
L: AbstractLayer<N, B, O>,
R: AbstractLayer<N, B, O>,
L: Layer<N, B, O>,
R: Layer<N, B, O>,
{
pub fn new(left: L, right: R) -> Self {
Self {
@ -70,30 +71,56 @@ impl<N, B, O, L, R> Chain<N, B, O, L, R>
}
}
impl<N, B, O, L, R> core::fmt::Display for Chain<N, B, O, L, R>
// impl<N, B, O, L, R> core::fmt::Display for Chain<N, B, O, L, R>
// where B: Backend<N>,
// O: Optimizer<N, B>,
// L: Layer<N, B, O>,
// R: Layer<N, B, O>,
// {
// fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
// self.left.fmt(f)?;
// self.right.fmt(f)?;
// Ok(())
// }
// }
impl<N, B, O, L, R> Layer<N, B, O> for Chain<N, B, O, L, R>
where B: Backend<N>,
O: Optimizer<N, B>,
L: AbstractLayer<N, B, O>,
R: AbstractLayer<N, B, O>,
{
fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
self.left.fmt(f)?;
self.right.fmt(f)?;
Ok(())
}
}
impl<N, B, O, L, R> AbstractLayer<N, B, O> for Chain<N, B, O, L, R>
where B: Backend<N>,
O: Optimizer<N, B>,
L: AbstractLayer<N, B, O>,
R: AbstractLayer<N, B, O>,
L: Layer<N, B, O>,
R: Layer<N, B, O>,
{
type Context = ChainContext<N, B, L::Context, R::Context>;
#[inline]
fn forward(&mut self, backend: &B, inputs: &B::Tensor, ctx: &mut Self::Context) {
fn name(&self) -> &str {
"Chain Layer"
}
#[inline]
fn param_count(&self) -> usize {
self.left.param_count() + self.right.param_count()
}
#[inline]
fn init(&mut self, backend: &B) {
self.left.init(backend);
self.right.init(backend);
}
#[inline]
fn input_shape(&self) -> TensorShape {
self.left.input_shape()
}
#[inline]
fn output_shape(&self) -> TensorShape {
self.right.output_shape()
}
#[inline]
fn forward(&self, backend: &B, inputs: &B::Tensor, ctx: &mut Self::Context) {
self.left.forward(backend, inputs, &mut ctx.left);
self.right.forward(backend, ctx.left.outputs(), &mut ctx.right);
}
@ -105,8 +132,21 @@ impl<N, B, O, L, R> AbstractLayer<N, B, O> for Chain<N, B, O, L, R>
}
#[inline]
fn update(&mut self, backend: &B, optimizer: &O, inputs: &B::Tensor, deltas: &B::Tensor, ctx: &mut Self::Context) {
self.left.update(backend, optimizer, inputs, ctx.right.deltas(), &mut ctx.left);
self.right.update(backend, optimizer, ctx.left.outputs(), deltas, &mut ctx.right);
fn calc_gradients(&mut self, backend: &B, deltas: &B::Tensor, inputs: &B::Tensor, ctx: &mut Self::Context) {
self.left.calc_gradients(backend, ctx.right.deltas(), inputs, &mut ctx.left);
self.right.calc_gradients(backend, deltas, ctx.left.outputs(), &mut ctx.right);
}
#[inline]
fn optimize(&mut self, backend: &B, optimizer: &O) {
self.left.optimize(backend, optimizer);
self.right.optimize(backend, optimizer);
}
fn fmt(&self, f: &mut core::fmt::Formatter, padding: usize) -> core::fmt::Result {
self.left.fmt(f, padding)?;
self.right.fmt(f, padding)?;
Ok(())
}
}

View File

@ -1,8 +1,8 @@
use crate::tensor::{Tensor, TensorShape};
use crate::layer::Layer;
use crate::layer::{Layer, LayerExt, DefaultLayerContext};
use crate::params::Params;
use crate::backend::{Backend, Conv2dInfo, PaddingKind, BackendBias, BackendConv2d, BackendScale};
use crate::optimizer::{Optimizable, Optimizer};
use crate::optimizer::Optimizer;
pub struct Conv2dConfig {
pub filters: u32,
@ -36,32 +36,24 @@ pub struct Conv2d<N, B, O>
biases: Params<N, B, O>,
}
impl <N, B, O> Layer<N, B> for Conv2d<N, B, O>
where B: Backend<N> + BackendConv2d<N> + BackendBias<N>,
impl <N, B, O> Layer<N, B, O> for Conv2d<N, B, O>
where B: Backend<N> + BackendConv2d<N> + BackendBias<N> + BackendScale<N>,
O: Optimizer<N, B>
{
type Config = Conv2dConfig;
type Context = DefaultLayerContext<N, B>;
fn name(&self) -> &str {
"Conv2d"
}
fn create(input_shape: TensorShape, cfg: Self::Config) -> Self {
assert!(input_shape.dims == 3);
Conv2d {
input_shape,
units: cfg.filters,
conv_info: Conv2dInfo {
kernel: cfg.kernel,
padding: cfg.padding,
strides: cfg.strides,
},
use_biases: cfg.biases,
filters: Params::new((cfg.filters, cfg.kernel.0, cfg.kernel.1)),
biases: Params::new((cfg.filters, )),
#[inline]
fn param_count(&self) -> usize {
if self.use_biases {
self.filters.params.shape().size() + self.biases.params.shape().size()
} else {
self.filters.params.shape().size()
}
}
}
fn init(&mut self, backend: &B) {
self.filters.init_random(backend, self.conv_info.kernel.0 * self.conv_info.kernel.1 + self.filters.params.shape().get(0));
@ -93,11 +85,12 @@ impl <N, B, O> Layer<N, B> for Conv2d<N, B, O>
}
#[inline]
fn forward(&self, backend: &B, y: &mut B::Tensor, x: &B::Tensor) {
assert_eq!(y.shape().dims, 4);
fn forward(&self, backend: &B, x: &B::Tensor, ctx: &mut Self::Context) {
assert_eq!(x.shape().dims, 4);
backend.conv2d_forward(y, x, &self.filters.params, &self.conv_info);
ctx.update_outputs_shape(x.shape().get(0), &self.output_shape());
backend.conv2d_forward(&mut ctx.outputs, x, &self.filters.params, &self.conv_info);
if self.use_biases {
unimplemented!();
@ -106,20 +99,16 @@ impl <N, B, O> Layer<N, B> for Conv2d<N, B, O>
}
#[inline]
fn backward(&self, backend: &B, dx: &mut B::Tensor, dy: &B::Tensor, _: &B::Tensor, _: &B::Tensor) {
fn backward(&mut self, backend: &B, dy: &B::Tensor, x: &B::Tensor, ctx: &mut Self::Context) {
assert_eq!(dy.shape().dims, 4);
assert_eq!(dx.shape().dims, 4);
ctx.update_deltas_shape(x.shape().get(0), &self.input_shape);
backend.conv2d_backward_input(dx, dy, &self.filters.params, &self.conv_info);
backend.conv2d_backward_input(&mut ctx.deltas, dy, &self.filters.params, &self.conv_info);
}
}
impl <N, B, O> Optimizable<N, B, O> for Conv2d<N, B, O>
where B: Backend<N> + BackendConv2d<N> + BackendBias<N> + BackendScale<N>,
O: Optimizer<N, B>
{
#[inline]
fn calc_gradients(&mut self, backend: &B, x: &B::Tensor, dy: &B::Tensor) {
fn calc_gradients(&mut self, backend: &B, dy: &B::Tensor, x: &B::Tensor, _ctx: &mut Self::Context) {
assert_eq!(dy.shape().dims, 4);
assert_eq!(x.shape().dims, 4);
@ -138,4 +127,28 @@ impl <N, B, O> Optimizable<N, B, O> for Conv2d<N, B, O>
// optimizer.update_params(backend, &mut self.biases.ctx, &mut self.biases.params, &self.biases.grads);
}
}
}
impl <N, B, O> LayerExt<N, B, O> for Conv2d<N, B, O>
where B: Backend<N> + BackendConv2d<N> + BackendBias<N> + BackendScale<N>,
O: Optimizer<N, B>
{
type Config = Conv2dConfig;
fn create(input_shape: TensorShape, cfg: Self::Config) -> Self {
assert!(input_shape.dims == 3);
Conv2d {
input_shape,
units: cfg.filters,
conv_info: Conv2dInfo {
kernel: cfg.kernel,
padding: cfg.padding,
strides: cfg.strides,
},
use_biases: cfg.biases,
filters: Params::new((cfg.filters, cfg.kernel.0, cfg.kernel.1)),
biases: Params::new((cfg.filters, )),
}
}
}

View File

@ -1,8 +1,10 @@
use crate::tensor::TensorShape;
use crate::tensor::{TensorShape, Tensor};
use crate::backend::{Backend, BackendCopy};
use crate::layer::Layer;
use crate::layer::{Layer, LayerExt, DefaultLayerContext};
use crate::optimizer::Optimizer;
use core::marker::PhantomData;
#[derive(Default)]
pub struct FlattenConfig;
@ -13,22 +15,16 @@ pub struct Flatten<N, B>
_x: PhantomData<fn(N, B)>,
}
impl <N, B> Layer<N, B> for Flatten<N, B>
where B: Backend<N> + BackendCopy<N>
impl <N, B, O> Layer<N, B, O> for Flatten<N, B>
where B: Backend<N> + BackendCopy<N>,
O: Optimizer<N, B>
{
type Config = FlattenConfig;
type Context = DefaultLayerContext<N, B>;
fn name(&self) -> &str {
"Flatten"
}
fn create(input_shape: TensorShape, _cfg: Self::Config) -> Self {
Flatten {
input_shape,
_x: Default::default()
}
}
#[inline]
fn input_shape(&self) -> TensorShape {
self.input_shape.clone()
@ -40,12 +36,30 @@ impl <N, B> Layer<N, B> for Flatten<N, B>
}
#[inline]
fn forward(&self, backend: &B, y: &mut B::Tensor, x: &B::Tensor) {
backend.copy(y, x);
fn forward(&self, backend: &B, x: &B::Tensor, ctx: &mut Self::Context) {
ctx.update_outputs_shape(x.shape().get(0), &Layer::<N, B, O>::output_shape(self));
backend.copy(&mut ctx.outputs, x);
}
#[inline]
fn backward(&self, backend: &B, dx: &mut B::Tensor, dy: &B::Tensor, _: &B::Tensor, _: &B::Tensor) {
backend.copy(dx, dy);
fn backward(&mut self, backend: &B, dy: &B::Tensor, x: &B::Tensor, ctx: &mut Self::Context) {
ctx.update_deltas_shape(x.shape().get(0), &self.input_shape);
backend.copy(&mut ctx.deltas, dy);
}
}
impl <N, B, O> LayerExt<N, B, O> for Flatten<N, B>
where B: Backend<N> + BackendCopy<N>,
O: Optimizer<N, B>
{
type Config = FlattenConfig;
fn create(input_shape: TensorShape, _cfg: Self::Config) -> Self {
Flatten {
input_shape,
_x: Default::default()
}
}
}

View File

@ -1,8 +1,8 @@
use crate::tensor::{Tensor, TensorShape};
use crate::layer::Layer;
use crate::layer::{Layer, LayerExt, DefaultLayerContext};
use crate::params::Params;
use crate::backend::{Backend, BackendGemm, BackendBias, BackendScale};
use crate::optimizer::{Optimizable, Optimizer};
use crate::optimizer::Optimizer;
pub struct LinearConfig {
pub units: u32,
@ -29,30 +29,25 @@ pub struct Linear<N, B, O>
biases: Params<N, B, O>,
}
impl <N, B, O> Layer<N, B> for Linear<N, B, O>
where B: Backend<N> + BackendGemm<N> + BackendBias<N>,
impl <N, B, O> Layer<N, B, O> for Linear<N, B, O>
where B: Backend<N> + BackendGemm<N> + BackendBias<N> + BackendScale<N>,
O: Optimizer<N, B>
{
type Config = LinearConfig;
type Context = DefaultLayerContext<N, B>;
fn name(&self) -> &str {
"Linear"
}
fn create(input_shape: TensorShape, cfg: Self::Config) -> Self {
assert!(input_shape.dims == 1);
let inputs = input_shape.get(0);
Linear {
inputs,
outputs: cfg.units,
use_biases: cfg.biases,
weights: Params::new((inputs, cfg.units)),
biases: Params::new((cfg.units, )),
#[inline]
fn param_count(&self) -> usize {
if self.use_biases {
self.weights.params.shape().size() + self.biases.params.shape().size()
} else {
self.weights.params.shape().size()
}
}
}
fn init(&mut self, backend: &B) {
self.weights.init_random(backend, self.inputs + self.outputs);
if self.use_biases {
@ -71,33 +66,32 @@ impl <N, B, O> Layer<N, B> for Linear<N, B, O>
}
#[inline]
fn forward(&self, backend: &B, y: &mut B::Tensor, x: &B::Tensor) {
backend.matmul(y, x, &self.weights.params);
fn forward(&self, backend: &B, x: &B::Tensor, ctx: &mut Self::Context) {
ctx.update_outputs_shape(x.shape().get(0), &self.output_shape());
backend.matmul(&mut ctx.outputs, x, &self.weights.params);
if self.use_biases {
backend.bias_add(y, &self.biases.params);
backend.bias_add(&mut ctx.outputs, &self.biases.params);
}
}
#[inline]
fn backward(&self, backend: &B, dx: &mut B::Tensor, dy: &B::Tensor, _: &B::Tensor, _: &B::Tensor) {
backend.matmul_nt(dx, dy, &self.weights.params);
fn backward(&mut self, backend: &B, dy: &B::Tensor, x: &B::Tensor, ctx: &mut Self::Context) {
ctx.update_deltas_shape(x.shape().get(0), &self.input_shape());
backend.matmul_nt(&mut ctx.deltas, dy, &self.weights.params);
}
}
impl <N, B, O> Optimizable<N, B, O> for Linear<N, B, O>
where B: Backend<N> + BackendGemm<N> + BackendBias<N> + BackendScale<N>,
O: Optimizer<N, B>
{
fn calc_gradients(&mut self, backend: &B, inputs: &B::Tensor, deltas: &B::Tensor) {
let prescaler = 1.0 / inputs.shape().get(0) as f32;
fn calc_gradients(&mut self, backend: &B, dy: &B::Tensor, x: &B::Tensor, _ctx: &mut Self::Context) {
let prescaler = 1.0 / x.shape().get(0) as f32;
backend.matmul_tn(&mut self.weights.grads, inputs, deltas);
backend.matmul_tn(&mut self.weights.grads, x, dy);
backend.scale(&mut self.weights.grads, backend.scalar_f32(prescaler));
if self.use_biases {
backend.scale(&mut self.biases.grads, backend.scalar_f32(prescaler));
backend.bias_grad(&mut self.biases.grads, deltas);
backend.bias_grad(&mut self.biases.grads, &dy);
}
}
@ -109,4 +103,25 @@ impl <N, B, O> Optimizable<N, B, O> for Linear<N, B, O>
optimizer.update_params(backend, &mut self.biases.ctx, &mut self.biases.params, &mut self.biases.grads);
}
}
}
impl <N, B, O> LayerExt<N, B, O> for Linear<N, B, O>
where B: Backend<N> + BackendGemm<N> + BackendBias<N> + BackendScale<N>,
O: Optimizer<N, B>
{
type Config = LinearConfig;
fn create(input_shape: TensorShape, cfg: Self::Config) -> Self {
assert!(input_shape.dims == 1);
let inputs = input_shape.get(0);
Linear {
inputs,
outputs: cfg.units,
use_biases: cfg.biases,
weights: Params::new((inputs, cfg.units)),
biases: Params::new((cfg.units, )),
}
}
}

View File

@ -1,6 +1,7 @@
use crate::tensor::TensorShape;
use crate::layer::Layer;
use crate::tensor::{TensorShape, Tensor};
use crate::layer::{Layer, LayerExt, DefaultLayerContext};
use crate::backend::{Backend, PaddingKind, BackendMaxPool2d, Conv2dInfo};
use crate::optimizer::Optimizer;
use core::marker::PhantomData;
pub struct MaxPool2dConfig {
@ -25,29 +26,16 @@ pub struct MaxPool2d<N, B>
_m: PhantomData<fn(N, B)>
}
impl <N, B> Layer<N, B> for MaxPool2d<N, B>
impl <N, B, O> Layer<N, B, O> for MaxPool2d<N, B>
where B: Backend<N> + BackendMaxPool2d<N>,
O: Optimizer<N, B>
{
type Config = MaxPool2dConfig;
type Context = DefaultLayerContext<N, B>;
fn name(&self) -> &str {
"MaxPool2d"
}
fn create(input_shape: TensorShape, config: Self::Config) -> Self {
assert!(input_shape.dims == 3);
MaxPool2d {
input_shape,
conv_info: Conv2dInfo {
kernel: config.pool,
strides: config.strides.unwrap_or(config.pool),
padding: PaddingKind::Valid,
},
_m: Default::default(),
}
}
#[inline]
fn input_shape(&self) -> TensorShape {
self.input_shape.clone()
@ -70,12 +58,37 @@ impl <N, B> Layer<N, B> for MaxPool2d<N, B>
}
#[inline]
fn forward(&self, backend: &B, y: &mut B::Tensor, x: &B::Tensor) {
backend.max_pool2d(y, x, &self.conv_info)
fn forward(&self, backend: &B, x: &B::Tensor, ctx: &mut Self::Context) {
ctx.update_outputs_shape(x.shape().get(0), &Layer::<N, B, O>::output_shape(self));
backend.max_pool2d(&mut ctx.outputs, x, &self.conv_info)
}
#[inline]
fn backward(&self, backend: &B, dx: &mut B::Tensor, dy: &B::Tensor, x: &B::Tensor, _: &B::Tensor) {
backend.max_pool2d_backprop(dx, dy, x, &self.conv_info);
fn backward(&mut self, backend: &B, dy: &B::Tensor, x: &B::Tensor, ctx: &mut Self::Context) {
ctx.update_deltas_shape(x.shape().get(0), &self.input_shape);
backend.max_pool2d_backprop(&mut ctx.deltas, dy, x, &self.conv_info);
}
}
impl <N, B, O> LayerExt<N, B, O> for MaxPool2d<N, B>
where B: Backend<N> + BackendMaxPool2d<N>,
O: Optimizer<N, B>
{
type Config = MaxPool2dConfig;
fn create(input_shape: TensorShape, config: Self::Config) -> Self {
assert!(input_shape.dims == 3);
MaxPool2d {
input_shape,
conv_info: Conv2dInfo {
kernel: config.pool,
strides: config.strides.unwrap_or(config.pool),
padding: PaddingKind::Valid,
},
_m: Default::default(),
}
}
}

View File

@ -1,6 +1,7 @@
use crate::tensor::TensorShape;
use crate::tensor::{Tensor, TensorShape};
use crate::backend::{Backend, BackendReLu};
use crate::layer::Layer;
use crate::layer::{Layer, LayerExt, DefaultLayerContext};
use crate::optimizer::Optimizer;
use core::marker::PhantomData;
#[derive(Default)]
@ -13,34 +14,46 @@ pub struct ReLu<N, B>
_x: PhantomData<fn(N, B)>,
}
impl <N, B> Layer<N, B> for ReLu<N, B>
where B: Backend<N> + BackendReLu<N>
impl <N, B, O> Layer<N, B, O> for ReLu<N, B>
where B: Backend<N> + BackendReLu<N>,
O: Optimizer<N, B>
{
type Config = ReLuConfig;
type Context = DefaultLayerContext<N, B>;
fn name(&self) -> &str {
"ReLU"
}
fn create(input_shape: TensorShape, _cfg: Self::Config) -> Self {
ReLu {
input_shape,
_x: Default::default()
}
}
#[inline]
fn input_shape(&self) -> TensorShape {
self.input_shape.clone()
}
#[inline]
fn forward(&self, backend: &B, y: &mut B::Tensor, x: &B::Tensor) {
backend.relu(y, x);
fn forward(&self, backend: &B, x: &B::Tensor, ctx: &mut Self::Context) {
ctx.update_outputs_shape(x.shape().get(0), &self.input_shape);
backend.relu(&mut ctx.outputs, x);
}
#[inline]
fn backward(&self, backend: &B, dx: &mut B::Tensor, dy: &B::Tensor, _x: &B::Tensor, y: &B::Tensor) {
backend.relu_grad(dx, y, dy);
fn backward(&mut self, backend: &B, dy: &B::Tensor, _: &B::Tensor, ctx: &mut Self::Context) {
ctx.update_deltas_shape(dy.shape().get(0), &self.input_shape);
backend.relu_grad(&mut ctx.deltas, &ctx.outputs, dy);
}
}
impl <N, B, O> LayerExt<N, B, O> for ReLu<N, B>
where B: Backend<N> + BackendReLu<N>,
O: Optimizer<N, B>
{
type Config = ReLuConfig;
fn create(input_shape: TensorShape, _cfg: Self::Config) -> Self {
ReLu {
input_shape,
_x: Default::default()
}
}
}

View File

@ -1,6 +1,7 @@
use crate::tensor::TensorShape;
use crate::tensor::{Tensor, TensorShape};
use crate::backend::{Backend, BackendSigmoid};
use crate::layer::Layer;
use crate::layer::{Layer, LayerExt, DefaultLayerContext};
use crate::optimizer::Optimizer;
use core::marker::PhantomData;
#[derive(Default)]
@ -13,21 +14,15 @@ pub struct Sigmoid<N, B>
_x: PhantomData<fn(N, B)>,
}
impl <N, B> Layer<N, B> for Sigmoid<N, B>
where B: Backend<N> + BackendSigmoid<N>
{
type Config = SigmoidConfig;
impl <N, B, O> Layer<N, B, O> for Sigmoid<N, B>
where B: Backend<N> + BackendSigmoid<N>,
O: Optimizer<N, B>
{
type Context = DefaultLayerContext<N, B>;
fn name(&self) -> &str {
"Sigmoid"
}
fn create(input_shape: TensorShape, _cfg: Self::Config) -> Self {
Sigmoid {
input_shape,
_x: Default::default()
}
}
#[inline]
fn input_shape(&self) -> TensorShape {
@ -35,12 +30,30 @@ impl <N, B> Layer<N, B> for Sigmoid<N, B>
}
#[inline]
fn forward(&self, backend: &B, y: &mut B::Tensor, x: &B::Tensor) {
backend.sigmoid(y, x);
fn forward(&self, backend: &B, x: &B::Tensor, ctx: &mut Self::Context) {
ctx.update_outputs_shape(x.shape().get(0), &self.input_shape);
backend.sigmoid(&mut ctx.outputs, x);
}
#[inline]
fn backward(&self, backend: &B, dx: &mut B::Tensor, dy: &B::Tensor, _x: &B::Tensor, y: &B::Tensor) {
backend.sigmoid_grad(dx, y, dy);
fn backward(&mut self, backend: &B, dy: &B::Tensor, _: &B::Tensor, ctx: &mut Self::Context) {
ctx.update_deltas_shape(dy.shape().get(0), &self.input_shape);
backend.sigmoid_grad(&mut ctx.deltas, &ctx.outputs, dy);
}
}
impl <N, B, O> LayerExt<N, B, O> for Sigmoid<N, B>
where B: Backend<N> + BackendSigmoid<N>,
O: Optimizer<N, B>
{
type Config = SigmoidConfig;
fn create(input_shape: TensorShape, _cfg: Self::Config) -> Self {
Sigmoid {
input_shape,
_x: Default::default()
}
}
}

View File

@ -1,6 +1,7 @@
use crate::tensor::TensorShape;
use crate::tensor::{Tensor, TensorShape};
use crate::backend::{Backend, BackendSoftmax};
use crate::layer::Layer;
use crate::layer::{Layer, LayerExt, DefaultLayerContext};
use crate::optimizer::Optimizer;
use core::marker::PhantomData;
#[derive(Default)]
@ -13,34 +14,46 @@ pub struct Softmax<N, B>
_x: PhantomData<fn(N, B)>,
}
impl <N, B> Layer<N, B> for Softmax<N, B>
where B: Backend<N> + BackendSoftmax<N>
{
type Config = SoftmaxConfig;
impl <N, B, O> Layer<N, B, O> for Softmax<N, B>
where B: Backend<N> + BackendSoftmax<N>,
O: Optimizer<N, B>
{
type Context = DefaultLayerContext<N, B>;
fn name(&self) -> &str {
"Softmax"
}
fn create(input_shape: TensorShape, _cfg: Self::Config) -> Self {
Softmax {
input_shape,
_x: Default::default()
}
}
#[inline]
fn input_shape(&self) -> TensorShape {
self.input_shape.clone()
}
#[inline]
fn forward(&self, backend: &B, y: &mut B::Tensor, x: &B::Tensor) {
backend.softmax(y, x);
fn forward(&self, backend: &B, x: &B::Tensor, ctx: &mut Self::Context) {
ctx.update_outputs_shape(x.shape().get(0), &self.input_shape);
backend.softmax(&mut ctx.outputs, x);
}
#[inline]
fn backward(&self, backend: &B, dx: &mut B::Tensor, dy: &B::Tensor, _: &B::Tensor, _: &B::Tensor) {
backend.copy(dx, dy);
fn backward(&mut self, backend: &B, dy: &B::Tensor, _: &B::Tensor, ctx: &mut Self::Context) {
ctx.update_deltas_shape(dy.shape().get(0), &self.input_shape);
backend.copy(&mut ctx.deltas, dy);
}
}
impl <N, B, O> LayerExt<N, B, O> for Softmax<N, B>
where B: Backend<N> + BackendSoftmax<N>,
O: Optimizer<N, B>
{
type Config = SoftmaxConfig;
fn create(input_shape: TensorShape, _cfg: Self::Config) -> Self {
Softmax {
input_shape,
_x: Default::default()
}
}
}

View File

@ -1,6 +1,7 @@
use crate::tensor::TensorShape;
use crate::tensor::{Tensor, TensorShape};
use crate::backend::{Backend, BackendCopy};
use crate::layer::Layer;
use crate::layer::{Layer, LayerExt, DefaultLayerContext};
use crate::optimizer::Optimizer;
use core::marker::PhantomData;
#[derive(Default)]
@ -16,23 +17,16 @@ pub struct ZeroPadding2d<N, B>
_x: PhantomData<fn(N, B)>,
}
impl <N, B> Layer<N, B> for ZeroPadding2d<N, B>
where B: Backend<N> + BackendCopy<N>
impl <N, B, O> Layer<N, B, O> for ZeroPadding2d<N, B>
where B: Backend<N> + BackendCopy<N>,
O: Optimizer<N, B>
{
type Config = ZeroPadding2dConfig;
type Context = DefaultLayerContext<N, B>;
fn name(&self) -> &str {
"ZeroPadding2d"
}
fn create(input_shape: TensorShape, config: Self::Config) -> Self {
ZeroPadding2d {
input_shape,
config,
_x: Default::default()
}
}
#[inline]
fn input_shape(&self) -> TensorShape {
self.input_shape.clone()
@ -50,12 +44,30 @@ impl <N, B> Layer<N, B> for ZeroPadding2d<N, B>
}
#[inline]
fn forward(&self, _backend: &B, _y: &mut B::Tensor, _x: &B::Tensor) {
fn forward(&self, _backend: &B, x: &B::Tensor, ctx: &mut Self::Context) {
ctx.update_outputs_shape(x.shape().get(0), &Layer::<N, B, O>::output_shape(self));
// backend.copy_with_padding(y, x, 0.0, (self.config.0, self.config.1, self.config.0, self.config.1));
}
#[inline]
fn backward(&self, _backend: &B, _dx: &mut B::Tensor, _dy: &B::Tensor, _: &B::Tensor, _: &B::Tensor) {
fn backward(&mut self, _backend: &B, dy: &B::Tensor, _x: &B::Tensor, ctx: &mut Self::Context) {
ctx.update_deltas_shape(dy.shape().get(0), &self.input_shape);
// backend.copy(dx, dy);
}
}
impl <N, B, O> LayerExt<N, B, O> for ZeroPadding2d<N, B>
where B: Backend<N> + BackendCopy<N>,
O: Optimizer<N, B>
{
type Config = ZeroPadding2dConfig;
fn create(input_shape: TensorShape, config: Self::Config) -> Self {
ZeroPadding2d {
input_shape,
config,
_x: Default::default()
}
}
}

View File

@ -8,22 +8,17 @@ macro_rules! sequential_type {
#[macro_export]
macro_rules! sequential_type_impl {
($t:ty {$($tt:tt)*}) => {
$crate::layer::LayerImpl<N, B, O, $t>
};
($t:ty {$($tt:tt)*}) => ($t);
($t:ty {$($xx:tt)*}, $($tt:tt)*) => {
$crate::layers::Chain<N, B, O,
$crate::layer::LayerImpl<N, B, O, $t>,
$crate::sequential_type_impl!($($tt)*)
$t, $crate::sequential_type_impl!($($tt)*)
>
};
($t:ty) => {
$crate::layer::LayerImpl<N, B, O, $t>
};
($t:ty) => ($t);
($t:ty, $($tt:tt)*) => {
$crate::layers::Chain<N, B, O,
$crate::layer::LayerImpl<N, B, O, $t>,
$crate::sequential_type_impl!($($tt)*)
$t, $crate::sequential_type_impl!($($tt)*)
>
};
}
@ -40,35 +35,29 @@ macro_rules! sequential {
#[macro_export]
macro_rules! sequential_impl {
($p:expr, $t:ty { $($name:ident : $val:expr),* }) => {{
#[allow(unused_imports)]
use core::convert::TryInto;
#[allow(unused_mut)]
let mut params = <$t as $crate::layer::Layer<_, _>>::Config::default();
let mut params = <$t as $crate::layer::LayerExt<N, B, O>>::Config::default();
$(
params.$name = ($val).try_into().unwrap_or($val);
params.$name = core::convert::TryInto::try_into($val).unwrap_or($val);
)*
$crate::layer::LayerImpl::new(<$t as $crate::layer::Layer<_, _>>::create(
<$t as $crate::layer::LayerExt<N, B, O>>::create(
$p, params
))
)
}};
($p:expr, $t:ty { $($name:ident : $val:expr),* }, $($tt:tt)*) => {{
#[allow(unused_imports)]
use core::convert::TryInto;
#[allow(unused_mut)]
let mut params = <$t as $crate::layer::Layer<_, _>>::Config::default();
let mut params = <$t as $crate::layer::LayerExt<N, B, O>>::Config::default();
$(
params.$name = ($val).try_into().unwrap_or($val);;
params.$name = core::convert::TryInto::try_into($val).unwrap_or($val);
)*
let layer = $crate::layer::LayerImpl::new(<$t as $crate::layer::Layer<_, _>>::create(
let layer = <$t as $crate::layer::LayerExt<N, B, O>>::create(
$p, params
));
);
let prev_shape = layer.layer.output_shape();
let prev_shape = $crate::layer::Layer::<N, B, O>::output_shape(&layer);
$crate::layers::Chain::new(
layer, $crate::sequential_impl! { prev_shape, $($tt)* },
@ -94,12 +83,12 @@ macro_rules! sequential_type_ctx {
#[macro_export]
macro_rules! sequential_type_ctx_impl {
($t:ty {$($xx:tt)*}) => {
$crate::layer::CommonLayerContext<N, B>
$crate::layer::DefaultLayerContext<N, B>
};
($t:ty {$($xx:tt)*}, $($tt:tt)*) => {
$crate::layers::ChainContext<N, B,
$crate::layer::CommonLayerContext<N, B>,
$crate::layer::DefaultLayerContext<N, B>,
$crate::sequential_type_ctx_impl!($($tt)*)
>
};
@ -149,38 +138,86 @@ macro_rules! model_impl {
}
}
impl<N, B, O> core::fmt::Display for $name<N, B, O>
where B: $crate::backend::Backend<N> + $trait,
O: $crate::optimizer::Optimizer<N, B>
{
fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
writeln!(f, "{} {{", stringify!($name))?;
write!(f, "{}", self.inner)?;
writeln!(f, "}}")?;
// impl<N, B, O> core::fmt::Display for $name<N, B, O>
// where B: $crate::backend::Backend<N> + $trait,
// O: $crate::optimizer::Optimizer<N, B>
// {
// fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
// writeln!(f, "{} {{", stringify!($name))?;
// write!(f, "{}", self.inner)?;
// writeln!(f, "}}")?;
Ok(())
}
}
// Ok(())
// }
// }
impl<N, B, O> $crate::layer::AbstractLayer<N, B, O> for $name<N, B, O>
impl<N, B, O> $crate::layer::Layer<N, B, O> for $name<N, B, O>
where B: $crate::backend::Backend<N> + $trait,
O: $crate::optimizer::Optimizer<N, B>
{
type Context = ctx::$name<N, B>;
#[inline]
fn forward(&mut self, backend: &B, inputs: &B::Tensor, ctx: &mut Self::Context) {
$crate::layer::AbstractLayer::forward(&mut self.inner, backend, inputs, ctx)
fn name(&self) -> &str {
stringify!($name)
}
#[inline]
fn init(&mut self, backend: &B) {
self.inner.init(backend);
}
#[inline]
fn param_count(&self) -> usize {
self.inner.param_count()
}
#[inline]
fn input_shape(&self) -> $crate::tensor::TensorShape {
self.inner.input_shape()
}
#[inline]
fn output_shape(&self) -> $crate::tensor::TensorShape {
self.inner.output_shape()
}
#[inline]
fn forward(&self, backend: &B, inputs: &B::Tensor, ctx: &mut Self::Context) {
self.inner.forward(backend, inputs, ctx);
}
#[inline]
fn backward(&mut self, backend: &B, deltas: &B::Tensor, inputs: &B::Tensor, ctx: &mut Self::Context) {
$crate::layer::AbstractLayer::backward(&mut self.inner, backend, deltas, inputs, ctx);
self.inner.backward(backend, deltas, inputs, ctx);
}
#[inline]
fn update(&mut self, backend: &B, optimizer: &O, inputs: &B::Tensor, deltas: &B::Tensor, ctx: &mut Self::Context) {
$crate::layer::AbstractLayer::update(&mut self.inner, backend, optimizer, inputs, deltas, ctx);
fn calc_gradients(&mut self, backend: &B, deltas: &B::Tensor, inputs: &B::Tensor, ctx: &mut Self::Context) {
self.inner.calc_gradients(backend, deltas, inputs, ctx);
}
#[inline]
fn optimize(&mut self, backend: &B, optimizer: &O) {
self.inner.optimize(backend, optimizer);
}
fn fmt(&self, f: &mut core::fmt::Formatter, padding: usize) -> core::fmt::Result {
writeln!(f, "{}{}[{}] {{", "", self.name(), self.param_count())?;
self.inner.fmt(f, padding + 2)?;
write!(f, "}}")?;
Ok(())
}
}
impl<N, B, O> core::fmt::Display for $name<N, B, O>
where B: $crate::backend::Backend<N> + $trait,
O: $crate::optimizer::Optimizer<N, B>
{
#[inline]
fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
<Self as $crate::layer::Layer<_, _, _>>::fmt(self, f, 0)
}
}
};
@ -192,6 +229,7 @@ macro_rules! model {
mod tmp {
pub trait BackendDefault<N> = $crate::backend::BackendReLu<N>
+ $crate::backend::BackendBias<N>
+ $crate::backend::BackendScale<N>
+ $crate::backend::BackendSigmoid<N>
+ $crate::backend::BackendSoftmax<N>
+ $crate::backend::BackendGemm<N>

View File

@ -19,9 +19,4 @@ impl <'a, N, B: Backend<N>, O: Optimizer<N, B>> Optimizer<N, B> for &'a O {
fn update_params(&self, backend: &B, ctx: &mut Self::Context, params: &mut B::Tensor, grads: &mut B::Tensor) {
(**self).update_params(backend, ctx, params, grads)
}
}
pub trait Optimizable<N, B: Backend<N>, O: Optimizer<N, B>> {
fn calc_gradients(&mut self, backend: &B, inputs: &B::Tensor, deltas: &B::Tensor);
fn optimize(&mut self, backend: &B, optimizer: &O);
}
}

View File

@ -121,8 +121,8 @@ impl TensorShape {
}
}
pub fn append<S: Into<TensorShape>>(&mut self, s: S) -> &mut Self {
let s = s.into();
pub fn append<S: core::borrow::Borrow<TensorShape>>(&mut self, s: S) -> &mut Self {
let s = s.borrow();
let sd = self.dims;
for i in 0 .. s.dims {