Merge pull request 'Fix fmt and some warnings' (#1) from fix-fmt into master
Some checks reported errors
continuous-integration/drone/push Build is passing
continuous-integration/drone Build was killed

Reviewed-on: #1
This commit is contained in:
Andrey Tkachenko 2022-03-05 18:49:40 +04:00
commit f24e10a1fd
20 changed files with 339 additions and 287 deletions

14
.drone.yml Normal file
View File

@ -0,0 +1,14 @@
kind: pipeline
name: default
steps:
- name: build
image: rust
commands:
- cargo build --verbose --all
- name: fmt-check
image: rust
commands:
- rustup component add rustfmt
- cargo fmt --all -- --check

View File

@ -42,19 +42,18 @@ impl Executor for RayonExecutor {
(start..end) (start..end)
.into_par_iter() .into_par_iter()
.map(|x| x * step) .map(|x| x * step)
.for_each(|x|f(x)); .for_each(|x| f(x));
} }
} }
pub struct ThreadPoolExecutor { pub struct ThreadPoolExecutor {
thread_pool: threadpool::ThreadPool thread_pool: threadpool::ThreadPool,
} }
impl ThreadPoolExecutor { impl ThreadPoolExecutor {
pub fn new() -> ThreadPoolExecutor { pub fn new() -> ThreadPoolExecutor {
ThreadPoolExecutor { ThreadPoolExecutor {
thread_pool: threadpool::Builder::new().build() thread_pool: threadpool::Builder::new().build(),
} }
} }
} }

View File

