RMSProp
This commit is contained in:
parent 9cfaa29d42
commit c86bc4c43d
src/main.rs (12 changed lines)
@@ -15,7 +15,7 @@ pub mod losses;
 pub mod tensor;
 
 use self::backends::{Native, NativeTensorF32};
-use self::optimizers::Adam;
+use self::optimizers::*;
 use self::layers::*;
 use self::layer::*;
 use crate::backend::{Backend, BackendScale};
@@ -59,16 +59,16 @@ fn main() {
 
     let backend = Native;
     // let optimizer = Sgd::new(0.01, 0.1, false);
-    let optimizer = Adam::default();
-    let hidden_count = 64;
+    let optimizer = RMSProp::default();
+    let hidden_count = 16;
 
-    let mut linear_1: LayerImpl<_, _, _, Linear<_, _, &Adam<_, _>>> = LayerImpl::new((784, ).into(), &backend, &optimizer, LinearConfig {
+    let mut linear_1: LayerImpl<_, _, _, Linear<_, _, &RMSProp<_, _>>> = LayerImpl::new((784, ).into(), &backend, &optimizer, LinearConfig {
         outputs: hidden_count
     });
 
     let mut sigmoid_1: LayerImpl<_, _, _, Sigmoid<_, _>> = LayerImpl::new((hidden_count, ).into(), &backend, &optimizer, SigmoidConfig);
 
-    let mut linear_2: LayerImpl<_, _, _, Linear<_, _, &Adam<_, _>>> = LayerImpl::new((hidden_count, ).into(), &backend, &optimizer, LinearConfig {
+    let mut linear_2: LayerImpl<_, _, _, Linear<_, _, &RMSProp<_, _>>> = LayerImpl::new((hidden_count, ).into(), &backend, &optimizer, LinearConfig {
         outputs: 10
     });
 
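Note that the layer type annotations must name the optimizer explicitly (`&RMSProp<_, _>`), since `Linear` is generic over the optimizer it borrows. If the defaults added in this commit (learning rate 0.001, rho 0.9) need tuning, the new constructor takes both explicitly; a minimal sketch, with purely illustrative values:

    let optimizer = RMSProp::new(0.0005, 0.95);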
@@ -114,7 +114,7 @@ fn main() {
     let mut test_linear_2 = LayerContext::new();
     let mut test_sigmoid_2 = LayerContext::new();
 
-    for epoch in 1 ..= 80 {
+    for epoch in 1 ..= 100 {
        println!("epoch {}", epoch);
 
        for step in 0 .. (60000 / BATCH_SIZE) {
src/optimizers/mod.rs
@@ -1,5 +1,7 @@
 mod sgd;
 mod adam;
+mod rmsprop;
 
 pub use self::sgd::*;
 pub use self::adam::*;
+pub use self::rmsprop::*;
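The glob re-export here is what lets the `use self::optimizers::*;` change in main.rs above bring `RMSProp` (and the other optimizers) into scope without naming each one.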
src/optimizers/rmsprop.rs (new file, 68 lines)
@@ -0,0 +1,68 @@
+use std::marker::PhantomData;
+use crate::backend::{Backend, BackendAdam};
+use crate::optimizer::{Optimizer, OptimizerContext};
+use crate::tensor::{Tensor, TensorShape};
+
+
+/// Per-parameter optimizer state: the running average of squared gradients.
+pub struct RMSPropContext<N, B>
+    where B: Backend<N>
+{
+    accum: B::Tensor,
+    _m: PhantomData<fn(N, B)>,
+}
+
+impl<N, B: Backend<N>> OptimizerContext for RMSPropContext<N, B> {
+    fn new<S: Into<TensorShape>>(shape: S) -> Self {
+        let shape = shape.into();
+
+        Self {
+            accum: B::Tensor::new(shape),
+            _m: Default::default(),
+        }
+    }
+}
+
+pub struct RMSProp<N, B: Backend<N>> {
+    learning_rate: f32,
+    rho: f32,
+    epsilon: Option<f32>,
+    _m: PhantomData<fn(N, B)>,
+}
+
+impl<N, B> Default for RMSProp<N, B>
+    where B: Backend<N>
+{
+    fn default() -> Self {
+        Self {
+            learning_rate: 0.001,
+            rho: 0.9,
+            epsilon: None,
+            _m: Default::default(),
+        }
+    }
+}
+
+impl<N, B: Backend<N>> RMSProp<N, B> {
+    pub fn new(learning_rate: f32, rho: f32) -> Self {
+        Self {
+            learning_rate,
+            rho,
+            epsilon: None,
+            _m: Default::default(),
+        }
+    }
+}
+
+impl<N, B: Backend<N> + BackendAdam<N>> Optimizer<N, B> for RMSProp<N, B> {
+    type Context = RMSPropContext<N, B>;
+
+    fn update_params(&self, backend: &B, ctx: &mut Self::Context, params: &mut B::Tensor, grads: &B::Tensor) {
+        // new_a = self.rho * a + (1. - self.rho) * K.square(g)
+        backend.scale(&mut ctx.accum, backend.scalar_f32(self.rho));
+        backend.axpys(&mut ctx.accum, backend.scalar_f32(1.0 - self.rho), grads);
+
+        // new_p = p - lr * g / (K.sqrt(new_a) + self.epsilon)
+        backend.adam_p(params, backend.scalar_f32(-self.learning_rate), grads, &ctx.accum, backend.scalar_f32(self.epsilon.unwrap_or(std::f32::EPSILON)));
+    }
+}