Fix fmt and some warnings #1
14
.drone.yml
Normal file
14
.drone.yml
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
kind: pipeline
|
||||||
|
name: default
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: build
|
||||||
|
image: rust
|
||||||
|
commands:
|
||||||
|
- cargo build --verbose --all
|
||||||
|
|
||||||
|
- name: fmt-check
|
||||||
|
image: rust
|
||||||
|
commands:
|
||||||
|
- rustup component add rustfmt
|
||||||
|
- cargo fmt --all -- --check
|
2
build.rs
2
build.rs
@ -4,4 +4,4 @@ fn main() {
|
|||||||
// and run this.
|
// and run this.
|
||||||
println!("cargo:rerun-if-changed=build.rs");
|
println!("cargo:rerun-if-changed=build.rs");
|
||||||
println!("cargo:rustc-link-lib={}=openblas", "dylib");
|
println!("cargo:rustc-link-lib={}=openblas", "dylib");
|
||||||
}
|
}
|
||||||
|
40
src/dim.rs
40
src/dim.rs
@ -2,49 +2,49 @@ pub trait Dim {
|
|||||||
const DIM: usize;
|
const DIM: usize;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct A1;
|
pub struct A1;
|
||||||
impl Dim for A1 {
|
impl Dim for A1 {
|
||||||
const DIM: usize = 1;
|
const DIM: usize = 1;
|
||||||
}
|
}
|
||||||
pub struct A2;
|
pub struct A2;
|
||||||
impl Dim for A2 {
|
impl Dim for A2 {
|
||||||
const DIM: usize = 2;
|
const DIM: usize = 2;
|
||||||
}
|
}
|
||||||
pub struct A3;
|
pub struct A3;
|
||||||
impl Dim for A3 {
|
impl Dim for A3 {
|
||||||
const DIM: usize = 3;
|
const DIM: usize = 3;
|
||||||
}
|
}
|
||||||
pub struct A4;
|
pub struct A4;
|
||||||
impl Dim for A4 {
|
impl Dim for A4 {
|
||||||
const DIM: usize = 4;
|
const DIM: usize = 4;
|
||||||
}
|
}
|
||||||
pub struct A5;
|
pub struct A5;
|
||||||
impl Dim for A5 {
|
impl Dim for A5 {
|
||||||
const DIM: usize = 5;
|
const DIM: usize = 5;
|
||||||
}
|
}
|
||||||
pub struct A6;
|
pub struct A6;
|
||||||
impl Dim for A6 {
|
impl Dim for A6 {
|
||||||
const DIM: usize = 6;
|
const DIM: usize = 6;
|
||||||
}
|
}
|
||||||
pub struct A7;
|
pub struct A7;
|
||||||
impl Dim for A7 {
|
impl Dim for A7 {
|
||||||
const DIM: usize = 7;
|
const DIM: usize = 7;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct A8;
|
pub struct A8;
|
||||||
impl Dim for A8 {
|
impl Dim for A8 {
|
||||||
const DIM: usize = 8;
|
const DIM: usize = 8;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct A16;
|
pub struct A16;
|
||||||
impl Dim for A16 {
|
impl Dim for A16 {
|
||||||
const DIM: usize = 16;
|
const DIM: usize = 16;
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
macro_rules! gen {
|
macro_rules! gen {
|
||||||
(($name: ident, $val: lit)) => {
|
(($name: ident, $val: lit)) => {
|
||||||
|
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -72,4 +72,4 @@ gen! {
|
|||||||
(A28, 28),
|
(A28, 28),
|
||||||
(A30, 30),
|
(A30, 30),
|
||||||
(A32, 32),
|
(A32, 32),
|
||||||
}*/
|
}*/
|
||||||
|
@ -42,19 +42,18 @@ impl Executor for RayonExecutor {
|
|||||||
(start..end)
|
(start..end)
|
||||||
.into_par_iter()
|
.into_par_iter()
|
||||||
.map(|x| x * step)
|
.map(|x| x * step)
|
||||||
.for_each(|x|f(x));
|
.for_each(|x| f(x));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
pub struct ThreadPoolExecutor {
|
pub struct ThreadPoolExecutor {
|
||||||
thread_pool: threadpool::ThreadPool
|
thread_pool: threadpool::ThreadPool,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ThreadPoolExecutor {
|
impl ThreadPoolExecutor {
|
||||||
pub fn new() -> ThreadPoolExecutor {
|
pub fn new() -> ThreadPoolExecutor {
|
||||||
ThreadPoolExecutor {
|
ThreadPoolExecutor {
|
||||||
thread_pool: threadpool::Builder::new().build()
|
thread_pool: threadpool::Builder::new().build(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
186
src/gemm.rs
186
src/gemm.rs
@ -1,10 +1,12 @@
|
|||||||
use crate::aligned_alloc;
|
use crate::aligned_alloc;
|
||||||
use crate::matrix::{Number, Matrix, MatrixMut, MutMatrix, ConstMatrix, ConstTransposedMatrix, MutTransposedMatrix};
|
|
||||||
use crate::kernel::params::single::{NC, MC, KC};
|
|
||||||
use crate::kernel;
|
|
||||||
use crate::kernel::GemmKernel;
|
|
||||||
use crate::dim::Dim;
|
use crate::dim::Dim;
|
||||||
use crate::executor::Executor;
|
use crate::executor::Executor;
|
||||||
|
use crate::kernel;
|
||||||
|
use crate::kernel::params::single::{KC, MC, NC};
|
||||||
|
use crate::kernel::GemmKernel;
|
||||||
|
use crate::matrix::{
|
||||||
|
ConstMatrix, ConstTransposedMatrix, Matrix, MatrixMut, MutMatrix, MutTransposedMatrix, Number,
|
||||||
|
};
|
||||||
|
|
||||||
pub unsafe fn gemm<E, F, K, MR, NR>(
|
pub unsafe fn gemm<E, F, K, MR, NR>(
|
||||||
e: &E,
|
e: &E,
|
||||||
@ -22,70 +24,109 @@ pub unsafe fn gemm<E, F, K, MR, NR>(
|
|||||||
beta: F,
|
beta: F,
|
||||||
c: *mut F,
|
c: *mut F,
|
||||||
ldc: usize,
|
ldc: usize,
|
||||||
)
|
) where
|
||||||
where E: Executor,
|
E: Executor,
|
||||||
F: Number,
|
F: Number,
|
||||||
MR: Dim, NR: Dim,
|
MR: Dim,
|
||||||
K: GemmKernel<F, MR, NR>,
|
NR: Dim,
|
||||||
|
K: GemmKernel<F, MR, NR>,
|
||||||
{
|
{
|
||||||
match (transa, transb, transc) {
|
match (transa, transb, transc) {
|
||||||
(false, false, false) => gemm_template::<E, F, K, MR, NR, _, _, _>(
|
(false, false, false) => gemm_template::<E, F, K, MR, NR, _, _, _>(
|
||||||
e, m, n, k, alpha,
|
e,
|
||||||
|
m,
|
||||||
|
n,
|
||||||
|
k,
|
||||||
|
alpha,
|
||||||
ConstMatrix::new(a, lda),
|
ConstMatrix::new(a, lda),
|
||||||
ConstMatrix::new(b, ldb),
|
ConstMatrix::new(b, ldb),
|
||||||
beta,
|
beta,
|
||||||
MutMatrix::new(c, ldc)),
|
MutMatrix::new(c, ldc),
|
||||||
|
),
|
||||||
|
|
||||||
(false, false, true) => gemm_template::<E, F, K, MR, NR, _, _, _>(
|
(false, false, true) => gemm_template::<E, F, K, MR, NR, _, _, _>(
|
||||||
e, m, n, k, alpha,
|
e,
|
||||||
|
m,
|
||||||
|
n,
|
||||||
|
k,
|
||||||
|
alpha,
|
||||||
ConstMatrix::new(a, lda),
|
ConstMatrix::new(a, lda),
|
||||||
ConstMatrix::new(b, ldb),
|
ConstMatrix::new(b, ldb),
|
||||||
beta,
|
beta,
|
||||||
MutTransposedMatrix::new(c, ldc)),
|
MutTransposedMatrix::new(c, ldc),
|
||||||
|
),
|
||||||
|
|
||||||
(false, true, false) => gemm_template::<E, F, K, MR, NR, _, _, _>(
|
(false, true, false) => gemm_template::<E, F, K, MR, NR, _, _, _>(
|
||||||
e, m, n, k, alpha,
|
e,
|
||||||
|
m,
|
||||||
|
n,
|
||||||
|
k,
|
||||||
|
alpha,
|
||||||
ConstMatrix::new(a, lda),
|
ConstMatrix::new(a, lda),
|
||||||
ConstTransposedMatrix::new(b, ldb),
|
ConstTransposedMatrix::new(b, ldb),
|
||||||
beta,
|
beta,
|
||||||
MutMatrix::new(c, ldc)),
|
MutMatrix::new(c, ldc),
|
||||||
|
),
|
||||||
|
|
||||||
(false, true, true) => gemm_template::<E, F, K, MR, NR, _, _, _>(
|
(false, true, true) => gemm_template::<E, F, K, MR, NR, _, _, _>(
|
||||||
e, m, n, k, alpha,
|
e,
|
||||||
|
m,
|
||||||
|
n,
|
||||||
|
k,
|
||||||
|
alpha,
|
||||||
ConstMatrix::new(a, lda),
|
ConstMatrix::new(a, lda),
|
||||||
ConstTransposedMatrix::new(b, ldb),
|
ConstTransposedMatrix::new(b, ldb),
|
||||||
beta,
|
beta,
|
||||||
MutTransposedMatrix::new(c, ldc)),
|
MutTransposedMatrix::new(c, ldc),
|
||||||
|
),
|
||||||
|
|
||||||
(true, false, false) => gemm_template::<E, F, K, MR, NR, _, _, _>(
|
(true, false, false) => gemm_template::<E, F, K, MR, NR, _, _, _>(
|
||||||
e, m, n, k, alpha,
|
e,
|
||||||
|
m,
|
||||||
|
n,
|
||||||
|
k,
|
||||||
|
alpha,
|
||||||
ConstTransposedMatrix::new(a, lda),
|
ConstTransposedMatrix::new(a, lda),
|
||||||
ConstMatrix::new(b, ldb),
|
ConstMatrix::new(b, ldb),
|
||||||
beta,
|
beta,
|
||||||
MutMatrix::new(c, ldc)),
|
MutMatrix::new(c, ldc),
|
||||||
|
),
|
||||||
|
|
||||||
(true, false, true) => gemm_template::<E, F, K, MR, NR, _, _, _>(
|
(true, false, true) => gemm_template::<E, F, K, MR, NR, _, _, _>(
|
||||||
e, m, n, k, alpha,
|
e,
|
||||||
|
m,
|
||||||
|
n,
|
||||||
|
k,
|
||||||
|
alpha,
|
||||||
ConstTransposedMatrix::new(a, lda),
|
ConstTransposedMatrix::new(a, lda),
|
||||||
ConstMatrix::new(b, ldb),
|
ConstMatrix::new(b, ldb),
|
||||||
beta,
|
beta,
|
||||||
MutTransposedMatrix::new(c, ldc)),
|
MutTransposedMatrix::new(c, ldc),
|
||||||
|
),
|
||||||
|
|
||||||
|
|
||||||
(true, true, false) => gemm_template::<E, F, K, MR, NR, _, _, _>(
|
(true, true, false) => gemm_template::<E, F, K, MR, NR, _, _, _>(
|
||||||
e, m, n, k, alpha,
|
e,
|
||||||
|
m,
|
||||||
|
n,
|
||||||
|
k,
|
||||||
|
alpha,
|
||||||
ConstTransposedMatrix::new(a, lda),
|
ConstTransposedMatrix::new(a, lda),
|
||||||
ConstTransposedMatrix::new(b, ldb),
|
ConstTransposedMatrix::new(b, ldb),
|
||||||
beta,
|
beta,
|
||||||
MutMatrix::new(c, ldc)),
|
MutMatrix::new(c, ldc),
|
||||||
|
),
|
||||||
|
|
||||||
|
|
||||||
(true, true, true) => gemm_template::<E, F, K, MR, NR, _, _, _>(
|
(true, true, true) => gemm_template::<E, F, K, MR, NR, _, _, _>(
|
||||||
e, m, n, k, alpha,
|
e,
|
||||||
|
m,
|
||||||
|
n,
|
||||||
|
k,
|
||||||
|
alpha,
|
||||||
ConstTransposedMatrix::new(a, lda),
|
ConstTransposedMatrix::new(a, lda),
|
||||||
ConstTransposedMatrix::new(b, ldb),
|
ConstTransposedMatrix::new(b, ldb),
|
||||||
beta,
|
beta,
|
||||||
MutTransposedMatrix::new(c, ldc)),
|
MutTransposedMatrix::new(c, ldc),
|
||||||
|
),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -98,15 +139,16 @@ unsafe fn gemm_template<E, F, K, MR, NR, A, B, C>(
|
|||||||
a: A,
|
a: A,
|
||||||
b: B,
|
b: B,
|
||||||
beta: F,
|
beta: F,
|
||||||
c: C
|
c: C,
|
||||||
)
|
) where
|
||||||
where E: Executor,
|
E: Executor,
|
||||||
F: Number,
|
F: Number,
|
||||||
MR: Dim, NR: Dim,
|
MR: Dim,
|
||||||
K: GemmKernel<F, MR, NR>,
|
NR: Dim,
|
||||||
A: Matrix<F>,
|
K: GemmKernel<F, MR, NR>,
|
||||||
B: Matrix<F>,
|
A: Matrix<F>,
|
||||||
C: MatrixMut<F>,
|
B: Matrix<F>,
|
||||||
|
C: MatrixMut<F>,
|
||||||
{
|
{
|
||||||
let packed_a = aligned_alloc::Alloc::new(MC * KC * std::mem::size_of::<F>());
|
let packed_a = aligned_alloc::Alloc::new(MC * KC * std::mem::size_of::<F>());
|
||||||
let packed_b = aligned_alloc::Alloc::new(KC * NC * std::mem::size_of::<F>());
|
let packed_b = aligned_alloc::Alloc::new(KC * NC * std::mem::size_of::<F>());
|
||||||
@ -122,7 +164,7 @@ where E: Executor,
|
|||||||
let pb = MutMatrix::new(packed_b.ptr::<F>(), p_b);
|
let pb = MutMatrix::new(packed_b.ptr::<F>(), p_b);
|
||||||
|
|
||||||
inner_kernel::<E, F, K, MR, NR, _, _, _>(
|
inner_kernel::<E, F, K, MR, NR, _, _, _>(
|
||||||
e,
|
e,
|
||||||
i_b,
|
i_b,
|
||||||
j_b,
|
j_b,
|
||||||
p_b,
|
p_b,
|
||||||
@ -133,7 +175,7 @@ where E: Executor,
|
|||||||
c.sub(j, i),
|
c.sub(j, i),
|
||||||
pa,
|
pa,
|
||||||
pb,
|
pb,
|
||||||
i == 0
|
i == 0,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -141,7 +183,7 @@ where E: Executor,
|
|||||||
}
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
// | MR |
|
// | MR |
|
||||||
// +-----------------------------+----+
|
// +-----------------------------+----+
|
||||||
// | : | |
|
// | : | |
|
||||||
// | TL : | TR |
|
// | TL : | TR |
|
||||||
@ -172,16 +214,15 @@ unsafe fn inner_kernel<E, F, K, MR, NR, A, B, C>(
|
|||||||
pa: MutMatrix<F>,
|
pa: MutMatrix<F>,
|
||||||
pb: MutMatrix<F>,
|
pb: MutMatrix<F>,
|
||||||
first_time: bool,
|
first_time: bool,
|
||||||
)
|
) where
|
||||||
where E: Executor,
|
E: Executor,
|
||||||
F: Number,
|
F: Number,
|
||||||
MR: Dim,
|
MR: Dim,
|
||||||
NR: Dim,
|
NR: Dim,
|
||||||
K: kernel::GemmKernel<F, MR, NR>,
|
K: kernel::GemmKernel<F, MR, NR>,
|
||||||
A: Matrix<F>,
|
A: Matrix<F>,
|
||||||
B: Matrix<F>,
|
B: Matrix<F>,
|
||||||
C: MatrixMut<F>,
|
C: MatrixMut<F>,
|
||||||
|
|
||||||
{
|
{
|
||||||
let n_left = n % NR::DIM;
|
let n_left = n % NR::DIM;
|
||||||
let n_main = n - n_left;
|
let n_main = n - n_left;
|
||||||
@ -190,33 +231,24 @@ unsafe fn inner_kernel<E, F, K, MR, NR, A, B, C>(
|
|||||||
let m_main = m - m_left;
|
let m_main = m - m_left;
|
||||||
|
|
||||||
if first_time {
|
if first_time {
|
||||||
e.execute(0, n_main, NR::DIM, move |j|
|
e.execute(0, n_main, NR::DIM, move |_j| K::pack_row_b(b, pb));
|
||||||
K::pack_row_b(b, pb));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
e.execute(0, m_main, MR::DIM, move |i|
|
e.execute(0, m_main, MR::DIM, move |i| {
|
||||||
K::pack_row_a(a.sub_col(i), pa.sub_row(i)));
|
K::pack_row_a(a.sub_col(i), pa.sub_row(i))
|
||||||
|
});
|
||||||
|
|
||||||
e.synchronize();
|
e.synchronize();
|
||||||
|
|
||||||
e.execute(0, n_main, NR::DIM, move |j| {
|
e.execute(0, n_main, NR::DIM, move |j| {
|
||||||
// Section TL
|
// Section TL
|
||||||
for i in (0..m_main).step_by(MR::DIM) {
|
for i in (0..m_main).step_by(MR::DIM) {
|
||||||
K::main_tl(alpha,
|
K::main_tl(alpha, pa.sub_row(i), pb.sub_row(j), beta, c.sub(j, i));
|
||||||
pa.sub_row(i),
|
|
||||||
pb.sub_row(j),
|
|
||||||
beta,
|
|
||||||
c.sub(j, i));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Section TR
|
// Section TR
|
||||||
for i in m_main..m {
|
for i in m_main..m {
|
||||||
K::sup_tr(
|
K::sup_tr(alpha, a.sub_col(i), pb.sub_row(j), beta, c.sub(j, i));
|
||||||
alpha,
|
|
||||||
a.sub_col(i),
|
|
||||||
pb.sub_row(j),
|
|
||||||
beta,
|
|
||||||
c.sub(j, i));
|
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -225,27 +257,15 @@ unsafe fn inner_kernel<E, F, K, MR, NR, A, B, C>(
|
|||||||
let i = ji % m_main;
|
let i = ji % m_main;
|
||||||
|
|
||||||
// Section BL
|
// Section BL
|
||||||
K::sup_bl(
|
K::sup_bl(alpha, pa.sub_row(i), b.sub_row(j), beta, c.sub(j, i));
|
||||||
alpha,
|
|
||||||
pa.sub_row(i),
|
|
||||||
b.sub_row(j),
|
|
||||||
beta,
|
|
||||||
c.sub(j, i)
|
|
||||||
);
|
|
||||||
});
|
});
|
||||||
|
|
||||||
// Section BR
|
// Section BR
|
||||||
for j in n_main..n {
|
for j in n_main..n {
|
||||||
for i in m_main..m {
|
for i in m_main..m {
|
||||||
K::sup_br(
|
K::sup_br(k, alpha, a.sub_col(i), b.sub_row(j), beta, c.sub(j, i))
|
||||||
k,
|
|
||||||
alpha,
|
|
||||||
a.sub_col(i),
|
|
||||||
b.sub_row(j),
|
|
||||||
beta,
|
|
||||||
c.sub(j, i))
|
|
||||||
}
|
}
|
||||||
};
|
}
|
||||||
|
|
||||||
e.synchronize();
|
e.synchronize();
|
||||||
}
|
}
|
||||||
|
@ -246,7 +246,7 @@ pub unsafe fn sdot(
|
|||||||
let mut acc1 = _mm256_setzero_ps();
|
let mut acc1 = _mm256_setzero_ps();
|
||||||
let mut acc2 = _mm256_setzero_ps();
|
let mut acc2 = _mm256_setzero_ps();
|
||||||
let mut acc3 = _mm256_setzero_ps();
|
let mut acc3 = _mm256_setzero_ps();
|
||||||
|
|
||||||
for _ in 0..n / 32 {
|
for _ in 0..n / 32 {
|
||||||
let x0 = _mm256_loadu_ps(x);
|
let x0 = _mm256_loadu_ps(x);
|
||||||
let y0 = _mm256_loadu_ps(y);
|
let y0 = _mm256_loadu_ps(y);
|
||||||
|
@ -1,16 +1,16 @@
|
|||||||
use super::fma::fmadd_ps;
|
use super::fma::fmadd_ps;
|
||||||
use super::intrinsics::*;
|
use super::intrinsics::*;
|
||||||
use crate::kernel::params::single::{MR, NR};
|
use crate::kernel::params::single::{MR, NR};
|
||||||
use crate::matrix::{Matrix, MutMatrix, MatrixMut};
|
use crate::matrix::{Matrix, MatrixMut, MutMatrix};
|
||||||
use crunchy::unroll;
|
use crunchy::unroll;
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
pub(crate) unsafe fn sgemm_ukr_16x8<C: MatrixMut<f32>>(
|
pub(crate) unsafe fn sgemm_ukr_16x8<C: MatrixMut<f32>>(
|
||||||
k: usize,
|
k: usize,
|
||||||
alpha: f32,
|
_alpha: f32,
|
||||||
pa: MutMatrix<f32>,
|
pa: MutMatrix<f32>,
|
||||||
pb: MutMatrix<f32>,
|
pb: MutMatrix<f32>,
|
||||||
beta: f32,
|
_beta: f32,
|
||||||
c: C,
|
c: C,
|
||||||
) {
|
) {
|
||||||
let mut mt00 = _mm256_setzero_ps();
|
let mut mt00 = _mm256_setzero_ps();
|
||||||
@ -21,7 +21,7 @@ pub(crate) unsafe fn sgemm_ukr_16x8<C: MatrixMut<f32>>(
|
|||||||
// let mut mt05 = _mm256_setzero_ps();
|
// let mut mt05 = _mm256_setzero_ps();
|
||||||
// let mut mt06 = _mm256_setzero_ps();
|
// let mut mt06 = _mm256_setzero_ps();
|
||||||
// let mut mt07 = _mm256_setzero_ps();
|
// let mut mt07 = _mm256_setzero_ps();
|
||||||
|
|
||||||
let mut mt10 = _mm256_setzero_ps();
|
let mut mt10 = _mm256_setzero_ps();
|
||||||
let mut mt11 = _mm256_setzero_ps();
|
let mut mt11 = _mm256_setzero_ps();
|
||||||
let mut mt12 = _mm256_setzero_ps();
|
let mut mt12 = _mm256_setzero_ps();
|
||||||
@ -30,7 +30,7 @@ pub(crate) unsafe fn sgemm_ukr_16x8<C: MatrixMut<f32>>(
|
|||||||
// let mut mt15 = _mm256_setzero_ps();
|
// let mut mt15 = _mm256_setzero_ps();
|
||||||
// let mut mt16 = _mm256_setzero_ps();
|
// let mut mt16 = _mm256_setzero_ps();
|
||||||
// let mut mt17 = _mm256_setzero_ps();
|
// let mut mt17 = _mm256_setzero_ps();
|
||||||
|
|
||||||
let mut pa = pa.ptr();
|
let mut pa = pa.ptr();
|
||||||
let mut pb = pb.ptr();
|
let mut pb = pb.ptr();
|
||||||
|
|
||||||
@ -53,7 +53,7 @@ pub(crate) unsafe fn sgemm_ukr_16x8<C: MatrixMut<f32>>(
|
|||||||
// let b5 = _mm256_broadcast_ss(&*pb.add(5));
|
// let b5 = _mm256_broadcast_ss(&*pb.add(5));
|
||||||
// let b6 = _mm256_broadcast_ss(&*pb.add(6));
|
// let b6 = _mm256_broadcast_ss(&*pb.add(6));
|
||||||
// let b7 = _mm256_broadcast_ss(&*pb.add(7));
|
// let b7 = _mm256_broadcast_ss(&*pb.add(7));
|
||||||
|
|
||||||
mt00 = fmadd_ps(a0, b0, mt00);
|
mt00 = fmadd_ps(a0, b0, mt00);
|
||||||
mt01 = fmadd_ps(a0, b1, mt01);
|
mt01 = fmadd_ps(a0, b1, mt01);
|
||||||
mt02 = fmadd_ps(a0, b2, mt02);
|
mt02 = fmadd_ps(a0, b2, mt02);
|
||||||
@ -75,7 +75,6 @@ pub(crate) unsafe fn sgemm_ukr_16x8<C: MatrixMut<f32>>(
|
|||||||
pb = pb.add(BATCH * NR);
|
pb = pb.add(BATCH * NR);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
for _ in k_main..k {
|
for _ in k_main..k {
|
||||||
let a0 = _mm256_load_ps(pa);
|
let a0 = _mm256_load_ps(pa);
|
||||||
let a1 = _mm256_load_ps(pa.add(8));
|
let a1 = _mm256_load_ps(pa.add(8));
|
||||||
@ -88,7 +87,7 @@ pub(crate) unsafe fn sgemm_ukr_16x8<C: MatrixMut<f32>>(
|
|||||||
// let b5 = _mm256_broadcast_ss(&*pb.add(5));
|
// let b5 = _mm256_broadcast_ss(&*pb.add(5));
|
||||||
// let b6 = _mm256_broadcast_ss(&*pb.add(6));
|
// let b6 = _mm256_broadcast_ss(&*pb.add(6));
|
||||||
// let b7 = _mm256_broadcast_ss(&*pb.add(7));
|
// let b7 = _mm256_broadcast_ss(&*pb.add(7));
|
||||||
|
|
||||||
mt00 = fmadd_ps(a0, b0, mt00);
|
mt00 = fmadd_ps(a0, b0, mt00);
|
||||||
mt01 = fmadd_ps(a0, b1, mt01);
|
mt01 = fmadd_ps(a0, b1, mt01);
|
||||||
mt02 = fmadd_ps(a0, b2, mt02);
|
mt02 = fmadd_ps(a0, b2, mt02);
|
||||||
@ -160,7 +159,7 @@ pub(crate) unsafe fn sgemm_ukr_16x8<C: MatrixMut<f32>>(
|
|||||||
// // mt05 = fmadd_ps(beta, _mm256_loadu_ps(ccol5), mt05);
|
// // mt05 = fmadd_ps(beta, _mm256_loadu_ps(ccol5), mt05);
|
||||||
// // mt06 = fmadd_ps(beta, _mm256_loadu_ps(ccol6), mt06);
|
// // mt06 = fmadd_ps(beta, _mm256_loadu_ps(ccol6), mt06);
|
||||||
// // mt07 = fmadd_ps(beta, _mm256_loadu_ps(ccol7), mt07);
|
// // mt07 = fmadd_ps(beta, _mm256_loadu_ps(ccol7), mt07);
|
||||||
|
|
||||||
// mt10 = fmadd_ps(beta, _mm256_loadu_ps(ccol0.add(8)), mt10);
|
// mt10 = fmadd_ps(beta, _mm256_loadu_ps(ccol0.add(8)), mt10);
|
||||||
// mt11 = fmadd_ps(beta, _mm256_loadu_ps(ccol1.add(8)), mt11);
|
// mt11 = fmadd_ps(beta, _mm256_loadu_ps(ccol1.add(8)), mt11);
|
||||||
// mt12 = fmadd_ps(beta, _mm256_loadu_ps(ccol2.add(8)), mt12);
|
// mt12 = fmadd_ps(beta, _mm256_loadu_ps(ccol2.add(8)), mt12);
|
||||||
@ -179,7 +178,7 @@ pub(crate) unsafe fn sgemm_ukr_16x8<C: MatrixMut<f32>>(
|
|||||||
// _mm256_storeu_ps(ccol5, mt05);
|
// _mm256_storeu_ps(ccol5, mt05);
|
||||||
// _mm256_storeu_ps(ccol6, mt06);
|
// _mm256_storeu_ps(ccol6, mt06);
|
||||||
// _mm256_storeu_ps(ccol7, mt07);
|
// _mm256_storeu_ps(ccol7, mt07);
|
||||||
|
|
||||||
_mm256_storeu_ps(ccol0.add(8), mt10);
|
_mm256_storeu_ps(ccol0.add(8), mt10);
|
||||||
_mm256_storeu_ps(ccol1.add(8), mt11);
|
_mm256_storeu_ps(ccol1.add(8), mt11);
|
||||||
_mm256_storeu_ps(ccol2.add(8), mt12);
|
_mm256_storeu_ps(ccol2.add(8), mt12);
|
||||||
@ -192,10 +191,10 @@ pub(crate) unsafe fn sgemm_ukr_16x8<C: MatrixMut<f32>>(
|
|||||||
|
|
||||||
pub(crate) unsafe fn sgemm_sup_16x1<B: Matrix<f32>, C: MatrixMut<f32>>(
|
pub(crate) unsafe fn sgemm_sup_16x1<B: Matrix<f32>, C: MatrixMut<f32>>(
|
||||||
k: usize,
|
k: usize,
|
||||||
alpha: f32,
|
_alpha: f32,
|
||||||
pa: MutMatrix<f32>,
|
pa: MutMatrix<f32>,
|
||||||
b: B,
|
b: B,
|
||||||
beta: f32,
|
_beta: f32,
|
||||||
c: C,
|
c: C,
|
||||||
) {
|
) {
|
||||||
let mut mt0 = _mm256_setzero_ps();
|
let mut mt0 = _mm256_setzero_ps();
|
||||||
@ -250,4 +249,4 @@ pub(crate) unsafe fn sgemm_pa_16x(k: usize, a: *const f32, lda: usize, pa: *mut
|
|||||||
pa = pa.add(16);
|
pa = pa.add(16);
|
||||||
a = a.add(lda);
|
a = a.add(lda);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -6,15 +6,16 @@ pub mod l1s;
|
|||||||
pub mod l3d;
|
pub mod l3d;
|
||||||
pub mod l3s;
|
pub mod l3s;
|
||||||
|
|
||||||
use core::marker::PhantomData;
|
|
||||||
use crate::matrix::{Number, MutMatrix, Matrix, MatrixMut};
|
|
||||||
use crate::kernel::{GemmKernel, GemmKernelSupNr, GemmKernelSupMr, GemmKernelSup};
|
|
||||||
use crate::dim::*;
|
use crate::dim::*;
|
||||||
|
use crate::kernel::{GemmKernel, GemmKernelSup, GemmKernelSupMr, GemmKernelSupNr};
|
||||||
|
use crate::matrix::{Matrix, MatrixMut, MutMatrix, Number};
|
||||||
|
use core::marker::PhantomData;
|
||||||
|
|
||||||
pub struct AvxKernel<F: Number, I>(PhantomData<fn(F, I)>);
|
pub struct AvxKernel<F: Number, I>(PhantomData<fn(F, I)>);
|
||||||
|
|
||||||
impl<I> GemmKernelSupNr<f32, A5> for AvxKernel<f32, I>
|
impl<I> GemmKernelSupNr<f32, A5> for AvxKernel<f32, I>
|
||||||
where I: GemmKernelSupNr<f32, A5>
|
where
|
||||||
|
I: GemmKernelSupNr<f32, A5>,
|
||||||
{
|
{
|
||||||
#[inline]
|
#[inline]
|
||||||
unsafe fn sup_tr<A: Matrix<f32>, C: MatrixMut<f32>>(
|
unsafe fn sup_tr<A: Matrix<f32>, C: MatrixMut<f32>>(
|
||||||
@ -26,10 +27,11 @@ impl<I> GemmKernelSupNr<f32, A5> for AvxKernel<f32, I>
|
|||||||
) {
|
) {
|
||||||
I::sup_tr(alpha, a, pb, beta, c);
|
I::sup_tr(alpha, a, pb, beta, c);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<I> GemmKernelSupMr<f32, A16> for AvxKernel<f32, I>
|
impl<I> GemmKernelSupMr<f32, A16> for AvxKernel<f32, I>
|
||||||
where I: GemmKernelSupMr<f32, A16>
|
where
|
||||||
|
I: GemmKernelSupMr<f32, A16>,
|
||||||
{
|
{
|
||||||
#[inline]
|
#[inline]
|
||||||
unsafe fn sup_bl<B: Matrix<f32>, C: MatrixMut<f32>>(
|
unsafe fn sup_bl<B: Matrix<f32>, C: MatrixMut<f32>>(
|
||||||
@ -43,8 +45,9 @@ impl<I> GemmKernelSupMr<f32, A16> for AvxKernel<f32, I>
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<I> GemmKernelSup<f32> for AvxKernel<f32, I>
|
impl<I> GemmKernelSup<f32> for AvxKernel<f32, I>
|
||||||
where I: GemmKernelSup<f32>
|
where
|
||||||
|
I: GemmKernelSup<f32>,
|
||||||
{
|
{
|
||||||
#[inline]
|
#[inline]
|
||||||
unsafe fn sup_br<A: Matrix<f32>, B: Matrix<f32>, C: MatrixMut<f32>>(
|
unsafe fn sup_br<A: Matrix<f32>, B: Matrix<f32>, C: MatrixMut<f32>>(
|
||||||
@ -59,12 +62,13 @@ impl<I> GemmKernelSup<f32> for AvxKernel<f32, I>
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<I> GemmKernel<f32, A16, A5> for AvxKernel<f32, I>
|
impl<I> GemmKernel<f32, A16, A5> for AvxKernel<f32, I>
|
||||||
where I: GemmKernel<f32, A16, A5>
|
where
|
||||||
|
I: GemmKernel<f32, A16, A5>,
|
||||||
{
|
{
|
||||||
#[inline]
|
#[inline]
|
||||||
unsafe fn pack_row_a<A: Matrix<f32>>(a: A, pa: MutMatrix<f32>) {
|
unsafe fn pack_row_a<A: Matrix<f32>>(a: A, pa: MutMatrix<f32>) {
|
||||||
if a.is_transposed() {
|
if a.is_transposed() {
|
||||||
I::pack_row_a(a, pa);
|
I::pack_row_a(a, pa);
|
||||||
} else {
|
} else {
|
||||||
self::l3s::sgemm_pa_16x(pa.stride, a.ptr(), a.stride(), pa.ptr_mut());
|
self::l3s::sgemm_pa_16x(pa.stride, a.ptr(), a.stride(), pa.ptr_mut());
|
||||||
@ -86,4 +90,4 @@ impl<I> GemmKernel<f32, A16, A5> for AvxKernel<f32, I>
|
|||||||
) {
|
) {
|
||||||
self::l3s::sgemm_ukr_16x8(pa.stride, alpha, pa, pb, beta, c);
|
self::l3s::sgemm_ukr_16x8(pa.stride, alpha, pa, pb, beta, c);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,13 +1,12 @@
|
|||||||
use crate::kernel::params::single::NR;
|
use crate::kernel::params::single::NR;
|
||||||
use crate::matrix::{MutMatrix, Matrix, MatrixMut};
|
use crate::matrix::{Matrix, MatrixMut, MutMatrix};
|
||||||
|
|
||||||
|
|
||||||
pub(crate) unsafe fn sgemm_sup_1x8<A: Matrix<f32>, C: MatrixMut<f32>>(
|
pub(crate) unsafe fn sgemm_sup_1x8<A: Matrix<f32>, C: MatrixMut<f32>>(
|
||||||
k: usize,
|
k: usize,
|
||||||
alpha: f32,
|
_alpha: f32,
|
||||||
a: A,
|
a: A,
|
||||||
pb: MutMatrix<f32>,
|
pb: MutMatrix<f32>,
|
||||||
beta: f32,
|
_beta: f32,
|
||||||
c: C,
|
c: C,
|
||||||
) {
|
) {
|
||||||
let mut c0 = 0.0f32;
|
let mut c0 = 0.0f32;
|
||||||
@ -46,7 +45,7 @@ pub(crate) unsafe fn sgemm_sup_1x8<A: Matrix<f32>, C: MatrixMut<f32>>(
|
|||||||
// // c5 *= alpha;
|
// // c5 *= alpha;
|
||||||
// // c6 *= alpha;
|
// // c6 *= alpha;
|
||||||
// // c7 *= alpha;
|
// // c7 *= alpha;
|
||||||
|
|
||||||
let ccol0 = c.ptr_mut();
|
let ccol0 = c.ptr_mut();
|
||||||
let ccol1 = c.row_mut(1);
|
let ccol1 = c.row_mut(1);
|
||||||
let ccol2 = c.row_mut(2);
|
let ccol2 = c.row_mut(2);
|
||||||
@ -139,4 +138,4 @@ pub(crate) unsafe fn sgemm_pa_t(k: usize, a: *const f32, lda: usize, pa: *mut f3
|
|||||||
a = a.add(1);
|
a = a.add(1);
|
||||||
pa = pa.add(MR);
|
pa = pa.add(MR);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
pub mod l3d;
|
pub mod l3d;
|
||||||
pub mod l3s;
|
pub mod l3s;
|
||||||
|
|
||||||
use crate::matrix::{Matrix, MatrixMut, MutMatrix};
|
|
||||||
use crate::kernel::{GemmKernel, GemmKernelSupNr, GemmKernelSupMr, GemmKernelSup};
|
|
||||||
use crate::dim::*;
|
use crate::dim::*;
|
||||||
|
use crate::kernel::{GemmKernel, GemmKernelSup, GemmKernelSupMr, GemmKernelSupNr};
|
||||||
|
use crate::matrix::{Matrix, MatrixMut, MutMatrix};
|
||||||
|
|
||||||
pub struct GenericKernel;
|
pub struct GenericKernel;
|
||||||
|
|
||||||
@ -16,24 +16,18 @@ impl GemmKernelSupNr<f32, A5> for GenericKernel {
|
|||||||
beta: f32,
|
beta: f32,
|
||||||
c: C,
|
c: C,
|
||||||
) {
|
) {
|
||||||
self::l3s::sgemm_sup_1x8(
|
self::l3s::sgemm_sup_1x8(pb.stride, alpha, a, pb, beta, c)
|
||||||
pb.stride,
|
|
||||||
alpha,
|
|
||||||
a,
|
|
||||||
pb,
|
|
||||||
beta,
|
|
||||||
c)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl GemmKernelSupMr<f32, A16> for GenericKernel {
|
impl GemmKernelSupMr<f32, A16> for GenericKernel {
|
||||||
#[inline]
|
#[inline]
|
||||||
unsafe fn sup_bl<B: Matrix<f32>, C: MatrixMut<f32>>(
|
unsafe fn sup_bl<B: Matrix<f32>, C: MatrixMut<f32>>(
|
||||||
alpha: f32,
|
_alpha: f32,
|
||||||
pa: MutMatrix<f32>,
|
_pa: MutMatrix<f32>,
|
||||||
b: B,
|
_b: B,
|
||||||
beta: f32,
|
_beta: f32,
|
||||||
c: C,
|
_c: C,
|
||||||
) {
|
) {
|
||||||
unimplemented!()
|
unimplemented!()
|
||||||
}
|
}
|
||||||
@ -66,10 +60,9 @@ impl GemmKernelSup<f32> for GenericKernel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl GemmKernel<f32, A16, A5> for GenericKernel {
|
impl GemmKernel<f32, A16, A5> for GenericKernel {
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
unsafe fn pack_row_a<A: Matrix<f32>>(a: A, pa: MutMatrix<f32>) {
|
unsafe fn pack_row_a<A: Matrix<f32>>(a: A, pa: MutMatrix<f32>) {
|
||||||
if a.is_transposed() {
|
if a.is_transposed() {
|
||||||
self::l3s::sgemm_pa_t(pa.stride, a.ptr(), a.stride(), pa.ptr_mut());
|
self::l3s::sgemm_pa_t(pa.stride, a.ptr(), a.stride(), pa.ptr_mut());
|
||||||
} else {
|
} else {
|
||||||
unimplemented!()
|
unimplemented!()
|
||||||
@ -78,7 +71,7 @@ impl GemmKernel<f32, A16, A5> for GenericKernel {
|
|||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
unsafe fn pack_row_b<B: Matrix<f32>>(b: B, pb: MutMatrix<f32>) {
|
unsafe fn pack_row_b<B: Matrix<f32>>(b: B, pb: MutMatrix<f32>) {
|
||||||
if b.is_transposed() {
|
if b.is_transposed() {
|
||||||
self::l3s::sgemm_pb_t(pb.stride, b.ptr(), b.stride(), pb.ptr_mut());
|
self::l3s::sgemm_pb_t(pb.stride, b.ptr(), b.stride(), pb.ptr_mut());
|
||||||
} else {
|
} else {
|
||||||
self::l3s::sgemm_pb_x8(pb.stride, b.ptr(), b.stride(), pb.ptr_mut());
|
self::l3s::sgemm_pb_x8(pb.stride, b.ptr(), b.stride(), pb.ptr_mut());
|
||||||
@ -87,12 +80,12 @@ impl GemmKernel<f32, A16, A5> for GenericKernel {
|
|||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
unsafe fn main_tl<C: MatrixMut<f32>>(
|
unsafe fn main_tl<C: MatrixMut<f32>>(
|
||||||
alpha: f32,
|
_alpha: f32,
|
||||||
pa: MutMatrix<f32>,
|
_pa: MutMatrix<f32>,
|
||||||
pb: MutMatrix<f32>,
|
_pb: MutMatrix<f32>,
|
||||||
beta: f32,
|
_beta: f32,
|
||||||
c: C,
|
_c: C,
|
||||||
) {
|
) {
|
||||||
unimplemented!()
|
unimplemented!()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -6,14 +6,14 @@ pub mod avx;
|
|||||||
pub mod generic;
|
pub mod generic;
|
||||||
pub mod sse;
|
pub mod sse;
|
||||||
|
|
||||||
use crate::matrix::{Number, Matrix, MatrixMut, MutMatrix};
|
|
||||||
use crate::dim::Dim;
|
use crate::dim::Dim;
|
||||||
|
use crate::matrix::{Matrix, MatrixMut, MutMatrix, Number};
|
||||||
|
|
||||||
pub mod params {
|
pub mod params {
|
||||||
pub mod single {
|
pub mod single {
|
||||||
pub const MC: usize = 128;
|
pub const MC: usize = 128;
|
||||||
pub const KC: usize = 256;
|
pub const KC: usize = 256;
|
||||||
pub const NC: usize = 8*1024;
|
pub const NC: usize = 8 * 1024;
|
||||||
pub const MR: usize = 16;
|
pub const MR: usize = 16;
|
||||||
pub const NR: usize = 5;
|
pub const NR: usize = 5;
|
||||||
}
|
}
|
||||||
@ -70,15 +70,11 @@ pub trait GemmKernelSupNr<F: Number, NR: Dim> {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait GemmKernel<F: Number, MR: Dim, NR: Dim>: GemmKernelSupMr<F, MR> + GemmKernelSupNr<F, NR> + GemmKernelSup<F> {
|
pub trait GemmKernel<F: Number, MR: Dim, NR: Dim>:
|
||||||
|
GemmKernelSupMr<F, MR> + GemmKernelSupNr<F, NR> + GemmKernelSup<F>
|
||||||
|
{
|
||||||
unsafe fn pack_row_a<A: Matrix<F>>(a: A, pa: MutMatrix<F>);
|
unsafe fn pack_row_a<A: Matrix<F>>(a: A, pa: MutMatrix<F>);
|
||||||
unsafe fn pack_row_b<B: Matrix<F>>(b: B, pb: MutMatrix<F>);
|
unsafe fn pack_row_b<B: Matrix<F>>(b: B, pb: MutMatrix<F>);
|
||||||
|
|
||||||
unsafe fn main_tl<C: MatrixMut<F>>(
|
unsafe fn main_tl<C: MatrixMut<F>>(alpha: F, pa: MutMatrix<F>, pb: MutMatrix<F>, beta: F, c: C);
|
||||||
alpha: F,
|
}
|
||||||
pa: MutMatrix<F>,
|
|
||||||
pb: MutMatrix<F>,
|
|
||||||
beta: F,
|
|
||||||
c: C,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
#![allow(dead_code)]
|
||||||
|
|
||||||
use super::intrinsics::*;
|
use super::intrinsics::*;
|
||||||
|
|
||||||
#[cfg(target_feature = "fma")]
|
#[cfg(target_feature = "fma")]
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
#![allow(dead_code)]
|
||||||
|
|
||||||
use super::fma::fmadd_pd;
|
use super::fma::fmadd_pd;
|
||||||
use super::intrinsics::*;
|
use super::intrinsics::*;
|
||||||
|
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
#![allow(dead_code)]
|
||||||
|
|
||||||
use super::fma::fmadd_ps;
|
use super::fma::fmadd_ps;
|
||||||
use super::intrinsics::*;
|
use super::intrinsics::*;
|
||||||
use crate::kernel::params::single::NR;
|
use crate::kernel::params::single::NR;
|
||||||
@ -5,10 +7,10 @@ use crate::matrix::{Matrix, MatrixMut, MutMatrix};
|
|||||||
|
|
||||||
pub(crate) unsafe fn sgemm_sup_1x8<A: Matrix<f32>, C: MatrixMut<f32>>(
|
pub(crate) unsafe fn sgemm_sup_1x8<A: Matrix<f32>, C: MatrixMut<f32>>(
|
||||||
k: usize,
|
k: usize,
|
||||||
alpha: f32,
|
_alpha: f32,
|
||||||
a: A,
|
a: A,
|
||||||
pb: MutMatrix<f32>,
|
pb: MutMatrix<f32>,
|
||||||
beta: f32,
|
_beta: f32,
|
||||||
c: C,
|
c: C,
|
||||||
) {
|
) {
|
||||||
let mut c0_3 = _mm_setzero_ps();
|
let mut c0_3 = _mm_setzero_ps();
|
||||||
@ -23,7 +25,7 @@ pub(crate) unsafe fn sgemm_sup_1x8<A: Matrix<f32>, C: MatrixMut<f32>>(
|
|||||||
for _ in 0..k {
|
for _ in 0..k {
|
||||||
let a0 = *a.ptr();
|
let a0 = *a.ptr();
|
||||||
let a0_simd = _mm_broadcast_ss(&*a.ptr());
|
let a0_simd = _mm_broadcast_ss(&*a.ptr());
|
||||||
|
|
||||||
c0_3 = fmadd_ps(_mm_loadu_ps(pb.ptr()), a0_simd, c0_3);
|
c0_3 = fmadd_ps(_mm_loadu_ps(pb.ptr()), a0_simd, c0_3);
|
||||||
c4 += *pb.col(4) * a0;
|
c4 += *pb.col(4) * a0;
|
||||||
// c5 += *pb.add(5) * a0;
|
// c5 += *pb.add(5) * a0;
|
||||||
@ -42,7 +44,7 @@ pub(crate) unsafe fn sgemm_sup_1x8<A: Matrix<f32>, C: MatrixMut<f32>>(
|
|||||||
// c5 *= alpha;
|
// c5 *= alpha;
|
||||||
// c6 *= alpha;
|
// c6 *= alpha;
|
||||||
// c7 *= alpha;
|
// c7 *= alpha;
|
||||||
|
|
||||||
let ccol0_3 = c.ptr_mut();
|
let ccol0_3 = c.ptr_mut();
|
||||||
let ccol4 = c.row_mut(4);
|
let ccol4 = c.row_mut(4);
|
||||||
// let ccol5 = c.add(ldc * 5);
|
// let ccol5 = c.add(ldc * 5);
|
||||||
@ -69,11 +71,11 @@ pub(crate) unsafe fn sgemm_sup_1x8<A: Matrix<f32>, C: MatrixMut<f32>>(
|
|||||||
|
|
||||||
pub(crate) unsafe fn sgemm_sup_1x8_t(
|
pub(crate) unsafe fn sgemm_sup_1x8_t(
|
||||||
k: usize,
|
k: usize,
|
||||||
alpha: f32,
|
_alpha: f32,
|
||||||
a: *const f32,
|
a: *const f32,
|
||||||
lda: usize,
|
_lda: usize,
|
||||||
pb: *const f32,
|
pb: *const f32,
|
||||||
beta: f32,
|
_beta: f32,
|
||||||
c: *mut f32,
|
c: *mut f32,
|
||||||
ldc: usize,
|
ldc: usize,
|
||||||
) {
|
) {
|
||||||
@ -88,8 +90,8 @@ pub(crate) unsafe fn sgemm_sup_1x8_t(
|
|||||||
|
|
||||||
for _ in 0..k {
|
for _ in 0..k {
|
||||||
let a0 = *a;
|
let a0 = *a;
|
||||||
let a0_simd = _mm_set_ps1(a0);
|
let a0_simd = _mm_set_ps1(a0);
|
||||||
|
|
||||||
c03 = fmadd_ps(_mm_loadu_ps(pb), a0_simd, c03);
|
c03 = fmadd_ps(_mm_loadu_ps(pb), a0_simd, c03);
|
||||||
c4 += *pb.add(4) * a0;
|
c4 += *pb.add(4) * a0;
|
||||||
// c5 += *pb.add(5) * a0;
|
// c5 += *pb.add(5) * a0;
|
||||||
@ -108,7 +110,7 @@ pub(crate) unsafe fn sgemm_sup_1x8_t(
|
|||||||
// c5 *= alpha;
|
// c5 *= alpha;
|
||||||
// c6 *= alpha;
|
// c6 *= alpha;
|
||||||
// c7 *= alpha;
|
// c7 *= alpha;
|
||||||
|
|
||||||
// let ccol0 = c;
|
// let ccol0 = c;
|
||||||
// let ccol1 = c.add(ldc);
|
// let ccol1 = c.add(ldc);
|
||||||
// let ccol2 = c.add(ldc * 2);
|
// let ccol2 = c.add(ldc * 2);
|
||||||
@ -175,4 +177,4 @@ pub(crate) unsafe fn sgemm_pb_x8(k: usize, b: *const f32, ldb: usize, pb: *mut f
|
|||||||
|
|
||||||
pb = pb.add(NR);
|
pb = pb.add(NR);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -6,16 +6,16 @@ mod intrinsics;
|
|||||||
// pub mod l3d;
|
// pub mod l3d;
|
||||||
pub mod l3s;
|
pub mod l3s;
|
||||||
|
|
||||||
use core::marker::PhantomData;
|
|
||||||
use crate::matrix::{Number, MutMatrix, Matrix, MatrixMut};
|
|
||||||
use crate::kernel::{GemmKernel, GemmKernelSupNr, GemmKernelSupMr, GemmKernelSup};
|
|
||||||
use crate::dim::*;
|
use crate::dim::*;
|
||||||
|
use crate::kernel::{GemmKernel, GemmKernelSup, GemmKernelSupMr, GemmKernelSupNr};
|
||||||
|
use crate::matrix::{Matrix, MatrixMut, MutMatrix, Number};
|
||||||
|
use core::marker::PhantomData;
|
||||||
|
|
||||||
pub struct SseKernel<F: Number, I>(PhantomData<fn(F, I)>);
|
pub struct SseKernel<F: Number, I>(PhantomData<fn(F, I)>);
|
||||||
|
|
||||||
impl<I> GemmKernelSupNr<f32, A5> for SseKernel<f32, I>
|
impl<I> GemmKernelSupNr<f32, A5> for SseKernel<f32, I>
|
||||||
where I: GemmKernelSupNr<f32, A5>
|
where
|
||||||
|
I: GemmKernelSupNr<f32, A5>,
|
||||||
{
|
{
|
||||||
#[inline]
|
#[inline]
|
||||||
unsafe fn sup_tr<A: Matrix<f32>, C: MatrixMut<f32>>(
|
unsafe fn sup_tr<A: Matrix<f32>, C: MatrixMut<f32>>(
|
||||||
@ -27,10 +27,11 @@ impl<I> GemmKernelSupNr<f32, A5> for SseKernel<f32, I>
|
|||||||
) {
|
) {
|
||||||
I::sup_tr(alpha, a, pb, beta, c);
|
I::sup_tr(alpha, a, pb, beta, c);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<I> GemmKernelSupMr<f32, A16> for SseKernel<f32, I>
|
impl<I> GemmKernelSupMr<f32, A16> for SseKernel<f32, I>
|
||||||
where I: GemmKernelSupMr<f32, A16>
|
where
|
||||||
|
I: GemmKernelSupMr<f32, A16>,
|
||||||
{
|
{
|
||||||
#[inline]
|
#[inline]
|
||||||
unsafe fn sup_bl<B: Matrix<f32>, C: MatrixMut<f32>>(
|
unsafe fn sup_bl<B: Matrix<f32>, C: MatrixMut<f32>>(
|
||||||
@ -44,8 +45,9 @@ impl<I> GemmKernelSupMr<f32, A16> for SseKernel<f32, I>
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<I> GemmKernelSup<f32> for SseKernel<f32, I>
|
impl<I> GemmKernelSup<f32> for SseKernel<f32, I>
|
||||||
where I: GemmKernelSup<f32>
|
where
|
||||||
|
I: GemmKernelSup<f32>,
|
||||||
{
|
{
|
||||||
#[inline]
|
#[inline]
|
||||||
unsafe fn sup_br<A: Matrix<f32>, B: Matrix<f32>, C: MatrixMut<f32>>(
|
unsafe fn sup_br<A: Matrix<f32>, B: Matrix<f32>, C: MatrixMut<f32>>(
|
||||||
@ -60,8 +62,9 @@ impl<I> GemmKernelSup<f32> for SseKernel<f32, I>
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<I> GemmKernel<f32, A16, A5> for SseKernel<f32, I>
|
impl<I> GemmKernel<f32, A16, A5> for SseKernel<f32, I>
|
||||||
where I: GemmKernel<f32, A16, A5>
|
where
|
||||||
|
I: GemmKernel<f32, A16, A5>,
|
||||||
{
|
{
|
||||||
#[inline]
|
#[inline]
|
||||||
unsafe fn pack_row_a<A: Matrix<f32>>(a: A, pa: MutMatrix<f32>) {
|
unsafe fn pack_row_a<A: Matrix<f32>>(a: A, pa: MutMatrix<f32>) {
|
||||||
@ -83,4 +86,4 @@ impl<I> GemmKernel<f32, A16, A5> for SseKernel<f32, I>
|
|||||||
) {
|
) {
|
||||||
I::main_tl(alpha, pa, pb, beta, c);
|
I::main_tl(alpha, pa, pb, beta, c);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,10 +1,10 @@
|
|||||||
pub mod matrix;
|
|
||||||
pub mod gemm;
|
|
||||||
mod sgemm;
|
|
||||||
mod aligned_alloc;
|
mod aligned_alloc;
|
||||||
pub mod kernel;
|
|
||||||
pub mod dim;
|
pub mod dim;
|
||||||
pub mod executor;
|
pub mod executor;
|
||||||
|
pub mod gemm;
|
||||||
|
pub mod kernel;
|
||||||
|
pub mod matrix;
|
||||||
|
mod sgemm;
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
extern crate blas;
|
extern crate blas;
|
||||||
@ -13,5 +13,4 @@ extern crate openblas;
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod test;
|
mod test;
|
||||||
|
|
||||||
|
|
||||||
pub use crate::sgemm::sgemm;
|
pub use crate::sgemm::sgemm;
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
|
pub trait Number: core::fmt::Display + Copy + Send + Sync + 'static {}
|
||||||
pub trait Number: core::fmt::Display + Copy + Send + Sync + 'static {}
|
|
||||||
|
|
||||||
impl Number for f32 {}
|
impl Number for f32 {}
|
||||||
impl Number for f64 {}
|
impl Number for f64 {}
|
||||||
@ -32,7 +31,7 @@ pub trait MatrixMut<F: Number>: Matrix<F> {
|
|||||||
#[derive(Copy, Clone)]
|
#[derive(Copy, Clone)]
|
||||||
pub struct ConstMatrix<F: Number> {
|
pub struct ConstMatrix<F: Number> {
|
||||||
pub stride: usize,
|
pub stride: usize,
|
||||||
pub ptr: *const F
|
pub ptr: *const F,
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe impl<F: Number> Send for ConstMatrix<F> {}
|
unsafe impl<F: Number> Send for ConstMatrix<F> {}
|
||||||
@ -47,31 +46,31 @@ impl<F: Number> ConstMatrix<F> {
|
|||||||
impl<F: Number> Matrix<F> for ConstMatrix<F> {
|
impl<F: Number> Matrix<F> for ConstMatrix<F> {
|
||||||
#[inline]
|
#[inline]
|
||||||
unsafe fn sub(&self, row: usize, col: usize) -> Self {
|
unsafe fn sub(&self, row: usize, col: usize) -> Self {
|
||||||
Self {
|
Self {
|
||||||
ptr: self.index(row, col),
|
ptr: self.index(row, col),
|
||||||
stride: self.stride
|
stride: self.stride,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
unsafe fn sub_col(&self, col: usize) -> Self {
|
unsafe fn sub_col(&self, col: usize) -> Self {
|
||||||
Self {
|
Self {
|
||||||
ptr: self.col(col),
|
ptr: self.col(col),
|
||||||
stride: self.stride
|
stride: self.stride,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
unsafe fn sub_row(&self, row: usize) -> Self {
|
unsafe fn sub_row(&self, row: usize) -> Self {
|
||||||
Self {
|
Self {
|
||||||
ptr: self.row(row),
|
ptr: self.row(row),
|
||||||
stride: self.stride
|
stride: self.stride,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
unsafe fn is_transposed(&self) -> bool {
|
unsafe fn is_transposed(&self) -> bool {
|
||||||
false
|
false
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
@ -128,7 +127,7 @@ impl<F: Number> Matrix<F> for ConstMatrix<F> {
|
|||||||
#[derive(Copy, Clone)]
|
#[derive(Copy, Clone)]
|
||||||
pub struct ConstTransposedMatrix<F: Number> {
|
pub struct ConstTransposedMatrix<F: Number> {
|
||||||
pub stride: usize,
|
pub stride: usize,
|
||||||
pub ptr: *const F
|
pub ptr: *const F,
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe impl<F: Number> Send for ConstTransposedMatrix<F> {}
|
unsafe impl<F: Number> Send for ConstTransposedMatrix<F> {}
|
||||||
@ -143,25 +142,25 @@ impl<F: Number> ConstTransposedMatrix<F> {
|
|||||||
impl<F: Number> Matrix<F> for ConstTransposedMatrix<F> {
|
impl<F: Number> Matrix<F> for ConstTransposedMatrix<F> {
|
||||||
#[inline]
|
#[inline]
|
||||||
unsafe fn sub(&self, row: usize, col: usize) -> Self {
|
unsafe fn sub(&self, row: usize, col: usize) -> Self {
|
||||||
Self {
|
Self {
|
||||||
ptr: self.index(row, col),
|
ptr: self.index(row, col),
|
||||||
stride: self.stride
|
stride: self.stride,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
unsafe fn sub_col(&self, col: usize) -> Self {
|
unsafe fn sub_col(&self, col: usize) -> Self {
|
||||||
Self {
|
Self {
|
||||||
ptr: self.col(col),
|
ptr: self.col(col),
|
||||||
stride: self.stride
|
stride: self.stride,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
unsafe fn sub_row(&self, row: usize) -> Self {
|
unsafe fn sub_row(&self, row: usize) -> Self {
|
||||||
Self {
|
Self {
|
||||||
ptr: self.row(row),
|
ptr: self.row(row),
|
||||||
stride: self.stride
|
stride: self.stride,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -224,7 +223,7 @@ impl<F: Number> Matrix<F> for ConstTransposedMatrix<F> {
|
|||||||
#[derive(Copy, Clone)]
|
#[derive(Copy, Clone)]
|
||||||
pub struct MutMatrix<F: Number> {
|
pub struct MutMatrix<F: Number> {
|
||||||
pub stride: usize,
|
pub stride: usize,
|
||||||
pub ptr: *mut F
|
pub ptr: *mut F,
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe impl<F: Number> Send for MutMatrix<F> {}
|
unsafe impl<F: Number> Send for MutMatrix<F> {}
|
||||||
@ -239,38 +238,38 @@ impl<F: Number> MutMatrix<F> {
|
|||||||
impl<F: Number> Matrix<F> for MutMatrix<F> {
|
impl<F: Number> Matrix<F> for MutMatrix<F> {
|
||||||
#[inline]
|
#[inline]
|
||||||
unsafe fn sub(&self, row: usize, col: usize) -> Self {
|
unsafe fn sub(&self, row: usize, col: usize) -> Self {
|
||||||
Self {
|
Self {
|
||||||
ptr: self.index_mut(row, col),
|
ptr: self.index_mut(row, col),
|
||||||
stride: self.stride
|
stride: self.stride,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
unsafe fn sub_col(&self, col: usize) -> Self {
|
unsafe fn sub_col(&self, col: usize) -> Self {
|
||||||
Self {
|
Self {
|
||||||
ptr: self.col_mut(col),
|
ptr: self.col_mut(col),
|
||||||
stride: self.stride
|
stride: self.stride,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
unsafe fn sub_row(&self, row: usize) -> Self {
|
unsafe fn sub_row(&self, row: usize) -> Self {
|
||||||
Self {
|
Self {
|
||||||
ptr: self.row_mut(row),
|
ptr: self.row_mut(row),
|
||||||
stride: self.stride
|
stride: self.stride,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
unsafe fn is_transposed(&self) -> bool {
|
unsafe fn is_transposed(&self) -> bool {
|
||||||
false
|
false
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
unsafe fn stride(&self) -> usize {
|
unsafe fn stride(&self) -> usize {
|
||||||
self.stride
|
self.stride
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
unsafe fn get(&self, row: usize, col: usize) -> F {
|
unsafe fn get(&self, row: usize, col: usize) -> F {
|
||||||
*self.ptr.add(row * self.stride + col)
|
*self.ptr.add(row * self.stride + col)
|
||||||
@ -347,7 +346,7 @@ impl<F: Number> MatrixMut<F> for MutMatrix<F> {
|
|||||||
#[derive(Copy, Clone)]
|
#[derive(Copy, Clone)]
|
||||||
pub struct MutTransposedMatrix<F: Number> {
|
pub struct MutTransposedMatrix<F: Number> {
|
||||||
pub stride: usize,
|
pub stride: usize,
|
||||||
pub ptr: *mut F
|
pub ptr: *mut F,
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe impl<F: Number> Send for MutTransposedMatrix<F> {}
|
unsafe impl<F: Number> Send for MutTransposedMatrix<F> {}
|
||||||
@ -362,25 +361,25 @@ impl<F: Number> MutTransposedMatrix<F> {
|
|||||||
impl<F: Number> Matrix<F> for MutTransposedMatrix<F> {
|
impl<F: Number> Matrix<F> for MutTransposedMatrix<F> {
|
||||||
#[inline]
|
#[inline]
|
||||||
unsafe fn sub(&self, row: usize, col: usize) -> Self {
|
unsafe fn sub(&self, row: usize, col: usize) -> Self {
|
||||||
Self {
|
Self {
|
||||||
ptr: self.index_mut(row, col),
|
ptr: self.index_mut(row, col),
|
||||||
stride: self.stride
|
stride: self.stride,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
unsafe fn sub_col(&self, col: usize) -> Self {
|
unsafe fn sub_col(&self, col: usize) -> Self {
|
||||||
Self {
|
Self {
|
||||||
ptr: self.col_mut(col),
|
ptr: self.col_mut(col),
|
||||||
stride: self.stride
|
stride: self.stride,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
unsafe fn sub_row(&self, row: usize) -> Self {
|
unsafe fn sub_row(&self, row: usize) -> Self {
|
||||||
Self {
|
Self {
|
||||||
ptr: self.row_mut(row),
|
ptr: self.row_mut(row),
|
||||||
stride: self.stride
|
stride: self.stride,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -440,7 +439,6 @@ impl<F: Number> Matrix<F> for MutTransposedMatrix<F> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
impl<F: Number> MatrixMut<F> for MutTransposedMatrix<F> {
|
impl<F: Number> MatrixMut<F> for MutTransposedMatrix<F> {
|
||||||
#[inline]
|
#[inline]
|
||||||
unsafe fn set(&self, row: usize, col: usize, value: F) {
|
unsafe fn set(&self, row: usize, col: usize, value: F) {
|
||||||
@ -467,5 +465,3 @@ impl<F: Number> MatrixMut<F> for MutTransposedMatrix<F> {
|
|||||||
self.ptr
|
self.ptr
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
14
src/sgemm.rs
14
src/sgemm.rs
@ -1,8 +1,7 @@
|
|||||||
|
|
||||||
use crate::kernel::avx::AvxKernel;
|
|
||||||
use crate::kernel::generic::GenericKernel;
|
|
||||||
use crate::dim::*;
|
use crate::dim::*;
|
||||||
use crate::executor::Executor;
|
use crate::executor::Executor;
|
||||||
|
use crate::kernel::avx::AvxKernel;
|
||||||
|
use crate::kernel::generic::GenericKernel;
|
||||||
|
|
||||||
pub unsafe fn sgemm<E: Executor>(
|
pub unsafe fn sgemm<E: Executor>(
|
||||||
e: &E,
|
e: &E,
|
||||||
@ -22,9 +21,6 @@ pub unsafe fn sgemm<E: Executor>(
|
|||||||
ldc: usize,
|
ldc: usize,
|
||||||
) {
|
) {
|
||||||
crate::gemm::gemm::<E, f32, AvxKernel<f32, GenericKernel>, A16, A5>(
|
crate::gemm::gemm::<E, f32, AvxKernel<f32, GenericKernel>, A16, A5>(
|
||||||
e,
|
e, transa, transb, transc, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
|
||||||
transa, transb, transc,
|
);
|
||||||
m, n, k, alpha,
|
}
|
||||||
a, lda, b, ldb,
|
|
||||||
beta, c, ldc);
|
|
||||||
}
|
|
||||||
|
@ -1,17 +1,15 @@
|
|||||||
use rand::Rng;
|
|
||||||
|
|
||||||
const M_LEN: usize = 17;
|
const M_LEN: usize = 17;
|
||||||
const N_LEN: usize = 6;
|
const N_LEN: usize = 6;
|
||||||
const K_LEN: usize = 4;
|
const K_LEN: usize = 4;
|
||||||
|
|
||||||
fn make_matrices() -> (Vec<f32>, Vec<f32>, Vec<f32>, Vec<f32>) {
|
fn make_matrices() -> (Vec<f32>, Vec<f32>, Vec<f32>, Vec<f32>) {
|
||||||
let mut rng = rand::thread_rng();
|
let _rng = rand::thread_rng();
|
||||||
|
|
||||||
let (m, n, k) = (M_LEN, N_LEN, K_LEN);
|
let (m, n, k) = (M_LEN, N_LEN, K_LEN);
|
||||||
|
|
||||||
let mut a = vec![0.0; m * k];
|
let mut a = vec![0.0; m * k];
|
||||||
let mut a_t = vec![0.0; m * k];
|
let mut a_t = vec![0.0; m * k];
|
||||||
|
|
||||||
let mut b = vec![0.0; n * k];
|
let mut b = vec![0.0; n * k];
|
||||||
let mut b_t = vec![0.0; n * k];
|
let mut b_t = vec![0.0; n * k];
|
||||||
|
|
||||||
@ -38,7 +36,6 @@ fn make_matrices() -> (Vec<f32>, Vec<f32>, Vec<f32>, Vec<f32>) {
|
|||||||
(a, a_t, b, b_t)
|
(a, a_t, b, b_t)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
unsafe fn sgemm_ref(
|
unsafe fn sgemm_ref(
|
||||||
m: usize,
|
m: usize,
|
||||||
n: usize,
|
n: usize,
|
||||||
@ -55,9 +52,9 @@ unsafe fn sgemm_ref(
|
|||||||
for i_m in 0..m {
|
for i_m in 0..m {
|
||||||
for i_k in 0..k {
|
for i_k in 0..k {
|
||||||
let a_part = *a.add(i_m * lda + i_k);
|
let a_part = *a.add(i_m * lda + i_k);
|
||||||
|
|
||||||
for i_n in 0..n {
|
for i_n in 0..n {
|
||||||
*c.add(i_m * ldc + i_n) += a_part * *b.add(i_k * ldb + i_n);
|
*c.add(i_m * ldc + i_n) += a_part * *b.add(i_k * ldb + i_n);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -75,7 +72,6 @@ unsafe fn sgemm_ref(
|
|||||||
// }
|
// }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_sgemm_nn() {
|
fn test_sgemm_nn() {
|
||||||
let (m, n, k) = (M_LEN, N_LEN, K_LEN);
|
let (m, n, k) = (M_LEN, N_LEN, K_LEN);
|
||||||
@ -133,7 +129,6 @@ fn test_sgemm_nn() {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
for row in 0..N_LEN {
|
for row in 0..N_LEN {
|
||||||
for col in 0..M_LEN {
|
for col in 0..M_LEN {
|
||||||
let index = row * M_LEN + col;
|
let index = row * M_LEN + col;
|
||||||
@ -160,16 +155,24 @@ fn test_sgemm_nn() {
|
|||||||
|
|
||||||
let (a, b) = (c[index], cref[index]);
|
let (a, b) = (c[index], cref[index]);
|
||||||
|
|
||||||
assert!(feq(a, b), "a != b, a[{}, {}]={}, b[{}, {}]={}", row, col, a, row, col, b);
|
assert!(
|
||||||
|
feq(a, b),
|
||||||
|
"a != b, a[{}, {}]={}, b[{}, {}]={}",
|
||||||
|
row,
|
||||||
|
col,
|
||||||
|
a,
|
||||||
|
row,
|
||||||
|
col,
|
||||||
|
b
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_sgemm_nt() {
|
fn test_sgemm_nt() {
|
||||||
let (m, n, k) = (M_LEN, N_LEN, K_LEN);
|
let (m, n, k) = (M_LEN, N_LEN, K_LEN);
|
||||||
let (a, _, b, b_t) = make_matrices();
|
let (a, _, _b, b_t) = make_matrices();
|
||||||
|
|
||||||
let mut c = vec![0.0; m * n];
|
let mut c = vec![0.0; m * n];
|
||||||
let mut cref = vec![0.0; m * n];
|
let mut cref = vec![0.0; m * n];
|
||||||
@ -216,16 +219,24 @@ fn test_sgemm_nt() {
|
|||||||
for col in 0..M_LEN {
|
for col in 0..M_LEN {
|
||||||
let index = row * M_LEN + col;
|
let index = row * M_LEN + col;
|
||||||
let (a, b) = (c[index], cref[index]);
|
let (a, b) = (c[index], cref[index]);
|
||||||
assert!(feq(a, b), "a != b, a[{}, {}]={}, b[{}, {}]={}", row, col, a, row, col, b);
|
assert!(
|
||||||
|
feq(a, b),
|
||||||
|
"a != b, a[{}, {}]={}, b[{}, {}]={}",
|
||||||
|
row,
|
||||||
|
col,
|
||||||
|
a,
|
||||||
|
row,
|
||||||
|
col,
|
||||||
|
b
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_sgemm_tn() {
|
fn test_sgemm_tn() {
|
||||||
let (m, n, k) = (M_LEN, N_LEN, K_LEN);
|
let (m, n, k) = (M_LEN, N_LEN, K_LEN);
|
||||||
let (a, a_t, b, _) = make_matrices();
|
let (_a, a_t, b, _) = make_matrices();
|
||||||
|
|
||||||
let mut c = vec![0.0; m * n];
|
let mut c = vec![0.0; m * n];
|
||||||
let mut cref = vec![0.0; m * n];
|
let mut cref = vec![0.0; m * n];
|
||||||
@ -272,16 +283,24 @@ fn test_sgemm_tn() {
|
|||||||
for col in 0..M_LEN {
|
for col in 0..M_LEN {
|
||||||
let index = row * M_LEN + col;
|
let index = row * M_LEN + col;
|
||||||
let (a, b) = (c[index], cref[index]);
|
let (a, b) = (c[index], cref[index]);
|
||||||
assert!(feq(a, b), "a != b, a[{}, {}]={}, b[{}, {}]={}", row, col, a, row, col, b);
|
assert!(
|
||||||
|
feq(a, b),
|
||||||
|
"a != b, a[{}, {}]={}, b[{}, {}]={}",
|
||||||
|
row,
|
||||||
|
col,
|
||||||
|
a,
|
||||||
|
row,
|
||||||
|
col,
|
||||||
|
b
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_sgemm_tt() {
|
fn test_sgemm_tt() {
|
||||||
let (m, n, k) = (M_LEN, N_LEN, K_LEN);
|
let (m, n, k) = (M_LEN, N_LEN, K_LEN);
|
||||||
let (a, a_t, b, b_t) = make_matrices();
|
let (_a, a_t, _b, b_t) = make_matrices();
|
||||||
|
|
||||||
let mut c = vec![0.0; m * n];
|
let mut c = vec![0.0; m * n];
|
||||||
let mut cref = vec![0.0; m * n];
|
let mut cref = vec![0.0; m * n];
|
||||||
@ -328,7 +347,16 @@ fn test_sgemm_tt() {
|
|||||||
for col in 0..M_LEN {
|
for col in 0..M_LEN {
|
||||||
let index = row * M_LEN + col;
|
let index = row * M_LEN + col;
|
||||||
let (a, b) = (c[index], cref[index]);
|
let (a, b) = (c[index], cref[index]);
|
||||||
assert!(feq(a, b), "a != b, a[{}, {}]={}, b[{}, {}]={}", row, col, a, row, col, b);
|
assert!(
|
||||||
|
feq(a, b),
|
||||||
|
"a != b, a[{}, {}]={}, b[{}, {}]={}",
|
||||||
|
row,
|
||||||
|
col,
|
||||||
|
a,
|
||||||
|
row,
|
||||||
|
col,
|
||||||
|
b
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -341,4 +369,4 @@ fn feq(a: f32, b: f32) -> bool {
|
|||||||
} else {
|
} else {
|
||||||
(a - b).abs() / (a.abs() + b.abs()) < std::f32::EPSILON * 10.0
|
(a - b).abs() / (a.abs() + b.abs()) < std::f32::EPSILON * 10.0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1 +1 @@
|
|||||||
mod l3s;
|
mod l3s;
|
||||||
|
Loading…
Reference in New Issue
Block a user