@ -1,10 +1,12 @@
use crate::aligned_alloc; use crate::aligned_alloc;
use crate::matrix::{Number, Matrix, MatrixMut, MutMatrix, ConstMatrix, ConstTransposedMatrix, MutTransposedMatrix};
use crate::kernel::params::single::{NC, MC, KC};
use crate::kernel;
use crate::kernel::GemmKernel;
use crate::dim::Dim; use crate::dim::Dim;
use crate::executor::Executor; use crate::executor::Executor;
use crate::kernel;
use crate::kernel::params::single::{KC, MC, NC};
use crate::kernel::GemmKernel;
use crate::matrix::{
ConstMatrix, ConstTransposedMatrix, Matrix, MatrixMut, MutMatrix, MutTransposedMatrix, Number,
};
pub unsafe fn gemm<E, F, K, MR, NR>( pub unsafe fn gemm<E, F, K, MR, NR>(
e: &E, e: &E,
@ -22,70 +24,109 @@ pub unsafe fn gemm<E, F, K, MR, NR>(
beta: F, beta: F,
c: *mut F, c: *mut F,
ldc: usize, ldc: usize,
) ) where
where E: Executor, E: Executor,
F: Number, F: Number,
MR: Dim, NR: Dim, MR: Dim,
NR: Dim,
K: GemmKernel<F, MR, NR>, K: GemmKernel<F, MR, NR>,
{ {
match (transa, transb, transc) { match (transa, transb, transc) {
(false, false, false) => gemm_template::<E, F, K, MR, NR, _, _, _>( (false, false, false) => gemm_template::<E, F, K, MR, NR, _, _, _>(
e, m, n, k, alpha, e,
m,
n,
k,
alpha,
ConstMatrix::new(a, lda), ConstMatrix::new(a, lda),
ConstMatrix::new(b, ldb), ConstMatrix::new(b, ldb),
beta, beta,
MutMatrix::new(c, ldc)), MutMatrix::new(c, ldc),
),
(false, false, true) => gemm_template::<E, F, K, MR, NR, _, _, _>( (false, false, true) => gemm_template::<E, F, K, MR, NR, _, _, _>(
e, m, n, k, alpha, e,
m,
n,
k,
alpha,
ConstMatrix::new(a, lda), ConstMatrix::new(a, lda),
ConstMatrix::new(b, ldb), ConstMatrix::new(b, ldb),
beta, beta,
MutTransposedMatrix::new(c, ldc)), MutTransposedMatrix::new(c, ldc),
),
(false, true, false) => gemm_template::<E, F, K, MR, NR, _, _, _>( (false, true, false) => gemm_template::<E, F, K, MR, NR, _, _, _>(
e, m, n, k, alpha, e,
m,
n,
k,
alpha,
ConstMatrix::new(a, lda), ConstMatrix::new(a, lda),
ConstTransposedMatrix::new(b, ldb), ConstTransposedMatrix::new(b, ldb),
beta, beta,
MutMatrix::new(c, ldc)), MutMatrix::new(c, ldc),
),
(false, true, true) => gemm_template::<E, F, K, MR, NR, _, _, _>( (false, true, true) => gemm_template::<E, F, K, MR, NR, _, _, _>(
e, m, n, k, alpha, e,
m,
n,
k,
alpha,
ConstMatrix::new(a, lda), ConstMatrix::new(a, lda),
ConstTransposedMatrix::new(b, ldb), ConstTransposedMatrix::new(b, ldb),
beta, beta,
MutTransposedMatrix::new(c, ldc)), MutTransposedMatrix::new(c, ldc),
),
(true, false, false) => gemm_template::<E, F, K, MR, NR, _, _, _>( (true, false, false) => gemm_template::<E, F, K, MR, NR, _, _, _>(
e, m, n, k, alpha, e,
m,
n,
k,
alpha,
ConstTransposedMatrix::new(a, lda), ConstTransposedMatrix::new(a, lda),
ConstMatrix::new(b, ldb), ConstMatrix::new(b, ldb),
beta, beta,
MutMatrix::new(c, ldc)), MutMatrix::new(c, ldc),
),
(true, false, true) => gemm_template::<E, F, K, MR, NR, _, _, _>( (true, false, true) => gemm_template::<E, F, K, MR, NR, _, _, _>(
e, m, n, k, alpha, e,
m,
n,
k,
alpha,
ConstTransposedMatrix::new(a, lda), ConstTransposedMatrix::new(a, lda),
ConstMatrix::new(b, ldb), ConstMatrix::new(b, ldb),
beta, beta,
MutTransposedMatrix::new(c, ldc)), MutTransposedMatrix::new(c, ldc),
),
(true, true, false) => gemm_template::<E, F, K, MR, NR, _, _, _>( (true, true, false) => gemm_template::<E, F, K, MR, NR, _, _, _>(
e, m, n, k, alpha, e,
m,
n,
k,
alpha,
ConstTransposedMatrix::new(a, lda), ConstTransposedMatrix::new(a, lda),
ConstTransposedMatrix::new(b, ldb), ConstTransposedMatrix::new(b, ldb),
beta, beta,
MutMatrix::new(c, ldc)), MutMatrix::new(c, ldc),
),
(true, true, true) => gemm_template::<E, F, K, MR, NR, _, _, _>( (true, true, true) => gemm_template::<E, F, K, MR, NR, _, _, _>(
e, m, n, k, alpha, e,
m,
n,
k,
alpha,
ConstTransposedMatrix::new(a, lda), ConstTransposedMatrix::new(a, lda),
ConstTransposedMatrix::new(b, ldb), ConstTransposedMatrix::new(b, ldb),
beta, beta,
MutTransposedMatrix::new(c, ldc)), MutTransposedMatrix::new(c, ldc),
),
} }
} }
@ -98,11 +139,12 @@ unsafe fn gemm_template<E, F, K, MR, NR, A, B, C>(
a: A, a: A,
b: B, b: B,
beta: F, beta: F,
c: C c: C,
) ) where
where E: Executor, E: Executor,
F: Number, F: Number,
MR: Dim, NR: Dim, MR: Dim,
NR: Dim,
K: GemmKernel<F, MR, NR>, K: GemmKernel<F, MR, NR>,
A: Matrix<F>, A: Matrix<F>,
B: Matrix<F>, B: Matrix<F>,
@ -133,7 +175,7 @@ where E: Executor,
c.sub(j, i), c.sub(j, i),
pa, pa,
pb, pb,
i == 0 i == 0,
); );
} }
} }
@ -172,8 +214,8 @@ unsafe fn inner_kernel<E, F, K, MR, NR, A, B, C>(
pa: MutMatrix<F>, pa: MutMatrix<F>,
pb: MutMatrix<F>, pb: MutMatrix<F>,
first_time: bool, first_time: bool,
) ) where
where E: Executor, E: Executor,
F: Number, F: Number,
MR: Dim, MR: Dim,
NR: Dim, NR: Dim,
@ -181,7 +223,6 @@ unsafe fn inner_kernel<E, F, K, MR, NR, A, B, C>(
A: Matrix<F>, A: Matrix<F>,
B: Matrix<F>, B: Matrix<F>,
C: MatrixMut<F>, C: MatrixMut<F>,
{ {
let n_left = n % NR::DIM; let n_left = n % NR::DIM;
let n_main = n - n_left; let n_main = n - n_left;
@ -190,33 +231,24 @@ unsafe fn inner_kernel<E, F, K, MR, NR, A, B, C>(
let m_main = m - m_left; let m_main = m - m_left;
if first_time { if first_time {
e.execute(0, n_main, NR::DIM, move |j| e.execute(0, n_main, NR::DIM, move |_j| K::pack_row_b(b, pb));
K::pack_row_b(b, pb));
} }
e.execute(0, m_main, MR::DIM, move |i| e.execute(0, m_main, MR::DIM, move |i| {
K::pack_row_a(a.sub_col(i), pa.sub_row(i))); K::pack_row_a(a.sub_col(i), pa.sub_row(i))
});
e.synchronize(); e.synchronize();
e.execute(0, n_main, NR::DIM, move |j| { e.execute(0, n_main, NR::DIM, move |j| {
// Section TL // Section TL
for i in (0..m_main).step_by(MR::DIM) { for i in (0..m_main).step_by(MR::DIM) {
K::main_tl(alpha, K::main_tl(alpha, pa.sub_row(i), pb.sub_row(j), beta, c.sub(j, i));
pa.sub_row(i),
pb.sub_row(j),
beta,
c.sub(j, i));
} }
// Section TR // Section TR
for i in m_main..m { for i in m_main..m {
K::sup_tr( K::sup_tr(alpha, a.sub_col(i), pb.sub_row(j), beta, c.sub(j, i));
alpha,
a.sub_col(i),
pb.sub_row(j),
beta,
c.sub(j, i));
} }
}); });
@ -225,27 +257,15 @@ unsafe fn inner_kernel<E, F, K, MR, NR, A, B, C>(
let i = ji % m_main; let i = ji % m_main;
// Section BL // Section BL
K::sup_bl( K::sup_bl(alpha, pa.sub_row(i), b.sub_row(j), beta, c.sub(j, i));
alpha,
pa.sub_row(i),
b.sub_row(j),
beta,
c.sub(j, i)
);
}); });
// Section BR // Section BR
for j in n_main..n { for j in n_main..n {
for i in m_main..m { for i in m_main..m {
K::sup_br( K::sup_br(k, alpha, a.sub_col(i), b.sub_row(j), beta, c.sub(j, i))
k, }
alpha,
a.sub_col(i),
b.sub_row(j),
beta,
c.sub(j, i))
} }
};
e.synchronize(); e.synchronize();
} }

