From 282d0c4716119f2d391d4c3757486c1d101c1a6c Mon Sep 17 00:00:00 2001 From: Andrey Tkachenko Date: Sat, 5 Mar 2022 18:30:23 +0400 Subject: [PATCH] Fix fmt and some warnings --- .drone.yml | 14 +++ build.rs | 2 +- src/dim.rs | 40 ++++---- src/executor.rs | 7 +- src/gemm.rs | 186 +++++++++++++++++++++----------------- src/kernel/avx/l1s.rs | 2 +- src/kernel/avx/l3s.rs | 25 +++-- src/kernel/avx/mod.rs | 32 ++++--- src/kernel/generic/l3s.rs | 11 +-- src/kernel/generic/mod.rs | 41 ++++----- src/kernel/mod.rs | 18 ++-- src/kernel/sse/fma.rs | 2 + src/kernel/sse/l3d.rs | 2 + src/kernel/sse/l3s.rs | 24 ++--- src/kernel/sse/mod.rs | 31 ++++--- src/lib.rs | 9 +- src/matrix.rs | 96 ++++++++++---------- src/sgemm.rs | 14 +-- src/test/l3s.rs | 68 ++++++++++---- src/test/mod.rs | 2 +- 20 files changed, 339 insertions(+), 287 deletions(-) create mode 100644 .drone.yml diff --git a/.drone.yml b/.drone.yml new file mode 100644 index 0000000..484f2d5 --- /dev/null +++ b/.drone.yml @@ -0,0 +1,14 @@ +kind: pipeline +name: default + +steps: +- name: build + image: rust + commands: + - cargo build --verbose --all + +- name: fmt-check + image: rust + commands: + - rustup component add rustfmt + - cargo fmt --all -- --check diff --git a/build.rs b/build.rs index 7a805fb..7ba9761 100644 --- a/build.rs +++ b/build.rs @@ -4,4 +4,4 @@ fn main() { // and run this. println!("cargo:rerun-if-changed=build.rs"); println!("cargo:rustc-link-lib={}=openblas", "dylib"); -} \ No newline at end of file +} diff --git a/src/dim.rs b/src/dim.rs index 36a9c4b..449c14a 100644 --- a/src/dim.rs +++ b/src/dim.rs @@ -2,49 +2,49 @@ pub trait Dim { const DIM: usize; } -pub struct A1; -impl Dim for A1 { +pub struct A1; +impl Dim for A1 { const DIM: usize = 1; } -pub struct A2; -impl Dim for A2 { +pub struct A2; +impl Dim for A2 { const DIM: usize = 2; } -pub struct A3; -impl Dim for A3 { +pub struct A3; +impl Dim for A3 { const DIM: usize = 3; } -pub struct A4; -impl Dim for A4 { +pub struct A4; +impl Dim for A4 { const DIM: usize = 4; } -pub struct A5; -impl Dim for A5 { +pub struct A5; +impl Dim for A5 { const DIM: usize = 5; } -pub struct A6; -impl Dim for A6 { +pub struct A6; +impl Dim for A6 { const DIM: usize = 6; } -pub struct A7; -impl Dim for A7 { +pub struct A7; +impl Dim for A7 { const DIM: usize = 7; } -pub struct A8; -impl Dim for A8 { +pub struct A8; +impl Dim for A8 { const DIM: usize = 8; } -pub struct A16; -impl Dim for A16 { +pub struct A16; +impl Dim for A16 { const DIM: usize = 16; } /* macro_rules! gen { (($name: ident, $val: lit)) => { - + }; } @@ -72,4 +72,4 @@ gen! { (A28, 28), (A30, 30), (A32, 32), -}*/ \ No newline at end of file +}*/ diff --git a/src/executor.rs b/src/executor.rs index 9719e41..7822897 100644 --- a/src/executor.rs +++ b/src/executor.rs @@ -42,19 +42,18 @@ impl Executor for RayonExecutor { (start..end) .into_par_iter() .map(|x| x * step) - .for_each(|x|f(x)); + .for_each(|x| f(x)); } } - pub struct ThreadPoolExecutor { - thread_pool: threadpool::ThreadPool + thread_pool: threadpool::ThreadPool, } impl ThreadPoolExecutor { pub fn new() -> ThreadPoolExecutor { ThreadPoolExecutor { - thread_pool: threadpool::Builder::new().build() + thread_pool: threadpool::Builder::new().build(), } } } diff --git a/src/gemm.rs b/src/gemm.rs index 77ecd92..fcaa029 100644 --- a/src/gemm.rs +++ b/src/gemm.rs @@ -1,10 +1,12 @@ use crate::aligned_alloc; -use crate::matrix::{Number, Matrix, MatrixMut, MutMatrix, ConstMatrix, ConstTransposedMatrix, MutTransposedMatrix}; -use crate::kernel::params::single::{NC, MC, KC}; -use crate::kernel; -use crate::kernel::GemmKernel; use crate::dim::Dim; use crate::executor::Executor; +use crate::kernel; +use crate::kernel::params::single::{KC, MC, NC}; +use crate::kernel::GemmKernel; +use crate::matrix::{ + ConstMatrix, ConstTransposedMatrix, Matrix, MatrixMut, MutMatrix, MutTransposedMatrix, Number, +}; pub unsafe fn gemm( e: &E, @@ -22,70 +24,109 @@ pub unsafe fn gemm( beta: F, c: *mut F, ldc: usize, -) -where E: Executor, - F: Number, - MR: Dim, NR: Dim, - K: GemmKernel, +) where + E: Executor, + F: Number, + MR: Dim, + NR: Dim, + K: GemmKernel, { match (transa, transb, transc) { (false, false, false) => gemm_template::( - e, m, n, k, alpha, + e, + m, + n, + k, + alpha, ConstMatrix::new(a, lda), ConstMatrix::new(b, ldb), beta, - MutMatrix::new(c, ldc)), + MutMatrix::new(c, ldc), + ), (false, false, true) => gemm_template::( - e, m, n, k, alpha, + e, + m, + n, + k, + alpha, ConstMatrix::new(a, lda), ConstMatrix::new(b, ldb), beta, - MutTransposedMatrix::new(c, ldc)), - + MutTransposedMatrix::new(c, ldc), + ), + (false, true, false) => gemm_template::( - e, m, n, k, alpha, + e, + m, + n, + k, + alpha, ConstMatrix::new(a, lda), ConstTransposedMatrix::new(b, ldb), beta, - MutMatrix::new(c, ldc)), + MutMatrix::new(c, ldc), + ), (false, true, true) => gemm_template::( - e, m, n, k, alpha, + e, + m, + n, + k, + alpha, ConstMatrix::new(a, lda), ConstTransposedMatrix::new(b, ldb), beta, - MutTransposedMatrix::new(c, ldc)), + MutTransposedMatrix::new(c, ldc), + ), (true, false, false) => gemm_template::( - e, m, n, k, alpha, + e, + m, + n, + k, + alpha, ConstTransposedMatrix::new(a, lda), ConstMatrix::new(b, ldb), beta, - MutMatrix::new(c, ldc)), - + MutMatrix::new(c, ldc), + ), + (true, false, true) => gemm_template::( - e, m, n, k, alpha, + e, + m, + n, + k, + alpha, ConstTransposedMatrix::new(a, lda), ConstMatrix::new(b, ldb), beta, - MutTransposedMatrix::new(c, ldc)), + MutTransposedMatrix::new(c, ldc), + ), - (true, true, false) => gemm_template::( - e, m, n, k, alpha, + e, + m, + n, + k, + alpha, ConstTransposedMatrix::new(a, lda), ConstTransposedMatrix::new(b, ldb), beta, - MutMatrix::new(c, ldc)), + MutMatrix::new(c, ldc), + ), - (true, true, true) => gemm_template::( - e, m, n, k, alpha, + e, + m, + n, + k, + alpha, ConstTransposedMatrix::new(a, lda), ConstTransposedMatrix::new(b, ldb), beta, - MutTransposedMatrix::new(c, ldc)), + MutTransposedMatrix::new(c, ldc), + ), } } @@ -98,15 +139,16 @@ unsafe fn gemm_template( a: A, b: B, beta: F, - c: C -) -where E: Executor, - F: Number, - MR: Dim, NR: Dim, - K: GemmKernel, - A: Matrix, - B: Matrix, - C: MatrixMut, + c: C, +) where + E: Executor, + F: Number, + MR: Dim, + NR: Dim, + K: GemmKernel, + A: Matrix, + B: Matrix, + C: MatrixMut, { let packed_a = aligned_alloc::Alloc::new(MC * KC * std::mem::size_of::()); let packed_b = aligned_alloc::Alloc::new(KC * NC * std::mem::size_of::()); @@ -122,7 +164,7 @@ where E: Executor, let pb = MutMatrix::new(packed_b.ptr::(), p_b); inner_kernel::( - e, + e, i_b, j_b, p_b, @@ -133,7 +175,7 @@ where E: Executor, c.sub(j, i), pa, pb, - i == 0 + i == 0, ); } } @@ -141,7 +183,7 @@ where E: Executor, } // -// | MR | +// | MR | // +-----------------------------+----+ // | : | | // | TL : | TR | @@ -172,16 +214,15 @@ unsafe fn inner_kernel( pa: MutMatrix, pb: MutMatrix, first_time: bool, -) - where E: Executor, - F: Number, - MR: Dim, - NR: Dim, - K: kernel::GemmKernel, - A: Matrix, - B: Matrix, - C: MatrixMut, - +) where + E: Executor, + F: Number, + MR: Dim, + NR: Dim, + K: kernel::GemmKernel, + A: Matrix, + B: Matrix, + C: MatrixMut, { let n_left = n % NR::DIM; let n_main = n - n_left; @@ -190,33 +231,24 @@ unsafe fn inner_kernel( let m_main = m - m_left; if first_time { - e.execute(0, n_main, NR::DIM, move |j| - K::pack_row_b(b, pb)); + e.execute(0, n_main, NR::DIM, move |_j| K::pack_row_b(b, pb)); } - e.execute(0, m_main, MR::DIM, move |i| - K::pack_row_a(a.sub_col(i), pa.sub_row(i))); + e.execute(0, m_main, MR::DIM, move |i| { + K::pack_row_a(a.sub_col(i), pa.sub_row(i)) + }); e.synchronize(); - + e.execute(0, n_main, NR::DIM, move |j| { // Section TL for i in (0..m_main).step_by(MR::DIM) { - K::main_tl(alpha, - pa.sub_row(i), - pb.sub_row(j), - beta, - c.sub(j, i)); + K::main_tl(alpha, pa.sub_row(i), pb.sub_row(j), beta, c.sub(j, i)); } // Section TR for i in m_main..m { - K::sup_tr( - alpha, - a.sub_col(i), - pb.sub_row(j), - beta, - c.sub(j, i)); + K::sup_tr(alpha, a.sub_col(i), pb.sub_row(j), beta, c.sub(j, i)); } }); @@ -225,27 +257,15 @@ unsafe fn inner_kernel( let i = ji % m_main; // Section BL - K::sup_bl( - alpha, - pa.sub_row(i), - b.sub_row(j), - beta, - c.sub(j, i) - ); + K::sup_bl(alpha, pa.sub_row(i), b.sub_row(j), beta, c.sub(j, i)); }); // Section BR for j in n_main..n { for i in m_main..m { - K::sup_br( - k, - alpha, - a.sub_col(i), - b.sub_row(j), - beta, - c.sub(j, i)) + K::sup_br(k, alpha, a.sub_col(i), b.sub_row(j), beta, c.sub(j, i)) } - }; + } e.synchronize(); -} \ No newline at end of file +} diff --git a/src/kernel/avx/l1s.rs b/src/kernel/avx/l1s.rs index e315ade..4e0fdbb 100644 --- a/src/kernel/avx/l1s.rs +++ b/src/kernel/avx/l1s.rs @@ -246,7 +246,7 @@ pub unsafe fn sdot( let mut acc1 = _mm256_setzero_ps(); let mut acc2 = _mm256_setzero_ps(); let mut acc3 = _mm256_setzero_ps(); - + for _ in 0..n / 32 { let x0 = _mm256_loadu_ps(x); let y0 = _mm256_loadu_ps(y); diff --git a/src/kernel/avx/l3s.rs b/src/kernel/avx/l3s.rs index b66759f..c630a50 100644 --- a/src/kernel/avx/l3s.rs +++ b/src/kernel/avx/l3s.rs @@ -1,16 +1,16 @@ use super::fma::fmadd_ps; use super::intrinsics::*; use crate::kernel::params::single::{MR, NR}; -use crate::matrix::{Matrix, MutMatrix, MatrixMut}; +use crate::matrix::{Matrix, MatrixMut, MutMatrix}; use crunchy::unroll; #[inline] pub(crate) unsafe fn sgemm_ukr_16x8>( k: usize, - alpha: f32, + _alpha: f32, pa: MutMatrix, pb: MutMatrix, - beta: f32, + _beta: f32, c: C, ) { let mut mt00 = _mm256_setzero_ps(); @@ -21,7 +21,7 @@ pub(crate) unsafe fn sgemm_ukr_16x8>( // let mut mt05 = _mm256_setzero_ps(); // let mut mt06 = _mm256_setzero_ps(); // let mut mt07 = _mm256_setzero_ps(); - + let mut mt10 = _mm256_setzero_ps(); let mut mt11 = _mm256_setzero_ps(); let mut mt12 = _mm256_setzero_ps(); @@ -30,7 +30,7 @@ pub(crate) unsafe fn sgemm_ukr_16x8>( // let mut mt15 = _mm256_setzero_ps(); // let mut mt16 = _mm256_setzero_ps(); // let mut mt17 = _mm256_setzero_ps(); - + let mut pa = pa.ptr(); let mut pb = pb.ptr(); @@ -53,7 +53,7 @@ pub(crate) unsafe fn sgemm_ukr_16x8>( // let b5 = _mm256_broadcast_ss(&*pb.add(5)); // let b6 = _mm256_broadcast_ss(&*pb.add(6)); // let b7 = _mm256_broadcast_ss(&*pb.add(7)); - + mt00 = fmadd_ps(a0, b0, mt00); mt01 = fmadd_ps(a0, b1, mt01); mt02 = fmadd_ps(a0, b2, mt02); @@ -75,7 +75,6 @@ pub(crate) unsafe fn sgemm_ukr_16x8>( pb = pb.add(BATCH * NR); } - for _ in k_main..k { let a0 = _mm256_load_ps(pa); let a1 = _mm256_load_ps(pa.add(8)); @@ -88,7 +87,7 @@ pub(crate) unsafe fn sgemm_ukr_16x8>( // let b5 = _mm256_broadcast_ss(&*pb.add(5)); // let b6 = _mm256_broadcast_ss(&*pb.add(6)); // let b7 = _mm256_broadcast_ss(&*pb.add(7)); - + mt00 = fmadd_ps(a0, b0, mt00); mt01 = fmadd_ps(a0, b1, mt01); mt02 = fmadd_ps(a0, b2, mt02); @@ -160,7 +159,7 @@ pub(crate) unsafe fn sgemm_ukr_16x8>( // // mt05 = fmadd_ps(beta, _mm256_loadu_ps(ccol5), mt05); // // mt06 = fmadd_ps(beta, _mm256_loadu_ps(ccol6), mt06); // // mt07 = fmadd_ps(beta, _mm256_loadu_ps(ccol7), mt07); - + // mt10 = fmadd_ps(beta, _mm256_loadu_ps(ccol0.add(8)), mt10); // mt11 = fmadd_ps(beta, _mm256_loadu_ps(ccol1.add(8)), mt11); // mt12 = fmadd_ps(beta, _mm256_loadu_ps(ccol2.add(8)), mt12); @@ -179,7 +178,7 @@ pub(crate) unsafe fn sgemm_ukr_16x8>( // _mm256_storeu_ps(ccol5, mt05); // _mm256_storeu_ps(ccol6, mt06); // _mm256_storeu_ps(ccol7, mt07); - + _mm256_storeu_ps(ccol0.add(8), mt10); _mm256_storeu_ps(ccol1.add(8), mt11); _mm256_storeu_ps(ccol2.add(8), mt12); @@ -192,10 +191,10 @@ pub(crate) unsafe fn sgemm_ukr_16x8>( pub(crate) unsafe fn sgemm_sup_16x1, C: MatrixMut>( k: usize, - alpha: f32, + _alpha: f32, pa: MutMatrix, b: B, - beta: f32, + _beta: f32, c: C, ) { let mut mt0 = _mm256_setzero_ps(); @@ -250,4 +249,4 @@ pub(crate) unsafe fn sgemm_pa_16x(k: usize, a: *const f32, lda: usize, pa: *mut pa = pa.add(16); a = a.add(lda); } -} \ No newline at end of file +} diff --git a/src/kernel/avx/mod.rs b/src/kernel/avx/mod.rs index 723b550..10397d5 100644 --- a/src/kernel/avx/mod.rs +++ b/src/kernel/avx/mod.rs @@ -6,15 +6,16 @@ pub mod l1s; pub mod l3d; pub mod l3s; -use core::marker::PhantomData; -use crate::matrix::{Number, MutMatrix, Matrix, MatrixMut}; -use crate::kernel::{GemmKernel, GemmKernelSupNr, GemmKernelSupMr, GemmKernelSup}; use crate::dim::*; +use crate::kernel::{GemmKernel, GemmKernelSup, GemmKernelSupMr, GemmKernelSupNr}; +use crate::matrix::{Matrix, MatrixMut, MutMatrix, Number}; +use core::marker::PhantomData; pub struct AvxKernel(PhantomData); -impl GemmKernelSupNr for AvxKernel - where I: GemmKernelSupNr +impl GemmKernelSupNr for AvxKernel +where + I: GemmKernelSupNr, { #[inline] unsafe fn sup_tr, C: MatrixMut>( @@ -26,10 +27,11 @@ impl GemmKernelSupNr for AvxKernel ) { I::sup_tr(alpha, a, pb, beta, c); } -} +} -impl GemmKernelSupMr for AvxKernel - where I: GemmKernelSupMr +impl GemmKernelSupMr for AvxKernel +where + I: GemmKernelSupMr, { #[inline] unsafe fn sup_bl, C: MatrixMut>( @@ -43,8 +45,9 @@ impl GemmKernelSupMr for AvxKernel } } -impl GemmKernelSup for AvxKernel - where I: GemmKernelSup +impl GemmKernelSup for AvxKernel +where + I: GemmKernelSup, { #[inline] unsafe fn sup_br, B: Matrix, C: MatrixMut>( @@ -59,12 +62,13 @@ impl GemmKernelSup for AvxKernel } } -impl GemmKernel for AvxKernel - where I: GemmKernel +impl GemmKernel for AvxKernel +where + I: GemmKernel, { #[inline] unsafe fn pack_row_a>(a: A, pa: MutMatrix) { - if a.is_transposed() { + if a.is_transposed() { I::pack_row_a(a, pa); } else { self::l3s::sgemm_pa_16x(pa.stride, a.ptr(), a.stride(), pa.ptr_mut()); @@ -86,4 +90,4 @@ impl GemmKernel for AvxKernel ) { self::l3s::sgemm_ukr_16x8(pa.stride, alpha, pa, pb, beta, c); } -} \ No newline at end of file +} diff --git a/src/kernel/generic/l3s.rs b/src/kernel/generic/l3s.rs index 6f33b89..85495bd 100644 --- a/src/kernel/generic/l3s.rs +++ b/src/kernel/generic/l3s.rs @@ -1,13 +1,12 @@ use crate::kernel::params::single::NR; -use crate::matrix::{MutMatrix, Matrix, MatrixMut}; - +use crate::matrix::{Matrix, MatrixMut, MutMatrix}; pub(crate) unsafe fn sgemm_sup_1x8, C: MatrixMut>( k: usize, - alpha: f32, + _alpha: f32, a: A, pb: MutMatrix, - beta: f32, + _beta: f32, c: C, ) { let mut c0 = 0.0f32; @@ -46,7 +45,7 @@ pub(crate) unsafe fn sgemm_sup_1x8, C: MatrixMut>( // // c5 *= alpha; // // c6 *= alpha; // // c7 *= alpha; - + let ccol0 = c.ptr_mut(); let ccol1 = c.row_mut(1); let ccol2 = c.row_mut(2); @@ -139,4 +138,4 @@ pub(crate) unsafe fn sgemm_pa_t(k: usize, a: *const f32, lda: usize, pa: *mut f3 a = a.add(1); pa = pa.add(MR); } -} \ No newline at end of file +} diff --git a/src/kernel/generic/mod.rs b/src/kernel/generic/mod.rs index 0f82c4f..516325e 100644 --- a/src/kernel/generic/mod.rs +++ b/src/kernel/generic/mod.rs @@ -1,9 +1,9 @@ pub mod l3d; pub mod l3s; -use crate::matrix::{Matrix, MatrixMut, MutMatrix}; -use crate::kernel::{GemmKernel, GemmKernelSupNr, GemmKernelSupMr, GemmKernelSup}; use crate::dim::*; +use crate::kernel::{GemmKernel, GemmKernelSup, GemmKernelSupMr, GemmKernelSupNr}; +use crate::matrix::{Matrix, MatrixMut, MutMatrix}; pub struct GenericKernel; @@ -16,24 +16,18 @@ impl GemmKernelSupNr for GenericKernel { beta: f32, c: C, ) { - self::l3s::sgemm_sup_1x8( - pb.stride, - alpha, - a, - pb, - beta, - c) + self::l3s::sgemm_sup_1x8(pb.stride, alpha, a, pb, beta, c) } -} +} impl GemmKernelSupMr for GenericKernel { #[inline] unsafe fn sup_bl, C: MatrixMut>( - alpha: f32, - pa: MutMatrix, - b: B, - beta: f32, - c: C, + _alpha: f32, + _pa: MutMatrix, + _b: B, + _beta: f32, + _c: C, ) { unimplemented!() } @@ -66,10 +60,9 @@ impl GemmKernelSup for GenericKernel { } impl GemmKernel for GenericKernel { - #[inline] unsafe fn pack_row_a>(a: A, pa: MutMatrix) { - if a.is_transposed() { + if a.is_transposed() { self::l3s::sgemm_pa_t(pa.stride, a.ptr(), a.stride(), pa.ptr_mut()); } else { unimplemented!() @@ -78,7 +71,7 @@ impl GemmKernel for GenericKernel { #[inline] unsafe fn pack_row_b>(b: B, pb: MutMatrix) { - if b.is_transposed() { + if b.is_transposed() { self::l3s::sgemm_pb_t(pb.stride, b.ptr(), b.stride(), pb.ptr_mut()); } else { self::l3s::sgemm_pb_x8(pb.stride, b.ptr(), b.stride(), pb.ptr_mut()); @@ -87,12 +80,12 @@ impl GemmKernel for GenericKernel { #[inline] unsafe fn main_tl>( - alpha: f32, - pa: MutMatrix, - pb: MutMatrix, - beta: f32, - c: C, + _alpha: f32, + _pa: MutMatrix, + _pb: MutMatrix, + _beta: f32, + _c: C, ) { unimplemented!() } -} \ No newline at end of file +} diff --git a/src/kernel/mod.rs b/src/kernel/mod.rs index 0c9615d..25f884a 100644 --- a/src/kernel/mod.rs +++ b/src/kernel/mod.rs @@ -6,14 +6,14 @@ pub mod avx; pub mod generic; pub mod sse; -use crate::matrix::{Number, Matrix, MatrixMut, MutMatrix}; use crate::dim::Dim; +use crate::matrix::{Matrix, MatrixMut, MutMatrix, Number}; pub mod params { pub mod single { pub const MC: usize = 128; pub const KC: usize = 256; - pub const NC: usize = 8*1024; + pub const NC: usize = 8 * 1024; pub const MR: usize = 16; pub const NR: usize = 5; } @@ -70,15 +70,11 @@ pub trait GemmKernelSupNr { ); } -pub trait GemmKernel: GemmKernelSupMr + GemmKernelSupNr + GemmKernelSup { +pub trait GemmKernel: + GemmKernelSupMr + GemmKernelSupNr + GemmKernelSup +{ unsafe fn pack_row_a>(a: A, pa: MutMatrix); unsafe fn pack_row_b>(b: B, pb: MutMatrix); - unsafe fn main_tl>( - alpha: F, - pa: MutMatrix, - pb: MutMatrix, - beta: F, - c: C, - ); -} \ No newline at end of file + unsafe fn main_tl>(alpha: F, pa: MutMatrix, pb: MutMatrix, beta: F, c: C); +} diff --git a/src/kernel/sse/fma.rs b/src/kernel/sse/fma.rs index 793ccda..3705670 100644 --- a/src/kernel/sse/fma.rs +++ b/src/kernel/sse/fma.rs @@ -1,3 +1,5 @@ +#![allow(dead_code)] + use super::intrinsics::*; #[cfg(target_feature = "fma")] diff --git a/src/kernel/sse/l3d.rs b/src/kernel/sse/l3d.rs index 82e7238..8a540ba 100644 --- a/src/kernel/sse/l3d.rs +++ b/src/kernel/sse/l3d.rs @@ -1,3 +1,5 @@ +#![allow(dead_code)] + use super::fma::fmadd_pd; use super::intrinsics::*; diff --git a/src/kernel/sse/l3s.rs b/src/kernel/sse/l3s.rs index 8ec2471..ac9e54f 100644 --- a/src/kernel/sse/l3s.rs +++ b/src/kernel/sse/l3s.rs @@ -1,3 +1,5 @@ +#![allow(dead_code)] + use super::fma::fmadd_ps; use super::intrinsics::*; use crate::kernel::params::single::NR; @@ -5,10 +7,10 @@ use crate::matrix::{Matrix, MatrixMut, MutMatrix}; pub(crate) unsafe fn sgemm_sup_1x8, C: MatrixMut>( k: usize, - alpha: f32, + _alpha: f32, a: A, pb: MutMatrix, - beta: f32, + _beta: f32, c: C, ) { let mut c0_3 = _mm_setzero_ps(); @@ -23,7 +25,7 @@ pub(crate) unsafe fn sgemm_sup_1x8, C: MatrixMut>( for _ in 0..k { let a0 = *a.ptr(); let a0_simd = _mm_broadcast_ss(&*a.ptr()); - + c0_3 = fmadd_ps(_mm_loadu_ps(pb.ptr()), a0_simd, c0_3); c4 += *pb.col(4) * a0; // c5 += *pb.add(5) * a0; @@ -42,7 +44,7 @@ pub(crate) unsafe fn sgemm_sup_1x8, C: MatrixMut>( // c5 *= alpha; // c6 *= alpha; // c7 *= alpha; - + let ccol0_3 = c.ptr_mut(); let ccol4 = c.row_mut(4); // let ccol5 = c.add(ldc * 5); @@ -69,11 +71,11 @@ pub(crate) unsafe fn sgemm_sup_1x8, C: MatrixMut>( pub(crate) unsafe fn sgemm_sup_1x8_t( k: usize, - alpha: f32, + _alpha: f32, a: *const f32, - lda: usize, + _lda: usize, pb: *const f32, - beta: f32, + _beta: f32, c: *mut f32, ldc: usize, ) { @@ -88,8 +90,8 @@ pub(crate) unsafe fn sgemm_sup_1x8_t( for _ in 0..k { let a0 = *a; - let a0_simd = _mm_set_ps1(a0); - + let a0_simd = _mm_set_ps1(a0); + c03 = fmadd_ps(_mm_loadu_ps(pb), a0_simd, c03); c4 += *pb.add(4) * a0; // c5 += *pb.add(5) * a0; @@ -108,7 +110,7 @@ pub(crate) unsafe fn sgemm_sup_1x8_t( // c5 *= alpha; // c6 *= alpha; // c7 *= alpha; - + // let ccol0 = c; // let ccol1 = c.add(ldc); // let ccol2 = c.add(ldc * 2); @@ -175,4 +177,4 @@ pub(crate) unsafe fn sgemm_pb_x8(k: usize, b: *const f32, ldb: usize, pb: *mut f pb = pb.add(NR); } -} \ No newline at end of file +} diff --git a/src/kernel/sse/mod.rs b/src/kernel/sse/mod.rs index 2e464c6..3993515 100644 --- a/src/kernel/sse/mod.rs +++ b/src/kernel/sse/mod.rs @@ -6,16 +6,16 @@ mod intrinsics; // pub mod l3d; pub mod l3s; -use core::marker::PhantomData; -use crate::matrix::{Number, MutMatrix, Matrix, MatrixMut}; -use crate::kernel::{GemmKernel, GemmKernelSupNr, GemmKernelSupMr, GemmKernelSup}; use crate::dim::*; - +use crate::kernel::{GemmKernel, GemmKernelSup, GemmKernelSupMr, GemmKernelSupNr}; +use crate::matrix::{Matrix, MatrixMut, MutMatrix, Number}; +use core::marker::PhantomData; pub struct SseKernel(PhantomData); -impl GemmKernelSupNr for SseKernel - where I: GemmKernelSupNr +impl GemmKernelSupNr for SseKernel +where + I: GemmKernelSupNr, { #[inline] unsafe fn sup_tr, C: MatrixMut>( @@ -27,10 +27,11 @@ impl GemmKernelSupNr for SseKernel ) { I::sup_tr(alpha, a, pb, beta, c); } -} +} -impl GemmKernelSupMr for SseKernel - where I: GemmKernelSupMr +impl GemmKernelSupMr for SseKernel +where + I: GemmKernelSupMr, { #[inline] unsafe fn sup_bl, C: MatrixMut>( @@ -44,8 +45,9 @@ impl GemmKernelSupMr for SseKernel } } -impl GemmKernelSup for SseKernel - where I: GemmKernelSup +impl GemmKernelSup for SseKernel +where + I: GemmKernelSup, { #[inline] unsafe fn sup_br, B: Matrix, C: MatrixMut>( @@ -60,8 +62,9 @@ impl GemmKernelSup for SseKernel } } -impl GemmKernel for SseKernel - where I: GemmKernel +impl GemmKernel for SseKernel +where + I: GemmKernel, { #[inline] unsafe fn pack_row_a>(a: A, pa: MutMatrix) { @@ -83,4 +86,4 @@ impl GemmKernel for SseKernel ) { I::main_tl(alpha, pa, pb, beta, c); } -} \ No newline at end of file +} diff --git a/src/lib.rs b/src/lib.rs index 0a9b138..8f3cb1a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,10 +1,10 @@ -pub mod matrix; -pub mod gemm; -mod sgemm; mod aligned_alloc; -pub mod kernel; pub mod dim; pub mod executor; +pub mod gemm; +pub mod kernel; +pub mod matrix; +mod sgemm; #[cfg(test)] extern crate blas; @@ -13,5 +13,4 @@ extern crate openblas; #[cfg(test)] mod test; - pub use crate::sgemm::sgemm; diff --git a/src/matrix.rs b/src/matrix.rs index 4c91c14..a9581d0 100644 --- a/src/matrix.rs +++ b/src/matrix.rs @@ -1,5 +1,4 @@ - -pub trait Number: core::fmt::Display + Copy + Send + Sync + 'static {} +pub trait Number: core::fmt::Display + Copy + Send + Sync + 'static {} impl Number for f32 {} impl Number for f64 {} @@ -32,7 +31,7 @@ pub trait MatrixMut: Matrix { #[derive(Copy, Clone)] pub struct ConstMatrix { pub stride: usize, - pub ptr: *const F + pub ptr: *const F, } unsafe impl Send for ConstMatrix {} @@ -47,31 +46,31 @@ impl ConstMatrix { impl Matrix for ConstMatrix { #[inline] unsafe fn sub(&self, row: usize, col: usize) -> Self { - Self { - ptr: self.index(row, col), - stride: self.stride + Self { + ptr: self.index(row, col), + stride: self.stride, } } #[inline] unsafe fn sub_col(&self, col: usize) -> Self { - Self { - ptr: self.col(col), - stride: self.stride + Self { + ptr: self.col(col), + stride: self.stride, } } #[inline] unsafe fn sub_row(&self, row: usize) -> Self { - Self { - ptr: self.row(row), - stride: self.stride + Self { + ptr: self.row(row), + stride: self.stride, } } #[inline] - unsafe fn is_transposed(&self) -> bool { - false + unsafe fn is_transposed(&self) -> bool { + false } #[inline] @@ -128,7 +127,7 @@ impl Matrix for ConstMatrix { #[derive(Copy, Clone)] pub struct ConstTransposedMatrix { pub stride: usize, - pub ptr: *const F + pub ptr: *const F, } unsafe impl Send for ConstTransposedMatrix {} @@ -143,25 +142,25 @@ impl ConstTransposedMatrix { impl Matrix for ConstTransposedMatrix { #[inline] unsafe fn sub(&self, row: usize, col: usize) -> Self { - Self { - ptr: self.index(row, col), - stride: self.stride + Self { + ptr: self.index(row, col), + stride: self.stride, } } #[inline] unsafe fn sub_col(&self, col: usize) -> Self { - Self { - ptr: self.col(col), - stride: self.stride + Self { + ptr: self.col(col), + stride: self.stride, } } #[inline] unsafe fn sub_row(&self, row: usize) -> Self { - Self { - ptr: self.row(row), - stride: self.stride + Self { + ptr: self.row(row), + stride: self.stride, } } @@ -224,7 +223,7 @@ impl Matrix for ConstTransposedMatrix { #[derive(Copy, Clone)] pub struct MutMatrix { pub stride: usize, - pub ptr: *mut F + pub ptr: *mut F, } unsafe impl Send for MutMatrix {} @@ -239,38 +238,38 @@ impl MutMatrix { impl Matrix for MutMatrix { #[inline] unsafe fn sub(&self, row: usize, col: usize) -> Self { - Self { - ptr: self.index_mut(row, col), - stride: self.stride + Self { + ptr: self.index_mut(row, col), + stride: self.stride, } } #[inline] unsafe fn sub_col(&self, col: usize) -> Self { - Self { - ptr: self.col_mut(col), - stride: self.stride + Self { + ptr: self.col_mut(col), + stride: self.stride, } } #[inline] unsafe fn sub_row(&self, row: usize) -> Self { - Self { - ptr: self.row_mut(row), - stride: self.stride + Self { + ptr: self.row_mut(row), + stride: self.stride, } } #[inline] - unsafe fn is_transposed(&self) -> bool { - false + unsafe fn is_transposed(&self) -> bool { + false } #[inline] unsafe fn stride(&self) -> usize { self.stride } - + #[inline] unsafe fn get(&self, row: usize, col: usize) -> F { *self.ptr.add(row * self.stride + col) @@ -347,7 +346,7 @@ impl MatrixMut for MutMatrix { #[derive(Copy, Clone)] pub struct MutTransposedMatrix { pub stride: usize, - pub ptr: *mut F + pub ptr: *mut F, } unsafe impl Send for MutTransposedMatrix {} @@ -362,25 +361,25 @@ impl MutTransposedMatrix { impl Matrix for MutTransposedMatrix { #[inline] unsafe fn sub(&self, row: usize, col: usize) -> Self { - Self { - ptr: self.index_mut(row, col), - stride: self.stride + Self { + ptr: self.index_mut(row, col), + stride: self.stride, } } #[inline] unsafe fn sub_col(&self, col: usize) -> Self { - Self { - ptr: self.col_mut(col), - stride: self.stride + Self { + ptr: self.col_mut(col), + stride: self.stride, } } #[inline] unsafe fn sub_row(&self, row: usize) -> Self { - Self { - ptr: self.row_mut(row), - stride: self.stride + Self { + ptr: self.row_mut(row), + stride: self.stride, } } @@ -440,7 +439,6 @@ impl Matrix for MutTransposedMatrix { } } - impl MatrixMut for MutTransposedMatrix { #[inline] unsafe fn set(&self, row: usize, col: usize, value: F) { @@ -467,5 +465,3 @@ impl MatrixMut for MutTransposedMatrix { self.ptr } } - - diff --git a/src/sgemm.rs b/src/sgemm.rs index 8da14e2..78d3cb5 100644 --- a/src/sgemm.rs +++ b/src/sgemm.rs @@ -1,8 +1,7 @@ - -use crate::kernel::avx::AvxKernel; -use crate::kernel::generic::GenericKernel; use crate::dim::*; use crate::executor::Executor; +use crate::kernel::avx::AvxKernel; +use crate::kernel::generic::GenericKernel; pub unsafe fn sgemm( e: &E, @@ -22,9 +21,6 @@ pub unsafe fn sgemm( ldc: usize, ) { crate::gemm::gemm::, A16, A5>( - e, - transa, transb, transc, - m, n, k, alpha, - a, lda, b, ldb, - beta, c, ldc); -} \ No newline at end of file + e, transa, transb, transc, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + ); +} diff --git a/src/test/l3s.rs b/src/test/l3s.rs index 4651529..a070618 100644 --- a/src/test/l3s.rs +++ b/src/test/l3s.rs @@ -1,17 +1,15 @@ -use rand::Rng; - const M_LEN: usize = 17; const N_LEN: usize = 6; const K_LEN: usize = 4; fn make_matrices() -> (Vec, Vec, Vec, Vec) { - let mut rng = rand::thread_rng(); + let _rng = rand::thread_rng(); let (m, n, k) = (M_LEN, N_LEN, K_LEN); let mut a = vec![0.0; m * k]; let mut a_t = vec![0.0; m * k]; - + let mut b = vec![0.0; n * k]; let mut b_t = vec![0.0; n * k]; @@ -38,7 +36,6 @@ fn make_matrices() -> (Vec, Vec, Vec, Vec) { (a, a_t, b, b_t) } - unsafe fn sgemm_ref( m: usize, n: usize, @@ -55,9 +52,9 @@ unsafe fn sgemm_ref( for i_m in 0..m { for i_k in 0..k { let a_part = *a.add(i_m * lda + i_k); - + for i_n in 0..n { - *c.add(i_m * ldc + i_n) += a_part * *b.add(i_k * ldb + i_n); + *c.add(i_m * ldc + i_n) += a_part * *b.add(i_k * ldb + i_n); } } } @@ -75,7 +72,6 @@ unsafe fn sgemm_ref( // } } - #[test] fn test_sgemm_nn() { let (m, n, k) = (M_LEN, N_LEN, K_LEN); @@ -133,7 +129,6 @@ fn test_sgemm_nn() { ); } - for row in 0..N_LEN { for col in 0..M_LEN { let index = row * M_LEN + col; @@ -160,16 +155,24 @@ fn test_sgemm_nn() { let (a, b) = (c[index], cref[index]); - assert!(feq(a, b), "a != b, a[{}, {}]={}, b[{}, {}]={}", row, col, a, row, col, b); + assert!( + feq(a, b), + "a != b, a[{}, {}]={}, b[{}, {}]={}", + row, + col, + a, + row, + col, + b + ); } } - } #[test] fn test_sgemm_nt() { let (m, n, k) = (M_LEN, N_LEN, K_LEN); - let (a, _, b, b_t) = make_matrices(); + let (a, _, _b, b_t) = make_matrices(); let mut c = vec![0.0; m * n]; let mut cref = vec![0.0; m * n]; @@ -216,16 +219,24 @@ fn test_sgemm_nt() { for col in 0..M_LEN { let index = row * M_LEN + col; let (a, b) = (c[index], cref[index]); - assert!(feq(a, b), "a != b, a[{}, {}]={}, b[{}, {}]={}", row, col, a, row, col, b); + assert!( + feq(a, b), + "a != b, a[{}, {}]={}, b[{}, {}]={}", + row, + col, + a, + row, + col, + b + ); } } } - #[test] fn test_sgemm_tn() { let (m, n, k) = (M_LEN, N_LEN, K_LEN); - let (a, a_t, b, _) = make_matrices(); + let (_a, a_t, b, _) = make_matrices(); let mut c = vec![0.0; m * n]; let mut cref = vec![0.0; m * n]; @@ -272,16 +283,24 @@ fn test_sgemm_tn() { for col in 0..M_LEN { let index = row * M_LEN + col; let (a, b) = (c[index], cref[index]); - assert!(feq(a, b), "a != b, a[{}, {}]={}, b[{}, {}]={}", row, col, a, row, col, b); + assert!( + feq(a, b), + "a != b, a[{}, {}]={}, b[{}, {}]={}", + row, + col, + a, + row, + col, + b + ); } } } - #[test] fn test_sgemm_tt() { let (m, n, k) = (M_LEN, N_LEN, K_LEN); - let (a, a_t, b, b_t) = make_matrices(); + let (_a, a_t, _b, b_t) = make_matrices(); let mut c = vec![0.0; m * n]; let mut cref = vec![0.0; m * n]; @@ -328,7 +347,16 @@ fn test_sgemm_tt() { for col in 0..M_LEN { let index = row * M_LEN + col; let (a, b) = (c[index], cref[index]); - assert!(feq(a, b), "a != b, a[{}, {}]={}, b[{}, {}]={}", row, col, a, row, col, b); + assert!( + feq(a, b), + "a != b, a[{}, {}]={}, b[{}, {}]={}", + row, + col, + a, + row, + col, + b + ); } } } @@ -341,4 +369,4 @@ fn feq(a: f32, b: f32) -> bool { } else { (a - b).abs() / (a.abs() + b.abs()) < std::f32::EPSILON * 10.0 } -} \ No newline at end of file +} diff --git a/src/test/mod.rs b/src/test/mod.rs index 5b99a65..8039281 100644 --- a/src/test/mod.rs +++ b/src/test/mod.rs @@ -1 +1 @@ -mod l3s; \ No newline at end of file +mod l3s; -- 2.45.2