Fix fmt and some warnings
This commit is contained in:
parent
a699078f73
commit
282d0c4716
14
.drone.yml
Normal file
14
.drone.yml
Normal file
@ -0,0 +1,14 @@
|
||||
kind: pipeline
|
||||
name: default
|
||||
|
||||
steps:
|
||||
- name: build
|
||||
image: rust
|
||||
commands:
|
||||
- cargo build --verbose --all
|
||||
|
||||
- name: fmt-check
|
||||
image: rust
|
||||
commands:
|
||||
- rustup component add rustfmt
|
||||
- cargo fmt --all -- --check
|
@ -46,15 +46,14 @@ impl Executor for RayonExecutor {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
pub struct ThreadPoolExecutor {
|
||||
thread_pool: threadpool::ThreadPool
|
||||
thread_pool: threadpool::ThreadPool,
|
||||
}
|
||||
|
||||
impl ThreadPoolExecutor {
|
||||
pub fn new() -> ThreadPoolExecutor {
|
||||
ThreadPoolExecutor {
|
||||
thread_pool: threadpool::Builder::new().build()
|
||||
thread_pool: threadpool::Builder::new().build(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
146
src/gemm.rs
146
src/gemm.rs
@ -1,10 +1,12 @@
|
||||
use crate::aligned_alloc;
|
||||
use crate::matrix::{Number, Matrix, MatrixMut, MutMatrix, ConstMatrix, ConstTransposedMatrix, MutTransposedMatrix};
|
||||
use crate::kernel::params::single::{NC, MC, KC};
|
||||
use crate::kernel;
|
||||
use crate::kernel::GemmKernel;
|
||||
use crate::dim::Dim;
|
||||
use crate::executor::Executor;
|
||||
use crate::kernel;
|
||||
use crate::kernel::params::single::{KC, MC, NC};
|
||||
use crate::kernel::GemmKernel;
|
||||
use crate::matrix::{
|
||||
ConstMatrix, ConstTransposedMatrix, Matrix, MatrixMut, MutMatrix, MutTransposedMatrix, Number,
|
||||
};
|
||||
|
||||
pub unsafe fn gemm<E, F, K, MR, NR>(
|
||||
e: &E,
|
||||
@ -22,70 +24,109 @@ pub unsafe fn gemm<E, F, K, MR, NR>(
|
||||
beta: F,
|
||||
c: *mut F,
|
||||
ldc: usize,
|
||||
)
|
||||
where E: Executor,
|
||||
) where
|
||||
E: Executor,
|
||||
F: Number,
|
||||
MR: Dim, NR: Dim,
|
||||
MR: Dim,
|
||||
NR: Dim,
|
||||
K: GemmKernel<F, MR, NR>,
|
||||
{
|
||||
match (transa, transb, transc) {
|
||||
(false, false, false) => gemm_template::<E, F, K, MR, NR, _, _, _>(
|
||||
e, m, n, k, alpha,
|
||||
e,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
alpha,
|
||||
ConstMatrix::new(a, lda),
|
||||
ConstMatrix::new(b, ldb),
|
||||
beta,
|
||||
MutMatrix::new(c, ldc)),
|
||||
MutMatrix::new(c, ldc),
|
||||
),
|
||||
|
||||
(false, false, true) => gemm_template::<E, F, K, MR, NR, _, _, _>(
|
||||
e, m, n, k, alpha,
|
||||
e,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
alpha,
|
||||
ConstMatrix::new(a, lda),
|
||||
ConstMatrix::new(b, ldb),
|
||||
beta,
|
||||
MutTransposedMatrix::new(c, ldc)),
|
||||
MutTransposedMatrix::new(c, ldc),
|
||||
),
|
||||
|
||||
(false, true, false) => gemm_template::<E, F, K, MR, NR, _, _, _>(
|
||||
e, m, n, k, alpha,
|
||||
e,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
alpha,
|
||||
ConstMatrix::new(a, lda),
|
||||
ConstTransposedMatrix::new(b, ldb),
|
||||
beta,
|
||||
MutMatrix::new(c, ldc)),
|
||||
MutMatrix::new(c, ldc),
|
||||
),
|
||||
|
||||
(false, true, true) => gemm_template::<E, F, K, MR, NR, _, _, _>(
|
||||
e, m, n, k, alpha,
|
||||
e,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
alpha,
|
||||
ConstMatrix::new(a, lda),
|
||||
ConstTransposedMatrix::new(b, ldb),
|
||||
beta,
|
||||
MutTransposedMatrix::new(c, ldc)),
|
||||
MutTransposedMatrix::new(c, ldc),
|
||||
),
|
||||
|
||||
(true, false, false) => gemm_template::<E, F, K, MR, NR, _, _, _>(
|
||||
e, m, n, k, alpha,
|
||||
e,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
alpha,
|
||||
ConstTransposedMatrix::new(a, lda),
|
||||
ConstMatrix::new(b, ldb),
|
||||
beta,
|
||||
MutMatrix::new(c, ldc)),
|
||||
MutMatrix::new(c, ldc),
|
||||
),
|
||||
|
||||
(true, false, true) => gemm_template::<E, F, K, MR, NR, _, _, _>(
|
||||
e, m, n, k, alpha,
|
||||
e,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
alpha,
|
||||
ConstTransposedMatrix::new(a, lda),
|
||||
ConstMatrix::new(b, ldb),
|
||||
beta,
|
||||
MutTransposedMatrix::new(c, ldc)),
|
||||
|
||||
MutTransposedMatrix::new(c, ldc),
|
||||
),
|
||||
|
||||
(true, true, false) => gemm_template::<E, F, K, MR, NR, _, _, _>(
|
||||
e, m, n, k, alpha,
|
||||
e,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
alpha,
|
||||
ConstTransposedMatrix::new(a, lda),
|
||||
ConstTransposedMatrix::new(b, ldb),
|
||||
beta,
|
||||
MutMatrix::new(c, ldc)),
|
||||
|
||||
MutMatrix::new(c, ldc),
|
||||
),
|
||||
|
||||
(true, true, true) => gemm_template::<E, F, K, MR, NR, _, _, _>(
|
||||
e, m, n, k, alpha,
|
||||
e,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
alpha,
|
||||
ConstTransposedMatrix::new(a, lda),
|
||||
ConstTransposedMatrix::new(b, ldb),
|
||||
beta,
|
||||
MutTransposedMatrix::new(c, ldc)),
|
||||
MutTransposedMatrix::new(c, ldc),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
@ -98,11 +139,12 @@ unsafe fn gemm_template<E, F, K, MR, NR, A, B, C>(
|
||||
a: A,
|
||||
b: B,
|
||||
beta: F,
|
||||
c: C
|
||||
)
|
||||
where E: Executor,
|
||||
c: C,
|
||||
) where
|
||||
E: Executor,
|
||||
F: Number,
|
||||
MR: Dim, NR: Dim,
|
||||
MR: Dim,
|
||||
NR: Dim,
|
||||
K: GemmKernel<F, MR, NR>,
|
||||
A: Matrix<F>,
|
||||
B: Matrix<F>,
|
||||
@ -133,7 +175,7 @@ where E: Executor,
|
||||
c.sub(j, i),
|
||||
pa,
|
||||
pb,
|
||||
i == 0
|
||||
i == 0,
|
||||
);
|
||||
}
|
||||
}
|
||||
@ -172,8 +214,8 @@ unsafe fn inner_kernel<E, F, K, MR, NR, A, B, C>(
|
||||
pa: MutMatrix<F>,
|
||||
pb: MutMatrix<F>,
|
||||
first_time: bool,
|
||||
)
|
||||
where E: Executor,
|
||||
) where
|
||||
E: Executor,
|
||||
F: Number,
|
||||
MR: Dim,
|
||||
NR: Dim,
|
||||
@ -181,7 +223,6 @@ unsafe fn inner_kernel<E, F, K, MR, NR, A, B, C>(
|
||||
A: Matrix<F>,
|
||||
B: Matrix<F>,
|
||||
C: MatrixMut<F>,
|
||||
|
||||
{
|
||||
let n_left = n % NR::DIM;
|
||||
let n_main = n - n_left;
|
||||
@ -190,33 +231,24 @@ unsafe fn inner_kernel<E, F, K, MR, NR, A, B, C>(
|
||||
let m_main = m - m_left;
|
||||
|
||||
if first_time {
|
||||
e.execute(0, n_main, NR::DIM, move |j|
|
||||
K::pack_row_b(b, pb));
|
||||
e.execute(0, n_main, NR::DIM, move |_j| K::pack_row_b(b, pb));
|
||||
}
|
||||
|
||||
e.execute(0, m_main, MR::DIM, move |i|
|
||||
K::pack_row_a(a.sub_col(i), pa.sub_row(i)));
|
||||
e.execute(0, m_main, MR::DIM, move |i| {
|
||||
K::pack_row_a(a.sub_col(i), pa.sub_row(i))
|
||||
});
|
||||
|
||||
e.synchronize();
|
||||
|
||||
e.execute(0, n_main, NR::DIM, move |j| {
|
||||
// Section TL
|
||||
for i in (0..m_main).step_by(MR::DIM) {
|
||||
K::main_tl(alpha,
|
||||
pa.sub_row(i),
|
||||
pb.sub_row(j),
|
||||
beta,
|
||||
c.sub(j, i));
|
||||
K::main_tl(alpha, pa.sub_row(i), pb.sub_row(j), beta, c.sub(j, i));
|
||||
}
|
||||
|
||||
// Section TR
|
||||
for i in m_main..m {
|
||||
K::sup_tr(
|
||||
alpha,
|
||||
a.sub_col(i),
|
||||
pb.sub_row(j),
|
||||
beta,
|
||||
c.sub(j, i));
|
||||
K::sup_tr(alpha, a.sub_col(i), pb.sub_row(j), beta, c.sub(j, i));
|
||||
}
|
||||
});
|
||||
|
||||
@ -225,27 +257,15 @@ unsafe fn inner_kernel<E, F, K, MR, NR, A, B, C>(
|
||||
let i = ji % m_main;
|
||||
|
||||
// Section BL
|
||||
K::sup_bl(
|
||||
alpha,
|
||||
pa.sub_row(i),
|
||||
b.sub_row(j),
|
||||
beta,
|
||||
c.sub(j, i)
|
||||
);
|
||||
K::sup_bl(alpha, pa.sub_row(i), b.sub_row(j), beta, c.sub(j, i));
|
||||
});
|
||||
|
||||
// Section BR
|
||||
for j in n_main..n {
|
||||
for i in m_main..m {
|
||||
K::sup_br(
|
||||
k,
|
||||
alpha,
|
||||
a.sub_col(i),
|
||||
b.sub_row(j),
|
||||
beta,
|
||||
c.sub(j, i))
|
||||
K::sup_br(k, alpha, a.sub_col(i), b.sub_row(j), beta, c.sub(j, i))
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
e.synchronize();
|
||||
}
|
@ -1,16 +1,16 @@
|
||||
use super::fma::fmadd_ps;
|
||||
use super::intrinsics::*;
|
||||
use crate::kernel::params::single::{MR, NR};
|
||||
use crate::matrix::{Matrix, MutMatrix, MatrixMut};
|
||||
use crate::matrix::{Matrix, MatrixMut, MutMatrix};
|
||||
use crunchy::unroll;
|
||||
|
||||
#[inline]
|
||||
pub(crate) unsafe fn sgemm_ukr_16x8<C: MatrixMut<f32>>(
|
||||
k: usize,
|
||||
alpha: f32,
|
||||
_alpha: f32,
|
||||
pa: MutMatrix<f32>,
|
||||
pb: MutMatrix<f32>,
|
||||
beta: f32,
|
||||
_beta: f32,
|
||||
c: C,
|
||||
) {
|
||||
let mut mt00 = _mm256_setzero_ps();
|
||||
@ -75,7 +75,6 @@ pub(crate) unsafe fn sgemm_ukr_16x8<C: MatrixMut<f32>>(
|
||||
pb = pb.add(BATCH * NR);
|
||||
}
|
||||
|
||||
|
||||
for _ in k_main..k {
|
||||
let a0 = _mm256_load_ps(pa);
|
||||
let a1 = _mm256_load_ps(pa.add(8));
|
||||
@ -192,10 +191,10 @@ pub(crate) unsafe fn sgemm_ukr_16x8<C: MatrixMut<f32>>(
|
||||
|
||||
pub(crate) unsafe fn sgemm_sup_16x1<B: Matrix<f32>, C: MatrixMut<f32>>(
|
||||
k: usize,
|
||||
alpha: f32,
|
||||
_alpha: f32,
|
||||
pa: MutMatrix<f32>,
|
||||
b: B,
|
||||
beta: f32,
|
||||
_beta: f32,
|
||||
c: C,
|
||||
) {
|
||||
let mut mt0 = _mm256_setzero_ps();
|
||||
|
@ -6,15 +6,16 @@ pub mod l1s;
|
||||
pub mod l3d;
|
||||
pub mod l3s;
|
||||
|
||||
use core::marker::PhantomData;
|
||||
use crate::matrix::{Number, MutMatrix, Matrix, MatrixMut};
|
||||
use crate::kernel::{GemmKernel, GemmKernelSupNr, GemmKernelSupMr, GemmKernelSup};
|
||||
use crate::dim::*;
|
||||
use crate::kernel::{GemmKernel, GemmKernelSup, GemmKernelSupMr, GemmKernelSupNr};
|
||||
use crate::matrix::{Matrix, MatrixMut, MutMatrix, Number};
|
||||
use core::marker::PhantomData;
|
||||
|
||||
pub struct AvxKernel<F: Number, I>(PhantomData<fn(F, I)>);
|
||||
|
||||
impl<I> GemmKernelSupNr<f32, A5> for AvxKernel<f32, I>
|
||||
where I: GemmKernelSupNr<f32, A5>
|
||||
where
|
||||
I: GemmKernelSupNr<f32, A5>,
|
||||
{
|
||||
#[inline]
|
||||
unsafe fn sup_tr<A: Matrix<f32>, C: MatrixMut<f32>>(
|
||||
@ -29,7 +30,8 @@ impl<I> GemmKernelSupNr<f32, A5> for AvxKernel<f32, I>
|
||||
}
|
||||
|
||||
impl<I> GemmKernelSupMr<f32, A16> for AvxKernel<f32, I>
|
||||
where I: GemmKernelSupMr<f32, A16>
|
||||
where
|
||||
I: GemmKernelSupMr<f32, A16>,
|
||||
{
|
||||
#[inline]
|
||||
unsafe fn sup_bl<B: Matrix<f32>, C: MatrixMut<f32>>(
|
||||
@ -44,7 +46,8 @@ impl<I> GemmKernelSupMr<f32, A16> for AvxKernel<f32, I>
|
||||
}
|
||||
|
||||
impl<I> GemmKernelSup<f32> for AvxKernel<f32, I>
|
||||
where I: GemmKernelSup<f32>
|
||||
where
|
||||
I: GemmKernelSup<f32>,
|
||||
{
|
||||
#[inline]
|
||||
unsafe fn sup_br<A: Matrix<f32>, B: Matrix<f32>, C: MatrixMut<f32>>(
|
||||
@ -60,7 +63,8 @@ impl<I> GemmKernelSup<f32> for AvxKernel<f32, I>
|
||||
}
|
||||
|
||||
impl<I> GemmKernel<f32, A16, A5> for AvxKernel<f32, I>
|
||||
where I: GemmKernel<f32, A16, A5>
|
||||
where
|
||||
I: GemmKernel<f32, A16, A5>,
|
||||
{
|
||||
#[inline]
|
||||
unsafe fn pack_row_a<A: Matrix<f32>>(a: A, pa: MutMatrix<f32>) {
|
||||
|
@ -1,13 +1,12 @@
|
||||
use crate::kernel::params::single::NR;
|
||||
use crate::matrix::{MutMatrix, Matrix, MatrixMut};
|
||||
|
||||
use crate::matrix::{Matrix, MatrixMut, MutMatrix};
|
||||
|
||||
pub(crate) unsafe fn sgemm_sup_1x8<A: Matrix<f32>, C: MatrixMut<f32>>(
|
||||
k: usize,
|
||||
alpha: f32,
|
||||
_alpha: f32,
|
||||
a: A,
|
||||
pb: MutMatrix<f32>,
|
||||
beta: f32,
|
||||
_beta: f32,
|
||||
c: C,
|
||||
) {
|
||||
let mut c0 = 0.0f32;
|
||||
|
@ -1,9 +1,9 @@
|
||||
pub mod l3d;
|
||||
pub mod l3s;
|
||||
|
||||
use crate::matrix::{Matrix, MatrixMut, MutMatrix};
|
||||
use crate::kernel::{GemmKernel, GemmKernelSupNr, GemmKernelSupMr, GemmKernelSup};
|
||||
use crate::dim::*;
|
||||
use crate::kernel::{GemmKernel, GemmKernelSup, GemmKernelSupMr, GemmKernelSupNr};
|
||||
use crate::matrix::{Matrix, MatrixMut, MutMatrix};
|
||||
|
||||
pub struct GenericKernel;
|
||||
|
||||
@ -16,24 +16,18 @@ impl GemmKernelSupNr<f32, A5> for GenericKernel {
|
||||
beta: f32,
|
||||
c: C,
|
||||
) {
|
||||
self::l3s::sgemm_sup_1x8(
|
||||
pb.stride,
|
||||
alpha,
|
||||
a,
|
||||
pb,
|
||||
beta,
|
||||
c)
|
||||
self::l3s::sgemm_sup_1x8(pb.stride, alpha, a, pb, beta, c)
|
||||
}
|
||||
}
|
||||
|
||||
impl GemmKernelSupMr<f32, A16> for GenericKernel {
|
||||
#[inline]
|
||||
unsafe fn sup_bl<B: Matrix<f32>, C: MatrixMut<f32>>(
|
||||
alpha: f32,
|
||||
pa: MutMatrix<f32>,
|
||||
b: B,
|
||||
beta: f32,
|
||||
c: C,
|
||||
_alpha: f32,
|
||||
_pa: MutMatrix<f32>,
|
||||
_b: B,
|
||||
_beta: f32,
|
||||
_c: C,
|
||||
) {
|
||||
unimplemented!()
|
||||
}
|
||||
@ -66,7 +60,6 @@ impl GemmKernelSup<f32> for GenericKernel {
|
||||
}
|
||||
|
||||
impl GemmKernel<f32, A16, A5> for GenericKernel {
|
||||
|
||||
#[inline]
|
||||
unsafe fn pack_row_a<A: Matrix<f32>>(a: A, pa: MutMatrix<f32>) {
|
||||
if a.is_transposed() {
|
||||
@ -87,11 +80,11 @@ impl GemmKernel<f32, A16, A5> for GenericKernel {
|
||||
|
||||
#[inline]
|
||||
unsafe fn main_tl<C: MatrixMut<f32>>(
|
||||
alpha: f32,
|
||||
pa: MutMatrix<f32>,
|
||||
pb: MutMatrix<f32>,
|
||||
beta: f32,
|
||||
c: C,
|
||||
_alpha: f32,
|
||||
_pa: MutMatrix<f32>,
|
||||
_pb: MutMatrix<f32>,
|
||||
_beta: f32,
|
||||
_c: C,
|
||||
) {
|
||||
unimplemented!()
|
||||
}
|
||||
|
@ -6,8 +6,8 @@ pub mod avx;
|
||||
pub mod generic;
|
||||
pub mod sse;
|
||||
|
||||
use crate::matrix::{Number, Matrix, MatrixMut, MutMatrix};
|
||||
use crate::dim::Dim;
|
||||
use crate::matrix::{Matrix, MatrixMut, MutMatrix, Number};
|
||||
|
||||
pub mod params {
|
||||
pub mod single {
|
||||
@ -70,15 +70,11 @@ pub trait GemmKernelSupNr<F: Number, NR: Dim> {
|
||||
);
|
||||
}
|
||||
|
||||
pub trait GemmKernel<F: Number, MR: Dim, NR: Dim>: GemmKernelSupMr<F, MR> + GemmKernelSupNr<F, NR> + GemmKernelSup<F> {
|
||||
pub trait GemmKernel<F: Number, MR: Dim, NR: Dim>:
|
||||
GemmKernelSupMr<F, MR> + GemmKernelSupNr<F, NR> + GemmKernelSup<F>
|
||||
{
|
||||
unsafe fn pack_row_a<A: Matrix<F>>(a: A, pa: MutMatrix<F>);
|
||||
unsafe fn pack_row_b<B: Matrix<F>>(b: B, pb: MutMatrix<F>);
|
||||
|
||||
unsafe fn main_tl<C: MatrixMut<F>>(
|
||||
alpha: F,
|
||||
pa: MutMatrix<F>,
|
||||
pb: MutMatrix<F>,
|
||||
beta: F,
|
||||
c: C,
|
||||
);
|
||||
unsafe fn main_tl<C: MatrixMut<F>>(alpha: F, pa: MutMatrix<F>, pb: MutMatrix<F>, beta: F, c: C);
|
||||
}
|
@ -1,3 +1,5 @@
|
||||
#![allow(dead_code)]
|
||||
|
||||
use super::intrinsics::*;
|
||||
|
||||
#[cfg(target_feature = "fma")]
|
||||
|
@ -1,3 +1,5 @@
|
||||
#![allow(dead_code)]
|
||||
|
||||
use super::fma::fmadd_pd;
|
||||
use super::intrinsics::*;
|
||||
|
||||
|
@ -1,3 +1,5 @@
|
||||
#![allow(dead_code)]
|
||||
|
||||
use super::fma::fmadd_ps;
|
||||
use super::intrinsics::*;
|
||||
use crate::kernel::params::single::NR;
|
||||
@ -5,10 +7,10 @@ use crate::matrix::{Matrix, MatrixMut, MutMatrix};
|
||||
|
||||
pub(crate) unsafe fn sgemm_sup_1x8<A: Matrix<f32>, C: MatrixMut<f32>>(
|
||||
k: usize,
|
||||
alpha: f32,
|
||||
_alpha: f32,
|
||||
a: A,
|
||||
pb: MutMatrix<f32>,
|
||||
beta: f32,
|
||||
_beta: f32,
|
||||
c: C,
|
||||
) {
|
||||
let mut c0_3 = _mm_setzero_ps();
|
||||
@ -69,11 +71,11 @@ pub(crate) unsafe fn sgemm_sup_1x8<A: Matrix<f32>, C: MatrixMut<f32>>(
|
||||
|
||||
pub(crate) unsafe fn sgemm_sup_1x8_t(
|
||||
k: usize,
|
||||
alpha: f32,
|
||||
_alpha: f32,
|
||||
a: *const f32,
|
||||
lda: usize,
|
||||
_lda: usize,
|
||||
pb: *const f32,
|
||||
beta: f32,
|
||||
_beta: f32,
|
||||
c: *mut f32,
|
||||
ldc: usize,
|
||||
) {
|
||||
|
@ -6,16 +6,16 @@ mod intrinsics;
|
||||
// pub mod l3d;
|
||||
pub mod l3s;
|
||||
|
||||
use core::marker::PhantomData;
|
||||
use crate::matrix::{Number, MutMatrix, Matrix, MatrixMut};
|
||||
use crate::kernel::{GemmKernel, GemmKernelSupNr, GemmKernelSupMr, GemmKernelSup};
|
||||
use crate::dim::*;
|
||||
|
||||
use crate::kernel::{GemmKernel, GemmKernelSup, GemmKernelSupMr, GemmKernelSupNr};
|
||||
use crate::matrix::{Matrix, MatrixMut, MutMatrix, Number};
|
||||
use core::marker::PhantomData;
|
||||
|
||||
pub struct SseKernel<F: Number, I>(PhantomData<fn(F, I)>);
|
||||
|
||||
impl<I> GemmKernelSupNr<f32, A5> for SseKernel<f32, I>
|
||||
where I: GemmKernelSupNr<f32, A5>
|
||||
where
|
||||
I: GemmKernelSupNr<f32, A5>,
|
||||
{
|
||||
#[inline]
|
||||
unsafe fn sup_tr<A: Matrix<f32>, C: MatrixMut<f32>>(
|
||||
@ -30,7 +30,8 @@ impl<I> GemmKernelSupNr<f32, A5> for SseKernel<f32, I>
|
||||
}
|
||||
|
||||
impl<I> GemmKernelSupMr<f32, A16> for SseKernel<f32, I>
|
||||
where I: GemmKernelSupMr<f32, A16>
|
||||
where
|
||||
I: GemmKernelSupMr<f32, A16>,
|
||||
{
|
||||
#[inline]
|
||||
unsafe fn sup_bl<B: Matrix<f32>, C: MatrixMut<f32>>(
|
||||
@ -45,7 +46,8 @@ impl<I> GemmKernelSupMr<f32, A16> for SseKernel<f32, I>
|
||||
}
|
||||
|
||||
impl<I> GemmKernelSup<f32> for SseKernel<f32, I>
|
||||
where I: GemmKernelSup<f32>
|
||||
where
|
||||
I: GemmKernelSup<f32>,
|
||||
{
|
||||
#[inline]
|
||||
unsafe fn sup_br<A: Matrix<f32>, B: Matrix<f32>, C: MatrixMut<f32>>(
|
||||
@ -61,7 +63,8 @@ impl<I> GemmKernelSup<f32> for SseKernel<f32, I>
|
||||
}
|
||||
|
||||
impl<I> GemmKernel<f32, A16, A5> for SseKernel<f32, I>
|
||||
where I: GemmKernel<f32, A16, A5>
|
||||
where
|
||||
I: GemmKernel<f32, A16, A5>,
|
||||
{
|
||||
#[inline]
|
||||
unsafe fn pack_row_a<A: Matrix<f32>>(a: A, pa: MutMatrix<f32>) {
|
||||
|
@ -1,10 +1,10 @@
|
||||
pub mod matrix;
|
||||
pub mod gemm;
|
||||
mod sgemm;
|
||||
mod aligned_alloc;
|
||||
pub mod kernel;
|
||||
pub mod dim;
|
||||
pub mod executor;
|
||||
pub mod gemm;
|
||||
pub mod kernel;
|
||||
pub mod matrix;
|
||||
mod sgemm;
|
||||
|
||||
#[cfg(test)]
|
||||
extern crate blas;
|
||||
@ -13,5 +13,4 @@ extern crate openblas;
|
||||
#[cfg(test)]
|
||||
mod test;
|
||||
|
||||
|
||||
pub use crate::sgemm::sgemm;
|
||||
|
@ -1,4 +1,3 @@
|
||||
|
||||
pub trait Number: core::fmt::Display + Copy + Send + Sync + 'static {}
|
||||
|
||||
impl Number for f32 {}
|
||||
@ -32,7 +31,7 @@ pub trait MatrixMut<F: Number>: Matrix<F> {
|
||||
#[derive(Copy, Clone)]
|
||||
pub struct ConstMatrix<F: Number> {
|
||||
pub stride: usize,
|
||||
pub ptr: *const F
|
||||
pub ptr: *const F,
|
||||
}
|
||||
|
||||
unsafe impl<F: Number> Send for ConstMatrix<F> {}
|
||||
@ -49,7 +48,7 @@ impl<F: Number> Matrix<F> for ConstMatrix<F> {
|
||||
unsafe fn sub(&self, row: usize, col: usize) -> Self {
|
||||
Self {
|
||||
ptr: self.index(row, col),
|
||||
stride: self.stride
|
||||
stride: self.stride,
|
||||
}
|
||||
}
|
||||
|
||||
@ -57,7 +56,7 @@ impl<F: Number> Matrix<F> for ConstMatrix<F> {
|
||||
unsafe fn sub_col(&self, col: usize) -> Self {
|
||||
Self {
|
||||
ptr: self.col(col),
|
||||
stride: self.stride
|
||||
stride: self.stride,
|
||||
}
|
||||
}
|
||||
|
||||
@ -65,7 +64,7 @@ impl<F: Number> Matrix<F> for ConstMatrix<F> {
|
||||
unsafe fn sub_row(&self, row: usize) -> Self {
|
||||
Self {
|
||||
ptr: self.row(row),
|
||||
stride: self.stride
|
||||
stride: self.stride,
|
||||
}
|
||||
}
|
||||
|
||||
@ -128,7 +127,7 @@ impl<F: Number> Matrix<F> for ConstMatrix<F> {
|
||||
#[derive(Copy, Clone)]
|
||||
pub struct ConstTransposedMatrix<F: Number> {
|
||||
pub stride: usize,
|
||||
pub ptr: *const F
|
||||
pub ptr: *const F,
|
||||
}
|
||||
|
||||
unsafe impl<F: Number> Send for ConstTransposedMatrix<F> {}
|
||||
@ -145,7 +144,7 @@ impl<F: Number> Matrix<F> for ConstTransposedMatrix<F> {
|
||||
unsafe fn sub(&self, row: usize, col: usize) -> Self {
|
||||
Self {
|
||||
ptr: self.index(row, col),
|
||||
stride: self.stride
|
||||
stride: self.stride,
|
||||
}
|
||||
}
|
||||
|
||||
@ -153,7 +152,7 @@ impl<F: Number> Matrix<F> for ConstTransposedMatrix<F> {
|
||||
unsafe fn sub_col(&self, col: usize) -> Self {
|
||||
Self {
|
||||
ptr: self.col(col),
|
||||
stride: self.stride
|
||||
stride: self.stride,
|
||||
}
|
||||
}
|
||||
|
||||
@ -161,7 +160,7 @@ impl<F: Number> Matrix<F> for ConstTransposedMatrix<F> {
|
||||
unsafe fn sub_row(&self, row: usize) -> Self {
|
||||
Self {
|
||||
ptr: self.row(row),
|
||||
stride: self.stride
|
||||
stride: self.stride,
|
||||
}
|
||||
}
|
||||
|
||||
@ -224,7 +223,7 @@ impl<F: Number> Matrix<F> for ConstTransposedMatrix<F> {
|
||||
#[derive(Copy, Clone)]
|
||||
pub struct MutMatrix<F: Number> {
|
||||
pub stride: usize,
|
||||
pub ptr: *mut F
|
||||
pub ptr: *mut F,
|
||||
}
|
||||
|
||||
unsafe impl<F: Number> Send for MutMatrix<F> {}
|
||||
@ -241,7 +240,7 @@ impl<F: Number> Matrix<F> for MutMatrix<F> {
|
||||
unsafe fn sub(&self, row: usize, col: usize) -> Self {
|
||||
Self {
|
||||
ptr: self.index_mut(row, col),
|
||||
stride: self.stride
|
||||
stride: self.stride,
|
||||
}
|
||||
}
|
||||
|
||||
@ -249,7 +248,7 @@ impl<F: Number> Matrix<F> for MutMatrix<F> {
|
||||
unsafe fn sub_col(&self, col: usize) -> Self {
|
||||
Self {
|
||||
ptr: self.col_mut(col),
|
||||
stride: self.stride
|
||||
stride: self.stride,
|
||||
}
|
||||
}
|
||||
|
||||
@ -257,7 +256,7 @@ impl<F: Number> Matrix<F> for MutMatrix<F> {
|
||||
unsafe fn sub_row(&self, row: usize) -> Self {
|
||||
Self {
|
||||
ptr: self.row_mut(row),
|
||||
stride: self.stride
|
||||
stride: self.stride,
|
||||
}
|
||||
}
|
||||
|
||||
@ -347,7 +346,7 @@ impl<F: Number> MatrixMut<F> for MutMatrix<F> {
|
||||
#[derive(Copy, Clone)]
|
||||
pub struct MutTransposedMatrix<F: Number> {
|
||||
pub stride: usize,
|
||||
pub ptr: *mut F
|
||||
pub ptr: *mut F,
|
||||
}
|
||||
|
||||
unsafe impl<F: Number> Send for MutTransposedMatrix<F> {}
|
||||
@ -364,7 +363,7 @@ impl<F: Number> Matrix<F> for MutTransposedMatrix<F> {
|
||||
unsafe fn sub(&self, row: usize, col: usize) -> Self {
|
||||
Self {
|
||||
ptr: self.index_mut(row, col),
|
||||
stride: self.stride
|
||||
stride: self.stride,
|
||||
}
|
||||
}
|
||||
|
||||
@ -372,7 +371,7 @@ impl<F: Number> Matrix<F> for MutTransposedMatrix<F> {
|
||||
unsafe fn sub_col(&self, col: usize) -> Self {
|
||||
Self {
|
||||
ptr: self.col_mut(col),
|
||||
stride: self.stride
|
||||
stride: self.stride,
|
||||
}
|
||||
}
|
||||
|
||||
@ -380,7 +379,7 @@ impl<F: Number> Matrix<F> for MutTransposedMatrix<F> {
|
||||
unsafe fn sub_row(&self, row: usize) -> Self {
|
||||
Self {
|
||||
ptr: self.row_mut(row),
|
||||
stride: self.stride
|
||||
stride: self.stride,
|
||||
}
|
||||
}
|
||||
|
||||
@ -440,7 +439,6 @@ impl<F: Number> Matrix<F> for MutTransposedMatrix<F> {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
impl<F: Number> MatrixMut<F> for MutTransposedMatrix<F> {
|
||||
#[inline]
|
||||
unsafe fn set(&self, row: usize, col: usize, value: F) {
|
||||
@ -467,5 +465,3 @@ impl<F: Number> MatrixMut<F> for MutTransposedMatrix<F> {
|
||||
self.ptr
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
12
src/sgemm.rs
12
src/sgemm.rs
@ -1,8 +1,7 @@
|
||||
|
||||
use crate::kernel::avx::AvxKernel;
|
||||
use crate::kernel::generic::GenericKernel;
|
||||
use crate::dim::*;
|
||||
use crate::executor::Executor;
|
||||
use crate::kernel::avx::AvxKernel;
|
||||
use crate::kernel::generic::GenericKernel;
|
||||
|
||||
pub unsafe fn sgemm<E: Executor>(
|
||||
e: &E,
|
||||
@ -22,9 +21,6 @@ pub unsafe fn sgemm<E: Executor>(
|
||||
ldc: usize,
|
||||
) {
|
||||
crate::gemm::gemm::<E, f32, AvxKernel<f32, GenericKernel>, A16, A5>(
|
||||
e,
|
||||
transa, transb, transc,
|
||||
m, n, k, alpha,
|
||||
a, lda, b, ldb,
|
||||
beta, c, ldc);
|
||||
e, transa, transb, transc, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
|
||||
);
|
||||
}
|
@ -1,11 +1,9 @@
|
||||
use rand::Rng;
|
||||
|
||||
const M_LEN: usize = 17;
|
||||
const N_LEN: usize = 6;
|
||||
const K_LEN: usize = 4;
|
||||
|
||||
fn make_matrices() -> (Vec<f32>, Vec<f32>, Vec<f32>, Vec<f32>) {
|
||||
let mut rng = rand::thread_rng();
|
||||
let _rng = rand::thread_rng();
|
||||
|
||||
let (m, n, k) = (M_LEN, N_LEN, K_LEN);
|
||||
|
||||
@ -38,7 +36,6 @@ fn make_matrices() -> (Vec<f32>, Vec<f32>, Vec<f32>, Vec<f32>) {
|
||||
(a, a_t, b, b_t)
|
||||
}
|
||||
|
||||
|
||||
unsafe fn sgemm_ref(
|
||||
m: usize,
|
||||
n: usize,
|
||||
@ -75,7 +72,6 @@ unsafe fn sgemm_ref(
|
||||
// }
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_sgemm_nn() {
|
||||
let (m, n, k) = (M_LEN, N_LEN, K_LEN);
|
||||
@ -133,7 +129,6 @@ fn test_sgemm_nn() {
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
for row in 0..N_LEN {
|
||||
for col in 0..M_LEN {
|
||||
let index = row * M_LEN + col;
|
||||
@ -160,16 +155,24 @@ fn test_sgemm_nn() {
|
||||
|
||||
let (a, b) = (c[index], cref[index]);
|
||||
|
||||
assert!(feq(a, b), "a != b, a[{}, {}]={}, b[{}, {}]={}", row, col, a, row, col, b);
|
||||
assert!(
|
||||
feq(a, b),
|
||||
"a != b, a[{}, {}]={}, b[{}, {}]={}",
|
||||
row,
|
||||
col,
|
||||
a,
|
||||
row,
|
||||
col,
|
||||
b
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sgemm_nt() {
|
||||
let (m, n, k) = (M_LEN, N_LEN, K_LEN);
|
||||
let (a, _, b, b_t) = make_matrices();
|
||||
let (a, _, _b, b_t) = make_matrices();
|
||||
|
||||
let mut c = vec![0.0; m * n];
|
||||
let mut cref = vec![0.0; m * n];
|
||||
@ -216,16 +219,24 @@ fn test_sgemm_nt() {
|
||||
for col in 0..M_LEN {
|
||||
let index = row * M_LEN + col;
|
||||
let (a, b) = (c[index], cref[index]);
|
||||
assert!(feq(a, b), "a != b, a[{}, {}]={}, b[{}, {}]={}", row, col, a, row, col, b);
|
||||
assert!(
|
||||
feq(a, b),
|
||||
"a != b, a[{}, {}]={}, b[{}, {}]={}",
|
||||
row,
|
||||
col,
|
||||
a,
|
||||
row,
|
||||
col,
|
||||
b
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_sgemm_tn() {
|
||||
let (m, n, k) = (M_LEN, N_LEN, K_LEN);
|
||||
let (a, a_t, b, _) = make_matrices();
|
||||
let (_a, a_t, b, _) = make_matrices();
|
||||
|
||||
let mut c = vec![0.0; m * n];
|
||||
let mut cref = vec![0.0; m * n];
|
||||
@ -272,16 +283,24 @@ fn test_sgemm_tn() {
|
||||
for col in 0..M_LEN {
|
||||
let index = row * M_LEN + col;
|
||||
let (a, b) = (c[index], cref[index]);
|
||||
assert!(feq(a, b), "a != b, a[{}, {}]={}, b[{}, {}]={}", row, col, a, row, col, b);
|
||||
assert!(
|
||||
feq(a, b),
|
||||
"a != b, a[{}, {}]={}, b[{}, {}]={}",
|
||||
row,
|
||||
col,
|
||||
a,
|
||||
row,
|
||||
col,
|
||||
b
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_sgemm_tt() {
|
||||
let (m, n, k) = (M_LEN, N_LEN, K_LEN);
|
||||
let (a, a_t, b, b_t) = make_matrices();
|
||||
let (_a, a_t, _b, b_t) = make_matrices();
|
||||
|
||||
let mut c = vec![0.0; m * n];
|
||||
let mut cref = vec![0.0; m * n];
|
||||
@ -328,7 +347,16 @@ fn test_sgemm_tt() {
|
||||
for col in 0..M_LEN {
|
||||
let index = row * M_LEN + col;
|
||||
let (a, b) = (c[index], cref[index]);
|
||||
assert!(feq(a, b), "a != b, a[{}, {}]={}, b[{}, {}]={}", row, col, a, row, col, b);
|
||||
assert!(
|
||||
feq(a, b),
|
||||
"a != b, a[{}, {}]={}, b[{}, {}]={}",
|
||||
row,
|
||||
col,
|
||||
a,
|
||||
row,
|
||||
col,
|
||||
b
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user