View File

@ -1,16 +1,16 @@
use super::fma::fmadd_ps; use super::fma::fmadd_ps;
use super::intrinsics::*; use super::intrinsics::*;
use crate::kernel::params::single::{MR, NR}; use crate::kernel::params::single::{MR, NR};
use crate::matrix::{Matrix, MutMatrix, MatrixMut}; use crate::matrix::{Matrix, MatrixMut, MutMatrix};
use crunchy::unroll; use crunchy::unroll;
#[inline] #[inline]
pub(crate) unsafe fn sgemm_ukr_16x8<C: MatrixMut<f32>>( pub(crate) unsafe fn sgemm_ukr_16x8<C: MatrixMut<f32>>(
k: usize, k: usize,
alpha: f32, _alpha: f32,
pa: MutMatrix<f32>, pa: MutMatrix<f32>,
pb: MutMatrix<f32>, pb: MutMatrix<f32>,
beta: f32, _beta: f32,
c: C, c: C,
) { ) {
let mut mt00 = _mm256_setzero_ps(); let mut mt00 = _mm256_setzero_ps();
@ -75,7 +75,6 @@ pub(crate) unsafe fn sgemm_ukr_16x8<C: MatrixMut<f32>>(
pb = pb.add(BATCH * NR); pb = pb.add(BATCH * NR);
} }
for _ in k_main..k { for _ in k_main..k {
let a0 = _mm256_load_ps(pa); let a0 = _mm256_load_ps(pa);
let a1 = _mm256_load_ps(pa.add(8)); let a1 = _mm256_load_ps(pa.add(8));
@ -192,10 +191,10 @@ pub(crate) unsafe fn sgemm_ukr_16x8<C: MatrixMut<f32>>(
pub(crate) unsafe fn sgemm_sup_16x1<B: Matrix<f32>, C: MatrixMut<f32>>( pub(crate) unsafe fn sgemm_sup_16x1<B: Matrix<f32>, C: MatrixMut<f32>>(
k: usize, k: usize,
alpha: f32, _alpha: f32,
pa: MutMatrix<f32>, pa: MutMatrix<f32>,
b: B, b: B,
beta: f32, _beta: f32,
c: C, c: C,
) { ) {
let mut mt0 = _mm256_setzero_ps(); let mut mt0 = _mm256_setzero_ps();

View File

@ -6,15 +6,16 @@ pub mod l1s;
pub mod l3d; pub mod l3d;
pub mod l3s; pub mod l3s;
use core::marker::PhantomData;
use crate::matrix::{Number, MutMatrix, Matrix, MatrixMut};
use crate::kernel::{GemmKernel, GemmKernelSupNr, GemmKernelSupMr, GemmKernelSup};
use crate::dim::*; use crate::dim::*;
use crate::kernel::{GemmKernel, GemmKernelSup, GemmKernelSupMr, GemmKernelSupNr};
use crate::matrix::{Matrix, MatrixMut, MutMatrix, Number};
use core::marker::PhantomData;
pub struct AvxKernel<F: Number, I>(PhantomData<fn(F, I)>); pub struct AvxKernel<F: Number, I>(PhantomData<fn(F, I)>);
impl<I> GemmKernelSupNr<f32, A5> for AvxKernel<f32, I> impl<I> GemmKernelSupNr<f32, A5> for AvxKernel<f32, I>
where I: GemmKernelSupNr<f32, A5> where
I: GemmKernelSupNr<f32, A5>,
{ {
#[inline] #[inline]
unsafe fn sup_tr<A: Matrix<f32>, C: MatrixMut<f32>>( unsafe fn sup_tr<A: Matrix<f32>, C: MatrixMut<f32>>(
@ -29,7 +30,8 @@ impl<I> GemmKernelSupNr<f32, A5> for AvxKernel<f32, I>
} }
impl<I> GemmKernelSupMr<f32, A16> for AvxKernel<f32, I> impl<I> GemmKernelSupMr<f32, A16> for AvxKernel<f32, I>
where I: GemmKernelSupMr<f32, A16> where
I: GemmKernelSupMr<f32, A16>,
{ {
#[inline] #[inline]
unsafe fn sup_bl<B: Matrix<f32>, C: MatrixMut<f32>>( unsafe fn sup_bl<B: Matrix<f32>, C: MatrixMut<f32>>(
@ -44,7 +46,8 @@ impl<I> GemmKernelSupMr<f32, A16> for AvxKernel<f32, I>
} }
impl<I> GemmKernelSup<f32> for AvxKernel<f32, I> impl<I> GemmKernelSup<f32> for AvxKernel<f32, I>
where I: GemmKernelSup<f32> where
I: GemmKernelSup<f32>,
{ {
#[inline] #[inline]
unsafe fn sup_br<A: Matrix<f32>, B: Matrix<f32>, C: MatrixMut<f32>>( unsafe fn sup_br<A: Matrix<f32>, B: Matrix<f32>, C: MatrixMut<f32>>(
@ -60,7 +63,8 @@ impl<I> GemmKernelSup<f32> for AvxKernel<f32, I>
} }
impl<I> GemmKernel<f32, A16, A5> for AvxKernel<f32, I> impl<I> GemmKernel<f32, A16, A5> for AvxKernel<f32, I>
where I: GemmKernel<f32, A16, A5> where
I: GemmKernel<f32, A16, A5>,
{ {
#[inline] #[inline]
unsafe fn pack_row_a<A: Matrix<f32>>(a: A, pa: MutMatrix<f32>) { unsafe fn pack_row_a<A: Matrix<f32>>(a: A, pa: MutMatrix<f32>) {

View File

@ -1,13 +1,12 @@
use crate::kernel::params::single::NR; use crate::kernel::params::single::NR;
use crate::matrix::{MutMatrix, Matrix, MatrixMut}; use crate::matrix::{Matrix, MatrixMut, MutMatrix};
pub(crate) unsafe fn sgemm_sup_1x8<A: Matrix<f32>, C: MatrixMut<f32>>( pub(crate) unsafe fn sgemm_sup_1x8<A: Matrix<f32>, C: MatrixMut<f32>>(
k: usize, k: usize,
alpha: f32, _alpha: f32,
a: A, a: A,
pb: MutMatrix<f32>, pb: MutMatrix<f32>,
beta: f32, _beta: f32,
c: C, c: C,
) { ) {
let mut c0 = 0.0f32; let mut c0 = 0.0f32;

View File

@ -1,9 +1,9 @@
pub mod l3d; pub mod l3d;
pub mod l3s; pub mod l3s;
use crate::matrix::{Matrix, MatrixMut, MutMatrix};
use crate::kernel::{GemmKernel, GemmKernelSupNr, GemmKernelSupMr, GemmKernelSup};
use crate::dim::*; use crate::dim::*;
use crate::kernel::{GemmKernel, GemmKernelSup, GemmKernelSupMr, GemmKernelSupNr};
use crate::matrix::{Matrix, MatrixMut, MutMatrix};
pub struct GenericKernel; pub struct GenericKernel;
@ -16,24 +16,18 @@ impl GemmKernelSupNr<f32, A5> for GenericKernel {
beta: f32, beta: f32,
c: C, c: C,
) { ) {
self::l3s::sgemm_sup_1x8( self::l3s::sgemm_sup_1x8(pb.stride, alpha, a, pb, beta, c)
pb.stride,
alpha,
a,
pb,
beta,
c)
} }
} }
impl GemmKernelSupMr<f32, A16> for GenericKernel { impl GemmKernelSupMr<f32, A16> for GenericKernel {
#[inline] #[inline]
unsafe fn sup_bl<B: Matrix<f32>, C: MatrixMut<f32>>( unsafe fn sup_bl<B: Matrix<f32>, C: MatrixMut<f32>>(
alpha: f32, _alpha: f32,
pa: MutMatrix<f32>, _pa: MutMatrix<f32>,
b: B, _b: B,
beta: f32, _beta: f32,
c: C, _c: C,
) { ) {
unimplemented!() unimplemented!()
} }
@ -66,7 +60,6 @@ impl GemmKernelSup<f32> for GenericKernel {
} }
impl GemmKernel<f32, A16, A5> for GenericKernel { impl GemmKernel<f32, A16, A5> for GenericKernel {
#[inline] #[inline]
unsafe fn pack_row_a<A: Matrix<f32>>(a: A, pa: MutMatrix<f32>) { unsafe fn pack_row_a<A: Matrix<f32>>(a: A, pa: MutMatrix<f32>) {
if a.is_transposed() { if a.is_transposed() {
@ -87,11 +80,11 @@ impl GemmKernel<f32, A16, A5> for GenericKernel {
#[inline] #[inline]
unsafe fn main_tl<C: MatrixMut<f32>>( unsafe fn main_tl<C: MatrixMut<f32>>(
alpha: f32, _alpha: f32,
pa: MutMatrix<f32>, _pa: MutMatrix<f32>,
pb: MutMatrix<f32>, _pb: MutMatrix<f32>,
beta: f32, _beta: f32,
c: C, _c: C,
) { ) {
unimplemented!() unimplemented!()
} }

View File

@ -6,14 +6,14 @@ pub mod avx;
pub mod generic; pub mod generic;
pub mod sse; pub mod sse;
use crate::matrix::{Number, Matrix, MatrixMut, MutMatrix};
use crate::dim::Dim; use crate::dim::Dim;
use crate::matrix::{Matrix, MatrixMut, MutMatrix, Number};
pub mod params { pub mod params {
pub mod single { pub mod single {
pub const MC: usize = 128; pub const MC: usize = 128;
pub const KC: usize = 256; pub const KC: usize = 256;
pub const NC: usize = 8*1024; pub const NC: usize = 8 * 1024;
pub const MR: usize = 16; pub const MR: usize = 16;
pub const NR: usize = 5; pub const NR: usize = 5;
} }
@ -70,15 +70,11 @@ pub trait GemmKernelSupNr<F: Number, NR: Dim> {
); );
} }
pub trait GemmKernel<F: Number, MR: Dim, NR: Dim>: GemmKernelSupMr<F, MR> + GemmKernelSupNr<F, NR> + GemmKernelSup<F> { pub trait GemmKernel<F: Number, MR: Dim, NR: Dim>:
GemmKernelSupMr<F, MR> + GemmKernelSupNr<F, NR> + GemmKernelSup<F>
{
unsafe fn pack_row_a<A: Matrix<F>>(a: A, pa: MutMatrix<F>); unsafe fn pack_row_a<A: Matrix<F>>(a: A, pa: MutMatrix<F>);
unsafe fn pack_row_b<B: Matrix<F>>(b: B, pb: MutMatrix<F>); unsafe fn pack_row_b<B: Matrix<F>>(b: B, pb: MutMatrix<F>);
unsafe fn main_tl<C: MatrixMut<F>>( unsafe fn main_tl<C: MatrixMut<F>>(alpha: F, pa: MutMatrix<F>, pb: MutMatrix<F>, beta: F, c: C);
alpha: F,
pa: MutMatrix<F>,
pb: MutMatrix<F>,
beta: F,
c: C,
);
} }

View File

@ -1,3 +1,5 @@
#![allow(dead_code)]
use super::intrinsics::*; use super::intrinsics::*;
#[cfg(target_feature = "fma")] #[cfg(target_feature = "fma")]

View File

@ -1,3 +1,5 @@
#![allow(dead_code)]
use super::fma::fmadd_pd; use super::fma::fmadd_pd;
use super::intrinsics::*; use super::intrinsics::*;

View File

@ -1,3 +1,5 @@
#![allow(dead_code)]
use super::fma::fmadd_ps; use super::fma::fmadd_ps;
use super::intrinsics::*; use super::intrinsics::*;
use crate::kernel::params::single::NR; use crate::kernel::params::single::NR;
@ -5,10 +7,10 @@ use crate::matrix::{Matrix, MatrixMut, MutMatrix};
pub(crate) unsafe fn sgemm_sup_1x8<A: Matrix<f32>, C: MatrixMut<f32>>( pub(crate) unsafe fn sgemm_sup_1x8<A: Matrix<f32>, C: MatrixMut<f32>>(
k: usize, k: usize,
alpha: f32, _alpha: f32,
a: A, a: A,
pb: MutMatrix<f32>, pb: MutMatrix<f32>,
beta: f32, _beta: f32,
c: C, c: C,
) { ) {
let mut c0_3 = _mm_setzero_ps(); let mut c0_3 = _mm_setzero_ps();
@ -69,11 +71,11 @@ pub(crate) unsafe fn sgemm_sup_1x8<A: Matrix<f32>, C: MatrixMut<f32>>(
pub(crate) unsafe fn sgemm_sup_1x8_t( pub(crate) unsafe fn sgemm_sup_1x8_t(
k: usize, k: usize,
alpha: f32, _alpha: f32,
a: *const f32, a: *const f32,
lda: usize, _lda: usize,
pb: *const f32, pb: *const f32,
beta: f32, _beta: f32,
c: *mut f32, c: *mut f32,
ldc: usize, ldc: usize,
) { ) {

View File

@ -6,16 +6,16 @@ mod intrinsics;
// pub mod l3d; // pub mod l3d;
pub mod l3s; pub mod l3s;
use core::marker::PhantomData;
use crate::matrix::{Number, MutMatrix, Matrix, MatrixMut};
use crate::kernel::{GemmKernel, GemmKernelSupNr, GemmKernelSupMr, GemmKernelSup};
use crate::dim::*; use crate::dim::*;
use crate::kernel::{GemmKernel, GemmKernelSup, GemmKernelSupMr, GemmKernelSupNr};
use crate::matrix::{Matrix, MatrixMut, MutMatrix, Number};
use core::marker::PhantomData;
pub struct SseKernel<F: Number, I>(PhantomData<fn(F, I)>); pub struct SseKernel<F: Number, I>(PhantomData<fn(F, I)>);
impl<I> GemmKernelSupNr<f32, A5> for SseKernel<f32, I> impl<I> GemmKernelSupNr<f32, A5> for SseKernel<f32, I>
where I: GemmKernelSupNr<f32, A5> where
I: GemmKernelSupNr<f32, A5>,
{ {
#[inline] #[inline]
unsafe fn sup_tr<A: Matrix<f32>, C: MatrixMut<f32>>( unsafe fn sup_tr<A: Matrix<f32>, C: MatrixMut<f32>>(
@ -30,7 +30,8 @@ impl<I> GemmKernelSupNr<f32, A5> for SseKernel<f32, I>
} }
impl<I> GemmKernelSupMr<f32, A16> for SseKernel<f32, I> impl<I> GemmKernelSupMr<f32, A16> for SseKernel<f32, I>
where I: GemmKernelSupMr<f32, A16> where
I: GemmKernelSupMr<f32, A16>,
{ {
#[inline] #[inline]
unsafe fn sup_bl<B: Matrix<f32>, C: MatrixMut<f32>>( unsafe fn sup_bl<B: Matrix<f32>, C: MatrixMut<f32>>(
@ -45,7 +46,8 @@ impl<I> GemmKernelSupMr<f32, A16> for SseKernel<f32, I>
} }
impl<I> GemmKernelSup<f32> for SseKernel<f32, I> impl<I> GemmKernelSup<f32> for SseKernel<f32, I>
where I: GemmKernelSup<f32> where
I: GemmKernelSup<f32>,
{ {
#[inline] #[inline]
unsafe fn sup_br<A: Matrix<f32>, B: Matrix<f32>, C: MatrixMut<f32>>( unsafe fn sup_br<A: Matrix<f32>, B: Matrix<f32>, C: MatrixMut<f32>>(
@ -61,7 +63,8 @@ impl<I> GemmKernelSup<f32> for SseKernel<f32, I>
} }
impl<I> GemmKernel<f32, A16, A5> for SseKernel<f32, I> impl<I> GemmKernel<f32, A16, A5> for SseKernel<f32, I>
where I: GemmKernel<f32, A16, A5> where
I: GemmKernel<f32, A16, A5>,
{ {
#[inline] #[inline]
unsafe fn pack_row_a<A: Matrix<f32>>(a: A, pa: MutMatrix<f32>) { unsafe fn pack_row_a<A: Matrix<f32>>(a: A, pa: MutMatrix<f32>) {

View File

@ -1,10 +1,10 @@
pub mod matrix;
pub mod gemm;
mod sgemm;
mod aligned_alloc; mod aligned_alloc;
pub mod kernel;
pub mod dim; pub mod dim;
pub mod executor; pub mod executor;
pub mod gemm;
pub mod kernel;
pub mod matrix;
mod sgemm;
#[cfg(test)] #[cfg(test)]
extern crate blas; extern crate blas;
@ -13,5 +13,4 @@ extern crate openblas;
#[cfg(test)] #[cfg(test)]
mod test; mod test;
pub use crate::sgemm::sgemm; pub use crate::sgemm::sgemm;

View File

@ -1,4 +1,3 @@
pub trait Number: core::fmt::Display + Copy + Send + Sync + 'static {} pub trait Number: core::fmt::Display + Copy + Send + Sync + 'static {}
impl Number for f32 {} impl Number for f32 {}
@ -32,7 +31,7 @@ pub trait MatrixMut<F: Number>: Matrix<F> {
#[derive(Copy, Clone)] #[derive(Copy, Clone)]
pub struct ConstMatrix<F: Number> { pub struct ConstMatrix<F: Number> {
pub stride: usize, pub stride: usize,
pub ptr: *const F pub ptr: *const F,
} }
unsafe impl<F: Number> Send for ConstMatrix<F> {} unsafe impl<F: Number> Send for ConstMatrix<F> {}
@ -49,7 +48,7 @@ impl<F: Number> Matrix<F> for ConstMatrix<F> {
unsafe fn sub(&self, row: usize, col: usize) -> Self { unsafe fn sub(&self, row: usize, col: usize) -> Self {
Self { Self {
ptr: self.index(row, col), ptr: self.index(row, col),
stride: self.stride stride: self.stride,
} }
} }
@ -57,7 +56,7 @@ impl<F: Number> Matrix<F> for ConstMatrix<F> {
unsafe fn sub_col(&self, col: usize) -> Self { unsafe fn sub_col(&self, col: usize) -> Self {
Self { Self {
ptr: self.col(col), ptr: self.col(col),
stride: self.stride stride: self.stride,
} }
} }
@ -65,7 +64,7 @@ impl<F: Number> Matrix<F> for ConstMatrix<F> {
unsafe fn sub_row(&self, row: usize) -> Self { unsafe fn sub_row(&self, row: usize) -> Self {
Self { Self {
ptr: self.row(row), ptr: self.row(row),
stride: self.stride stride: self.stride,
} }
} }
@ -128,7 +127,7 @@ impl<F: Number> Matrix<F> for ConstMatrix<F> {
#[derive(Copy, Clone)] #[derive(Copy, Clone)]
pub struct ConstTransposedMatrix<F: Number> { pub struct ConstTransposedMatrix<F: Number> {
pub stride: usize, pub stride: usize,
pub ptr: *const F pub ptr: *const F,
} }
unsafe impl<F: Number> Send for ConstTransposedMatrix<F> {} unsafe impl<F: Number> Send for ConstTransposedMatrix<F> {}
@ -145,7 +144,7 @@ impl<F: Number> Matrix<F> for ConstTransposedMatrix<F> {
unsafe fn sub(&self, row: usize, col: usize) -> Self { unsafe fn sub(&self, row: usize, col: usize) -> Self {
Self { Self {
ptr: self.index(row, col), ptr: self.index(row, col),
stride: self.stride stride: self.stride,
} }
} }
@ -153,7 +152,7 @@ impl<F: Number> Matrix<F> for ConstTransposedMatrix<F> {
unsafe fn sub_col(&self, col: usize) -> Self { unsafe fn sub_col(&self, col: usize) -> Self {
Self { Self {
ptr: self.col(col), ptr: self.col(col),
stride: self.stride stride: self.stride,
} }
} }
@ -161,7 +160,7 @@ impl<F: Number> Matrix<F> for ConstTransposedMatrix<F> {
unsafe fn sub_row(&self, row: usize) -> Self { unsafe fn sub_row(&self, row: usize) -> Self {
Self { Self {
ptr: self.row(row), ptr: self.row(row),
stride: self.stride stride: self.stride,
} }
} }
@ -224,7 +223,7 @@ impl<F: Number> Matrix<F> for ConstTransposedMatrix<F> {
#[derive(Copy, Clone)] #[derive(Copy, Clone)]
pub struct MutMatrix<F: Number> { pub struct MutMatrix<F: Number> {
pub stride: usize, pub stride: usize,
pub ptr: *mut F pub ptr: *mut F,
} }
unsafe impl<F: Number> Send for MutMatrix<F> {} unsafe impl<F: Number> Send for MutMatrix<F> {}
@ -241,7 +240,7 @@ impl<F: Number> Matrix<F> for MutMatrix<F> {
unsafe fn sub(&self, row: usize, col: usize) -> Self { unsafe fn sub(&self, row: usize, col: usize) -> Self {
Self { Self {
ptr: self.index_mut(row, col), ptr: self.index_mut(row, col),
stride: self.stride stride: self.stride,
} }
} }
@ -249,7 +248,7 @@ impl<F: Number> Matrix<F> for MutMatrix<F> {
unsafe fn sub_col(&self, col: usize) -> Self { unsafe fn sub_col(&self, col: usize) -> Self {
Self { Self {
ptr: self.col_mut(col), ptr: self.col_mut(col),
stride: self.stride stride: self.stride,
} }
} }
@ -257,7 +256,7 @@ impl<F: Number> Matrix<F> for MutMatrix<F> {
unsafe fn sub_row(&self, row: usize) -> Self { unsafe fn sub_row(&self, row: usize) -> Self {
Self { Self {
ptr: self.row_mut(row), ptr: self.row_mut(row),
stride: self.stride stride: self.stride,
} }
} }
@ -347,7 +346,7 @@ impl<F: Number> MatrixMut<F> for MutMatrix<F> {
#[derive(Copy, Clone)] #[derive(Copy, Clone)]
pub struct MutTransposedMatrix<F: Number> { pub struct MutTransposedMatrix<F: Number> {
pub stride: usize, pub stride: usize,
pub ptr: *mut F pub ptr: *mut F,
} }
unsafe impl<F: Number> Send for MutTransposedMatrix<F> {} unsafe impl<F: Number> Send for MutTransposedMatrix<F> {}
@ -364,7 +363,7 @@ impl<F: Number> Matrix<F> for MutTransposedMatrix<F> {
unsafe fn sub(&self, row: usize, col: usize) -> Self { unsafe fn sub(&self, row: usize, col: usize) -> Self {
Self { Self {
ptr: self.index_mut(row, col), ptr: self.index_mut(row, col),
stride: self.stride stride: self.stride,
} }
} }
@ -372,7 +371,7 @@ impl<F: Number> Matrix<F> for MutTransposedMatrix<F> {
unsafe fn sub_col(&self, col: usize) -> Self { unsafe fn sub_col(&self, col: usize) -> Self {
Self { Self {
ptr: self.col_mut(col), ptr: self.col_mut(col),
stride: self.stride stride: self.stride,
} }
} }
@ -380,7 +379,7 @@ impl<F: Number> Matrix<F> for MutTransposedMatrix<F> {
unsafe fn sub_row(&self, row: usize) -> Self { unsafe fn sub_row(&self, row: usize) -> Self {
Self { Self {
ptr: self.row_mut(row), ptr: self.row_mut(row),
stride: self.stride stride: self.stride,
} }
} }
@ -440,7 +439,6 @@ impl<F: Number> Matrix<F> for MutTransposedMatrix<F> {
} }
} }
impl<F: Number> MatrixMut<F> for MutTransposedMatrix<F> { impl<F: Number> MatrixMut<F> for MutTransposedMatrix<F> {
#[inline] #[inline]
unsafe fn set(&self, row: usize, col: usize, value: F) { unsafe fn set(&self, row: usize, col: usize, value: F) {
@ -467,5 +465,3 @@ impl<F: Number> MatrixMut<F> for MutTransposedMatrix<F> {
self.ptr self.ptr
} }
} }

View File

@ -1,8 +1,7 @@
use crate::kernel::avx::AvxKernel;
use crate::kernel::generic::GenericKernel;
use crate::dim::*; use crate::dim::*;
use crate::executor::Executor; use crate::executor::Executor;
use crate::kernel::avx::AvxKernel;
use crate::kernel::generic::GenericKernel;
pub unsafe fn sgemm<E: Executor>( pub unsafe fn sgemm<E: Executor>(
e: &E, e: &E,
@ -22,9 +21,6 @@ pub unsafe fn sgemm<E: Executor>(
ldc: usize, ldc: usize,
) { ) {
crate::gemm::gemm::<E, f32, AvxKernel<f32, GenericKernel>, A16, A5>( crate::gemm::gemm::<E, f32, AvxKernel<f32, GenericKernel>, A16, A5>(
e, e, transa, transb, transc, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
transa, transb, transc, );
m, n, k, alpha,
a, lda, b, ldb,
beta, c, ldc);
} }

View File

@ -1,11 +1,9 @@
use rand::Rng;
const M_LEN: usize = 17; const M_LEN: usize = 17;
const N_LEN: usize = 6; const N_LEN: usize = 6;
const K_LEN: usize = 4; const K_LEN: usize = 4;
fn make_matrices() -> (Vec<f32>, Vec<f32>, Vec<f32>, Vec<f32>) { fn make_matrices() -> (Vec<f32>, Vec<f32>, Vec<f32>, Vec<f32>) {
let mut rng = rand::thread_rng(); let _rng = rand::thread_rng();
let (m, n, k) = (M_LEN, N_LEN, K_LEN); let (m, n, k) = (M_LEN, N_LEN, K_LEN);
@ -38,7 +36,6 @@ fn make_matrices() -> (Vec<f32>, Vec<f32>, Vec<f32>, Vec<f32>) {
(a, a_t, b, b_t) (a, a_t, b, b_t)
} }
unsafe fn sgemm_ref( unsafe fn sgemm_ref(
m: usize, m: usize,
n: usize, n: usize,
@ -75,7 +72,6 @@ unsafe fn sgemm_ref(
// } // }
} }
#[test] #[test]
fn test_sgemm_nn() { fn test_sgemm_nn() {
let (m, n, k) = (M_LEN, N_LEN, K_LEN); let (m, n, k) = (M_LEN, N_LEN, K_LEN);
@ -133,7 +129,6 @@ fn test_sgemm_nn() {
); );
} }
for row in 0..N_LEN { for row in 0..N_LEN {
for col in 0..M_LEN { for col in 0..M_LEN {
let index = row * M_LEN + col; let index = row * M_LEN + col;
@ -160,16 +155,24 @@ fn test_sgemm_nn() {
let (a, b) = (c[index], cref[index]); let (a, b) = (c[index], cref[index]);
assert!(feq(a, b), "a != b, a[{}, {}]={}, b[{}, {}]={}", row, col, a, row, col, b); assert!(
feq(a, b),
"a != b, a[{}, {}]={}, b[{}, {}]={}",
row,
col,
a,
row,
col,
b
);
} }
} }
} }
#[test] #[test]
fn test_sgemm_nt() { fn test_sgemm_nt() {
let (m, n, k) = (M_LEN, N_LEN, K_LEN); let (m, n, k) = (M_LEN, N_LEN, K_LEN);
let (a, _, b, b_t) = make_matrices(); let (a, _, _b, b_t) = make_matrices();
let mut c = vec![0.0; m * n]; let mut c = vec![0.0; m * n];
let mut cref = vec![0.0; m * n]; let mut cref = vec![0.0; m * n];
@ -216,16 +219,24 @@ fn test_sgemm_nt() {
for col in 0..M_LEN { for col in 0..M_LEN {
let index = row * M_LEN + col; let index = row * M_LEN + col;
let (a, b) = (c[index], cref[index]); let (a, b) = (c[index], cref[index]);
assert!(feq(a, b), "a != b, a[{}, {}]={}, b[{}, {}]={}", row, col, a, row, col, b); assert!(
feq(a, b),
"a != b, a[{}, {}]={}, b[{}, {}]={}",
row,
col,
a,
row,
col,
b
);
} }
} }
} }
#[test] #[test]
fn test_sgemm_tn() { fn test_sgemm_tn() {
let (m, n, k) = (M_LEN, N_LEN, K_LEN); let (m, n, k) = (M_LEN, N_LEN, K_LEN);
let (a, a_t, b, _) = make_matrices(); let (_a, a_t, b, _) = make_matrices();
let mut c = vec![0.0; m * n]; let mut c = vec![0.0; m * n];
let mut cref = vec![0.0; m * n]; let mut cref = vec![0.0; m * n];
@ -272,16 +283,24 @@ fn test_sgemm_tn() {
for col in 0..M_LEN { for col in 0..M_LEN {
let index = row * M_LEN + col; let index = row * M_LEN + col;
let (a, b) = (c[index], cref[index]); let (a, b) = (c[index], cref[index]);
assert!(feq(a, b), "a != b, a[{}, {}]={}, b[{}, {}]={}", row, col, a, row, col, b); assert!(
feq(a, b),
"a != b, a[{}, {}]={}, b[{}, {}]={}",
row,
col,
a,
row,
col,
b
);
} }
} }
} }
#[test] #[test]
fn test_sgemm_tt() { fn test_sgemm_tt() {
let (m, n, k) = (M_LEN, N_LEN, K_LEN); let (m, n, k) = (M_LEN, N_LEN, K_LEN);
let (a, a_t, b, b_t) = make_matrices(); let (_a, a_t, _b, b_t) = make_matrices();
let mut c = vec![0.0; m * n]; let mut c = vec![0.0; m * n];
let mut cref = vec![0.0; m * n]; let mut cref = vec![0.0; m * n];
@ -328,7 +347,16 @@ fn test_sgemm_tt() {
for col in 0..M_LEN { for col in 0..M_LEN {
let index = row * M_LEN + col; let index = row * M_LEN + col;
let (a, b) = (c[index], cref[index]); let (a, b) = (c[index], cref[index]);
assert!(feq(a, b), "a != b, a[{}, {}]={}, b[{}, {}]={}", row, col, a, row, col, b); assert!(
feq(a, b),
"a != b, a[{}, {}]={}, b[{}, {}]={}",
row,
col,
a,
row,
col,
b
);
} }
} }
} }