Got rid of AbstractLayer
This commit is contained in:
parent
82375f890b
commit
63d0428e32
@ -43,6 +43,7 @@ fn main() {
|
||||
|
||||
// let mut model = MnistDenseModel::new(28, 28, 1);
|
||||
let mut model = MnistConvModel::new(28, 28, 1);
|
||||
model.init(&backend);
|
||||
|
||||
println!("{}", &model);
|
||||
|
||||
@ -52,7 +53,7 @@ fn main() {
|
||||
let loss = CrossEntropyLoss::new();
|
||||
|
||||
let Mnist { trn_img, trn_lbl, tst_img, tst_lbl, .. } = MnistBuilder::new()
|
||||
.base_path("../../dataset/mnist")
|
||||
.base_path("./datasets/mnist")
|
||||
.label_format_digit()
|
||||
.finalize();
|
||||
|
||||
@ -101,7 +102,8 @@ fn main() {
|
||||
model.forward(&backend, &inputs, &mut train_ctx);
|
||||
loss.derivative(&backend, &mut deltas, train_ctx.outputs(), &targets);
|
||||
model.backward(&backend, &deltas, &inputs, &mut train_ctx);
|
||||
model.update(&backend, &optimizer, &inputs, &deltas, &mut train_ctx);
|
||||
model.calc_gradients(&backend, &deltas, &inputs, &mut train_ctx);
|
||||
model.optimize(&backend, &optimizer);
|
||||
}
|
||||
|
||||
model.forward(&backend, &inputs0, &mut test_ctx);
|
||||
|
@ -1,12 +1,10 @@
|
||||
#![feature(trait_alias)]
|
||||
|
||||
|
||||
pub use self::dense::MnistDenseModel;
|
||||
pub use self::conv::MnistConvModel;
|
||||
|
||||
mod dense {
|
||||
use yarnn::model;
|
||||
use yarnn::layer::*;
|
||||
use yarnn::layers::*;
|
||||
|
||||
model! {
|
||||
@ -30,7 +28,6 @@ mod dense {
|
||||
|
||||
mod conv {
|
||||
use yarnn::model;
|
||||
use yarnn::layer::*;
|
||||
use yarnn::layers::*;
|
||||
|
||||
model! {
|
||||
@ -62,4 +59,4 @@ mod conv {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,7 +1,6 @@
|
||||
#![feature(trait_alias)]
|
||||
|
||||
use yarnn::layers::*;
|
||||
use yarnn::layer::*;
|
||||
use yarnn::model;
|
||||
|
||||
model! {
|
||||
|
@ -1,14 +1,20 @@
|
||||
use crate::backend::Backend;
|
||||
use crate::optimizer::{Optimizer, Optimizable};
|
||||
use crate::optimizer::Optimizer;
|
||||
use crate::tensor::{Tensor, TensorShape};
|
||||
|
||||
use core::marker::PhantomData;
|
||||
// use core::marker::PhantomData;
|
||||
|
||||
|
||||
pub trait Layer<N, B: Backend<N>> {
|
||||
type Config: Default;
|
||||
pub trait Layer<N, B, O>
|
||||
where B: Backend<N>,
|
||||
O: Optimizer<N, B>
|
||||
{
|
||||
type Context: LayerContext<N, B>;
|
||||
|
||||
fn name(&self) -> &str;
|
||||
fn create(input_shape: TensorShape, cfg: Self::Config) -> Self;
|
||||
fn param_count(&self) -> usize {
|
||||
0
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn init(&mut self, _backend: &B) {}
|
||||
@ -20,34 +26,39 @@ pub trait Layer<N, B: Backend<N>> {
|
||||
self.input_shape()
|
||||
}
|
||||
|
||||
fn forward(&self, backend: &B, y: &mut B::Tensor, x: &B::Tensor);
|
||||
fn backward(&self, backend: &B, dx: &mut B::Tensor, dy: &B::Tensor, x: &B::Tensor, y: &B::Tensor);
|
||||
fn forward(&self, backend: &B, x: &B::Tensor, ctx: &mut Self::Context);
|
||||
fn backward(&mut self, backend: &B, dy: &B::Tensor, x: &B::Tensor, ctx: &mut Self::Context);
|
||||
|
||||
#[inline]
|
||||
fn calc_gradients(&mut self, _backend: &B, _dy: &B::Tensor, _x: &B::Tensor, _ctx: &mut Self::Context) {}
|
||||
|
||||
#[inline]
|
||||
fn optimize(&mut self, _backend: &B, _optimizer: &O) {}
|
||||
|
||||
fn fmt(&self, f: &mut core::fmt::Formatter, padding: usize) -> core::fmt::Result {
|
||||
writeln!(f, "{}{} -> {}[{}] -> {}", "".repeat(padding), self.input_shape(), self.name(), self.param_count(), self.output_shape())?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Temporary solution until I find a solution with problem of inference with specializations
|
||||
impl <T, N, B, O> Optimizable<N, B, O> for T
|
||||
where T: Layer<N, B>,
|
||||
B: Backend<N>,
|
||||
pub trait LayerExt<N, B, O>: Layer<N, B, O>
|
||||
where B: Backend<N>,
|
||||
O: Optimizer<N, B>
|
||||
{
|
||||
default fn calc_gradients(&mut self, _backend: &B, _inputs: &B::Tensor, _deltas: &B::Tensor) {}
|
||||
default fn optimize(&mut self, _backend: &B, _optimizer: &O) {}
|
||||
}
|
||||
type Config: Default;
|
||||
|
||||
pub trait AbstractLayer<N, B: Backend<N>, O: Optimizer<N, B>>: core::fmt::Display {
|
||||
type Context: LayerContext<N, B>;
|
||||
fn create(input_shape: TensorShape, cfg: Self::Config) -> Self;
|
||||
|
||||
fn forward(&mut self, backend: &B, inputs: &B::Tensor, ctx: &mut Self::Context);
|
||||
fn backward(&mut self, backend: &B, deltas: &B::Tensor, inputs: &B::Tensor, ctx: &mut Self::Context);
|
||||
fn update(&mut self, backend: &B, optimizer: &O, inputs: &B::Tensor, deltas: &B::Tensor, ctx: &mut Self::Context);
|
||||
|
||||
#[inline]
|
||||
fn add_layer<L: Layer<N, B>>(self, cfg: L::Config) -> crate::layers::Chain<N, B, O, Self, LayerImpl<N, B, O, L>>
|
||||
fn add_layer<L: LayerExt<N, B, O>>(self, cfg: L::Config) -> crate::layers::Chain<N, B, O, Self, L>
|
||||
where Self: Sized
|
||||
{
|
||||
let shape = self.output_shape();
|
||||
|
||||
crate::layers::Chain::new(
|
||||
self,
|
||||
LayerImpl::new(L::create(().into(), cfg)),
|
||||
L::create(shape, cfg),
|
||||
)
|
||||
}
|
||||
}
|
||||
@ -57,14 +68,15 @@ pub trait LayerContext<N, B: Backend<N>>: Default {
|
||||
fn deltas(&self) -> &B::Tensor;
|
||||
}
|
||||
|
||||
pub struct CommonLayerContext<N, B>
|
||||
|
||||
pub struct DefaultLayerContext<N, B>
|
||||
where B: Backend<N>,
|
||||
{
|
||||
pub outputs: B::Tensor,
|
||||
pub deltas: B::Tensor,
|
||||
}
|
||||
|
||||
impl <N, B> Default for CommonLayerContext<N, B>
|
||||
impl <N, B> Default for DefaultLayerContext<N, B>
|
||||
where B: Backend<N>,
|
||||
{
|
||||
fn default() -> Self {
|
||||
@ -75,10 +87,10 @@ impl <N, B> Default for CommonLayerContext<N, B>
|
||||
}
|
||||
}
|
||||
|
||||
impl <N, B> CommonLayerContext<N, B>
|
||||
impl <N, B> DefaultLayerContext<N, B>
|
||||
where B: Backend<N>,
|
||||
{
|
||||
pub fn update_deltas_bs(&mut self, bs: u32, input_shape: &TensorShape) {
|
||||
pub fn update_deltas_shape(&mut self, bs: u32, input_shape: &TensorShape) {
|
||||
let mut new_deltas_shape = TensorShape::new1d(bs);
|
||||
new_deltas_shape.append(input_shape.clone());
|
||||
|
||||
@ -87,7 +99,7 @@ impl <N, B> CommonLayerContext<N, B>
|
||||
}
|
||||
}
|
||||
|
||||
pub fn update_outputs_bs(&mut self, bs: u32, output_shape: &TensorShape) {
|
||||
pub fn update_outputs_shape(&mut self, bs: u32, output_shape: &TensorShape) {
|
||||
let mut new_output_shape = TensorShape::new1d(bs);
|
||||
|
||||
new_output_shape.append(output_shape.clone());
|
||||
@ -98,7 +110,7 @@ impl <N, B> CommonLayerContext<N, B>
|
||||
}
|
||||
}
|
||||
|
||||
impl <N, B> LayerContext<N, B> for CommonLayerContext<N, B>
|
||||
impl <N, B> LayerContext<N, B> for DefaultLayerContext<N, B>
|
||||
where B: Backend<N>,
|
||||
{
|
||||
#[inline]
|
||||
@ -110,72 +122,4 @@ impl <N, B> LayerContext<N, B> for CommonLayerContext<N, B>
|
||||
fn deltas(&self) -> &B::Tensor {
|
||||
&self.deltas
|
||||
}
|
||||
}
|
||||
|
||||
pub struct LayerImpl <N, B, O, L>
|
||||
where B: Backend<N>,
|
||||
O: Optimizer<N, B>
|
||||
{
|
||||
pub layer: L,
|
||||
initialized: bool,
|
||||
_m: PhantomData<fn(N, B, O)>,
|
||||
}
|
||||
|
||||
impl <N, B, O, L> LayerImpl<N, B, O, L>
|
||||
where B: Backend<N>,
|
||||
O: Optimizer<N, B>,
|
||||
L: Layer<N, B> + Optimizable<N, B, O>
|
||||
{
|
||||
pub fn new(layer: L) -> Self {
|
||||
Self {
|
||||
layer,
|
||||
initialized: false,
|
||||
_m: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl <N, B, O, L> core::fmt::Display for LayerImpl<N, B, O, L>
|
||||
where B: Backend<N>,
|
||||
O: Optimizer<N, B>,
|
||||
L: Layer<N, B>
|
||||
{
|
||||
fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
|
||||
write!(f, "{} -> ", self.layer.input_shape())?;
|
||||
write!(f, "{}", self.layer.name())?;
|
||||
writeln!(f, " -> {}", self.layer.output_shape())?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl <N, B, O, L> AbstractLayer<N, B, O> for LayerImpl<N, B, O, L>
|
||||
where B: Backend<N>,
|
||||
O: Optimizer<N, B>,
|
||||
L: Layer<N, B> + Optimizable<N, B, O>
|
||||
{
|
||||
type Context = CommonLayerContext<N, B>;
|
||||
|
||||
#[inline]
|
||||
fn forward(&mut self, backend: &B, inputs: &B::Tensor, ctx: &mut Self::Context) {
|
||||
if !self.initialized {
|
||||
self.initialized = true;
|
||||
self.layer.init(&backend);
|
||||
}
|
||||
|
||||
ctx.update_outputs_bs(inputs.shape().get(0), &self.layer.output_shape());
|
||||
self.layer.forward(&backend, &mut ctx.outputs, inputs);
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn backward(&mut self, backend: &B, deltas: &B::Tensor, inputs: &B::Tensor, ctx: &mut Self::Context) {
|
||||
ctx.update_deltas_bs(deltas.shape().get(0), &self.layer.input_shape());
|
||||
self.layer.backward(&backend, &mut ctx.deltas, deltas, inputs, &ctx.outputs);
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn update(&mut self, backend: &B, optimizer: &O, inputs: &B::Tensor, deltas: &B::Tensor, _ctx: &mut Self::Context) {
|
||||
self.layer.calc_gradients(&backend, inputs, deltas);
|
||||
self.layer.optimize(&backend, &optimizer);
|
||||
}
|
||||
}
|
||||
}
|
@ -1,20 +1,20 @@
|
||||
use crate::tensor::{Tensor, TensorShape};
|
||||
use crate::layer::Layer;
|
||||
use crate::layer::{Layer, LayerExt, DefaultLayerContext};
|
||||
use crate::backend::{Backend, PaddingKind, BackendAvgPool2d, Conv2dInfo};
|
||||
use crate::optimizer::Optimizer;
|
||||
|
||||
use core::marker::PhantomData;
|
||||
|
||||
pub struct AvgPool2dConfig {
|
||||
pub pool: (u32, u32),
|
||||
pub strides: (u32, u32),
|
||||
pub padding: PaddingKind,
|
||||
pub strides: Option<(u32, u32)>,
|
||||
}
|
||||
|
||||
impl Default for AvgPool2dConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
pool: (2, 2),
|
||||
strides: (2, 2),
|
||||
padding: PaddingKind::Valid,
|
||||
strides: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -27,29 +27,16 @@ pub struct AvgPool2d<N, B>
|
||||
_m: PhantomData<fn(N, B)>
|
||||
}
|
||||
|
||||
impl <N, B> Layer<N, B> for AvgPool2d<N, B>
|
||||
where B: Backend<N> + BackendAvgPool2d<N>
|
||||
impl <N, B, O> Layer<N, B, O> for AvgPool2d<N, B>
|
||||
where B: Backend<N> + BackendAvgPool2d<N>,
|
||||
O: Optimizer<N, B>
|
||||
{
|
||||
type Config = AvgPool2dConfig;
|
||||
type Context = DefaultLayerContext<N, B>;
|
||||
|
||||
fn name(&self) -> &str {
|
||||
"AvgPool2d"
|
||||
}
|
||||
|
||||
fn create(input_shape: TensorShape, config: Self::Config) -> Self {
|
||||
assert!(input_shape.dims == 3);
|
||||
|
||||
AvgPool2d {
|
||||
input_shape,
|
||||
conv_info: Conv2dInfo {
|
||||
kernel: config.pool,
|
||||
strides: config.strides,
|
||||
padding: config.padding,
|
||||
},
|
||||
_m: Default::default(),
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn input_shape(&self) -> TensorShape {
|
||||
self.input_shape.clone()
|
||||
@ -72,7 +59,11 @@ impl <N, B> Layer<N, B> for AvgPool2d<N, B>
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn forward(&self, backend: &B, y: &mut B::Tensor, x: &B::Tensor) {
|
||||
fn forward(&self, backend: &B, x: &B::Tensor, ctx: &mut Self::Context) {
|
||||
ctx.update_outputs_shape(x.shape().get(0), &Layer::<N, B, O>::output_shape(self));
|
||||
|
||||
let y = &mut ctx.outputs;
|
||||
|
||||
assert_eq!(y.shape().dims, 4);
|
||||
assert_eq!(x.shape().dims, 4);
|
||||
|
||||
@ -80,10 +71,35 @@ impl <N, B> Layer<N, B> for AvgPool2d<N, B>
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn backward(&self, backend: &B, dx: &mut B::Tensor, dy: &B::Tensor, x: &B::Tensor, _: &B::Tensor) {
|
||||
fn backward(&mut self, backend: &B, dy: &B::Tensor, x: &B::Tensor, ctx: &mut Self::Context) {
|
||||
ctx.update_deltas_shape(x.shape().get(0), &self.input_shape);
|
||||
|
||||
let dx = &mut ctx.deltas;
|
||||
|
||||
assert_eq!(dy.shape().dims, 4);
|
||||
assert_eq!(dx.shape().dims, 4);
|
||||
|
||||
backend.avg_pool2d_backprop(dx, dy, x, &self.conv_info);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl <N, B, O> LayerExt<N, B, O> for AvgPool2d<N, B>
|
||||
where B: Backend<N> + BackendAvgPool2d<N>,
|
||||
O: Optimizer<N, B>
|
||||
{
|
||||
type Config = AvgPool2dConfig;
|
||||
|
||||
fn create(input_shape: TensorShape, config: Self::Config) -> Self {
|
||||
assert!(input_shape.dims == 3);
|
||||
|
||||
AvgPool2d {
|
||||
input_shape,
|
||||
conv_info: Conv2dInfo {
|
||||
kernel: config.pool,
|
||||
strides: config.strides.unwrap_or(config.pool),
|
||||
padding: PaddingKind::Valid,
|
||||
},
|
||||
_m: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
use crate::backend::Backend;
|
||||
use crate::layer::{AbstractLayer, LayerContext};
|
||||
use crate::layer::{Layer, LayerContext};
|
||||
use crate::optimizer::Optimizer;
|
||||
use crate::tensor::TensorShape;
|
||||
|
||||
use core::marker::PhantomData;
|
||||
|
||||
@ -47,8 +48,8 @@ impl<N, B, L, R> LayerContext<N, B> for ChainContext<N, B, L, R>
|
||||
pub struct Chain<N, B, O, L, R>
|
||||
where B: Backend<N>,
|
||||
O: Optimizer<N, B>,
|
||||
L: AbstractLayer<N, B, O>,
|
||||
R: AbstractLayer<N, B, O>,
|
||||
L: Layer<N, B, O>,
|
||||
R: Layer<N, B, O>,
|
||||
{
|
||||
left: L,
|
||||
right: R,
|
||||
@ -58,8 +59,8 @@ pub struct Chain<N, B, O, L, R>
|
||||
impl<N, B, O, L, R> Chain<N, B, O, L, R>
|
||||
where B: Backend<N>,
|
||||
O: Optimizer<N, B>,
|
||||
L: AbstractLayer<N, B, O>,
|
||||
R: AbstractLayer<N, B, O>,
|
||||
L: Layer<N, B, O>,
|
||||
R: Layer<N, B, O>,
|
||||
{
|
||||
pub fn new(left: L, right: R) -> Self {
|
||||
Self {
|
||||
@ -70,30 +71,56 @@ impl<N, B, O, L, R> Chain<N, B, O, L, R>
|
||||
}
|
||||
}
|
||||
|
||||
impl<N, B, O, L, R> core::fmt::Display for Chain<N, B, O, L, R>
|
||||
// impl<N, B, O, L, R> core::fmt::Display for Chain<N, B, O, L, R>
|
||||
// where B: Backend<N>,
|
||||
// O: Optimizer<N, B>,
|
||||
// L: Layer<N, B, O>,
|
||||
// R: Layer<N, B, O>,
|
||||
// {
|
||||
// fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
|
||||
// self.left.fmt(f)?;
|
||||
// self.right.fmt(f)?;
|
||||
|
||||
// Ok(())
|
||||
// }
|
||||
// }
|
||||
|
||||
impl<N, B, O, L, R> Layer<N, B, O> for Chain<N, B, O, L, R>
|
||||
where B: Backend<N>,
|
||||
O: Optimizer<N, B>,
|
||||
L: AbstractLayer<N, B, O>,
|
||||
R: AbstractLayer<N, B, O>,
|
||||
{
|
||||
fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
|
||||
self.left.fmt(f)?;
|
||||
self.right.fmt(f)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl<N, B, O, L, R> AbstractLayer<N, B, O> for Chain<N, B, O, L, R>
|
||||
where B: Backend<N>,
|
||||
O: Optimizer<N, B>,
|
||||
L: AbstractLayer<N, B, O>,
|
||||
R: AbstractLayer<N, B, O>,
|
||||
L: Layer<N, B, O>,
|
||||
R: Layer<N, B, O>,
|
||||
{
|
||||
type Context = ChainContext<N, B, L::Context, R::Context>;
|
||||
|
||||
#[inline]
|
||||
fn forward(&mut self, backend: &B, inputs: &B::Tensor, ctx: &mut Self::Context) {
|
||||
fn name(&self) -> &str {
|
||||
"Chain Layer"
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn param_count(&self) -> usize {
|
||||
self.left.param_count() + self.right.param_count()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn init(&mut self, backend: &B) {
|
||||
self.left.init(backend);
|
||||
self.right.init(backend);
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn input_shape(&self) -> TensorShape {
|
||||
self.left.input_shape()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn output_shape(&self) -> TensorShape {
|
||||
self.right.output_shape()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn forward(&self, backend: &B, inputs: &B::Tensor, ctx: &mut Self::Context) {
|
||||
self.left.forward(backend, inputs, &mut ctx.left);
|
||||
self.right.forward(backend, ctx.left.outputs(), &mut ctx.right);
|
||||
}
|
||||
@ -105,8 +132,21 @@ impl<N, B, O, L, R> AbstractLayer<N, B, O> for Chain<N, B, O, L, R>
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn update(&mut self, backend: &B, optimizer: &O, inputs: &B::Tensor, deltas: &B::Tensor, ctx: &mut Self::Context) {
|
||||
self.left.update(backend, optimizer, inputs, ctx.right.deltas(), &mut ctx.left);
|
||||
self.right.update(backend, optimizer, ctx.left.outputs(), deltas, &mut ctx.right);
|
||||
fn calc_gradients(&mut self, backend: &B, deltas: &B::Tensor, inputs: &B::Tensor, ctx: &mut Self::Context) {
|
||||
self.left.calc_gradients(backend, ctx.right.deltas(), inputs, &mut ctx.left);
|
||||
self.right.calc_gradients(backend, deltas, ctx.left.outputs(), &mut ctx.right);
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn optimize(&mut self, backend: &B, optimizer: &O) {
|
||||
self.left.optimize(backend, optimizer);
|
||||
self.right.optimize(backend, optimizer);
|
||||
}
|
||||
|
||||
fn fmt(&self, f: &mut core::fmt::Formatter, padding: usize) -> core::fmt::Result {
|
||||
self.left.fmt(f, padding)?;
|
||||
self.right.fmt(f, padding)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
@ -1,8 +1,8 @@
|
||||
use crate::tensor::{Tensor, TensorShape};
|
||||
use crate::layer::Layer;
|
||||
use crate::layer::{Layer, LayerExt, DefaultLayerContext};
|
||||
use crate::params::Params;
|
||||
use crate::backend::{Backend, Conv2dInfo, PaddingKind, BackendBias, BackendConv2d, BackendScale};
|
||||
use crate::optimizer::{Optimizable, Optimizer};
|
||||
use crate::optimizer::Optimizer;
|
||||
|
||||
pub struct Conv2dConfig {
|
||||
pub filters: u32,
|
||||
@ -36,32 +36,24 @@ pub struct Conv2d<N, B, O>
|
||||
biases: Params<N, B, O>,
|
||||
}
|
||||
|
||||
impl <N, B, O> Layer<N, B> for Conv2d<N, B, O>
|
||||
where B: Backend<N> + BackendConv2d<N> + BackendBias<N>,
|
||||
impl <N, B, O> Layer<N, B, O> for Conv2d<N, B, O>
|
||||
where B: Backend<N> + BackendConv2d<N> + BackendBias<N> + BackendScale<N>,
|
||||
O: Optimizer<N, B>
|
||||
{
|
||||
type Config = Conv2dConfig;
|
||||
type Context = DefaultLayerContext<N, B>;
|
||||
|
||||
fn name(&self) -> &str {
|
||||
"Conv2d"
|
||||
}
|
||||
|
||||
fn create(input_shape: TensorShape, cfg: Self::Config) -> Self {
|
||||
assert!(input_shape.dims == 3);
|
||||
|
||||
Conv2d {
|
||||
input_shape,
|
||||
units: cfg.filters,
|
||||
conv_info: Conv2dInfo {
|
||||
kernel: cfg.kernel,
|
||||
padding: cfg.padding,
|
||||
strides: cfg.strides,
|
||||
},
|
||||
use_biases: cfg.biases,
|
||||
filters: Params::new((cfg.filters, cfg.kernel.0, cfg.kernel.1)),
|
||||
biases: Params::new((cfg.filters, )),
|
||||
#[inline]
|
||||
fn param_count(&self) -> usize {
|
||||
if self.use_biases {
|
||||
self.filters.params.shape().size() + self.biases.params.shape().size()
|
||||
} else {
|
||||
self.filters.params.shape().size()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn init(&mut self, backend: &B) {
|
||||
self.filters.init_random(backend, self.conv_info.kernel.0 * self.conv_info.kernel.1 + self.filters.params.shape().get(0));
|
||||
@ -93,11 +85,12 @@ impl <N, B, O> Layer<N, B> for Conv2d<N, B, O>
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn forward(&self, backend: &B, y: &mut B::Tensor, x: &B::Tensor) {
|
||||
assert_eq!(y.shape().dims, 4);
|
||||
fn forward(&self, backend: &B, x: &B::Tensor, ctx: &mut Self::Context) {
|
||||
assert_eq!(x.shape().dims, 4);
|
||||
|
||||
backend.conv2d_forward(y, x, &self.filters.params, &self.conv_info);
|
||||
ctx.update_outputs_shape(x.shape().get(0), &self.output_shape());
|
||||
|
||||
backend.conv2d_forward(&mut ctx.outputs, x, &self.filters.params, &self.conv_info);
|
||||
|
||||
if self.use_biases {
|
||||
unimplemented!();
|
||||
@ -106,20 +99,16 @@ impl <N, B, O> Layer<N, B> for Conv2d<N, B, O>
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn backward(&self, backend: &B, dx: &mut B::Tensor, dy: &B::Tensor, _: &B::Tensor, _: &B::Tensor) {
|
||||
fn backward(&mut self, backend: &B, dy: &B::Tensor, x: &B::Tensor, ctx: &mut Self::Context) {
|
||||
assert_eq!(dy.shape().dims, 4);
|
||||
assert_eq!(dx.shape().dims, 4);
|
||||
|
||||
ctx.update_deltas_shape(x.shape().get(0), &self.input_shape);
|
||||
|
||||
backend.conv2d_backward_input(dx, dy, &self.filters.params, &self.conv_info);
|
||||
backend.conv2d_backward_input(&mut ctx.deltas, dy, &self.filters.params, &self.conv_info);
|
||||
}
|
||||
}
|
||||
|
||||
impl <N, B, O> Optimizable<N, B, O> for Conv2d<N, B, O>
|
||||
where B: Backend<N> + BackendConv2d<N> + BackendBias<N> + BackendScale<N>,
|
||||
O: Optimizer<N, B>
|
||||
{
|
||||
#[inline]
|
||||
fn calc_gradients(&mut self, backend: &B, x: &B::Tensor, dy: &B::Tensor) {
|
||||
fn calc_gradients(&mut self, backend: &B, dy: &B::Tensor, x: &B::Tensor, _ctx: &mut Self::Context) {
|
||||
assert_eq!(dy.shape().dims, 4);
|
||||
assert_eq!(x.shape().dims, 4);
|
||||
|
||||
@ -138,4 +127,28 @@ impl <N, B, O> Optimizable<N, B, O> for Conv2d<N, B, O>
|
||||
// optimizer.update_params(backend, &mut self.biases.ctx, &mut self.biases.params, &self.biases.grads);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl <N, B, O> LayerExt<N, B, O> for Conv2d<N, B, O>
|
||||
where B: Backend<N> + BackendConv2d<N> + BackendBias<N> + BackendScale<N>,
|
||||
O: Optimizer<N, B>
|
||||
{
|
||||
type Config = Conv2dConfig;
|
||||
|
||||
fn create(input_shape: TensorShape, cfg: Self::Config) -> Self {
|
||||
assert!(input_shape.dims == 3);
|
||||
|
||||
Conv2d {
|
||||
input_shape,
|
||||
units: cfg.filters,
|
||||
conv_info: Conv2dInfo {
|
||||
kernel: cfg.kernel,
|
||||
padding: cfg.padding,
|
||||
strides: cfg.strides,
|
||||
},
|
||||
use_biases: cfg.biases,
|
||||
filters: Params::new((cfg.filters, cfg.kernel.0, cfg.kernel.1)),
|
||||
biases: Params::new((cfg.filters, )),
|
||||
}
|
||||
}
|
||||
}
|
@ -1,8 +1,10 @@
|
||||
use crate::tensor::TensorShape;
|
||||
use crate::tensor::{TensorShape, Tensor};
|
||||
use crate::backend::{Backend, BackendCopy};
|
||||
use crate::layer::Layer;
|
||||
use crate::layer::{Layer, LayerExt, DefaultLayerContext};
|
||||
use crate::optimizer::Optimizer;
|
||||
use core::marker::PhantomData;
|
||||
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct FlattenConfig;
|
||||
|
||||
@ -13,22 +15,16 @@ pub struct Flatten<N, B>
|
||||
_x: PhantomData<fn(N, B)>,
|
||||
}
|
||||
|
||||
impl <N, B> Layer<N, B> for Flatten<N, B>
|
||||
where B: Backend<N> + BackendCopy<N>
|
||||
impl <N, B, O> Layer<N, B, O> for Flatten<N, B>
|
||||
where B: Backend<N> + BackendCopy<N>,
|
||||
O: Optimizer<N, B>
|
||||
{
|
||||
type Config = FlattenConfig;
|
||||
type Context = DefaultLayerContext<N, B>;
|
||||
|
||||
fn name(&self) -> &str {
|
||||
"Flatten"
|
||||
}
|
||||
|
||||
fn create(input_shape: TensorShape, _cfg: Self::Config) -> Self {
|
||||
Flatten {
|
||||
input_shape,
|
||||
_x: Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn input_shape(&self) -> TensorShape {
|
||||
self.input_shape.clone()
|
||||
@ -40,12 +36,30 @@ impl <N, B> Layer<N, B> for Flatten<N, B>
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn forward(&self, backend: &B, y: &mut B::Tensor, x: &B::Tensor) {
|
||||
backend.copy(y, x);
|
||||
fn forward(&self, backend: &B, x: &B::Tensor, ctx: &mut Self::Context) {
|
||||
ctx.update_outputs_shape(x.shape().get(0), &Layer::<N, B, O>::output_shape(self));
|
||||
|
||||
backend.copy(&mut ctx.outputs, x);
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn backward(&self, backend: &B, dx: &mut B::Tensor, dy: &B::Tensor, _: &B::Tensor, _: &B::Tensor) {
|
||||
backend.copy(dx, dy);
|
||||
fn backward(&mut self, backend: &B, dy: &B::Tensor, x: &B::Tensor, ctx: &mut Self::Context) {
|
||||
ctx.update_deltas_shape(x.shape().get(0), &self.input_shape);
|
||||
|
||||
backend.copy(&mut ctx.deltas, dy);
|
||||
}
|
||||
}
|
||||
|
||||
impl <N, B, O> LayerExt<N, B, O> for Flatten<N, B>
|
||||
where B: Backend<N> + BackendCopy<N>,
|
||||
O: Optimizer<N, B>
|
||||
{
|
||||
type Config = FlattenConfig;
|
||||
|
||||
fn create(input_shape: TensorShape, _cfg: Self::Config) -> Self {
|
||||
Flatten {
|
||||
input_shape,
|
||||
_x: Default::default()
|
||||
}
|
||||
}
|
||||
}
|
@ -1,8 +1,8 @@
|
||||
use crate::tensor::{Tensor, TensorShape};
|
||||
use crate::layer::Layer;
|
||||
use crate::layer::{Layer, LayerExt, DefaultLayerContext};
|
||||
use crate::params::Params;
|
||||
use crate::backend::{Backend, BackendGemm, BackendBias, BackendScale};
|
||||
use crate::optimizer::{Optimizable, Optimizer};
|
||||
use crate::optimizer::Optimizer;
|
||||
|
||||
pub struct LinearConfig {
|
||||
pub units: u32,
|
||||
@ -29,30 +29,25 @@ pub struct Linear<N, B, O>
|
||||
biases: Params<N, B, O>,
|
||||
}
|
||||
|
||||
impl <N, B, O> Layer<N, B> for Linear<N, B, O>
|
||||
where B: Backend<N> + BackendGemm<N> + BackendBias<N>,
|
||||
impl <N, B, O> Layer<N, B, O> for Linear<N, B, O>
|
||||
where B: Backend<N> + BackendGemm<N> + BackendBias<N> + BackendScale<N>,
|
||||
O: Optimizer<N, B>
|
||||
{
|
||||
type Config = LinearConfig;
|
||||
type Context = DefaultLayerContext<N, B>;
|
||||
|
||||
fn name(&self) -> &str {
|
||||
"Linear"
|
||||
}
|
||||
|
||||
fn create(input_shape: TensorShape, cfg: Self::Config) -> Self {
|
||||
assert!(input_shape.dims == 1);
|
||||
|
||||
let inputs = input_shape.get(0);
|
||||
|
||||
Linear {
|
||||
inputs,
|
||||
outputs: cfg.units,
|
||||
use_biases: cfg.biases,
|
||||
weights: Params::new((inputs, cfg.units)),
|
||||
biases: Params::new((cfg.units, )),
|
||||
#[inline]
|
||||
fn param_count(&self) -> usize {
|
||||
if self.use_biases {
|
||||
self.weights.params.shape().size() + self.biases.params.shape().size()
|
||||
} else {
|
||||
self.weights.params.shape().size()
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
fn init(&mut self, backend: &B) {
|
||||
self.weights.init_random(backend, self.inputs + self.outputs);
|
||||
if self.use_biases {
|
||||
@ -71,33 +66,32 @@ impl <N, B, O> Layer<N, B> for Linear<N, B, O>
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn forward(&self, backend: &B, y: &mut B::Tensor, x: &B::Tensor) {
|
||||
backend.matmul(y, x, &self.weights.params);
|
||||
fn forward(&self, backend: &B, x: &B::Tensor, ctx: &mut Self::Context) {
|
||||
ctx.update_outputs_shape(x.shape().get(0), &self.output_shape());
|
||||
|
||||
backend.matmul(&mut ctx.outputs, x, &self.weights.params);
|
||||
|
||||
if self.use_biases {
|
||||
backend.bias_add(y, &self.biases.params);
|
||||
backend.bias_add(&mut ctx.outputs, &self.biases.params);
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn backward(&self, backend: &B, dx: &mut B::Tensor, dy: &B::Tensor, _: &B::Tensor, _: &B::Tensor) {
|
||||
backend.matmul_nt(dx, dy, &self.weights.params);
|
||||
fn backward(&mut self, backend: &B, dy: &B::Tensor, x: &B::Tensor, ctx: &mut Self::Context) {
|
||||
ctx.update_deltas_shape(x.shape().get(0), &self.input_shape());
|
||||
|
||||
backend.matmul_nt(&mut ctx.deltas, dy, &self.weights.params);
|
||||
}
|
||||
}
|
||||
|
||||
impl <N, B, O> Optimizable<N, B, O> for Linear<N, B, O>
|
||||
where B: Backend<N> + BackendGemm<N> + BackendBias<N> + BackendScale<N>,
|
||||
O: Optimizer<N, B>
|
||||
{
|
||||
fn calc_gradients(&mut self, backend: &B, inputs: &B::Tensor, deltas: &B::Tensor) {
|
||||
let prescaler = 1.0 / inputs.shape().get(0) as f32;
|
||||
fn calc_gradients(&mut self, backend: &B, dy: &B::Tensor, x: &B::Tensor, _ctx: &mut Self::Context) {
|
||||
let prescaler = 1.0 / x.shape().get(0) as f32;
|
||||
|
||||
backend.matmul_tn(&mut self.weights.grads, inputs, deltas);
|
||||
backend.matmul_tn(&mut self.weights.grads, x, dy);
|
||||
backend.scale(&mut self.weights.grads, backend.scalar_f32(prescaler));
|
||||
|
||||
if self.use_biases {
|
||||
backend.scale(&mut self.biases.grads, backend.scalar_f32(prescaler));
|
||||
backend.bias_grad(&mut self.biases.grads, deltas);
|
||||
backend.bias_grad(&mut self.biases.grads, &dy);
|
||||
}
|
||||
}
|
||||
|
||||
@ -109,4 +103,25 @@ impl <N, B, O> Optimizable<N, B, O> for Linear<N, B, O>
|
||||
optimizer.update_params(backend, &mut self.biases.ctx, &mut self.biases.params, &mut self.biases.grads);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl <N, B, O> LayerExt<N, B, O> for Linear<N, B, O>
|
||||
where B: Backend<N> + BackendGemm<N> + BackendBias<N> + BackendScale<N>,
|
||||
O: Optimizer<N, B>
|
||||
{
|
||||
type Config = LinearConfig;
|
||||
|
||||
fn create(input_shape: TensorShape, cfg: Self::Config) -> Self {
|
||||
assert!(input_shape.dims == 1);
|
||||
|
||||
let inputs = input_shape.get(0);
|
||||
|
||||
Linear {
|
||||
inputs,
|
||||
outputs: cfg.units,
|
||||
use_biases: cfg.biases,
|
||||
weights: Params::new((inputs, cfg.units)),
|
||||
biases: Params::new((cfg.units, )),
|
||||
}
|
||||
}
|
||||
}
|
@ -1,6 +1,7 @@
|
||||
use crate::tensor::TensorShape;
|
||||
use crate::layer::Layer;
|
||||
use crate::tensor::{TensorShape, Tensor};
|
||||
use crate::layer::{Layer, LayerExt, DefaultLayerContext};
|
||||
use crate::backend::{Backend, PaddingKind, BackendMaxPool2d, Conv2dInfo};
|
||||
use crate::optimizer::Optimizer;
|
||||
use core::marker::PhantomData;
|
||||
|
||||
pub struct MaxPool2dConfig {
|
||||
@ -25,29 +26,16 @@ pub struct MaxPool2d<N, B>
|
||||
_m: PhantomData<fn(N, B)>
|
||||
}
|
||||
|
||||
impl <N, B> Layer<N, B> for MaxPool2d<N, B>
|
||||
impl <N, B, O> Layer<N, B, O> for MaxPool2d<N, B>
|
||||
where B: Backend<N> + BackendMaxPool2d<N>,
|
||||
O: Optimizer<N, B>
|
||||
{
|
||||
type Config = MaxPool2dConfig;
|
||||
|
||||
type Context = DefaultLayerContext<N, B>;
|
||||
|
||||
fn name(&self) -> &str {
|
||||
"MaxPool2d"
|
||||
}
|
||||
|
||||
fn create(input_shape: TensorShape, config: Self::Config) -> Self {
|
||||
assert!(input_shape.dims == 3);
|
||||
|
||||
MaxPool2d {
|
||||
input_shape,
|
||||
conv_info: Conv2dInfo {
|
||||
kernel: config.pool,
|
||||
strides: config.strides.unwrap_or(config.pool),
|
||||
padding: PaddingKind::Valid,
|
||||
},
|
||||
_m: Default::default(),
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn input_shape(&self) -> TensorShape {
|
||||
self.input_shape.clone()
|
||||
@ -70,12 +58,37 @@ impl <N, B> Layer<N, B> for MaxPool2d<N, B>
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn forward(&self, backend: &B, y: &mut B::Tensor, x: &B::Tensor) {
|
||||
backend.max_pool2d(y, x, &self.conv_info)
|
||||
fn forward(&self, backend: &B, x: &B::Tensor, ctx: &mut Self::Context) {
|
||||
ctx.update_outputs_shape(x.shape().get(0), &Layer::<N, B, O>::output_shape(self));
|
||||
|
||||
backend.max_pool2d(&mut ctx.outputs, x, &self.conv_info)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn backward(&self, backend: &B, dx: &mut B::Tensor, dy: &B::Tensor, x: &B::Tensor, _: &B::Tensor) {
|
||||
backend.max_pool2d_backprop(dx, dy, x, &self.conv_info);
|
||||
fn backward(&mut self, backend: &B, dy: &B::Tensor, x: &B::Tensor, ctx: &mut Self::Context) {
|
||||
ctx.update_deltas_shape(x.shape().get(0), &self.input_shape);
|
||||
|
||||
backend.max_pool2d_backprop(&mut ctx.deltas, dy, x, &self.conv_info);
|
||||
}
|
||||
}
|
||||
|
||||
impl <N, B, O> LayerExt<N, B, O> for MaxPool2d<N, B>
|
||||
where B: Backend<N> + BackendMaxPool2d<N>,
|
||||
O: Optimizer<N, B>
|
||||
{
|
||||
type Config = MaxPool2dConfig;
|
||||
|
||||
fn create(input_shape: TensorShape, config: Self::Config) -> Self {
|
||||
assert!(input_shape.dims == 3);
|
||||
|
||||
MaxPool2d {
|
||||
input_shape,
|
||||
conv_info: Conv2dInfo {
|
||||
kernel: config.pool,
|
||||
strides: config.strides.unwrap_or(config.pool),
|
||||
padding: PaddingKind::Valid,
|
||||
},
|
||||
_m: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
@ -1,6 +1,7 @@
|
||||
use crate::tensor::TensorShape;
|
||||
use crate::tensor::{Tensor, TensorShape};
|
||||
use crate::backend::{Backend, BackendReLu};
|
||||
use crate::layer::Layer;
|
||||
use crate::layer::{Layer, LayerExt, DefaultLayerContext};
|
||||
use crate::optimizer::Optimizer;
|
||||
use core::marker::PhantomData;
|
||||
|
||||
#[derive(Default)]
|
||||
@ -13,34 +14,46 @@ pub struct ReLu<N, B>
|
||||
_x: PhantomData<fn(N, B)>,
|
||||
}
|
||||
|
||||
impl <N, B> Layer<N, B> for ReLu<N, B>
|
||||
where B: Backend<N> + BackendReLu<N>
|
||||
impl <N, B, O> Layer<N, B, O> for ReLu<N, B>
|
||||
where B: Backend<N> + BackendReLu<N>,
|
||||
O: Optimizer<N, B>
|
||||
{
|
||||
type Config = ReLuConfig;
|
||||
|
||||
type Context = DefaultLayerContext<N, B>;
|
||||
|
||||
fn name(&self) -> &str {
|
||||
"ReLU"
|
||||
}
|
||||
|
||||
fn create(input_shape: TensorShape, _cfg: Self::Config) -> Self {
|
||||
ReLu {
|
||||
input_shape,
|
||||
_x: Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn input_shape(&self) -> TensorShape {
|
||||
self.input_shape.clone()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn forward(&self, backend: &B, y: &mut B::Tensor, x: &B::Tensor) {
|
||||
backend.relu(y, x);
|
||||
fn forward(&self, backend: &B, x: &B::Tensor, ctx: &mut Self::Context) {
|
||||
ctx.update_outputs_shape(x.shape().get(0), &self.input_shape);
|
||||
|
||||
backend.relu(&mut ctx.outputs, x);
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn backward(&self, backend: &B, dx: &mut B::Tensor, dy: &B::Tensor, _x: &B::Tensor, y: &B::Tensor) {
|
||||
backend.relu_grad(dx, y, dy);
|
||||
fn backward(&mut self, backend: &B, dy: &B::Tensor, _: &B::Tensor, ctx: &mut Self::Context) {
|
||||
ctx.update_deltas_shape(dy.shape().get(0), &self.input_shape);
|
||||
|
||||
backend.relu_grad(&mut ctx.deltas, &ctx.outputs, dy);
|
||||
}
|
||||
}
|
||||
|
||||
impl <N, B, O> LayerExt<N, B, O> for ReLu<N, B>
|
||||
where B: Backend<N> + BackendReLu<N>,
|
||||
O: Optimizer<N, B>
|
||||
{
|
||||
type Config = ReLuConfig;
|
||||
|
||||
fn create(input_shape: TensorShape, _cfg: Self::Config) -> Self {
|
||||
ReLu {
|
||||
input_shape,
|
||||
_x: Default::default()
|
||||
}
|
||||
}
|
||||
}
|
@ -1,6 +1,7 @@
|
||||
use crate::tensor::TensorShape;
|
||||
use crate::tensor::{Tensor, TensorShape};
|
||||
use crate::backend::{Backend, BackendSigmoid};
|
||||
use crate::layer::Layer;
|
||||
use crate::layer::{Layer, LayerExt, DefaultLayerContext};
|
||||
use crate::optimizer::Optimizer;
|
||||
use core::marker::PhantomData;
|
||||
|
||||
#[derive(Default)]
|
||||
@ -13,21 +14,15 @@ pub struct Sigmoid<N, B>
|
||||
_x: PhantomData<fn(N, B)>,
|
||||
}
|
||||
|
||||
impl <N, B> Layer<N, B> for Sigmoid<N, B>
|
||||
where B: Backend<N> + BackendSigmoid<N>
|
||||
{
|
||||
type Config = SigmoidConfig;
|
||||
|
||||
impl <N, B, O> Layer<N, B, O> for Sigmoid<N, B>
|
||||
where B: Backend<N> + BackendSigmoid<N>,
|
||||
O: Optimizer<N, B>
|
||||
{
|
||||
type Context = DefaultLayerContext<N, B>;
|
||||
|
||||
fn name(&self) -> &str {
|
||||
"Sigmoid"
|
||||
}
|
||||
|
||||
fn create(input_shape: TensorShape, _cfg: Self::Config) -> Self {
|
||||
Sigmoid {
|
||||
input_shape,
|
||||
_x: Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn input_shape(&self) -> TensorShape {
|
||||
@ -35,12 +30,30 @@ impl <N, B> Layer<N, B> for Sigmoid<N, B>
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn forward(&self, backend: &B, y: &mut B::Tensor, x: &B::Tensor) {
|
||||
backend.sigmoid(y, x);
|
||||
fn forward(&self, backend: &B, x: &B::Tensor, ctx: &mut Self::Context) {
|
||||
ctx.update_outputs_shape(x.shape().get(0), &self.input_shape);
|
||||
|
||||
backend.sigmoid(&mut ctx.outputs, x);
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn backward(&self, backend: &B, dx: &mut B::Tensor, dy: &B::Tensor, _x: &B::Tensor, y: &B::Tensor) {
|
||||
backend.sigmoid_grad(dx, y, dy);
|
||||
fn backward(&mut self, backend: &B, dy: &B::Tensor, _: &B::Tensor, ctx: &mut Self::Context) {
|
||||
ctx.update_deltas_shape(dy.shape().get(0), &self.input_shape);
|
||||
|
||||
backend.sigmoid_grad(&mut ctx.deltas, &ctx.outputs, dy);
|
||||
}
|
||||
}
|
||||
|
||||
impl <N, B, O> LayerExt<N, B, O> for Sigmoid<N, B>
|
||||
where B: Backend<N> + BackendSigmoid<N>,
|
||||
O: Optimizer<N, B>
|
||||
{
|
||||
type Config = SigmoidConfig;
|
||||
|
||||
fn create(input_shape: TensorShape, _cfg: Self::Config) -> Self {
|
||||
Sigmoid {
|
||||
input_shape,
|
||||
_x: Default::default()
|
||||
}
|
||||
}
|
||||
}
|
@ -1,6 +1,7 @@
|
||||
use crate::tensor::TensorShape;
|
||||
use crate::tensor::{Tensor, TensorShape};
|
||||
use crate::backend::{Backend, BackendSoftmax};
|
||||
use crate::layer::Layer;
|
||||
use crate::layer::{Layer, LayerExt, DefaultLayerContext};
|
||||
use crate::optimizer::Optimizer;
|
||||
use core::marker::PhantomData;
|
||||
|
||||
#[derive(Default)]
|
||||
@ -13,34 +14,46 @@ pub struct Softmax<N, B>
|
||||
_x: PhantomData<fn(N, B)>,
|
||||
}
|
||||
|
||||
impl <N, B> Layer<N, B> for Softmax<N, B>
|
||||
where B: Backend<N> + BackendSoftmax<N>
|
||||
{
|
||||
type Config = SoftmaxConfig;
|
||||
|
||||
impl <N, B, O> Layer<N, B, O> for Softmax<N, B>
|
||||
where B: Backend<N> + BackendSoftmax<N>,
|
||||
O: Optimizer<N, B>
|
||||
{
|
||||
type Context = DefaultLayerContext<N, B>;
|
||||
|
||||
fn name(&self) -> &str {
|
||||
"Softmax"
|
||||
}
|
||||
|
||||
fn create(input_shape: TensorShape, _cfg: Self::Config) -> Self {
|
||||
Softmax {
|
||||
input_shape,
|
||||
_x: Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn input_shape(&self) -> TensorShape {
|
||||
self.input_shape.clone()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn forward(&self, backend: &B, y: &mut B::Tensor, x: &B::Tensor) {
|
||||
backend.softmax(y, x);
|
||||
fn forward(&self, backend: &B, x: &B::Tensor, ctx: &mut Self::Context) {
|
||||
ctx.update_outputs_shape(x.shape().get(0), &self.input_shape);
|
||||
|
||||
backend.softmax(&mut ctx.outputs, x);
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn backward(&self, backend: &B, dx: &mut B::Tensor, dy: &B::Tensor, _: &B::Tensor, _: &B::Tensor) {
|
||||
backend.copy(dx, dy);
|
||||
fn backward(&mut self, backend: &B, dy: &B::Tensor, _: &B::Tensor, ctx: &mut Self::Context) {
|
||||
ctx.update_deltas_shape(dy.shape().get(0), &self.input_shape);
|
||||
|
||||
backend.copy(&mut ctx.deltas, dy);
|
||||
}
|
||||
}
|
||||
|
||||
impl <N, B, O> LayerExt<N, B, O> for Softmax<N, B>
|
||||
where B: Backend<N> + BackendSoftmax<N>,
|
||||
O: Optimizer<N, B>
|
||||
{
|
||||
type Config = SoftmaxConfig;
|
||||
|
||||
fn create(input_shape: TensorShape, _cfg: Self::Config) -> Self {
|
||||
Softmax {
|
||||
input_shape,
|
||||
_x: Default::default()
|
||||
}
|
||||
}
|
||||
}
|
@ -1,6 +1,7 @@
|
||||
use crate::tensor::TensorShape;
|
||||
use crate::tensor::{Tensor, TensorShape};
|
||||
use crate::backend::{Backend, BackendCopy};
|
||||
use crate::layer::Layer;
|
||||
use crate::layer::{Layer, LayerExt, DefaultLayerContext};
|
||||
use crate::optimizer::Optimizer;
|
||||
use core::marker::PhantomData;
|
||||
|
||||
#[derive(Default)]
|
||||
@ -16,23 +17,16 @@ pub struct ZeroPadding2d<N, B>
|
||||
_x: PhantomData<fn(N, B)>,
|
||||
}
|
||||
|
||||
impl <N, B> Layer<N, B> for ZeroPadding2d<N, B>
|
||||
where B: Backend<N> + BackendCopy<N>
|
||||
impl <N, B, O> Layer<N, B, O> for ZeroPadding2d<N, B>
|
||||
where B: Backend<N> + BackendCopy<N>,
|
||||
O: Optimizer<N, B>
|
||||
{
|
||||
type Config = ZeroPadding2dConfig;
|
||||
type Context = DefaultLayerContext<N, B>;
|
||||
|
||||
fn name(&self) -> &str {
|
||||
"ZeroPadding2d"
|
||||
}
|
||||
|
||||
fn create(input_shape: TensorShape, config: Self::Config) -> Self {
|
||||
ZeroPadding2d {
|
||||
input_shape,
|
||||
config,
|
||||
_x: Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn input_shape(&self) -> TensorShape {
|
||||
self.input_shape.clone()
|
||||
@ -50,12 +44,30 @@ impl <N, B> Layer<N, B> for ZeroPadding2d<N, B>
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn forward(&self, _backend: &B, _y: &mut B::Tensor, _x: &B::Tensor) {
|
||||
fn forward(&self, _backend: &B, x: &B::Tensor, ctx: &mut Self::Context) {
|
||||
ctx.update_outputs_shape(x.shape().get(0), &Layer::<N, B, O>::output_shape(self));
|
||||
// backend.copy_with_padding(y, x, 0.0, (self.config.0, self.config.1, self.config.0, self.config.1));
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn backward(&self, _backend: &B, _dx: &mut B::Tensor, _dy: &B::Tensor, _: &B::Tensor, _: &B::Tensor) {
|
||||
fn backward(&mut self, _backend: &B, dy: &B::Tensor, _x: &B::Tensor, ctx: &mut Self::Context) {
|
||||
ctx.update_deltas_shape(dy.shape().get(0), &self.input_shape);
|
||||
// backend.copy(dx, dy);
|
||||
}
|
||||
}
|
||||
|
||||
impl <N, B, O> LayerExt<N, B, O> for ZeroPadding2d<N, B>
|
||||
where B: Backend<N> + BackendCopy<N>,
|
||||
O: Optimizer<N, B>
|
||||
{
|
||||
type Config = ZeroPadding2dConfig;
|
||||
|
||||
fn create(input_shape: TensorShape, config: Self::Config) -> Self {
|
||||
ZeroPadding2d {
|
||||
input_shape,
|
||||
config,
|
||||
_x: Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
}
|
@ -8,22 +8,17 @@ macro_rules! sequential_type {
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! sequential_type_impl {
|
||||
($t:ty {$($tt:tt)*}) => {
|
||||
$crate::layer::LayerImpl<N, B, O, $t>
|
||||
};
|
||||
($t:ty {$($tt:tt)*}) => ($t);
|
||||
|
||||
($t:ty {$($xx:tt)*}, $($tt:tt)*) => {
|
||||
$crate::layers::Chain<N, B, O,
|
||||
$crate::layer::LayerImpl<N, B, O, $t>,
|
||||
$crate::sequential_type_impl!($($tt)*)
|
||||
$t, $crate::sequential_type_impl!($($tt)*)
|
||||
>
|
||||
};
|
||||
($t:ty) => {
|
||||
$crate::layer::LayerImpl<N, B, O, $t>
|
||||
};
|
||||
($t:ty) => ($t);
|
||||
($t:ty, $($tt:tt)*) => {
|
||||
$crate::layers::Chain<N, B, O,
|
||||
$crate::layer::LayerImpl<N, B, O, $t>,
|
||||
$crate::sequential_type_impl!($($tt)*)
|
||||
$t, $crate::sequential_type_impl!($($tt)*)
|
||||
>
|
||||
};
|
||||
}
|
||||
@ -40,35 +35,29 @@ macro_rules! sequential {
|
||||
#[macro_export]
|
||||
macro_rules! sequential_impl {
|
||||
($p:expr, $t:ty { $($name:ident : $val:expr),* }) => {{
|
||||
#[allow(unused_imports)]
|
||||
use core::convert::TryInto;
|
||||
|
||||
#[allow(unused_mut)]
|
||||
let mut params = <$t as $crate::layer::Layer<_, _>>::Config::default();
|
||||
let mut params = <$t as $crate::layer::LayerExt<N, B, O>>::Config::default();
|
||||
$(
|
||||
params.$name = ($val).try_into().unwrap_or($val);
|
||||
params.$name = core::convert::TryInto::try_into($val).unwrap_or($val);
|
||||
)*
|
||||
|
||||
$crate::layer::LayerImpl::new(<$t as $crate::layer::Layer<_, _>>::create(
|
||||
<$t as $crate::layer::LayerExt<N, B, O>>::create(
|
||||
$p, params
|
||||
))
|
||||
)
|
||||
}};
|
||||
|
||||
($p:expr, $t:ty { $($name:ident : $val:expr),* }, $($tt:tt)*) => {{
|
||||
#[allow(unused_imports)]
|
||||
use core::convert::TryInto;
|
||||
|
||||
#[allow(unused_mut)]
|
||||
let mut params = <$t as $crate::layer::Layer<_, _>>::Config::default();
|
||||
let mut params = <$t as $crate::layer::LayerExt<N, B, O>>::Config::default();
|
||||
$(
|
||||
params.$name = ($val).try_into().unwrap_or($val);;
|
||||
params.$name = core::convert::TryInto::try_into($val).unwrap_or($val);
|
||||
)*
|
||||
|
||||
let layer = $crate::layer::LayerImpl::new(<$t as $crate::layer::Layer<_, _>>::create(
|
||||
let layer = <$t as $crate::layer::LayerExt<N, B, O>>::create(
|
||||
$p, params
|
||||
));
|
||||
);
|
||||
|
||||
let prev_shape = layer.layer.output_shape();
|
||||
let prev_shape = $crate::layer::Layer::<N, B, O>::output_shape(&layer);
|
||||
|
||||
$crate::layers::Chain::new(
|
||||
layer, $crate::sequential_impl! { prev_shape, $($tt)* },
|
||||
@ -94,12 +83,12 @@ macro_rules! sequential_type_ctx {
|
||||
#[macro_export]
|
||||
macro_rules! sequential_type_ctx_impl {
|
||||
($t:ty {$($xx:tt)*}) => {
|
||||
$crate::layer::CommonLayerContext<N, B>
|
||||
$crate::layer::DefaultLayerContext<N, B>
|
||||
};
|
||||
|
||||
($t:ty {$($xx:tt)*}, $($tt:tt)*) => {
|
||||
$crate::layers::ChainContext<N, B,
|
||||
$crate::layer::CommonLayerContext<N, B>,
|
||||
$crate::layer::DefaultLayerContext<N, B>,
|
||||
$crate::sequential_type_ctx_impl!($($tt)*)
|
||||
>
|
||||
};
|
||||
@ -149,38 +138,86 @@ macro_rules! model_impl {
|
||||
}
|
||||
}
|
||||
|
||||
impl<N, B, O> core::fmt::Display for $name<N, B, O>
|
||||
where B: $crate::backend::Backend<N> + $trait,
|
||||
O: $crate::optimizer::Optimizer<N, B>
|
||||
{
|
||||
fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
|
||||
writeln!(f, "{} {{", stringify!($name))?;
|
||||
write!(f, "{}", self.inner)?;
|
||||
writeln!(f, "}}")?;
|
||||
// impl<N, B, O> core::fmt::Display for $name<N, B, O>
|
||||
// where B: $crate::backend::Backend<N> + $trait,
|
||||
// O: $crate::optimizer::Optimizer<N, B>
|
||||
// {
|
||||
// fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
|
||||
// writeln!(f, "{} {{", stringify!($name))?;
|
||||
// write!(f, "{}", self.inner)?;
|
||||
// writeln!(f, "}}")?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
// Ok(())
|
||||
// }
|
||||
// }
|
||||
|
||||
impl<N, B, O> $crate::layer::AbstractLayer<N, B, O> for $name<N, B, O>
|
||||
impl<N, B, O> $crate::layer::Layer<N, B, O> for $name<N, B, O>
|
||||
where B: $crate::backend::Backend<N> + $trait,
|
||||
O: $crate::optimizer::Optimizer<N, B>
|
||||
{
|
||||
type Context = ctx::$name<N, B>;
|
||||
|
||||
#[inline]
|
||||
fn forward(&mut self, backend: &B, inputs: &B::Tensor, ctx: &mut Self::Context) {
|
||||
$crate::layer::AbstractLayer::forward(&mut self.inner, backend, inputs, ctx)
|
||||
fn name(&self) -> &str {
|
||||
stringify!($name)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn init(&mut self, backend: &B) {
|
||||
self.inner.init(backend);
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn param_count(&self) -> usize {
|
||||
self.inner.param_count()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn input_shape(&self) -> $crate::tensor::TensorShape {
|
||||
self.inner.input_shape()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn output_shape(&self) -> $crate::tensor::TensorShape {
|
||||
self.inner.output_shape()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn forward(&self, backend: &B, inputs: &B::Tensor, ctx: &mut Self::Context) {
|
||||
self.inner.forward(backend, inputs, ctx);
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn backward(&mut self, backend: &B, deltas: &B::Tensor, inputs: &B::Tensor, ctx: &mut Self::Context) {
|
||||
$crate::layer::AbstractLayer::backward(&mut self.inner, backend, deltas, inputs, ctx);
|
||||
self.inner.backward(backend, deltas, inputs, ctx);
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn update(&mut self, backend: &B, optimizer: &O, inputs: &B::Tensor, deltas: &B::Tensor, ctx: &mut Self::Context) {
|
||||
$crate::layer::AbstractLayer::update(&mut self.inner, backend, optimizer, inputs, deltas, ctx);
|
||||
fn calc_gradients(&mut self, backend: &B, deltas: &B::Tensor, inputs: &B::Tensor, ctx: &mut Self::Context) {
|
||||
self.inner.calc_gradients(backend, deltas, inputs, ctx);
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn optimize(&mut self, backend: &B, optimizer: &O) {
|
||||
self.inner.optimize(backend, optimizer);
|
||||
}
|
||||
|
||||
fn fmt(&self, f: &mut core::fmt::Formatter, padding: usize) -> core::fmt::Result {
|
||||
writeln!(f, "{}{}[{}] {{", "", self.name(), self.param_count())?;
|
||||
self.inner.fmt(f, padding + 2)?;
|
||||
write!(f, "}}")?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl<N, B, O> core::fmt::Display for $name<N, B, O>
|
||||
where B: $crate::backend::Backend<N> + $trait,
|
||||
O: $crate::optimizer::Optimizer<N, B>
|
||||
{
|
||||
#[inline]
|
||||
fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
|
||||
<Self as $crate::layer::Layer<_, _, _>>::fmt(self, f, 0)
|
||||
}
|
||||
}
|
||||
};
|
||||
@ -192,6 +229,7 @@ macro_rules! model {
|
||||
mod tmp {
|
||||
pub trait BackendDefault<N> = $crate::backend::BackendReLu<N>
|
||||
+ $crate::backend::BackendBias<N>
|
||||
+ $crate::backend::BackendScale<N>
|
||||
+ $crate::backend::BackendSigmoid<N>
|
||||
+ $crate::backend::BackendSoftmax<N>
|
||||
+ $crate::backend::BackendGemm<N>
|
||||
|
@ -19,9 +19,4 @@ impl <'a, N, B: Backend<N>, O: Optimizer<N, B>> Optimizer<N, B> for &'a O {
|
||||
fn update_params(&self, backend: &B, ctx: &mut Self::Context, params: &mut B::Tensor, grads: &mut B::Tensor) {
|
||||
(**self).update_params(backend, ctx, params, grads)
|
||||
}
|
||||
}
|
||||
|
||||
pub trait Optimizable<N, B: Backend<N>, O: Optimizer<N, B>> {
|
||||
fn calc_gradients(&mut self, backend: &B, inputs: &B::Tensor, deltas: &B::Tensor);
|
||||
fn optimize(&mut self, backend: &B, optimizer: &O);
|
||||
}
|
||||
}
|
@ -121,8 +121,8 @@ impl TensorShape {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn append<S: Into<TensorShape>>(&mut self, s: S) -> &mut Self {
|
||||
let s = s.into();
|
||||
pub fn append<S: core::borrow::Borrow<TensorShape>>(&mut self, s: S) -> &mut Self {
|
||||
let s = s.borrow();
|
||||
let sd = self.dims;
|
||||
|
||||
for i in 0 .. s.dims {
|
||||
|
Loading…
Reference in New Issue
Block a user