NativeBlas
This commit is contained in:
parent
9117c70844
commit
765a1eccfa
74
Cargo.lock
generated
74
Cargo.lock
generated
@ -25,6 +25,24 @@ dependencies = [
|
||||
"libc 0.2.59 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "blas"
|
||||
version = "0.20.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"blas-sys 0.7.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"libc 0.2.59 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"num-complex 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "blas-sys"
|
||||
version = "0.7.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"libc 0.2.59 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bumpalo"
|
||||
version = "2.5.0"
|
||||
@ -44,6 +62,24 @@ dependencies = [
|
||||
"ppv-lite86 0.2.5 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cblas"
|
||||
version = "0.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"cblas-sys 0.1.4 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"libc 0.2.59 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"num-complex 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cblas-sys"
|
||||
version = "0.1.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"libc 0.2.59 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cc"
|
||||
version = "1.0.37"
|
||||
@ -142,6 +178,28 @@ dependencies = [
|
||||
"version_check 0.1.5 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-complex"
|
||||
version = "0.2.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"autocfg 0.1.4 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"num-traits 0.2.8 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-traits"
|
||||
version = "0.2.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"autocfg 0.1.4 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "openblas-src"
|
||||
version = "0.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[[package]]
|
||||
name = "ppv-lite86"
|
||||
version = "0.2.5"
|
||||
@ -358,6 +416,7 @@ dependencies = [
|
||||
"mnist 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"yarnn 0.1.0",
|
||||
"yarnn-model-mnist 0.1.0",
|
||||
"yarnn-native-blas 0.1.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -392,17 +451,27 @@ dependencies = [
|
||||
"yarnn 0.1.0",
|
||||
]
|
||||
|
||||
[[patch.unused]]
|
||||
[[package]]
|
||||
name = "yarnn-native-blas"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"blas 0.20.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"cblas 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"openblas-src 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"yarnn 0.1.0",
|
||||
]
|
||||
|
||||
[metadata]
|
||||
"checksum autocfg 0.1.4 (registry+https://github.com/rust-lang/crates.io-index)" = "0e49efa51329a5fd37e7c79db4621af617cd4e3e5bc224939808d076077077bf"
|
||||
"checksum backtrace 0.3.32 (registry+https://github.com/rust-lang/crates.io-index)" = "18b50f5258d1a9ad8396d2d345827875de4261b158124d4c819d9b351454fae5"
|
||||
"checksum backtrace-sys 0.1.30 (registry+https://github.com/rust-lang/crates.io-index)" = "5b3a000b9c543553af61bc01cbfc403b04b5caa9e421033866f2e98061eb3e61"
|
||||
"checksum blas 0.20.0 (registry+https://github.com/rust-lang/crates.io-index)" = "e4b3b6399f2fe6e70f780b06e278a64770aae273e0005c3860f94fc2f5f01ba7"
|
||||
"checksum blas-sys 0.7.1 (registry+https://github.com/rust-lang/crates.io-index)" = "13b1b279ceb25d7c4faaea95a5f7addbe7d8c34f9462044bd8e630cebcfc2440"
|
||||
"checksum bumpalo 2.5.0 (registry+https://github.com/rust-lang/crates.io-index)" = "2cd43d82f27d68911e6ee11ee791fb248f138f5d69424dc02e098d4f152b0b05"
|
||||
"checksum byteorder 1.3.2 (registry+https://github.com/rust-lang/crates.io-index)" = "a7c3dd8985a7111efc5c80b44e23ecdd8c007de8ade3b96595387e812b957cf5"
|
||||
"checksum c2-chacha 0.2.2 (registry+https://github.com/rust-lang/crates.io-index)" = "7d64d04786e0f528460fc884753cf8dddcc466be308f6026f8e355c41a0e4101"
|
||||
"checksum cblas 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "d82f331add33eceb4c41cb28d878049b96f56577016daf190831e94e4aece5db"
|
||||
"checksum cblas-sys 0.1.4 (registry+https://github.com/rust-lang/crates.io-index)" = "b6feecd82cce51b0204cf063f0041d69f24ce83f680d87514b004248e7b0fa65"
|
||||
"checksum cc 1.0.37 (registry+https://github.com/rust-lang/crates.io-index)" = "39f75544d7bbaf57560d2168f28fd649ff9c76153874db88bdbdfd839b1a7e7d"
|
||||
"checksum cfg-if 0.1.9 (registry+https://github.com/rust-lang/crates.io-index)" = "b486ce3ccf7ffd79fdeb678eac06a9e6c09fc88d33836340becb8fffe87c5e33"
|
||||
"checksum failure 0.1.5 (registry+https://github.com/rust-lang/crates.io-index)" = "795bd83d3abeb9220f257e597aa0080a508b27533824adf336529648f6abf7e2"
|
||||
@ -416,6 +485,9 @@ version = "0.1.0"
|
||||
"checksum memchr 2.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "88579771288728879b57485cc7d6b07d648c9f0141eb955f8ab7f9d45394468e"
|
||||
"checksum mnist 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "25f19bfda80095b4294000bbb50506f028149ed0ddb7fabf46ebb673b91626bc"
|
||||
"checksum nom 4.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "2ad2a91a8e869eeb30b9cb3119ae87773a8f4ae617f41b1eb9c154b2905f7bd6"
|
||||
"checksum num-complex 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "fcb0cf31fb3ff77e6d2a6ebd6800df7fdcd106f2ad89113c9130bcd07f93dffc"
|
||||
"checksum num-traits 0.2.8 (registry+https://github.com/rust-lang/crates.io-index)" = "6ba9a427cfca2be13aa6f6403b0b7e7368fe982bfa16fccc450ce74c46cd9b32"
|
||||
"checksum openblas-src 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)" = "0b3533e568814bee9620fcc529158408384404bae5b277c73c73d66ca03fceb7"
|
||||
"checksum ppv-lite86 0.2.5 (registry+https://github.com/rust-lang/crates.io-index)" = "e3cbf9f658cdb5000fcf6f362b8ea2ba154b9f146a61c7a20d647034c6b6561b"
|
||||
"checksum proc-macro2 0.4.30 (registry+https://github.com/rust-lang/crates.io-index)" = "cf3d2011ab5c909338f7887f4fc896d35932e29146c12c8d01da6b22a80ba759"
|
||||
"checksum quote 0.6.13 (registry+https://github.com/rust-lang/crates.io-index)" = "6ce23b6b870e8f94f81fb0a363d65d86675884b34a09043c81e5562f11c1f8e1"
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
members = [
|
||||
"yarnn",
|
||||
# "yarnn-native-blas",
|
||||
"yarnn-native-blas",
|
||||
|
||||
"yarnn-models/mnist",
|
||||
"yarnn-models/vgg16",
|
||||
|
@ -7,4 +7,5 @@ edition = "2018"
|
||||
[dependencies]
|
||||
mnist = "0.4.0"
|
||||
yarnn = "0.1.0"
|
||||
yarnn-model-mnist = "0.1.0"
|
||||
yarnn-model-mnist = "0.1.0"
|
||||
yarnn-native-blas = "0.1.0"
|
@ -1,8 +1,9 @@
|
||||
use yarnn::prelude::*;
|
||||
use yarnn::native::{Native, NativeTensorF32};
|
||||
use yarnn::native::{Native, NativeTensor};
|
||||
use yarnn_model_mnist::*;
|
||||
use yarnn::losses::CrossEntropyLoss;
|
||||
use yarnn::optimizers::Adam;
|
||||
use yarnn_native_blas::NativeBlas;
|
||||
use mnist::{Mnist, MnistBuilder};
|
||||
|
||||
fn calc_accuracy<N, B: Backend<N>>(back: &B, pred: &B::Tensor, targets: &[u8]) -> f32 {
|
||||
@ -38,7 +39,7 @@ fn calc_accuracy<N, B: Backend<N>>(back: &B, pred: &B::Tensor, targets: &[u8]) -
|
||||
fn main() {
|
||||
const BATCH_SIZE: usize = 64;
|
||||
|
||||
let backend = Native;
|
||||
let backend: NativeBlas<f32, Native<_>> = Default::default();
|
||||
let optimizer = Adam::default();
|
||||
|
||||
// let mut model = MnistDenseModel::new(28, 28, 1);
|
||||
@ -57,14 +58,14 @@ fn main() {
|
||||
.label_format_digit()
|
||||
.finalize();
|
||||
|
||||
let mut inputs = NativeTensorF32::new((BATCH_SIZE as u32, 1, 28, 28));
|
||||
let mut targets = NativeTensorF32::new((BATCH_SIZE as u32, 10));
|
||||
let mut deltas = NativeTensorF32::new((BATCH_SIZE as u32, 10));
|
||||
let mut inputs = NativeTensor::new((BATCH_SIZE as u32, 1, 28, 28));
|
||||
let mut targets = NativeTensor::new((BATCH_SIZE as u32, 10));
|
||||
let mut deltas = NativeTensor::new((BATCH_SIZE as u32, 10));
|
||||
|
||||
let test_count = 1000;
|
||||
|
||||
let mut inputs0 = NativeTensorF32::new((test_count as u32, 1, 28, 28));
|
||||
let mut targets0 = NativeTensorF32::new((test_count as u32, 10));
|
||||
let mut inputs0 = NativeTensor::new((test_count as u32, 1, 28, 28));
|
||||
let mut targets0 = NativeTensor::new((test_count as u32, 10));
|
||||
|
||||
let mut tmp = vec![0u8; 10 * test_count];
|
||||
|
||||
|
@ -4,7 +4,7 @@ use yarnn_model_vgg16::Vgg16Model;
|
||||
|
||||
|
||||
fn main() {
|
||||
let vgg16: Vgg16Model<f32, Native, Adam<_, _>> = Vgg16Model::new(224, 224, 3);
|
||||
let vgg16: Vgg16Model<f32, Native<_>, Adam<_, _>> = Vgg16Model::new(224, 224, 3);
|
||||
|
||||
println!("{}", vgg16);
|
||||
}
|
||||
|
@ -1,56 +1,105 @@
|
||||
mod img2col;
|
||||
|
||||
use yarnn::backend::*;
|
||||
use yarnn::native::*;
|
||||
use yarnn::tensor::*;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
extern crate openblas_src;
|
||||
|
||||
pub struct NativeBlas<N, B: Native> {
|
||||
inner: B,
|
||||
}
|
||||
|
||||
impl<N, B: Native> Native for NativeBlas<N, B> {}
|
||||
|
||||
impl<N, B> NativeBlas<N, B>
|
||||
where N: NativeTensor,
|
||||
pub struct NativeBlas<N, B>
|
||||
where N: NativeNumber,
|
||||
B: NativeBackend<N>
|
||||
{
|
||||
pub fn new(native: B) -> Self {
|
||||
inner: B,
|
||||
_m: PhantomData<fn(N)>
|
||||
}
|
||||
|
||||
impl<N, B> Default for NativeBlas<N, B>
|
||||
where N: NativeNumber,
|
||||
B: NativeBackend<N>
|
||||
{
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
inner: native
|
||||
inner: Default::default(),
|
||||
_m: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<N, B> NativeBackend<N> for NativeBlas<N, B>
|
||||
where N: NativeNumber,
|
||||
B: NativeBackend<N>
|
||||
{
|
||||
#[inline]
|
||||
fn read_tensor<'a>(&self, t: &'a Self::Tensor) -> &'a [N] {
|
||||
self.inner.read_tensor(t)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn write_tensor<'a>(&self, t: &'a mut Self::Tensor) -> &'a mut [N] {
|
||||
self.inner.write_tensor(t)
|
||||
}
|
||||
}
|
||||
|
||||
impl<N, B> NativeBlas<N, B>
|
||||
where N: NativeNumber,
|
||||
B: NativeBackend<N>
|
||||
{
|
||||
pub fn new(inner: B) -> Self {
|
||||
Self {
|
||||
inner,
|
||||
_m: Default::default()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<N, B> Backend<N> for NativeBlas<N, B>
|
||||
where N: NativeTensor,
|
||||
where N: NativeNumber,
|
||||
B: NativeBackend<N>
|
||||
{
|
||||
type Tensor = B::Tensor;
|
||||
|
||||
#[inline]
|
||||
fn store_tensor_f32(&self, t: &Self::Tensor, data: &mut [f32]) {
|
||||
|
||||
self.inner.store_tensor_f32(t, data)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn load_tensor_u8(&self, t: &mut Self::Tensor, data: &[u8]) {
|
||||
|
||||
self.inner.load_tensor_u8(t, data)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn load_tensor_f32(&self, t: &mut Self::Tensor, data: &[f32]) {
|
||||
|
||||
self.inner.load_tensor_f32(t, data)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn scalar_f32(&self, val: f32) -> N {
|
||||
|
||||
N::from_f32(val)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn fill_scalar(&self, t: &mut Self::Tensor, scalar: N) {
|
||||
|
||||
self.inner.fill_scalar(t, scalar)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn fill_random(&self, t: &mut Self::Tensor, from: N, to: N) {
|
||||
|
||||
self.inner.fill_random(t, from, to)
|
||||
}
|
||||
fn print_tensor(&self, t: &Self::Tensor) {
|
||||
|
||||
#[inline]
|
||||
fn print_tensor(&self, t: &Self::Tensor) {
|
||||
self.inner.print_tensor(t)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> BackendGemm<f32> for NativeBlas<f32, B>
|
||||
where B: NativeBackend<f32>
|
||||
{
|
||||
#[inline]
|
||||
fn matmul(&self, dst: &mut Self::Tensor, a: &Self::Tensor, b: &Self::Tensor) {
|
||||
let a_shape = a.shape();
|
||||
let b_shape = b.shape();
|
||||
@ -62,21 +111,22 @@ impl<B> BackendGemm<f32> for NativeBlas<f32, B>
|
||||
assert_eq!(a_shape.dims, 2);
|
||||
assert_eq!(b_shape.dims, 2);
|
||||
|
||||
let m = a_shape.get(0) as usize;
|
||||
let n = b_shape.get(1) as usize;
|
||||
let k = b_shape.get(0) as usize;
|
||||
let m = a_shape.get(0) as i32;
|
||||
let n = b_shape.get(1) as i32;
|
||||
let k = b_shape.get(0) as i32;
|
||||
|
||||
unsafe {
|
||||
blas::sgemm('N' as u8, 'N' as u8,
|
||||
n, m, k,
|
||||
1.0,
|
||||
b.read(), n,
|
||||
a.read(), k,
|
||||
self.read_tensor(b), n,
|
||||
self.read_tensor(a), k,
|
||||
0.0,
|
||||
&mut dst.write(), n);
|
||||
self.write_tensor(dst), n);
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn matmul_nt(&self, dst: &mut Self::Tensor, a: &Self::Tensor, b: &Self::Tensor) {
|
||||
let a_shape = a.shape();
|
||||
let b_shape = b.shape();
|
||||
@ -88,21 +138,22 @@ impl<B> BackendGemm<f32> for NativeBlas<f32, B>
|
||||
assert_eq!(a_shape.dims, 2);
|
||||
assert_eq!(b_shape.dims, 2);
|
||||
|
||||
let m = a_shape.get(0) as usize;
|
||||
let n = b_shape.get(0) as usize;
|
||||
let k = b_shape.get(1) as usize;
|
||||
let m = a_shape.get(0) as i32;
|
||||
let n = b_shape.get(0) as i32;
|
||||
let k = b_shape.get(1) as i32;
|
||||
|
||||
unsafe {
|
||||
blas::sgemm('T' as u8, 'N' as u8,
|
||||
n, m, k,
|
||||
1.0,
|
||||
b.read(), k,
|
||||
a.read(), k,
|
||||
self.read_tensor(b), k,
|
||||
self.read_tensor(a), k,
|
||||
0.0,
|
||||
&mut dst.write(), n);
|
||||
self.write_tensor(dst), n);
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn matmul_tn(&self, dst: &mut Self::Tensor, a: &Self::Tensor, b: &Self::Tensor) {
|
||||
let a_shape = a.shape();
|
||||
let b_shape = b.shape();
|
||||
@ -114,21 +165,22 @@ impl<B> BackendGemm<f32> for NativeBlas<f32, B>
|
||||
assert_eq!(a_shape.dims, 2);
|
||||
assert_eq!(b_shape.dims, 2);
|
||||
|
||||
let m = a_shape.get(1) as usize;
|
||||
let n = b_shape.get(1) as usize;
|
||||
let k = b_shape.get(0) as usize;
|
||||
let m = a_shape.get(1) as i32;
|
||||
let n = b_shape.get(1) as i32;
|
||||
let k = b_shape.get(0) as i32;
|
||||
|
||||
unsafe {
|
||||
blas::sgemm('N' as u8, 'T' as u8,
|
||||
n, m, k,
|
||||
1.0,
|
||||
b.read(), n,
|
||||
a.read(), m,
|
||||
self.read_tensor(b), n,
|
||||
self.read_tensor(a), m,
|
||||
0.0,
|
||||
&mut dst.write(), n);
|
||||
self.write_tensor(dst), n);
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn matmul_tt(&self, _dst: &mut Self::Tensor, _a: &Self::Tensor, _b: &Self::Tensor) {
|
||||
unimplemented!();
|
||||
}
|
||||
@ -137,6 +189,7 @@ impl<B> BackendGemm<f32> for NativeBlas<f32, B>
|
||||
impl<B> BackendAxpy<f32> for NativeBlas<f32, B>
|
||||
where B: NativeBackend<f32>
|
||||
{
|
||||
#[inline]
|
||||
fn axpy(&self, dst: &mut Self::Tensor, scale: f32, x: &Self::Tensor) {
|
||||
let dst_size = dst.shape().size();
|
||||
|
||||
@ -146,9 +199,9 @@ impl<B> BackendAxpy<f32> for NativeBlas<f32, B>
|
||||
blas::saxpy(
|
||||
dst_size as i32,
|
||||
scale,
|
||||
x.read(),
|
||||
self.read_tensor(x),
|
||||
1,
|
||||
dst.write(),
|
||||
self.write_tensor(dst),
|
||||
1
|
||||
);
|
||||
}
|
||||
@ -158,6 +211,7 @@ impl<B> BackendAxpy<f32> for NativeBlas<f32, B>
|
||||
impl<B> BackendScale<f32> for NativeBlas<f32, B>
|
||||
where B: NativeBackend<f32>
|
||||
{
|
||||
#[inline]
|
||||
fn scale(&self, dst: &mut Self::Tensor, scale: f32) {
|
||||
let dst_size = dst.shape().size();
|
||||
|
||||
@ -165,9 +219,166 @@ impl<B> BackendScale<f32> for NativeBlas<f32, B>
|
||||
blas::sscal(
|
||||
dst_size as i32,
|
||||
scale,
|
||||
dst.write(),
|
||||
self.write_tensor(dst),
|
||||
1
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
impl<B: NativeBackend<f32> + BackendSigmoid<f32>> BackendSigmoid<f32> for NativeBlas<f32, B> {
|
||||
#[inline]
|
||||
fn sigmoid(&self, dst: &mut Self::Tensor, data: &Self::Tensor) {
|
||||
self.inner.sigmoid(dst, data)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn sigmoid_grad(&self, dst: &mut Self::Tensor, z: &Self::Tensor, d: &Self::Tensor) {
|
||||
self.inner.sigmoid_grad(dst, z, d)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: NativeBackend<f32> + BackendReLu<f32>> BackendReLu<f32> for NativeBlas<f32, B> {
|
||||
#[inline]
|
||||
fn relu(&self, dst: &mut Self::Tensor, data: &Self::Tensor) {
|
||||
self.inner.relu(dst, data)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn relu_grad(&self, dst: &mut Self::Tensor, z: &Self::Tensor, d: &Self::Tensor) {
|
||||
self.inner.relu_grad(dst, z, d)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: NativeBackend<f32> + BackendBias<f32>> BackendBias<f32> for NativeBlas<f32, B> {
|
||||
#[inline]
|
||||
fn bias_add(&self, dst: &mut Self::Tensor, biases: &Self::Tensor) {
|
||||
self.inner.bias_add(dst, biases)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn bias_grad(&self, dbiases: &mut Self::Tensor, deltas: &Self::Tensor) {
|
||||
self.inner.bias_grad(dbiases, deltas)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: NativeBackend<f32> + BackendMse<f32>> BackendMse<f32> for NativeBlas<f32, B> {
|
||||
#[inline]
|
||||
fn scaled_square_diff(&self, dst: &mut Self::Tensor, a: &Self::Tensor, b: &Self::Tensor, scale: f32) {
|
||||
self.inner.scaled_square_diff(dst, a, b, scale)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn scaled_diff(&self, dst: &mut Self::Tensor, a: &Self::Tensor, b: &Self::Tensor, scale: f32) {
|
||||
self.inner.scaled_diff(dst, a, b, scale)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: NativeBackend<f32> + BackendAxpys<f32>> BackendAxpys<f32> for NativeBlas<f32, B> {
|
||||
#[inline]
|
||||
fn axpys(&self, dst: &mut Self::Tensor, scale: f32, a: &Self::Tensor) {
|
||||
self.inner.axpys(dst, scale, a)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: NativeBackend<f32> + BackendAdd<f32>> BackendAdd<f32> for NativeBlas<f32, B> {
|
||||
#[inline]
|
||||
fn add(&self, dst: &mut Self::Tensor, a: &Self::Tensor) {
|
||||
self.inner.add(dst, a)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: NativeBackend<f32> + BackendSub<f32>> BackendSub<f32> for NativeBlas<f32, B> {
|
||||
#[inline]
|
||||
fn sub(&self, dst: &mut Self::Tensor, a: &Self::Tensor, b: &Self::Tensor) {
|
||||
self.inner.sub(dst, a, b)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: NativeBackend<f32> + BackendMul<f32>> BackendMul<f32> for NativeBlas<f32, B> {
|
||||
#[inline]
|
||||
fn mul(&self, dst: &mut Self::Tensor, a: &Self::Tensor) {
|
||||
self.inner.mul(dst, a)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
impl<B: NativeBackend<f32> + BackendCopy<f32>> BackendCopy<f32> for NativeBlas<f32, B> {
|
||||
#[inline]
|
||||
fn copy(&self, dst: &mut Self::Tensor, a: &Self::Tensor) {
|
||||
self.inner.copy(dst, a)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: NativeBackend<f32> + BackendMaximum<f32>> BackendMaximum<f32> for NativeBlas<f32, B> {
|
||||
#[inline]
|
||||
fn maximum(&self, dst: &mut Self::Tensor, a: &Self::Tensor) {
|
||||
self.inner.maximum(dst, a)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
impl<B: NativeBackend<f32> + BackendAdam<f32>> BackendAdam<f32> for NativeBlas<f32, B> {
|
||||
#[inline]
|
||||
fn adam_p(&self, dst: &mut Self::Tensor, lr: f32, moms: &Self::Tensor, vels: &Self::Tensor, eps: f32) {
|
||||
self.inner.adam_p(dst, lr, moms, vels, eps)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: NativeBackend<f32> + BackendSoftmax<f32>> BackendSoftmax<f32> for NativeBlas<f32, B> {
|
||||
#[inline]
|
||||
fn softmax(&self, y: &mut Self::Tensor, x: &Self::Tensor) {
|
||||
self.inner.softmax(y, x)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: NativeBackend<f32> + BackendConv2d<f32>> BackendConv2d<f32> for NativeBlas<f32, B> {
|
||||
type Context = ();
|
||||
|
||||
#[inline]
|
||||
fn conv2d_forward(&self, y: &mut Self::Tensor, x: &Self::Tensor, w: &Self::Tensor, conv_info: &Conv2dInfo) {
|
||||
self.inner.conv2d_forward(y, x, w, conv_info)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn conv2d_backward_input(&self, dx: &mut Self::Tensor, dy: &Self::Tensor, w: &Self::Tensor, conv_info: &Conv2dInfo) {
|
||||
self.inner.conv2d_backward_input(dx, dy, w, conv_info)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn conv2d_backward_filter(&self, dw: &mut Self::Tensor, x: &Self::Tensor, dy: &Self::Tensor, conv_info: &Conv2dInfo) {
|
||||
self.inner.conv2d_backward_filter(dw, x, dy, conv_info)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: NativeBackend<f32> + BackendMaxPool2d<f32>> BackendMaxPool2d<f32> for NativeBlas<f32, B> {
|
||||
#[inline]
|
||||
fn max_pool2d(&self, y: &mut Self::Tensor, x: &Self::Tensor, conv_info: &Conv2dInfo) {
|
||||
self.inner.max_pool2d(y, x, conv_info)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn max_pool2d_backprop(&self, dx: &mut Self::Tensor, dy: &Self::Tensor, x: &Self::Tensor, conv_info: &Conv2dInfo) {
|
||||
self.inner.max_pool2d_backprop(dx, dy, x, conv_info)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: NativeBackend<f32> + BackendAvgPool2d<f32>> BackendAvgPool2d<f32> for NativeBlas<f32, B> {
|
||||
#[inline]
|
||||
fn avg_pool2d(&self, y: &mut Self::Tensor, x: &Self::Tensor, conv_info: &Conv2dInfo) {
|
||||
self.inner.avg_pool2d(y, x, conv_info)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn avg_pool2d_backprop(&self, dx: &mut Self::Tensor, dy: &Self::Tensor, x: &Self::Tensor, conv_info: &Conv2dInfo) {
|
||||
self.inner.avg_pool2d_backprop(dx, dy, x, conv_info)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: NativeBackend<f32> + BackendPaddingCopy2d<f32>> BackendPaddingCopy2d<f32> for NativeBlas<f32, B> {
|
||||
#[inline]
|
||||
fn copy_with_padding2d(&self, y: &mut Self::Tensor, x: &Self::Tensor, y_paddings: (u32, u32), x_paddings: (u32, u32)) {
|
||||
self.inner.copy_with_padding2d(y, x, y_paddings, x_paddings)
|
||||
}
|
||||
}
|
||||
|
@ -24,4 +24,4 @@ pub mod prelude {
|
||||
pub use super::loss::*;
|
||||
pub use super::tensor::*;
|
||||
pub use super::layer::*;
|
||||
}
|
||||
}
|
@ -1,10 +1,8 @@
|
||||
mod kernel_3x3;
|
||||
mod kernel_5x5;
|
||||
mod img2col;
|
||||
|
||||
pub use self::kernel_3x3::*;
|
||||
pub use self::kernel_5x5::*;
|
||||
pub use self::img2col::*;
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn valid_conv2d(y: &mut [f32], x: &[f32], w: &[f32], alpha: f32,
|
||||
|
@ -14,28 +14,72 @@ use core::fmt::Write;
|
||||
use rand_distr::{Normal, Distribution};
|
||||
|
||||
|
||||
pub struct NativeTensorF32 {
|
||||
shape: TensorShape,
|
||||
ptr: Option<Box<[f32]>>
|
||||
pub trait NativeNumber: Copy {
|
||||
fn from_f32(val: f32) -> Self;
|
||||
fn from_f64(val: f64) -> Self;
|
||||
fn from_i64(val: i64) -> Self;
|
||||
fn from_i32(val: i32) -> Self;
|
||||
fn from_i16(val: i16) -> Self;
|
||||
fn from_i8(val: i8) -> Self;
|
||||
fn from_u64(val: u64) -> Self;
|
||||
fn from_u32(val: u32) -> Self;
|
||||
fn from_u16(val: u16) -> Self;
|
||||
fn from_u8(val: u8) -> Self;
|
||||
}
|
||||
|
||||
impl NativeTensorF32 {
|
||||
pub fn read(&self) -> &[f32] {
|
||||
impl NativeNumber for f32 {
|
||||
fn from_f32(val: f32) -> Self { val as f32 }
|
||||
fn from_f64(val: f64) -> Self { val as f32 }
|
||||
fn from_i64(val: i64) -> Self { val as f32 }
|
||||
fn from_i32(val: i32) -> Self { val as f32 }
|
||||
fn from_i16(val: i16) -> Self { val as f32 }
|
||||
fn from_i8(val: i8) -> Self { val as f32 }
|
||||
fn from_u64(val: u64) -> Self { val as f32 }
|
||||
fn from_u32(val: u32) -> Self { val as f32 }
|
||||
fn from_u16(val: u16) -> Self { val as f32 }
|
||||
fn from_u8(val: u8) -> Self { val as f32 }
|
||||
}
|
||||
|
||||
impl NativeNumber for f64 {
|
||||
fn from_f32(val: f32) -> Self { val as f64 }
|
||||
fn from_f64(val: f64) -> Self { val as f64 }
|
||||
fn from_i64(val: i64) -> Self { val as f64 }
|
||||
fn from_i32(val: i32) -> Self { val as f64 }
|
||||
fn from_i16(val: i16) -> Self { val as f64 }
|
||||
fn from_i8(val: i8) -> Self { val as f64 }
|
||||
fn from_u64(val: u64) -> Self { val as f64 }
|
||||
fn from_u32(val: u32) -> Self { val as f64 }
|
||||
fn from_u16(val: u16) -> Self { val as f64 }
|
||||
fn from_u8(val: u8) -> Self { val as f64 }
|
||||
}
|
||||
|
||||
pub trait NativeBackend<N: NativeNumber>: Backend<N> + Default {
|
||||
fn read_tensor<'a>(&self, t: &'a Self::Tensor) -> &'a [N];
|
||||
fn write_tensor<'a>(&self, t: &'a mut Self::Tensor) -> &'a mut [N];
|
||||
}
|
||||
|
||||
pub struct NativeTensor<N: NativeNumber> {
|
||||
shape: TensorShape,
|
||||
ptr: Option<Box<[N]>>
|
||||
}
|
||||
|
||||
impl<N: NativeNumber> NativeTensor<N> {
|
||||
pub fn read(&self) -> &[N] {
|
||||
self.ptr.as_ref().unwrap()
|
||||
}
|
||||
|
||||
pub fn write(&mut self) -> &mut [f32] {
|
||||
pub fn write(&mut self) -> &mut [N] {
|
||||
if self.ptr.is_none() {
|
||||
self.ptr = Some(vec![0.0; self.shape.size()].into_boxed_slice());
|
||||
self.ptr = Some(vec![N::from_f32(0.0); self.shape.size()].into_boxed_slice());
|
||||
}
|
||||
|
||||
return self.ptr.as_mut().unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
impl Tensor<f32> for NativeTensorF32 {
|
||||
impl<N: NativeNumber> Tensor<N> for NativeTensor<N> {
|
||||
fn new<S: Into<TensorShape>>(shape: S) -> Self {
|
||||
NativeTensorF32 {
|
||||
NativeTensor {
|
||||
shape: shape.into(),
|
||||
ptr: None,
|
||||
}
|
||||
@ -46,24 +90,16 @@ impl Tensor<f32> for NativeTensorF32 {
|
||||
}
|
||||
|
||||
fn resize(&mut self, shape: TensorShape) {
|
||||
self.ptr = if let Some(ptr) = self.ptr.take() {
|
||||
let size = self.shape.size();
|
||||
let raw = Box::into_raw(ptr) as *mut f32;
|
||||
let mut data = unsafe {Vec::from_raw_parts(raw, size, size)};
|
||||
data.resize(shape.size(), 0.0);
|
||||
|
||||
Some(data.into_boxed_slice())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
self.ptr = None;
|
||||
self.shape = shape;
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Native;
|
||||
#[derive(Default)]
|
||||
pub struct Native<N: NativeNumber>(core::marker::PhantomData<N>);
|
||||
|
||||
impl Native {
|
||||
fn fmt_tensor(&self, t: &NativeTensorF32, f: &mut String) -> fmt::Result {
|
||||
impl<N: NativeNumber + core::fmt::Display> Native<N> {
|
||||
fn fmt_tensor(&self, t: &NativeTensor<N>, f: &mut String) -> fmt::Result {
|
||||
let strides = t.shape.default_strides();
|
||||
let last_idx = strides.dims - 1;
|
||||
writeln!(f, "default stridses {} {}", t.shape.default_strides(), last_idx)?;
|
||||
@ -97,8 +133,8 @@ impl Native {
|
||||
}
|
||||
}
|
||||
|
||||
impl Backend<f32> for Native {
|
||||
type Tensor = NativeTensorF32;
|
||||
impl Backend<f32> for Native<f32> {
|
||||
type Tensor = NativeTensor<f32>;
|
||||
|
||||
fn store_tensor_f32(&self, t: &Self::Tensor, data: &mut [f32]) {
|
||||
let size = t.shape().size();
|
||||
@ -170,8 +206,19 @@ impl Backend<f32> for Native {
|
||||
}
|
||||
}
|
||||
|
||||
impl NativeBackend<f32> for Native<f32> {
|
||||
#[inline]
|
||||
fn read_tensor<'a>(&self, t: &'a Self::Tensor) -> &'a [f32] {
|
||||
t.read()
|
||||
}
|
||||
|
||||
impl BackendGemm<f32> for Native {
|
||||
#[inline]
|
||||
fn write_tensor<'a>(&self, t: &'a mut Self::Tensor) -> &'a mut [f32] {
|
||||
t.write()
|
||||
}
|
||||
}
|
||||
|
||||
impl BackendGemm<f32> for Native<f32> {
|
||||
fn matmul(&self, dst: &mut Self::Tensor, a: &Self::Tensor, b: &Self::Tensor) {
|
||||
let a_shape = a.shape();
|
||||
let b_shape = b.shape();
|
||||
@ -251,7 +298,7 @@ impl BackendGemm<f32> for Native {
|
||||
}
|
||||
}
|
||||
|
||||
impl BackendSigmoid<f32> for Native {
|
||||
impl BackendSigmoid<f32> for Native<f32> {
|
||||
fn sigmoid(&self, dst: &mut Self::Tensor, data: &Self::Tensor) {
|
||||
let dst_size = dst.shape().size();
|
||||
|
||||
@ -281,7 +328,7 @@ impl BackendSigmoid<f32> for Native {
|
||||
}
|
||||
}
|
||||
|
||||
impl BackendReLu<f32> for Native {
|
||||
impl BackendReLu<f32> for Native<f32> {
|
||||
fn relu(&self, dst: &mut Self::Tensor, data: &Self::Tensor) {
|
||||
let dst_size = dst.shape().size();
|
||||
|
||||
@ -321,7 +368,7 @@ impl BackendReLu<f32> for Native {
|
||||
}
|
||||
}
|
||||
|
||||
impl BackendBias<f32> for Native {
|
||||
impl BackendBias<f32> for Native<f32> {
|
||||
fn bias_add(&self, dst: &mut Self::Tensor, biases: &Self::Tensor) {
|
||||
let biases_shape = biases.shape();
|
||||
let dst_shape = dst.shape().clone();
|
||||
@ -391,7 +438,7 @@ impl BackendBias<f32> for Native {
|
||||
}
|
||||
}
|
||||
|
||||
impl BackendScale<f32> for Native {
|
||||
impl BackendScale<f32> for Native<f32> {
|
||||
fn scale(&self, dst: &mut Self::Tensor, scale: f32) {
|
||||
let dst_size = dst.shape().size();
|
||||
let dst_s = &mut dst.write()[0 .. dst_size];
|
||||
@ -402,7 +449,7 @@ impl BackendScale<f32> for Native {
|
||||
}
|
||||
}
|
||||
|
||||
impl BackendMse<f32> for Native {
|
||||
impl BackendMse<f32> for Native<f32> {
|
||||
fn scaled_square_diff(&self, dst: &mut Self::Tensor, a: &Self::Tensor, b: &Self::Tensor, scale: f32) {
|
||||
let a_size = a.shape().size();
|
||||
let b_size = b.shape().size();
|
||||
@ -440,7 +487,7 @@ impl BackendMse<f32> for Native {
|
||||
}
|
||||
}
|
||||
|
||||
impl BackendAxpy<f32> for Native {
|
||||
impl BackendAxpy<f32> for Native<f32> {
|
||||
default fn axpy(&self, dst: &mut Self::Tensor, scale: f32, a: &Self::Tensor) {
|
||||
let dst_size = dst.shape().size();
|
||||
|
||||
@ -455,7 +502,7 @@ impl BackendAxpy<f32> for Native {
|
||||
}
|
||||
}
|
||||
|
||||
impl BackendAxpys<f32> for Native {
|
||||
impl BackendAxpys<f32> for Native<f32> {
|
||||
fn axpys(&self, dst: &mut Self::Tensor, scale: f32, a: &Self::Tensor) {
|
||||
let dst_size = dst.shape().size();
|
||||
|
||||
@ -470,7 +517,7 @@ impl BackendAxpys<f32> for Native {
|
||||
}
|
||||
}
|
||||
|
||||
impl BackendAdd<f32> for Native {
|
||||
impl BackendAdd<f32> for Native<f32> {
|
||||
fn add(&self, dst: &mut Self::Tensor, a: &Self::Tensor) {
|
||||
let dst_size = dst.shape().size();
|
||||
|
||||
@ -485,7 +532,7 @@ impl BackendAdd<f32> for Native {
|
||||
}
|
||||
}
|
||||
|
||||
impl BackendSub<f32> for Native {
|
||||
impl BackendSub<f32> for Native<f32> {
|
||||
fn sub(&self, dst: &mut Self::Tensor, a: &Self::Tensor, b: &Self::Tensor) {
|
||||
let a_size = a.shape().size();
|
||||
let b_size = b.shape().size();
|
||||
@ -505,7 +552,7 @@ impl BackendSub<f32> for Native {
|
||||
|
||||
}
|
||||
|
||||
impl BackendMul<f32> for Native {
|
||||
impl BackendMul<f32> for Native<f32> {
|
||||
fn mul(&self, dst: &mut Self::Tensor, a: &Self::Tensor) {
|
||||
let dst_size = dst.shape().size();
|
||||
|
||||
@ -521,7 +568,7 @@ impl BackendMul<f32> for Native {
|
||||
}
|
||||
|
||||
|
||||
impl BackendCopy<f32> for Native {
|
||||
impl BackendCopy<f32> for Native<f32> {
|
||||
fn copy(&self, dst: &mut Self::Tensor, a: &Self::Tensor) {
|
||||
let size = dst.shape().size();
|
||||
|
||||
@ -536,7 +583,7 @@ impl BackendCopy<f32> for Native {
|
||||
}
|
||||
}
|
||||
|
||||
impl BackendMaximum<f32> for Native {
|
||||
impl BackendMaximum<f32> for Native<f32> {
|
||||
fn maximum(&self, dst: &mut Self::Tensor, a: &Self::Tensor) {
|
||||
let dst_size = dst.shape().size();
|
||||
|
||||
@ -552,7 +599,7 @@ impl BackendMaximum<f32> for Native {
|
||||
}
|
||||
|
||||
|
||||
impl BackendAdam<f32> for Native {
|
||||
impl BackendAdam<f32> for Native<f32> {
|
||||
fn adam_p(&self, dst: &mut Self::Tensor, lr: f32, moms: &Self::Tensor, vels: &Self::Tensor, eps: f32) {
|
||||
let dst_size = dst.shape().size();
|
||||
|
||||
@ -569,7 +616,7 @@ impl BackendAdam<f32> for Native {
|
||||
}
|
||||
}
|
||||
|
||||
impl BackendSoftmax<f32> for Native {
|
||||
impl BackendSoftmax<f32> for Native<f32> {
|
||||
fn softmax(&self, y: &mut Self::Tensor, x: &Self::Tensor) {
|
||||
let y_shape = y.shape();
|
||||
let x_shape = x.shape();
|
||||
@ -619,7 +666,7 @@ impl BackendSoftmax<f32> for Native {
|
||||
}
|
||||
}
|
||||
|
||||
impl BackendConv2d<f32> for Native {
|
||||
impl BackendConv2d<f32> for Native<f32> {
|
||||
type Context = ();
|
||||
|
||||
fn conv2d_forward(&self, y: &mut Self::Tensor, x: &Self::Tensor, w: &Self::Tensor, conv_info: &Conv2dInfo) {
|
||||
@ -742,7 +789,7 @@ impl BackendConv2d<f32> for Native {
|
||||
}
|
||||
}
|
||||
|
||||
impl BackendMaxPool2d<f32> for Native {
|
||||
impl BackendMaxPool2d<f32> for Native<f32> {
|
||||
fn max_pool2d(&self, y: &mut Self::Tensor, x: &Self::Tensor, conv_info: &Conv2dInfo) {
|
||||
let x_shape = &x.shape().as_slice()[0..4];
|
||||
let y_shape = &y.shape().as_slice()[0..4];
|
||||
@ -845,7 +892,7 @@ impl BackendMaxPool2d<f32> for Native {
|
||||
}
|
||||
}
|
||||
|
||||
impl BackendAvgPool2d<f32> for Native {
|
||||
impl BackendAvgPool2d<f32> for Native<f32> {
|
||||
fn avg_pool2d(&self, y: &mut Self::Tensor, x: &Self::Tensor, conv_info: &Conv2dInfo) {
|
||||
let x_shape = &x.shape().as_slice()[0..4];
|
||||
let y_shape = &y.shape().as_slice()[0..4];
|
||||
@ -899,7 +946,7 @@ impl BackendAvgPool2d<f32> for Native {
|
||||
}
|
||||
}
|
||||
|
||||
impl BackendPaddingCopy2d<f32> for Native {
|
||||
impl BackendPaddingCopy2d<f32> for Native<f32> {
|
||||
fn copy_with_padding2d(&self, y: &mut Self::Tensor, x: &Self::Tensor, y_paddings: (u32, u32), x_paddings: (u32, u32)) {
|
||||
let y_shape = &y.shape().as_slice()[0..4];
|
||||
let x_shape = &x.shape().as_slice()[0..4];
|
||||
@ -964,16 +1011,16 @@ impl BackendPaddingCopy2d<f32> for Native {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::backend::*;
|
||||
use super::{Native, NativeTensorF32};
|
||||
use super::{Native, NativeTensor};
|
||||
use crate::tensor::Tensor;
|
||||
|
||||
#[test]
|
||||
fn test_copy_with_padding2d() {
|
||||
let bac = Native;
|
||||
let mut a1 = NativeTensorF32::new((1, 1, 3, 3));
|
||||
let mut b1 = NativeTensorF32::new((1, 1, 5, 5));
|
||||
let mut a2 = NativeTensorF32::new((1, 1, 5, 5));
|
||||
let mut b2 = NativeTensorF32::new((1, 1, 3, 3));
|
||||
let bac: Native<f32> = Default::default();
|
||||
let mut a1 = NativeTensor::new((1, 1, 3, 3));
|
||||
let mut b1 = NativeTensor::new((1, 1, 5, 5));
|
||||
let mut a2 = NativeTensor::new((1, 1, 5, 5));
|
||||
let mut b2 = NativeTensor::new((1, 1, 3, 3));
|
||||
|
||||
bac.load_tensor_u8(&mut a1, &[
|
||||
1, 2, 3,
|
||||
@ -1014,9 +1061,9 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_softmax() {
|
||||
let bac = Native;
|
||||
let mut a = NativeTensorF32::new((3, 3));
|
||||
let mut b = NativeTensorF32::new((3, 3));
|
||||
let bac: Native<f32> = Default::default();
|
||||
let mut a = NativeTensor::new((3, 3));
|
||||
let mut b = NativeTensor::new((3, 3));
|
||||
|
||||
bac.load_tensor_u8(&mut a, &[
|
||||
1,2,3,
|
||||
@ -1037,10 +1084,10 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_matmul() {
|
||||
let bac = Native;
|
||||
let mut a = NativeTensorF32::new((2, 3));
|
||||
let mut b = NativeTensorF32::new((3, 4));
|
||||
let mut c = NativeTensorF32::new((2, 4));
|
||||
let bac: Native<f32> = Default::default();
|
||||
let mut a = NativeTensor::new((2, 3));
|
||||
let mut b = NativeTensor::new((3, 4));
|
||||
let mut c = NativeTensor::new((2, 4));
|
||||
|
||||
bac.load_tensor_u8(&mut a, &[
|
||||
1,2,3,
|
||||
@ -1065,10 +1112,10 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_matmul_nt() {
|
||||
let bac = Native;
|
||||
let mut a = NativeTensorF32::new((2, 3));
|
||||
let mut b = NativeTensorF32::new((4, 3));
|
||||
let mut c = NativeTensorF32::new((2, 4));
|
||||
let bac: Native<f32> = Default::default();
|
||||
let mut a = NativeTensor::new((2, 3));
|
||||
let mut b = NativeTensor::new((4, 3));
|
||||
let mut c = NativeTensor::new((2, 4));
|
||||
|
||||
bac.load_tensor_u8(&mut a, &[
|
||||
1,2,3,
|
||||
@ -1095,10 +1142,10 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_matmul_tn() {
|
||||
let bac = Native;
|
||||
let mut a = NativeTensorF32::new((8, 5));
|
||||
let mut b = NativeTensorF32::new((8, 3));
|
||||
let mut c = NativeTensorF32::new((5, 3));
|
||||
let bac: Native<f32> = Default::default();
|
||||
let mut a = NativeTensor::new((8, 5));
|
||||
let mut b = NativeTensor::new((8, 3));
|
||||
let mut c = NativeTensor::new((5, 3));
|
||||
|
||||
bac.load_tensor_u8(&mut a, &[
|
||||
0, 1, 2, 3, 4,
|
||||
@ -1138,10 +1185,10 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_axpy() {
|
||||
let bac = Native;
|
||||
let bac: Native<f32> = Default::default();
|
||||
|
||||
let mut a = NativeTensorF32::new((2, 2));
|
||||
let mut b = NativeTensorF32::new((2, 2));
|
||||
let mut a = NativeTensor::new((2, 2));
|
||||
let mut b = NativeTensor::new((2, 2));
|
||||
|
||||
bac.load_tensor_u8(&mut a, &[1, 2, 3, 4]);
|
||||
bac.load_tensor_u8(&mut b, &[1, 2, 3, 4]);
|
||||
@ -1155,10 +1202,10 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_add() {
|
||||
let bac = Native;
|
||||
let bac: Native<f32> = Default::default();
|
||||
|
||||
let mut a = NativeTensorF32::new((2, 2));
|
||||
let mut b = NativeTensorF32::new((2, 2));
|
||||
let mut a = NativeTensor::new((2, 2));
|
||||
let mut b = NativeTensor::new((2, 2));
|
||||
|
||||
bac.load_tensor_u8(&mut a, &[1, 2, 3, 4]);
|
||||
bac.load_tensor_u8(&mut b, &[1, 2, 3, 4]);
|
||||
|
Loading…
Reference in New Issue
Block a user