Generalization + SSE

This commit is contained in:
Andrey Tkachenko 2019-12-05 11:16:35 +04:00
parent ffae99bafd
commit aaf8ec3203
27 changed files with 2948 additions and 984 deletions

156
Cargo.lock generated
View File

@ -36,14 +36,66 @@ name = "cfg-if"
version = "0.1.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "crossbeam-deque"
version = "0.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"crossbeam-epoch 0.8.0 (registry+https://github.com/rust-lang/crates.io-index)",
"crossbeam-utils 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "crossbeam-epoch"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"autocfg 0.1.7 (registry+https://github.com/rust-lang/crates.io-index)",
"cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)",
"crossbeam-utils 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)",
"lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)",
"memoffset 0.5.3 (registry+https://github.com/rust-lang/crates.io-index)",
"scopeguard 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "crossbeam-queue"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"crossbeam-utils 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "crossbeam-utils"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"autocfg 0.1.7 (registry+https://github.com/rust-lang/crates.io-index)",
"cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)",
"lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "crunchy"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "either"
version = "1.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "gemm"
version = "0.1.0"
dependencies = [
"blas 0.20.0 (registry+https://github.com/rust-lang/crates.io-index)",
"matrixmultiply 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)",
"crunchy 0.2.2 (registry+https://github.com/rust-lang/crates.io-index)",
"openblas 0.0.1 (registry+https://github.com/rust-lang/crates.io-index)",
"rand 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)",
"rayon 1.2.1 (registry+https://github.com/rust-lang/crates.io-index)",
"threadpool 1.7.1 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
@ -56,17 +108,30 @@ dependencies = [
"wasi 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "hermit-abi"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"libc 0.2.65 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "lazy_static"
version = "1.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "libc"
version = "0.2.65"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "matrixmultiply"
version = "0.2.3"
name = "memoffset"
version = "0.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"rawpointer 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)",
"rustc_version 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
@ -86,6 +151,15 @@ dependencies = [
"autocfg 0.1.7 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "num_cpus"
version = "1.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"hermit-abi 0.1.3 (registry+https://github.com/rust-lang/crates.io-index)",
"libc 0.2.65 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "openblas"
version = "0.0.1"
@ -134,9 +208,60 @@ dependencies = [
]
[[package]]
name = "rawpointer"
version = "0.2.1"
name = "rayon"
version = "1.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"crossbeam-deque 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)",
"either 1.5.3 (registry+https://github.com/rust-lang/crates.io-index)",
"rayon-core 1.6.1 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "rayon-core"
version = "1.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"crossbeam-deque 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)",
"crossbeam-queue 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)",
"crossbeam-utils 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)",
"lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)",
"num_cpus 1.11.1 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "rustc_version"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"semver 0.9.0 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "scopeguard"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "semver"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"semver-parser 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "semver-parser"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "threadpool"
version = "1.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"num_cpus 1.11.1 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "wasi"
@ -149,16 +274,31 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
"checksum blas-sys 0.7.1 (registry+https://github.com/rust-lang/crates.io-index)" = "13b1b279ceb25d7c4faaea95a5f7addbe7d8c34f9462044bd8e630cebcfc2440"
"checksum c2-chacha 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "214238caa1bf3a496ec3392968969cab8549f96ff30652c9e56885329315f6bb"
"checksum cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)" = "4785bdd1c96b2a846b2bd7cc02e86b6b3dbf14e7e53446c4f54c92a361040822"
"checksum crossbeam-deque 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)" = "c3aa945d63861bfe624b55d153a39684da1e8c0bc8fba932f7ee3a3c16cea3ca"
"checksum crossbeam-epoch 0.8.0 (registry+https://github.com/rust-lang/crates.io-index)" = "5064ebdbf05ce3cb95e45c8b086f72263f4166b29b97f6baff7ef7fe047b55ac"
"checksum crossbeam-queue 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "dfd6515864a82d2f877b42813d4553292c6659498c9a2aa31bab5a15243c2700"
"checksum crossbeam-utils 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ce446db02cdc3165b94ae73111e570793400d0794e46125cc4056c81cbb039f4"
"checksum crunchy 0.2.2 (registry+https://github.com/rust-lang/crates.io-index)" = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7"
"checksum either 1.5.3 (registry+https://github.com/rust-lang/crates.io-index)" = "bb1f6b1ce1c140482ea30ddd3335fc0024ac7ee112895426e0a629a6c20adfe3"
"checksum getrandom 0.1.13 (registry+https://github.com/rust-lang/crates.io-index)" = "e7db7ca94ed4cd01190ceee0d8a8052f08a247aa1b469a7f68c6a3b71afcf407"
"checksum hermit-abi 0.1.3 (registry+https://github.com/rust-lang/crates.io-index)" = "307c3c9f937f38e3534b1d6447ecf090cafcc9744e4a6360e8b037b2cf5af120"
"checksum lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
"checksum libc 0.2.65 (registry+https://github.com/rust-lang/crates.io-index)" = "1a31a0627fdf1f6a39ec0dd577e101440b7db22672c0901fe00a9a6fbb5c24e8"
"checksum matrixmultiply 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "d4f7ec66360130972f34830bfad9ef05c6610a43938a467bcc9ab9369ab3478f"
"checksum memoffset 0.5.3 (registry+https://github.com/rust-lang/crates.io-index)" = "75189eb85871ea5c2e2c15abbdd541185f63b408415e5051f5cac122d8c774b9"
"checksum num-complex 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "fcb0cf31fb3ff77e6d2a6ebd6800df7fdcd106f2ad89113c9130bcd07f93dffc"
"checksum num-traits 0.2.10 (registry+https://github.com/rust-lang/crates.io-index)" = "d4c81ffc11c212fa327657cb19dd85eb7419e163b5b076bede2bdb5c974c07e4"
"checksum num_cpus 1.11.1 (registry+https://github.com/rust-lang/crates.io-index)" = "76dac5ed2a876980778b8b85f75a71b6cbf0db0b1232ee12f826bccb00d09d72"
"checksum openblas 0.0.1 (registry+https://github.com/rust-lang/crates.io-index)" = "f005eb075fb562358060245c5d00a877757b4ddec8e8cd69cab96d8cc55998e8"
"checksum ppv-lite86 0.2.6 (registry+https://github.com/rust-lang/crates.io-index)" = "74490b50b9fbe561ac330df47c08f3f33073d2d00c150f719147d7c54522fa1b"
"checksum rand 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)" = "3ae1b169243eaf61759b8475a998f0a385e42042370f3a7dbaf35246eacc8412"
"checksum rand_chacha 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "03a2a90da8c7523f554344f921aa97283eadf6ac484a6d2a7d0212fa7f8d6853"
"checksum rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)" = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19"
"checksum rand_hc 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ca3129af7b92a17112d59ad498c6f81eaf463253766b90396d39ea7a39d6613c"
"checksum rawpointer 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3"
"checksum rayon 1.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "43739f8831493b276363637423d3622d4bd6394ab6f0a9c4a552e208aeb7fddd"
"checksum rayon-core 1.6.1 (registry+https://github.com/rust-lang/crates.io-index)" = "f8bf17de6f23b05473c437eb958b9c850bfc8af0961fe17b4cc92d5a627b4791"
"checksum rustc_version 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "138e3e0acb6c9fb258b19b67cb8abd63c00679d2851805ea151465464fe9030a"
"checksum scopeguard 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)" = "b42e15e59b18a828bbf5c58ea01debb36b9b096346de35d941dcb89009f24a0d"
"checksum semver 0.9.0 (registry+https://github.com/rust-lang/crates.io-index)" = "1d7eb9ef2c18661902cc47e535f9bc51b78acd254da71d375c2f6720d9a40403"
"checksum semver-parser 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)" = "388a1df253eca08550bef6c72392cfe7c30914bf41df5269b68cbd6ff8f570a3"
"checksum threadpool 1.7.1 (registry+https://github.com/rust-lang/crates.io-index)" = "e2f0c90a5f3459330ac8bc0d2f879c693bb7a2f59689c1083fc4ef83834da865"
"checksum wasi 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)" = "b89c3ce4ce14bdc6fb6beaf9ec7928ca331de5df7e5ea278375642a2f478570d"

View File

@ -9,6 +9,11 @@ build = "build.rs"
[dependencies]
rand = "0.7.2"
matrixmultiply = "0.2.3"
blas = "0.20.0"
openblas = "0.0.1"
crunchy = "0.2.2"
rayon = "1.2.1"
threadpool = "1.7.1"
[profile.release]
debug = true

View File

@ -13,12 +13,8 @@ impl Alloc {
Alloc { ptr, layout }
}
pub fn ptr_f32(&self) -> *mut f32 {
self.ptr as *mut f32
}
pub fn ptr(&self) -> *mut u8 {
self.ptr
pub fn ptr<F>(&self) -> *mut F {
self.ptr as *mut F
}
}

75
src/dim.rs Normal file
View File

@ -0,0 +1,75 @@
pub trait Dim {
const DIM: usize;
}
pub struct A1;
impl Dim for A1 {
const DIM: usize = 1;
}
pub struct A2;
impl Dim for A2 {
const DIM: usize = 2;
}
pub struct A3;
impl Dim for A3 {
const DIM: usize = 3;
}
pub struct A4;
impl Dim for A4 {
const DIM: usize = 4;
}
pub struct A5;
impl Dim for A5 {
const DIM: usize = 5;
}
pub struct A6;
impl Dim for A6 {
const DIM: usize = 6;
}
pub struct A7;
impl Dim for A7 {
const DIM: usize = 7;
}
pub struct A8;
impl Dim for A8 {
const DIM: usize = 8;
}
pub struct A16;
impl Dim for A16 {
const DIM: usize = 16;
}
/*
macro_rules! gen {
(($name: ident, $val: lit)) => {
};
}
gen! {
(A0, 0),
(A1, 1),
(A2, 2),
(A3, 3),
(A4, 4),
(A5, 5),
(A6, 6),
(A7, 7),
(A8, 8),
(A9, 9),
(A10, 10),
(A11, 11),
(A12, 12),
(A14, 14),
(A16, 16),
(A18, 18),
(A20, 20),
(A22, 22),
(A24, 24),
(A26, 26),
(A28, 28),
(A30, 30),
(A32, 32),
}*/

111
src/executor.rs Normal file
View File

@ -0,0 +1,111 @@
pub trait Executor {
fn execute<F: FnOnce(usize) + Send + 'static + Copy + Sync>(
&self,
start: usize,
end: usize,
step: usize,
f: F,
);
fn synchronize(&self);
}
pub struct DefaultExecutor;
impl Executor for DefaultExecutor {
#[inline]
fn execute<F: FnOnce(usize) + Send + 'static + Copy + Sync>(
&self,
start: usize,
end: usize,
step: usize,
f: F,
) {
for i in (start..end).step_by(step) {
f(i);
}
}
fn synchronize(&self) {
}
}
pub struct RayonExecutor;
impl Executor for RayonExecutor {
#[inline]
fn execute<F: FnOnce(usize) + Send + 'static + Copy + Sync>(
&self,
start: usize,
end: usize,
step: usize,
f: F,
) {
let end = 1 + ((end - 1) / step);
use rayon::prelude::*;
(start..end)
.into_par_iter()
.map(|x| x * step)
.for_each(|x|f(x));
}
fn synchronize(&self) {
}
}
pub struct ThreadPoolExecutor {
thread_pool: threadpool::ThreadPool
}
impl ThreadPoolExecutor {
pub fn new() -> ThreadPoolExecutor {
ThreadPoolExecutor {
thread_pool: threadpool::Builder::new().build()
}
}
}
impl Executor for ThreadPoolExecutor {
#[inline]
fn execute<F: FnOnce(usize) + Send + 'static + Copy>(
&self,
start: usize,
end: usize,
step: usize,
f: F,
) {
let thread_count = self.thread_pool.max_count();
let len = end - start;
let num_steps = len / step;
let mut left_steps = num_steps % thread_count;
let main_steps = num_steps - left_steps;
let job_size = main_steps / thread_count;
let mut prev_end = 0;
for _ in 0..thread_count {
let mut now_end = prev_end + job_size;
if left_steps > 0 {
now_end += 1;
left_steps -= 1;
}
self.thread_pool.execute(move || {
for j in prev_end..now_end {
f(start + j * step);
}
});
prev_end = now_end
}
}
#[inline]
fn synchronize(&self) {
self.thread_pool.join();
}
}

242
src/gemm.rs Normal file
View File

@ -0,0 +1,242 @@
use crate::aligned_alloc;
use crate::matrix::{Number, Matrix, MatrixMut, MutMatrix, ConstMatrix, ConstTransposedMatrix, MutTransposedMatrix};
use crate::kernel::params::single::*;
use crate::kernel;
use crate::kernel::GemmKernel;
use crate::dim::Dim;
use crate::executor::Executor;
pub unsafe fn gemm<E, F, K, MR, NR>(
e: &E,
transa: bool,
transb: bool,
transc: bool,
m: usize,
n: usize,
k: usize,
alpha: F,
a: *const F,
lda: usize,
b: *const F,
ldb: usize,
beta: F,
c: *mut F,
ldc: usize,
)
where E: Executor,
F: Number,
MR: Dim, NR: Dim,
K: GemmKernel<F, MR, NR>,
{
match (transa, transb, transc) {
(false, false, false) => gemm_template::<E, F, K, MR, NR, _, _, _>(
e, m, n, k, alpha,
ConstMatrix::new(a, lda),
ConstMatrix::new(b, ldb),
beta,
MutMatrix::new(c, ldc)),
(false, false, true) => gemm_template::<E, F, K, MR, NR, _, _, _>(
e, m, n, k, alpha,
ConstMatrix::new(a, lda),
ConstMatrix::new(b, ldb),
beta,
MutTransposedMatrix::new(c, ldc)),
(false, true, false) => gemm_template::<E, F, K, MR, NR, _, _, _>(
e, m, n, k, alpha,
ConstMatrix::new(a, lda),
ConstTransposedMatrix::new(b, ldb),
beta,
MutMatrix::new(c, ldc)),
(false, true, true) => gemm_template::<E, F, K, MR, NR, _, _, _>(
e, m, n, k, alpha,
ConstMatrix::new(a, lda),
ConstTransposedMatrix::new(b, ldb),
beta,
MutTransposedMatrix::new(c, ldc)),
(true, false, false) => gemm_template::<E, F, K, MR, NR, _, _, _>(
e, m, n, k, alpha,
ConstTransposedMatrix::new(a, lda),
ConstMatrix::new(b, ldb),
beta,
MutMatrix::new(c, ldc)),
(true, false, true) => gemm_template::<E, F, K, MR, NR, _, _, _>(
e, m, n, k, alpha,
ConstTransposedMatrix::new(a, lda),
ConstMatrix::new(b, ldb),
beta,
MutTransposedMatrix::new(c, ldc)),
(true, true, false) => gemm_template::<E, F, K, MR, NR, _, _, _>(
e, m, n, k, alpha,
ConstTransposedMatrix::new(a, lda),
ConstTransposedMatrix::new(b, ldb),
beta,
MutMatrix::new(c, ldc)),
(true, true, true) => gemm_template::<E, F, K, MR, NR, _, _, _>(
e, m, n, k, alpha,
ConstTransposedMatrix::new(a, lda),
ConstTransposedMatrix::new(b, ldb),
beta,
MutTransposedMatrix::new(c, ldc)),
}
}
unsafe fn gemm_template<E, F, K, MR, NR, A, B, C>(
e: &E,
m: usize,
n: usize,
k: usize,
alpha: F,
a: A,
b: B,
beta: F,
c: C
)
where E: Executor,
F: Number,
MR: Dim, NR: Dim,
K: GemmKernel<F, MR, NR>,
A: Matrix<F>,
B: Matrix<F>,
C: MatrixMut<F>,
{
let packed_a = aligned_alloc::Alloc::new(MC * KC * std::mem::size_of::<F>());
let packed_b = aligned_alloc::Alloc::new(KC * NC * std::mem::size_of::<F>());
for j in (0..n).step_by(NC) {
let j_b = std::cmp::min(n - j, NC);
for p in (0..k).step_by(KC) {
let p_b = std::cmp::min(k - p, KC);
for i in (0..m).step_by(MC) {
let i_b = std::cmp::min(m - i, MC);
let pa = MutMatrix::new(packed_a.ptr::<F>(), p_b);
let pb = MutMatrix::new(packed_b.ptr::<F>(), p_b);
inner_kernel::<E, F, K, MR, NR, _, _, _>(
e,
i_b,
j_b,
p_b,
alpha,
a.sub(p, i),
b.sub(j, p),
beta,
c.sub(j, i),
pa,
pb,
i == 0
);
}
}
}
}
// | MR |
// +----------------------+
// |TL : : : |TR|
// | : : : | |
// + - - - - - - - - - ---+----
// | : : : | | NR
// | : : : | |
// + - - - - - - - - - ---+----
// | : : : | |
// | : : : | |
// +----------------------+
// |BL | | | |BR|
// +----------------------+
unsafe fn inner_kernel<E, F, K, MR, NR, A, B, C>(
e: &E,
m: usize,
n: usize,
k: usize,
alpha: F,
a: A,
b: B,
beta: F,
c: C,
pa: MutMatrix<F>,
pb: MutMatrix<F>,
first_time: bool,
)
where E: Executor,
F: Number,
MR: Dim,
NR: Dim,
K: kernel::GemmKernel<F, MR, NR>,
A: Matrix<F>,
B: Matrix<F>,
C: MatrixMut<F>,
{
let n_left = n % NR;
let n_main = n - n_left;
let m_left = m % MR;
let m_main = m - m_left;
if first_time {
e.execute(0, n_main, NR, move |j|
K::pack_row_b(b, pb, j));
}
e.execute(0, m_main, MR, move |i|
K::pack_row_a(a, pa, i));
e.synchronize();
e.execute(0, n_main, NR, move |j| {
// Section TL
for i in (0..m_main).step_by(MR) {
K::main_tl(alpha,
pa.sub_row(i),
pb.sub_row(j),
beta,
c.sub(j, i));
}
// Section TR
for i in m_main..m {
K::sup_tr(
alpha,
a.sub_col(i),
pb.sub_row(j),
beta,
c.sub(j, i));
}
});
e.execute(n_main, n, 1, move |j| {
// Section BL
for i in (0..m_main).step_by(MR) {
K::sup_bl(
alpha,
pa.sub_row(i),
b.sub_row(j),
beta,
c.sub(j, i)
);
}
// Section BR
for i in m_main..m {
K::sup_br(
k,
alpha,
a.sub_col(i),
b.sub_row(j),
beta,
c.sub(j, i))
}
});
e.synchronize();
}

View File

@ -246,6 +246,7 @@ pub unsafe fn sdot(
let mut acc1 = _mm256_setzero_ps();
let mut acc2 = _mm256_setzero_ps();
let mut acc3 = _mm256_setzero_ps();
for _ in 0..n / 32 {
let x0 = _mm256_loadu_ps(x);
let y0 = _mm256_loadu_ps(y);

View File

@ -1,29 +1,82 @@
use super::fma::fmadd_ps;
use super::intrinsics::*;
use crate::kernel::params::single::{MR, NR};
use crate::matrix::{Matrix, MutMatrix, MatrixMut};
use crunchy::unroll;
pub(crate) unsafe fn sgemm_ukr_16x4(
#[inline]
pub(crate) unsafe fn sgemm_ukr_16x8<C: MatrixMut<f32>>(
k: usize,
alpha: f32,
pa: *const f32,
pb: *const f32,
pa: MutMatrix<f32>,
pb: MutMatrix<f32>,
beta: f32,
c: *mut f32,
ldc: usize,
c: C,
) {
let mut mt00 = _mm256_setzero_ps();
let mut mt01 = _mm256_setzero_ps();
let mut mt02 = _mm256_setzero_ps();
let mut mt03 = _mm256_setzero_ps();
let mut mt04 = _mm256_setzero_ps();
// let mut mt05 = _mm256_setzero_ps();
// let mut mt06 = _mm256_setzero_ps();
// let mut mt07 = _mm256_setzero_ps();
let mut mt10 = _mm256_setzero_ps();
let mut mt11 = _mm256_setzero_ps();
let mut mt12 = _mm256_setzero_ps();
let mut mt13 = _mm256_setzero_ps();
let mut mt14 = _mm256_setzero_ps();
// let mut mt15 = _mm256_setzero_ps();
// let mut mt16 = _mm256_setzero_ps();
// let mut mt17 = _mm256_setzero_ps();
let mut pa = pa;
let mut pb = pb;
let mut pa = pa.ptr();
let mut pb = pb.ptr();
for _ in 0..k {
const BATCH: usize = 16;
let k_right = k % BATCH;
let k_main = k - k_right;
for _ in (0..k_main).step_by(BATCH) {
unroll! {
for i in 0..16 {
let a0 = _mm256_load_ps(pa.add(i * MR));
let a1 = _mm256_load_ps(pa.add(i * MR + 8));
let b0 = _mm256_broadcast_ss(&*pb.add(i * NR));
let b1 = _mm256_broadcast_ss(&*pb.add(i * NR + 1));
let b2 = _mm256_broadcast_ss(&*pb.add(i * NR + 2));
let b3 = _mm256_broadcast_ss(&*pb.add(i * NR + 3));
let b4 = _mm256_broadcast_ss(&*pb.add(i * NR + 4));
// let b5 = _mm256_broadcast_ss(&*pb.add(5));
// let b6 = _mm256_broadcast_ss(&*pb.add(6));
// let b7 = _mm256_broadcast_ss(&*pb.add(7));
mt00 = fmadd_ps(a0, b0, mt00);
mt01 = fmadd_ps(a0, b1, mt01);
mt02 = fmadd_ps(a0, b2, mt02);
mt03 = fmadd_ps(a0, b3, mt03);
mt04 = fmadd_ps(a0, b4, mt04);
// mt05 = fmadd_ps(a0, b5, mt05);
// mt06 = fmadd_ps(a0, b6, mt06);
// mt07 = fmadd_ps(a0, b7, mt07);
mt10 = fmadd_ps(a1, b0, mt10);
mt11 = fmadd_ps(a1, b1, mt11);
mt12 = fmadd_ps(a1, b2, mt12);
mt13 = fmadd_ps(a1, b3, mt13);
mt14 = fmadd_ps(a1, b4, mt14);
}
}
pa = pa.add(BATCH * MR);
pb = pb.add(BATCH * NR);
}
for _ in k_main..k {
let a0 = _mm256_load_ps(pa);
let a1 = _mm256_load_ps(pa.add(8));
@ -31,66 +84,107 @@ pub(crate) unsafe fn sgemm_ukr_16x4(
let b1 = _mm256_broadcast_ss(&*pb.add(1));
let b2 = _mm256_broadcast_ss(&*pb.add(2));
let b3 = _mm256_broadcast_ss(&*pb.add(3));
let b4 = _mm256_broadcast_ss(&*pb.add(4));
// let b5 = _mm256_broadcast_ss(&*pb.add(5));
// let b6 = _mm256_broadcast_ss(&*pb.add(6));
// let b7 = _mm256_broadcast_ss(&*pb.add(7));
mt00 = fmadd_ps(a0, b0, mt00);
mt01 = fmadd_ps(a0, b1, mt01);
mt02 = fmadd_ps(a0, b2, mt02);
mt03 = fmadd_ps(a0, b3, mt03);
mt04 = fmadd_ps(a0, b4, mt04);
// mt05 = fmadd_ps(a0, b5, mt05);
// mt06 = fmadd_ps(a0, b6, mt06);
// mt07 = fmadd_ps(a0, b7, mt07);
mt10 = fmadd_ps(a1, b0, mt10);
mt11 = fmadd_ps(a1, b1, mt11);
mt12 = fmadd_ps(a1, b2, mt12);
mt13 = fmadd_ps(a1, b3, mt13);
mt14 = fmadd_ps(a1, b4, mt14);
pa = pa.add(16);
pb = pb.add(4);
pa = pa.add(MR);
pb = pb.add(NR);
}
let alpha = _mm256_broadcast_ss(&alpha);
// let alpha = _mm256_broadcast_ss(&alpha);
mt00 = _mm256_mul_ps(alpha, mt00);
mt01 = _mm256_mul_ps(alpha, mt01);
mt02 = _mm256_mul_ps(alpha, mt02);
mt03 = _mm256_mul_ps(alpha, mt03);
mt10 = _mm256_mul_ps(alpha, mt10);
mt11 = _mm256_mul_ps(alpha, mt11);
mt12 = _mm256_mul_ps(alpha, mt12);
mt13 = _mm256_mul_ps(alpha, mt13);
// mt00 = _mm256_mul_ps(alpha, mt00);
// mt01 = _mm256_mul_ps(alpha, mt01);
// mt02 = _mm256_mul_ps(alpha, mt02);
// mt03 = _mm256_mul_ps(alpha, mt03);
// mt04 = _mm256_mul_ps(alpha, mt04);
// // mt05 = _mm256_mul_ps(alpha, mt05);
// // mt06 = _mm256_mul_ps(alpha, mt06);
// // mt07 = _mm256_mul_ps(alpha, mt07);
let ccol0 = c;
let ccol1 = c.add(ldc);
let ccol2 = c.add(ldc * 2);
let ccol3 = c.add(ldc * 3);
// mt10 = _mm256_mul_ps(alpha, mt10);
// mt11 = _mm256_mul_ps(alpha, mt11);
// mt12 = _mm256_mul_ps(alpha, mt12);
// mt13 = _mm256_mul_ps(alpha, mt13);
// mt14 = _mm256_mul_ps(alpha, mt14);
// // mt15 = _mm256_mul_ps(alpha, mt15);
// // mt16 = _mm256_mul_ps(alpha, mt16);
// // mt17 = _mm256_mul_ps(alpha, mt17);
if beta != 0.0 {
let beta = _mm256_broadcast_ss(&beta);
let ccol0 = c.ptr_mut();
let ccol1 = c.row_mut(1);
let ccol2 = c.row_mut(2);
let ccol3 = c.row_mut(3);
let ccol4 = c.row_mut(4);
// let ccol5 = c.row_mut(5);
// let ccol6 = c.row_mut(6);
// let ccol7 = c.add(ldc * 7);
mt00 = fmadd_ps(beta, _mm256_loadu_ps(ccol0), mt00);
mt01 = fmadd_ps(beta, _mm256_loadu_ps(ccol1), mt01);
mt02 = fmadd_ps(beta, _mm256_loadu_ps(ccol2), mt02);
mt03 = fmadd_ps(beta, _mm256_loadu_ps(ccol3), mt03);
mt10 = fmadd_ps(beta, _mm256_loadu_ps(ccol0.add(8)), mt10);
mt11 = fmadd_ps(beta, _mm256_loadu_ps(ccol1.add(8)), mt11);
mt12 = fmadd_ps(beta, _mm256_loadu_ps(ccol2.add(8)), mt12);
mt13 = fmadd_ps(beta, _mm256_loadu_ps(ccol3.add(8)), mt13);
}
// if beta != 0.0 {
// let beta = _mm256_broadcast_ss(&beta);
// mt00 = fmadd_ps(beta, _mm256_loadu_ps(ccol0), mt00);
// mt01 = fmadd_ps(beta, _mm256_loadu_ps(ccol1), mt01);
// mt02 = fmadd_ps(beta, _mm256_loadu_ps(ccol2), mt02);
// mt03 = fmadd_ps(beta, _mm256_loadu_ps(ccol3), mt03);
// mt04 = fmadd_ps(beta, _mm256_loadu_ps(ccol4), mt04);
// // mt05 = fmadd_ps(beta, _mm256_loadu_ps(ccol5), mt05);
// // mt06 = fmadd_ps(beta, _mm256_loadu_ps(ccol6), mt06);
// // mt07 = fmadd_ps(beta, _mm256_loadu_ps(ccol7), mt07);
// mt10 = fmadd_ps(beta, _mm256_loadu_ps(ccol0.add(8)), mt10);
// mt11 = fmadd_ps(beta, _mm256_loadu_ps(ccol1.add(8)), mt11);
// mt12 = fmadd_ps(beta, _mm256_loadu_ps(ccol2.add(8)), mt12);
// mt13 = fmadd_ps(beta, _mm256_loadu_ps(ccol3.add(8)), mt13);
// mt14 = fmadd_ps(beta, _mm256_loadu_ps(ccol4.add(8)), mt14);
// // mt15 = fmadd_ps(beta, _mm256_loadu_ps(ccol5.add(8)), mt15);
// // mt16 = fmadd_ps(beta, _mm256_loadu_ps(ccol6.add(8)), mt16);
// // mt17 = fmadd_ps(beta, _mm256_loadu_ps(ccol7.add(8)), mt17);
// }
_mm256_storeu_ps(ccol0, mt00);
_mm256_storeu_ps(ccol1, mt01);
_mm256_storeu_ps(ccol2, mt02);
_mm256_storeu_ps(ccol3, mt03);
_mm256_storeu_ps(ccol4, mt04);
// _mm256_storeu_ps(ccol5, mt05);
// _mm256_storeu_ps(ccol6, mt06);
// _mm256_storeu_ps(ccol7, mt07);
_mm256_storeu_ps(ccol0.add(8), mt10);
_mm256_storeu_ps(ccol1.add(8), mt11);
_mm256_storeu_ps(ccol2.add(8), mt12);
_mm256_storeu_ps(ccol3.add(8), mt13);
_mm256_storeu_ps(ccol4.add(8), mt14);
// _mm256_storeu_ps(ccol5.add(8), mt15);
// _mm256_storeu_ps(ccol6.add(8), mt16);
// _mm256_storeu_ps(ccol7.add(8), mt17);
}
pub(crate) unsafe fn sgemm_sup_16x1(
pub(crate) unsafe fn sgemm_sup_16x1<B: Matrix<f32>, C: MatrixMut<f32>>(
k: usize,
alpha: f32,
pa: *const f32,
b: *const f32,
pa: MutMatrix<f32>,
b: B,
beta: f32,
c: *mut f32,
c: C,
) {
let mut mt0 = _mm256_setzero_ps();
let mut mt1 = _mm256_setzero_ps();
@ -99,76 +193,32 @@ pub(crate) unsafe fn sgemm_sup_16x1(
let mut b = b;
for _ in 0..k {
let a0 = _mm256_load_ps(pa);
let a1 = _mm256_load_ps(pa.add(8));
let a0 = _mm256_load_ps(pa.ptr());
let a1 = _mm256_load_ps(pa.col(8));
let b0 = _mm256_broadcast_ss(&*b);
let b0 = _mm256_broadcast_ss(&*b.ptr());
mt0 = fmadd_ps(a0, b0, mt0);
mt1 = fmadd_ps(a1, b0, mt1);
pa = pa.add(16);
b = b.add(1);
pa.shift_col(16);
b.inc_col();
}
let alpha = _mm256_broadcast_ss(&alpha);
// let alpha = _mm256_broadcast_ss(&alpha);
mt0 = _mm256_mul_ps(alpha, mt0);
mt1 = _mm256_mul_ps(alpha, mt1);
// mt0 = _mm256_mul_ps(alpha, mt0);
// mt1 = _mm256_mul_ps(alpha, mt1);
if beta != 0.0 {
let beta = _mm256_broadcast_ss(&beta);
// if beta != 0.0 {
// let beta = _mm256_broadcast_ss(&beta);
mt0 = fmadd_ps(beta, _mm256_loadu_ps(c), mt0);
mt1 = fmadd_ps(beta, _mm256_loadu_ps(c.add(8)), mt1);
}
// mt0 = fmadd_ps(beta, _mm256_loadu_ps(c), mt0);
// mt1 = fmadd_ps(beta, _mm256_loadu_ps(c.add(8)), mt1);
// }
_mm256_storeu_ps(c, mt0);
_mm256_storeu_ps(c.add(8), mt1);
}
pub(crate) unsafe fn sgemm_sup0_t(
k: usize,
alpha: f32,
pa: *const f32,
b: *const f32,
ldb: usize,
beta: f32,
c: *mut f32,
) {
let mut mt0 = _mm256_setzero_ps();
let mut mt1 = _mm256_setzero_ps();
let mut pa = pa;
let mut b = b;
for _ in 0..k {
let a0 = _mm256_load_ps(pa);
let a1 = _mm256_load_ps(pa.add(8));
let b0 = _mm256_broadcast_ss(&*b);
mt0 = fmadd_ps(a0, b0, mt0);
mt1 = fmadd_ps(a1, b0, mt1);
pa = pa.add(16);
b = b.add(ldb);
}
let alpha = _mm256_broadcast_ss(&alpha);
mt0 = _mm256_mul_ps(alpha, mt0);
mt1 = _mm256_mul_ps(alpha, mt1);
if beta != 0.0 {
let beta = _mm256_broadcast_ss(&beta);
mt0 = fmadd_ps(beta, _mm256_loadu_ps(c), mt0);
mt1 = fmadd_ps(beta, _mm256_loadu_ps(c.add(8)), mt1);
}
_mm256_storeu_ps(c, mt0);
_mm256_storeu_ps(c.add(8), mt1);
_mm256_storeu_ps(c.ptr_mut(), mt0);
_mm256_storeu_ps(c.col_mut(8), mt1);
}
pub(crate) unsafe fn sgemm_pa_16x(k: usize, a: *const f32, lda: usize, pa: *mut f32) {

View File

@ -5,3 +5,85 @@ pub mod l1d;
pub mod l1s;
pub mod l3d;
pub mod l3s;
use core::marker::PhantomData;
use crate::matrix::{Number, MutMatrix, Matrix, MatrixMut};
use crate::kernel::{GemmKernel, GemmKernelSupNr, GemmKernelSupMr, GemmKernelSup};
use crate::dim::*;
pub struct AvxKernel<F: Number, I>(PhantomData<fn(F, I)>);
impl<I> GemmKernelSupNr<f32, A5> for AvxKernel<f32, I>
where I: GemmKernelSupNr<f32, A5>
{
#[inline]
unsafe fn sup_tr<A: Matrix<f32>, C: MatrixMut<f32>>(
alpha: f32,
a: A,
pb: MutMatrix<f32>,
beta: f32,
c: C,
) {
I::sup_tr(alpha, a, pb, beta, c);
}
}
impl<I> GemmKernelSupMr<f32, A16> for AvxKernel<f32, I>
where I: GemmKernelSupMr<f32, A16>
{
#[inline]
unsafe fn sup_bl<B: Matrix<f32>, C: MatrixMut<f32>>(
alpha: f32,
pa: MutMatrix<f32>,
b: B,
beta: f32,
c: C,
) {
self::l3s::sgemm_sup_16x1(pa.stride, alpha, pa, b, beta, c);
}
}
impl<I> GemmKernelSup<f32> for AvxKernel<f32, I>
where I: GemmKernelSup<f32>
{
#[inline]
unsafe fn sup_br<A: Matrix<f32>, B: Matrix<f32>, C: MatrixMut<f32>>(
k: usize,
alpha: f32,
a: A,
b: B,
beta: f32,
c: C,
) {
I::sup_br(k, alpha, a, b, beta, c);
}
}
impl<I> GemmKernel<f32, A16, A5> for AvxKernel<f32, I>
where I: GemmKernel<f32, A16, A5>
{
#[inline]
unsafe fn pack_row_a<A: Matrix<f32>>(a: A, pa: MutMatrix<f32>, i: usize) {
if a.is_transposed() {
I::pack_row_a(a, pa, i);
} else {
self::l3s::sgemm_pa_16x(pa.stride, a.col(i), a.stride(), pa.row_mut(i));
}
}
#[inline]
unsafe fn pack_row_b<B: Matrix<f32>>(b: B, pb: MutMatrix<f32>, j: usize) {
I::pack_row_b(b, pb, j);
}
#[inline]
unsafe fn main_tl<C: MatrixMut<f32>>(
alpha: f32,
pa: MutMatrix<f32>,
pb: MutMatrix<f32>,
beta: f32,
c: C,
) {
self::l3s::sgemm_ukr_16x8(pa.stride, alpha, pa, pb, beta, c);
}
}

View File

@ -1,114 +1,91 @@
pub(crate) unsafe fn sgemm_sup_1x4(
use crate::kernel::params::single::NR;
use crate::matrix::{MutMatrix, Matrix, MatrixMut};
pub(crate) unsafe fn sgemm_sup_1x8<A: Matrix<f32>, C: MatrixMut<f32>>(
k: usize,
alpha: f32,
a: *const f32,
lda: usize,
pb: *const f32,
a: A,
pb: MutMatrix<f32>,
beta: f32,
c: *mut f32,
ldc: usize,
c: C,
) {
let mut c0 = 0.0;
let mut c1 = 0.0;
let mut c2 = 0.0;
let mut c3 = 0.0;
let mut c0 = 0.0f32;
let mut c1 = 0.0f32;
let mut c2 = 0.0f32;
let mut c3 = 0.0f32;
let mut c4 = 0.0f32;
// let mut c5 = 0.0f32;
// let mut c6 = 0.0f32;
// let mut c7 = 0.0f32;
let mut a = a;
let mut pb = pb;
for _ in 0..k {
let a0 = *a;
let a0 = *a.ptr();
c0 += *pb * a0;
c1 += *pb.add(1) * a0;
c2 += *pb.add(2) * a0;
c3 += *pb.add(3) * a0;
c0 += *pb.ptr() * a0;
c1 += *pb.col(1) * a0;
c2 += *pb.col(2) * a0;
c3 += *pb.col(3) * a0;
c4 += *pb.col(4) * a0;
// c5 += *pb.col(5) * a0;
// c6 += *pb.col(6) * a0;
// c7 += *pb.col(7) * a0;
a = a.add(lda);
pb = pb.add(4);
a.inc_row();
pb.shift_col(NR);
}
c0 *= alpha;
c1 *= alpha;
c2 *= alpha;
c3 *= alpha;
// c0 *= alpha;
// c1 *= alpha;
// c2 *= alpha;
// c3 *= alpha;
// c4 *= alpha;
// // c5 *= alpha;
// // c6 *= alpha;
// // c7 *= alpha;
let ccol0 = c;
let ccol1 = c.add(ldc);
let ccol2 = c.add(ldc * 2);
let ccol3 = c.add(ldc * 3);
let ccol0 = c.ptr_mut();
let ccol1 = c.row_mut(1);
let ccol2 = c.row_mut(2);
let ccol3 = c.row_mut(3);
let ccol4 = c.row_mut(4);
// let ccol5 = c.row_mut(5);
// let ccol6 = c.row_mut(6);
// let ccol7 = c.add(ldc * 7);
if beta != 0.0 {
c0 += beta * *ccol0;
c1 += beta * *ccol1;
c2 += beta * *ccol2;
c3 += beta * *ccol3;
}
// if beta != 0.0 {
// c0 += beta * *ccol0;
// c1 += beta * *ccol1;
// c2 += beta * *ccol2;
// c3 += beta * *ccol3;
// c4 += beta * *ccol4;
// // c5 += beta * *ccol5;
// // c6 += beta * *ccol6;
// // c7 += beta * *ccol7;
// }
*ccol0 = c0;
*ccol1 = c1;
*ccol2 = c2;
*ccol3 = c3;
*ccol4 = c4;
// *ccol5 = c5;
// *ccol6 = c6;
// *ccol7 = c7;
}
pub(crate) unsafe fn sgemm_sup1_t(
k: usize,
alpha: f32,
a: *const f32,
lda: usize,
pb: *const f32,
beta: f32,
c: *mut f32,
ldc: usize,
) {
let mut c0 = 0.0;
let mut c1 = 0.0;
let mut c2 = 0.0;
let mut c3 = 0.0;
let mut a = a;
let mut pb = pb;
for _ in 0..k {
let a0 = *a;
c0 += *pb * a0;
c1 += *pb.add(1) * a0;
c2 += *pb.add(2) * a0;
c3 += *pb.add(3) * a0;
a = a.add(1);
pb = pb.add(4);
}
c0 *= alpha;
c1 *= alpha;
c2 *= alpha;
c3 *= alpha;
let ccol0 = c;
let ccol1 = c.add(ldc);
let ccol2 = c.add(ldc * 2);
let ccol3 = c.add(ldc * 3);
if beta != 0.0 {
c0 += beta * *ccol0;
c1 += beta * *ccol1;
c2 += beta * *ccol2;
c3 += beta * *ccol3;
}
*ccol0 = c0;
*ccol1 = c1;
*ccol2 = c2;
*ccol3 = c3;
}
pub(crate) unsafe fn sgemm_pb_x4(k: usize, b: *const f32, ldb: usize, pb: *mut f32) {
pub(crate) unsafe fn sgemm_pb_x8(k: usize, b: *const f32, ldb: usize, pb: *mut f32) {
let mut bcol0 = b;
let mut bcol1 = b.add(ldb);
let mut bcol2 = b.add(ldb * 2);
let mut bcol3 = b.add(ldb * 3);
let mut bcol4 = b.add(ldb * 4);
// let mut bcol5 = b.add(ldb * 5);
// let mut bcol6 = b.add(ldb * 6);
// let mut bcol7 = b.row(7);
let mut pb = pb;
@ -117,27 +94,24 @@ pub(crate) unsafe fn sgemm_pb_x4(k: usize, b: *const f32, ldb: usize, pb: *mut f
*pb.add(1) = *bcol1;
*pb.add(2) = *bcol2;
*pb.add(3) = *bcol3;
*pb.add(4) = *bcol4;
// *pb.add(5) = *bcol5;
// *pb.add(6) = *bcol6;
// *pb.col(7) = *bcol7;
bcol0 = bcol0.add(1);
bcol1 = bcol1.add(1);
bcol2 = bcol2.add(1);
bcol3 = bcol3.add(1);
pb = pb.add(4);
bcol4 = bcol4.add(1);
// bcol5 = bcol5.add(1);
// bcol6 = bcol6.add(1);
// bcol7 = bcol7.add(1);
pb = pb.add(NR);
}
}
// pub(crate) unsafe fn sgemm_pa_n(k: usize, a: *const f32, lda: usize, pa: *mut f32) {
// use crate::kernel::params::single::MR;
// let mut a = a;
// let mut pa = pa;
// for p in 0..k {
// for j in 0..MR {
// *pa.add(j * k + p) = *a.add(p * lda + j);
// }
// }
// }
pub(crate) unsafe fn sgemm_pb_t(k: usize, b: *const f32, ldb: usize, pb: *mut f32) {
use crate::kernel::params::single::NR;
let mut b = b;

View File

@ -1,2 +1,98 @@
pub mod l3d;
pub mod l3s;
use crate::matrix::{Matrix, MatrixMut, MutMatrix};
use crate::kernel::{GemmKernel, GemmKernelSupNr, GemmKernelSupMr, GemmKernelSup};
use crate::dim::*;
pub struct GenericKernel;
impl GemmKernelSupNr<f32, A5> for GenericKernel {
#[inline]
unsafe fn sup_tr<A: Matrix<f32>, C: MatrixMut<f32>>(
alpha: f32,
a: A,
pb: MutMatrix<f32>,
beta: f32,
c: C,
) {
self::l3s::sgemm_sup_1x8(
pb.stride,
alpha,
a,
pb,
beta,
c)
}
}
impl GemmKernelSupMr<f32, A16> for GenericKernel {
#[inline]
unsafe fn sup_bl<B: Matrix<f32>, C: MatrixMut<f32>>(
alpha: f32,
pa: MutMatrix<f32>,
b: B,
beta: f32,
c: C,
) {
unimplemented!()
}
}
impl GemmKernelSup<f32> for GenericKernel {
#[inline]
unsafe fn sup_br<A: Matrix<f32>, B: Matrix<f32>, C: MatrixMut<f32>>(
k: usize,
alpha: f32,
a: A,
b: B,
beta: f32,
c: C,
) {
let mut elem = 0.0;
for p in 0..k {
elem += *a.row(p) * *b.col(p);
}
elem *= alpha;
if beta != 0.0 {
elem += beta * *c.ptr();
}
*c.ptr_mut() = elem;
}
}
impl GemmKernel<f32, A16, A5> for GenericKernel {
#[inline]
unsafe fn pack_row_a<A: Matrix<f32>>(a: A, pa: MutMatrix<f32>, i: usize) {
if a.is_transposed() {
self::l3s::sgemm_pa_t(pa.stride, a.col(i), a.stride(), pa.row_mut(i));
} else {
unimplemented!()
}
}
#[inline]
unsafe fn pack_row_b<B: Matrix<f32>>(b: B, pb: MutMatrix<f32>, j: usize) {
if b.is_transposed() {
self::l3s::sgemm_pb_t(pb.stride, b.row(j), b.stride(), pb.row_mut(j));
} else {
self::l3s::sgemm_pb_x8(pb.stride, b.row(j), b.stride(), pb.row_mut(j));
}
}
#[inline]
unsafe fn main_tl<C: MatrixMut<f32>>(
alpha: f32,
pa: MutMatrix<f32>,
pb: MutMatrix<f32>,
beta: f32,
c: C,
) {
unimplemented!()
}
}

View File

@ -2,53 +2,20 @@
any(target_arch = "x86_64", target_arch = "x86"),
target_feature = "avx"
))]
mod avx;
#[cfg(all(
any(target_arch = "x86_64", target_arch = "x86"),
target_feature = "avx"
))]
pub use avx::{l1d::*, l1s::*};
#[cfg(all(
any(target_arch = "x86_64", target_arch = "x86"),
target_feature = "avx"
))]
pub(crate) use avx::l3s::{
sgemm_pa_16x as sgemm_pa_n, sgemm_sup_16x1 as sgemm_sup0, sgemm_ukr_16x4 as sgemm_ukr, sgemm_sup0_t,
};
#[cfg(all(
any(target_arch = "x86_64", target_arch = "x86"),
target_feature = "avx"
))]
pub(crate) use avx::l3d::{
dgemm_pa_8x as dgemm_pa, dgemm_sup_8x1 as dgemm_sup0, dgemm_ukr_8x4 as dgemm_ukr,
};
#[cfg(all(
any(target_arch = "x86_64", target_arch = "x86"),
target_feature = "avx"
))]
pub(crate) use generic::l3s::{sgemm_pb_x4 as sgemm_pb_n, sgemm_pb_t, sgemm_pa_t, sgemm_sup_1x4 as sgemm_sup1, sgemm_sup1_t};
#[cfg(all(
any(target_arch = "x86_64", target_arch = "x86"),
target_feature = "avx"
))]
pub(crate) use generic::l3d::{dgemm_pb_x4 as dgemm_pb, dgemm_sup_1x4 as dgemm_sup1};
mod generic;
pub mod avx;
pub mod generic;
pub mod sse;
use crate::matrix::{Number, Matrix, MatrixMut, MutMatrix};
use crate::dim::Dim;
pub mod params {
pub mod single {
pub const MC: usize = 256;
pub const KC: usize = 128;
pub const NC: usize = 1024;
pub const MC: usize = 128;
pub const KC: usize = 256;
pub const NC: usize = 8*1024;
pub const MR: usize = 16;
pub const NR: usize = 4;
pub const NR: usize = 5;
}
pub mod double {
pub const MC: usize = 256;
@ -58,3 +25,60 @@ pub mod params {
pub const NR: usize = 4;
}
}
// +----------------------+
// |TL : : : |TR|
// | : : : | |
// + - - - - - - - - - ---+
// | : : : | |
// | : : : | |
// + - - - - - - - - - ---+
// | : : : | |
// | : : : | |
// +----------------------+
// |BL | | | |BR|
// +----------------------+
pub trait GemmKernelSup<F: Number> {
unsafe fn sup_br<A: Matrix<F>, B: Matrix<F>, C: MatrixMut<F>>(
k: usize,
alpha: F,
a: A,
b: B,
beta: F,
c: C,
);
}
pub trait GemmKernelSupMr<F: Number, MR: Dim> {
unsafe fn sup_bl<B: Matrix<F>, C: MatrixMut<F>>(
alpha: F,
pa: MutMatrix<F>,
b: B,
beta: F,
c: C,
);
}
pub trait GemmKernelSupNr<F: Number, NR: Dim> {
unsafe fn sup_tr<A: Matrix<F>, C: MatrixMut<F>>(
alpha: F,
a: A,
pb: MutMatrix<F>,
beta: F,
c: C,
);
}
pub trait GemmKernel<F: Number, MR: Dim, NR: Dim>: GemmKernelSupMr<F, MR> + GemmKernelSupNr<F, NR> + GemmKernelSup<F> {
unsafe fn pack_row_a<A: Matrix<F>>(a: A, pa: MutMatrix<F>, i: usize);
unsafe fn pack_row_b<B: Matrix<F>>(b: B, pb: MutMatrix<F>, j: usize);
unsafe fn main_tl<C: MatrixMut<F>>(
alpha: F,
pa: MutMatrix<F>,
pb: MutMatrix<F>,
beta: F,
c: C,
);
}

47
src/kernel/sse/fma.rs Normal file
View File

@ -0,0 +1,47 @@
use super::intrinsics::*;
#[cfg(target_feature = "fma")]
#[inline(always)]
pub unsafe fn fmadd_ps(a: __m128, b: __m128, c: __m128) -> __m128 {
_mm_fmadd_ps(a, b, c)
}
#[cfg(not(target_feature = "fma"))]
#[inline(always)]
pub unsafe fn fmadd_ps(a: __m128, b: __m128, c: __m128) -> __m128 {
_mm_add_ps(_mm_mul_ps(a, b), c)
}
#[cfg(target_feature = "fma")]
#[inline(always)]
pub unsafe fn fmsub_ps(a: __m128, b: __m128, c: __m128) -> __m128 {
_mm_fmsub_ps(a, b, c)
}
#[cfg(not(target_feature = "fma"))]
pub unsafe fn fmsub_ps(a: __m128, b: __m128, c: __m128) -> __m128 {
_mm_sub_ps(_mm_mul_ps(a, b), c)
}
#[cfg(target_feature = "fma")]
#[inline(always)]
pub unsafe fn fmadd_pd(a: __m128d, b: __m128d, c: __m128d) -> __m128d {
_mm_fmadd_pd(a, b, c)
}
#[cfg(not(target_feature = "fma"))]
#[inline(always)]
pub unsafe fn fmadd_pd(a: __m128d, b: __m128d, c: __m128d) -> __m128d {
_mm_add_pd(_mm_mul_pd(a, b), c)
}
#[cfg(target_feature = "fma")]
#[inline(always)]
pub unsafe fn fmsub_pd(a: __m128d, b: __m128d, c: __m128d) -> __m128d {
_mm_fmsub_pd(a, b, c)
}
#[cfg(not(target_feature = "fma"))]
pub unsafe fn fmsub_pd(a: __m128d, b: __m128d, c: __m128d) -> __m128d {
_mm_sub_pd(_mm_mul_pd(a, b), c)
}

23
src/kernel/sse/hsum.rs Normal file
View File

@ -0,0 +1,23 @@
use super::intrinsics::*;
#[inline(always)]
pub unsafe fn hsum_ps(v: __m256) -> f32 {
let qhigh = _mm256_extractf128_ps(v, 1);
let qlow = _mm256_castps256_ps128(v);
let qsum = _mm_add_ps(qhigh, qlow);
let dhigh = _mm_movehl_ps(qsum, qsum);
let dlow = qsum;
let dsum = _mm_add_ps(dhigh, dlow);
let high = _mm_shuffle_ps(dsum, dsum, 1);
let low = dsum;
_mm_cvtss_f32(_mm_add_ss(high, low))
}
#[inline(always)]
pub unsafe fn hsum_pd(v: __m256d) -> f64 {
let vhigh = _mm256_extractf128_pd(v, 1);
let vlow = _mm256_castpd256_pd128(v);
let vsum = _mm_add_pd(vlow, vhigh);
let h64 = _mm_unpackhi_pd(vsum, vsum);
_mm_cvtsd_f64(_mm_add_sd(vsum, h64))
}

View File

@ -0,0 +1,5 @@
#[cfg(target_arch = "x86")]
pub use std::arch::x86::*;
#[cfg(target_arch = "x86_64")]
pub use std::arch::x86_64::*;

396
src/kernel/sse/l1d.rs Normal file
View File

@ -0,0 +1,396 @@
use super::fma::{fmadd_pd, fmsub_pd};
use super::hsum::hsum_pd;
use super::intrinsics::*;
pub unsafe fn drot(
n: usize,
mut x: *mut f64,
incx: usize,
mut y: *mut f64,
incy: usize,
c: f64,
s: f64,
) {
if incx == 1 && incy == 1 {
let c0 = _mm256_broadcast_sd(&c);
let s0 = _mm256_broadcast_sd(&s);
for _ in 0..n / 16 {
let x0 = _mm256_loadu_pd(x);
let y0 = _mm256_loadu_pd(y);
let x1 = _mm256_loadu_pd(x.add(4));
let y1 = _mm256_loadu_pd(y.add(4));
let x2 = _mm256_loadu_pd(x.add(8));
let y2 = _mm256_loadu_pd(y.add(8));
let x3 = _mm256_loadu_pd(x.add(12));
let y3 = _mm256_loadu_pd(y.add(12));
_mm256_storeu_pd(x, fmadd_pd(c0, x0, _mm256_mul_pd(s0, y0)));
_mm256_storeu_pd(y, fmsub_pd(c0, y0, _mm256_mul_pd(s0, x0)));
_mm256_storeu_pd(x.add(4), fmadd_pd(c0, x1, _mm256_mul_pd(s0, y1)));
_mm256_storeu_pd(y.add(4), fmsub_pd(c0, y1, _mm256_mul_pd(s0, x1)));
_mm256_storeu_pd(x.add(8), fmadd_pd(c0, x2, _mm256_mul_pd(s0, y2)));
_mm256_storeu_pd(y.add(8), fmsub_pd(c0, y2, _mm256_mul_pd(s0, x2)));
_mm256_storeu_pd(x.add(12), fmadd_pd(c0, x3, _mm256_mul_pd(s0, y3)));
_mm256_storeu_pd(y.add(12), fmsub_pd(c0, y3, _mm256_mul_pd(s0, x3)));
x = x.add(16);
y = y.add(16);
}
for _ in 0..n % 16 {
let x0 = *x;
let y0 = *y;
*x = c * x0 + s * y0;
*y = c * y0 - s * x0;
x = x.add(1);
y = y.add(1);
}
} else {
for _ in 0..n {
let x0 = *x;
let y0 = *y;
*x = c * x0 + s * y0;
*y = c * y0 - s * x0;
x = x.add(incx);
y = y.add(incy);
}
}
}
pub unsafe fn dswap(n: usize, mut x: *mut f64, incx: usize, mut y: *mut f64, incy: usize) {
if incx == 1 && incy == 1 {
for _ in 0..n / 16 {
let x0 = _mm256_loadu_pd(x);
let y0 = _mm256_loadu_pd(y);
let x1 = _mm256_loadu_pd(x.add(4));
let y1 = _mm256_loadu_pd(y.add(4));
let x2 = _mm256_loadu_pd(x.add(8));
let y2 = _mm256_loadu_pd(y.add(8));
let x3 = _mm256_loadu_pd(x.add(12));
let y3 = _mm256_loadu_pd(y.add(12));
_mm256_storeu_pd(x, y0);
_mm256_storeu_pd(y, x0);
_mm256_storeu_pd(x.add(4), y1);
_mm256_storeu_pd(y.add(4), x1);
_mm256_storeu_pd(x.add(8), y2);
_mm256_storeu_pd(y.add(8), x2);
_mm256_storeu_pd(x.add(12), y3);
_mm256_storeu_pd(y.add(12), x3);
x = x.add(16);
y = y.add(16);
}
for _ in 0..n % 16 {
let x0 = *x;
*x = *y;
*y = x0;
x = x.add(1);
y = y.add(1);
}
} else {
for _ in 0..n {
let x0 = *x;
*x = *y;
*y = x0;
x = x.add(incx);
y = y.add(incy);
}
}
}
pub unsafe fn dscal(n: usize, a: f64, mut x: *mut f64, incx: usize) {
if incx == 1 {
let a0 = _mm256_broadcast_sd(&a);
for _ in 0..n / 32 {
let mut x0 = _mm256_loadu_pd(x);
let mut x1 = _mm256_loadu_pd(x.add(4));
let mut x2 = _mm256_loadu_pd(x.add(8));
let mut x3 = _mm256_loadu_pd(x.add(12));
let mut x4 = _mm256_loadu_pd(x.add(16));
let mut x5 = _mm256_loadu_pd(x.add(20));
let mut x6 = _mm256_loadu_pd(x.add(24));
let mut x7 = _mm256_loadu_pd(x.add(28));
x0 = _mm256_mul_pd(a0, x0);
x1 = _mm256_mul_pd(a0, x1);
x2 = _mm256_mul_pd(a0, x2);
x3 = _mm256_mul_pd(a0, x3);
x4 = _mm256_mul_pd(a0, x4);
x5 = _mm256_mul_pd(a0, x5);
x6 = _mm256_mul_pd(a0, x6);
x7 = _mm256_mul_pd(a0, x7);
_mm256_storeu_pd(x, x0);
_mm256_storeu_pd(x.add(4), x1);
_mm256_storeu_pd(x.add(8), x2);
_mm256_storeu_pd(x.add(12), x3);
_mm256_storeu_pd(x.add(16), x4);
_mm256_storeu_pd(x.add(20), x5);
_mm256_storeu_pd(x.add(24), x6);
_mm256_storeu_pd(x.add(28), x7);
x = x.add(32);
}
for _ in 0..n % 32 {
*x *= a;
x = x.add(1);
}
} else {
for _ in 0..n {
*x *= a;
x = x.add(incx);
}
}
}
pub unsafe fn dcopy(n: usize, mut x: *const f64, incx: usize, mut y: *mut f64, incy: usize) {
if incx == 1 && incy == 1 {
for _ in 0..n / 32 {
let x0 = _mm256_loadu_pd(x);
let x1 = _mm256_loadu_pd(x.add(4));
let x2 = _mm256_loadu_pd(x.add(8));
let x3 = _mm256_loadu_pd(x.add(12));
let x4 = _mm256_loadu_pd(x.add(16));
let x5 = _mm256_loadu_pd(x.add(20));
let x6 = _mm256_loadu_pd(x.add(24));
let x7 = _mm256_loadu_pd(x.add(28));
_mm256_storeu_pd(y, x0);
_mm256_storeu_pd(y.add(4), x1);
_mm256_storeu_pd(y.add(8), x2);
_mm256_storeu_pd(y.add(12), x3);
_mm256_storeu_pd(y.add(16), x4);
_mm256_storeu_pd(y.add(20), x5);
_mm256_storeu_pd(y.add(24), x6);
_mm256_storeu_pd(y.add(28), x7);
x = x.add(32);
y = y.add(32);
}
for _ in 0..n % 32 {
*y = *x;
x = x.add(1);
y = y.add(1);
}
} else {
for _ in 0..n {
*y = *x;
x = x.add(incx);
y = y.add(incy);
}
}
}
pub unsafe fn daxpy(
n: usize,
a: f64,
mut x: *const f64,
incx: usize,
mut y: *mut f64,
incy: usize,
) {
if incx == 1 && incy == 1 {
let a0 = _mm256_broadcast_sd(&a);
for _ in 0..n / 16 {
let x0 = _mm256_loadu_pd(x);
let y0 = _mm256_loadu_pd(y);
let x1 = _mm256_loadu_pd(x.add(4));
let y1 = _mm256_loadu_pd(y.add(4));
let x2 = _mm256_loadu_pd(x.add(8));
let y2 = _mm256_loadu_pd(y.add(8));
let x3 = _mm256_loadu_pd(x.add(12));
let y3 = _mm256_loadu_pd(y.add(12));
_mm256_storeu_pd(y, fmadd_pd(a0, x0, y0));
_mm256_storeu_pd(y.add(4), fmadd_pd(a0, x1, y1));
_mm256_storeu_pd(y.add(8), fmadd_pd(a0, x2, y2));
_mm256_storeu_pd(y.add(12), fmadd_pd(a0, x3, y3));
x = x.add(16);
y = y.add(16);
}
for _ in 0..n % 16 {
*y += a * *x;
x = x.add(1);
y = y.add(1);
}
} else {
for _ in 0..n {
*y += a * *x;
x = x.add(incx);
y = y.add(incy);
}
}
}
pub unsafe fn ddot(
n: usize,
mut x: *const f64,
incx: usize,
mut y: *const f64,
incy: usize,
) -> f64 {
if incx == 1 && incy == 1 {
let mut acc0 = _mm256_setzero_pd();
let mut acc1 = _mm256_setzero_pd();
let mut acc2 = _mm256_setzero_pd();
let mut acc3 = _mm256_setzero_pd();
for _ in 0..n / 16 {
let x0 = _mm256_loadu_pd(x);
let y0 = _mm256_loadu_pd(y);
let x1 = _mm256_loadu_pd(x.add(4));
let y1 = _mm256_loadu_pd(y.add(4));
let x2 = _mm256_loadu_pd(x.add(8));
let y2 = _mm256_loadu_pd(y.add(8));
let x3 = _mm256_loadu_pd(x.add(12));
let y3 = _mm256_loadu_pd(y.add(12));
acc0 = fmadd_pd(x0, y0, acc0);
acc1 = fmadd_pd(x1, y1, acc1);
acc2 = fmadd_pd(x2, y2, acc2);
acc3 = fmadd_pd(x3, y3, acc3);
x = x.add(16);
y = y.add(16);
}
acc0 = _mm256_add_pd(acc0, acc1);
acc2 = _mm256_add_pd(acc2, acc3);
acc0 = _mm256_add_pd(acc0, acc2);
let mut acc = hsum_pd(acc0);
for _ in 0..n % 16 {
acc += *x * *y;
x = x.add(1);
y = y.add(1);
}
acc
} else {
let mut acc = 0.0;
for _ in 0..n {
acc += *x * *y;
x = x.add(incx);
y = y.add(incy);
}
acc
}
}
pub unsafe fn dnrm2(n: usize, mut x: *const f64, incx: usize) -> f64 {
if incx == 1 {
let mut acc0 = _mm256_setzero_pd();
let mut acc1 = _mm256_setzero_pd();
let mut acc2 = _mm256_setzero_pd();
let mut acc3 = _mm256_setzero_pd();
for _ in 0..n / 16 {
let x0 = _mm256_loadu_pd(x);
let x1 = _mm256_loadu_pd(x.add(4));
let x2 = _mm256_loadu_pd(x.add(8));
let x3 = _mm256_loadu_pd(x.add(12));
acc0 = fmadd_pd(x0, x0, acc0);
acc1 = fmadd_pd(x1, x1, acc1);
acc2 = fmadd_pd(x2, x2, acc2);
acc3 = fmadd_pd(x3, x3, acc3);
x = x.add(16);
}
acc0 = _mm256_add_pd(acc0, acc1);
acc2 = _mm256_add_pd(acc2, acc3);
acc0 = _mm256_add_pd(acc0, acc2);
let mut acc = hsum_pd(acc0);
for _ in 0..n % 16 {
let x0 = *x;
acc += x0 * x0;
x = x.add(1);
}
acc.sqrt()
} else {
let mut acc = 0.0;
for _ in 0..n {
let x0 = *x;
acc += x0 * x0;
x = x.add(incx);
}
acc.sqrt()
}
}
pub unsafe fn dasum(n: usize, mut x: *const f64, incx: usize) -> f64 {
if incx == 1 {
let mask = _mm256_broadcast_sd(&f64::from_bits(0x7FFF_FFFF_FFFF_FFFF));
let mut acc0 = _mm256_setzero_pd();
let mut acc1 = _mm256_setzero_pd();
let mut acc2 = _mm256_setzero_pd();
let mut acc3 = _mm256_setzero_pd();
let mut acc4 = _mm256_setzero_pd();
let mut acc5 = _mm256_setzero_pd();
let mut acc6 = _mm256_setzero_pd();
let mut acc7 = _mm256_setzero_pd();
for _ in 0..n / 32 {
let mut x0 = _mm256_loadu_pd(x);
let mut x1 = _mm256_loadu_pd(x.add(4));
let mut x2 = _mm256_loadu_pd(x.add(8));
let mut x3 = _mm256_loadu_pd(x.add(12));
let mut x4 = _mm256_loadu_pd(x.add(16));
let mut x5 = _mm256_loadu_pd(x.add(20));
let mut x6 = _mm256_loadu_pd(x.add(24));
let mut x7 = _mm256_loadu_pd(x.add(28));
x0 = _mm256_and_pd(mask, x0);
x1 = _mm256_and_pd(mask, x1);
x2 = _mm256_and_pd(mask, x2);
x3 = _mm256_and_pd(mask, x3);
x4 = _mm256_and_pd(mask, x4);
x5 = _mm256_and_pd(mask, x5);
x6 = _mm256_and_pd(mask, x6);
x7 = _mm256_and_pd(mask, x7);
acc0 = _mm256_add_pd(acc0, x0);
acc1 = _mm256_add_pd(acc1, x1);
acc2 = _mm256_add_pd(acc2, x2);
acc3 = _mm256_add_pd(acc3, x3);
acc4 = _mm256_add_pd(acc4, x4);
acc5 = _mm256_add_pd(acc5, x5);
acc6 = _mm256_add_pd(acc6, x6);
acc7 = _mm256_add_pd(acc7, x7);
x = x.add(32);
}
acc0 = _mm256_add_pd(acc0, acc1);
acc2 = _mm256_add_pd(acc2, acc3);
acc4 = _mm256_add_pd(acc4, acc5);
acc6 = _mm256_add_pd(acc6, acc7);
acc0 = _mm256_add_pd(acc0, acc2);
acc4 = _mm256_add_pd(acc4, acc6);
acc0 = _mm256_add_pd(acc0, acc4);
let mut acc = hsum_pd(acc0);
for _ in 0..n % 32 {
acc += (*x).abs();
x = x.add(1);
}
acc
} else {
let mut acc = 0.0;
for _ in 0..n {
acc += (*x).abs();
x = x.add(incx);
}
acc
}
}

396
src/kernel/sse/l1s.rs Normal file
View File

@ -0,0 +1,396 @@
use super::fma::{fmadd_ps, fmsub_ps};
use super::hsum::hsum_ps;
use super::intrinsics::*;
pub unsafe fn srot(
n: usize,
mut x: *mut f32,
incx: usize,
mut y: *mut f32,
incy: usize,
c: f32,
s: f32,
) {
if incx == 1 && incy == 1 {
let c0 = _mm256_broadcast_ss(&c);
let s0 = _mm256_broadcast_ss(&s);
for _ in 0..n / 32 {
let x0 = _mm256_loadu_ps(x);
let y0 = _mm256_loadu_ps(y);
let x1 = _mm256_loadu_ps(x.add(8));
let y1 = _mm256_loadu_ps(y.add(8));
let x2 = _mm256_loadu_ps(x.add(16));
let y2 = _mm256_loadu_ps(y.add(16));
let x3 = _mm256_loadu_ps(x.add(24));
let y3 = _mm256_loadu_ps(y.add(24));
_mm256_storeu_ps(x, fmadd_ps(c0, x0, _mm256_mul_ps(s0, y0)));
_mm256_storeu_ps(y, fmsub_ps(c0, y0, _mm256_mul_ps(s0, x0)));
_mm256_storeu_ps(x.add(8), fmadd_ps(c0, x1, _mm256_mul_ps(s0, y1)));
_mm256_storeu_ps(y.add(8), fmsub_ps(c0, y1, _mm256_mul_ps(s0, x1)));
_mm256_storeu_ps(x.add(16), fmadd_ps(c0, x2, _mm256_mul_ps(s0, y2)));
_mm256_storeu_ps(y.add(16), fmsub_ps(c0, y2, _mm256_mul_ps(s0, x2)));
_mm256_storeu_ps(x.add(24), fmadd_ps(c0, x3, _mm256_mul_ps(s0, y3)));
_mm256_storeu_ps(y.add(24), fmsub_ps(c0, y3, _mm256_mul_ps(s0, x3)));
x = x.add(32);
y = y.add(32);
}
for _ in 0..n % 32 {
let x0 = *x;
let y0 = *y;
*x = c * x0 + s * y0;
*y = c * y0 - s * x0;
x = x.add(1);
y = y.add(1);
}
} else {
for _ in 0..n {
let x0 = *x;
let y0 = *y;
*x = c * x0 + s * y0;
*y = c * y0 - s * x0;
x = x.add(incx);
y = y.add(incy);
}
}
}
pub unsafe fn sswap(n: usize, mut x: *mut f32, incx: usize, mut y: *mut f32, incy: usize) {
if incx == 1 && incy == 1 {
for _ in 0..n / 32 {
let x0 = _mm256_loadu_ps(x);
let y0 = _mm256_loadu_ps(y);
let x1 = _mm256_loadu_ps(x.add(8));
let y1 = _mm256_loadu_ps(y.add(8));
let x2 = _mm256_loadu_ps(x.add(16));
let y2 = _mm256_loadu_ps(y.add(16));
let x3 = _mm256_loadu_ps(x.add(24));
let y3 = _mm256_loadu_ps(y.add(24));
_mm256_storeu_ps(x, y0);
_mm256_storeu_ps(y, x0);
_mm256_storeu_ps(x.add(8), y1);
_mm256_storeu_ps(y.add(8), x1);
_mm256_storeu_ps(x.add(16), y2);
_mm256_storeu_ps(y.add(16), x2);
_mm256_storeu_ps(x.add(24), y3);
_mm256_storeu_ps(y.add(24), x3);
x = x.add(32);
y = y.add(32);
}
for _ in 0..n % 32 {
let x0 = *x;
*x = *y;
*y = x0;
x = x.add(1);
y = y.add(1);
}
} else {
for _ in 0..n {
let x0 = *x;
*x = *y;
*y = x0;
x = x.add(incx);
y = y.add(incy);
}
}
}
pub unsafe fn sscal(n: usize, a: f32, mut x: *mut f32, incx: usize) {
if incx == 1 {
let a0 = _mm256_broadcast_ss(&a);
for _ in 0..n / 64 {
let mut x0 = _mm256_loadu_ps(x);
let mut x1 = _mm256_loadu_ps(x.add(8));
let mut x2 = _mm256_loadu_ps(x.add(16));
let mut x3 = _mm256_loadu_ps(x.add(24));
let mut x4 = _mm256_loadu_ps(x.add(32));
let mut x5 = _mm256_loadu_ps(x.add(40));
let mut x6 = _mm256_loadu_ps(x.add(48));
let mut x7 = _mm256_loadu_ps(x.add(56));
x0 = _mm256_mul_ps(a0, x0);
x1 = _mm256_mul_ps(a0, x1);
x2 = _mm256_mul_ps(a0, x2);
x3 = _mm256_mul_ps(a0, x3);
x4 = _mm256_mul_ps(a0, x4);
x5 = _mm256_mul_ps(a0, x5);
x6 = _mm256_mul_ps(a0, x6);
x7 = _mm256_mul_ps(a0, x7);
_mm256_storeu_ps(x, x0);
_mm256_storeu_ps(x.add(8), x1);
_mm256_storeu_ps(x.add(16), x2);
_mm256_storeu_ps(x.add(24), x3);
_mm256_storeu_ps(x.add(32), x4);
_mm256_storeu_ps(x.add(40), x5);
_mm256_storeu_ps(x.add(48), x6);
_mm256_storeu_ps(x.add(56), x7);
x = x.add(64);
}
for _ in 0..n % 64 {
*x *= a;
x = x.add(1);
}
} else {
for _ in 0..n {
*x *= a;
x = x.add(incx);
}
}
}
pub unsafe fn scopy(n: usize, mut x: *const f32, incx: usize, mut y: *mut f32, incy: usize) {
if incx == 1 && incy == 1 {
for _ in 0..n / 64 {
let x0 = _mm256_loadu_ps(x);
let x1 = _mm256_loadu_ps(x.add(8));
let x2 = _mm256_loadu_ps(x.add(16));
let x3 = _mm256_loadu_ps(x.add(24));
let x4 = _mm256_loadu_ps(x.add(32));
let x5 = _mm256_loadu_ps(x.add(40));
let x6 = _mm256_loadu_ps(x.add(48));
let x7 = _mm256_loadu_ps(x.add(56));
_mm256_storeu_ps(y, x0);
_mm256_storeu_ps(y.add(8), x1);
_mm256_storeu_ps(y.add(16), x2);
_mm256_storeu_ps(y.add(24), x3);
_mm256_storeu_ps(y.add(32), x4);
_mm256_storeu_ps(y.add(40), x5);
_mm256_storeu_ps(y.add(48), x6);
_mm256_storeu_ps(y.add(56), x7);
x = x.add(64);
y = y.add(64);
}
for _ in 0..n % 64 {
*y = *x;
x = x.add(1);
y = y.add(1);
}
} else {
for _ in 0..n {
*y = *x;
x = x.add(incx);
y = y.add(incy);
}
}
}
pub unsafe fn saxpy(
n: usize,
a: f32,
mut x: *const f32,
incx: usize,
mut y: *mut f32,
incy: usize,
) {
if incx == 1 && incy == 1 {
let a0 = _mm256_broadcast_ss(&a);
for _ in 0..n / 32 {
let x0 = _mm256_loadu_ps(x);
let y0 = _mm256_loadu_ps(y);
let x1 = _mm256_loadu_ps(x.add(8));
let y1 = _mm256_loadu_ps(y.add(8));
let x2 = _mm256_loadu_ps(x.add(16));
let y2 = _mm256_loadu_ps(y.add(16));
let x3 = _mm256_loadu_ps(x.add(24));
let y3 = _mm256_loadu_ps(y.add(24));
_mm256_storeu_ps(y, fmadd_ps(a0, x0, y0));
_mm256_storeu_ps(y.add(8), fmadd_ps(a0, x1, y1));
_mm256_storeu_ps(y.add(16), fmadd_ps(a0, x2, y2));
_mm256_storeu_ps(y.add(24), fmadd_ps(a0, x3, y3));
x = x.add(32);
y = y.add(32);
}
for _ in 0..n % 32 {
*y += a * *x;
x = x.add(1);
y = y.add(1);
}
} else {
for _ in 0..n {
*y += a * *x;
x = x.add(incx);
y = y.add(incy);
}
}
}
pub unsafe fn sdot(
n: usize,
mut x: *const f32,
incx: usize,
mut y: *const f32,
incy: usize,
) -> f32 {
if incx == 1 && incy == 1 {
let mut acc0 = _mm256_setzero_ps();
let mut acc1 = _mm256_setzero_ps();
let mut acc2 = _mm256_setzero_ps();
let mut acc3 = _mm256_setzero_ps();
for _ in 0..n / 32 {
let x0 = _mm256_loadu_ps(x);
let y0 = _mm256_loadu_ps(y);
let x1 = _mm256_loadu_ps(x.add(8));
let y1 = _mm256_loadu_ps(y.add(8));
let x2 = _mm256_loadu_ps(x.add(16));
let y2 = _mm256_loadu_ps(y.add(16));
let x3 = _mm256_loadu_ps(x.add(24));
let y3 = _mm256_loadu_ps(y.add(24));
acc0 = fmadd_ps(x0, y0, acc0);
acc1 = fmadd_ps(x1, y1, acc1);
acc2 = fmadd_ps(x2, y2, acc2);
acc3 = fmadd_ps(x3, y3, acc3);
x = x.add(32);
y = y.add(32);
}
acc0 = _mm256_add_ps(acc0, acc1);
acc2 = _mm256_add_ps(acc2, acc3);
acc0 = _mm256_add_ps(acc0, acc2);
let mut acc = hsum_ps(acc0);
for _ in 0..n % 32 {
acc += *x * *y;
x = x.add(1);
y = y.add(1);
}
acc
} else {
let mut acc = 0.0;
for _ in 0..n {
acc += *x * *y;
x = x.add(incx);
y = y.add(incy);
}
acc
}
}
pub unsafe fn snrm2(n: usize, mut x: *const f32, incx: usize) -> f32 {
if incx == 1 {
let mut acc0 = _mm256_setzero_ps();
let mut acc1 = _mm256_setzero_ps();
let mut acc2 = _mm256_setzero_ps();
let mut acc3 = _mm256_setzero_ps();
for _ in 0..n / 32 {
let x0 = _mm256_loadu_ps(x);
let x1 = _mm256_loadu_ps(x.add(8));
let x2 = _mm256_loadu_ps(x.add(16));
let x3 = _mm256_loadu_ps(x.add(24));
acc0 = fmadd_ps(x0, x0, acc0);
acc1 = fmadd_ps(x1, x1, acc1);
acc2 = fmadd_ps(x2, x2, acc2);
acc3 = fmadd_ps(x3, x3, acc3);
x = x.add(32);
}
acc0 = _mm256_add_ps(acc0, acc1);
acc2 = _mm256_add_ps(acc2, acc3);
acc0 = _mm256_add_ps(acc0, acc2);
let mut acc = hsum_ps(acc0);
for _ in 0..n % 32 {
let x0 = *x;
acc += x0 * x0;
x = x.add(1);
}
acc.sqrt()
} else {
let mut acc = 0.0;
for _ in 0..n {
let x0 = *x;
acc += x0 * x0;
x = x.add(incx);
}
acc.sqrt()
}
}
pub unsafe fn sasum(n: usize, mut x: *const f32, incx: usize) -> f32 {
if incx == 1 {
let mask = _mm256_broadcast_ss(&f32::from_bits(0x7FFF_FFFF));
let mut acc0 = _mm256_setzero_ps();
let mut acc1 = _mm256_setzero_ps();
let mut acc2 = _mm256_setzero_ps();
let mut acc3 = _mm256_setzero_ps();
let mut acc4 = _mm256_setzero_ps();
let mut acc5 = _mm256_setzero_ps();
let mut acc6 = _mm256_setzero_ps();
let mut acc7 = _mm256_setzero_ps();
for _ in 0..n / 64 {
let mut x0 = _mm256_loadu_ps(x);
let mut x1 = _mm256_loadu_ps(x.add(8));
let mut x2 = _mm256_loadu_ps(x.add(16));
let mut x3 = _mm256_loadu_ps(x.add(24));
let mut x4 = _mm256_loadu_ps(x.add(32));
let mut x5 = _mm256_loadu_ps(x.add(40));
let mut x6 = _mm256_loadu_ps(x.add(48));
let mut x7 = _mm256_loadu_ps(x.add(56));
x0 = _mm256_and_ps(mask, x0);
x1 = _mm256_and_ps(mask, x1);
x2 = _mm256_and_ps(mask, x2);
x3 = _mm256_and_ps(mask, x3);
x4 = _mm256_and_ps(mask, x4);
x5 = _mm256_and_ps(mask, x5);
x6 = _mm256_and_ps(mask, x6);
x7 = _mm256_and_ps(mask, x7);
acc0 = _mm256_add_ps(acc0, x0);
acc1 = _mm256_add_ps(acc1, x1);
acc2 = _mm256_add_ps(acc2, x2);
acc3 = _mm256_add_ps(acc3, x3);
acc4 = _mm256_add_ps(acc4, x4);
acc5 = _mm256_add_ps(acc5, x5);
acc6 = _mm256_add_ps(acc6, x6);
acc7 = _mm256_add_ps(acc7, x7);
x = x.add(64);
}
acc0 = _mm256_add_ps(acc0, acc1);
acc2 = _mm256_add_ps(acc2, acc3);
acc4 = _mm256_add_ps(acc4, acc5);
acc6 = _mm256_add_ps(acc6, acc7);
acc0 = _mm256_add_ps(acc0, acc2);
acc4 = _mm256_add_ps(acc4, acc6);
acc0 = _mm256_add_ps(acc0, acc4);
let mut acc = hsum_ps(acc0);
for _ in 0..n % 64 {
acc += (*x).abs();
x = x.add(1);
}
acc
} else {
let mut acc = 0.0;
for _ in 0..n {
acc += (*x).abs();
x = x.add(incx);
}
acc
}
}

140
src/kernel/sse/l3d.rs Normal file
View File

@ -0,0 +1,140 @@
use super::fma::fmadd_pd;
use super::intrinsics::*;
pub(crate) unsafe fn dgemm_ukr_8x4(
k: usize,
alpha: f64,
pa: *const f64,
pb: *const f64,
beta: f64,
c: *mut f64,
ldc: usize,
) {
let mut mt00 = _mm256_setzero_pd();
let mut mt01 = _mm256_setzero_pd();
let mut mt02 = _mm256_setzero_pd();
let mut mt03 = _mm256_setzero_pd();
let mut mt10 = _mm256_setzero_pd();
let mut mt11 = _mm256_setzero_pd();
let mut mt12 = _mm256_setzero_pd();
let mut mt13 = _mm256_setzero_pd();
let mut pa = pa;
let mut pb = pb;
for _ in 0..k {
let a0 = _mm256_loadu_pd(pa);
let a1 = _mm256_loadu_pd(pa.add(4));
let b0 = _mm256_broadcast_sd(&*pb);
let b1 = _mm256_broadcast_sd(&*pb.add(1));
let b2 = _mm256_broadcast_sd(&*pb.add(2));
let b3 = _mm256_broadcast_sd(&*pb.add(3));
mt00 = fmadd_pd(a0, b0, mt00);
mt01 = fmadd_pd(a0, b1, mt01);
mt02 = fmadd_pd(a0, b2, mt02);
mt03 = fmadd_pd(a0, b3, mt03);
mt10 = fmadd_pd(a1, b0, mt10);
mt11 = fmadd_pd(a1, b1, mt11);
mt12 = fmadd_pd(a1, b2, mt12);
mt13 = fmadd_pd(a1, b3, mt13);
pa = pa.add(8);
pb = pb.add(4);
}
let alpha = _mm256_broadcast_sd(&alpha);
mt00 = _mm256_mul_pd(alpha, mt00);
mt01 = _mm256_mul_pd(alpha, mt01);
mt02 = _mm256_mul_pd(alpha, mt02);
mt03 = _mm256_mul_pd(alpha, mt03);
mt10 = _mm256_mul_pd(alpha, mt10);
mt11 = _mm256_mul_pd(alpha, mt11);
mt12 = _mm256_mul_pd(alpha, mt12);
mt13 = _mm256_mul_pd(alpha, mt13);
let ccol0 = c;
let ccol1 = c.add(ldc);
let ccol2 = c.add(ldc * 2);
let ccol3 = c.add(ldc * 3);
if beta != 0.0 {
let beta = _mm256_broadcast_sd(&beta);
mt00 = fmadd_pd(beta, _mm256_loadu_pd(ccol0), mt00);
mt01 = fmadd_pd(beta, _mm256_loadu_pd(ccol1), mt01);
mt02 = fmadd_pd(beta, _mm256_loadu_pd(ccol2), mt02);
mt03 = fmadd_pd(beta, _mm256_loadu_pd(ccol3), mt03);
mt10 = fmadd_pd(beta, _mm256_loadu_pd(ccol0.add(4)), mt10);
mt11 = fmadd_pd(beta, _mm256_loadu_pd(ccol1.add(4)), mt11);
mt12 = fmadd_pd(beta, _mm256_loadu_pd(ccol2.add(4)), mt12);
mt13 = fmadd_pd(beta, _mm256_loadu_pd(ccol3.add(4)), mt13);
}
_mm256_storeu_pd(ccol0, mt00);
_mm256_storeu_pd(ccol1, mt01);
_mm256_storeu_pd(ccol2, mt02);
_mm256_storeu_pd(ccol3, mt03);
_mm256_storeu_pd(ccol0.add(4), mt10);
_mm256_storeu_pd(ccol1.add(4), mt11);
_mm256_storeu_pd(ccol2.add(4), mt12);
_mm256_storeu_pd(ccol3.add(4), mt13);
}
pub(crate) unsafe fn dgemm_sup_8x1(
k: usize,
alpha: f64,
pa: *const f64,
b: *const f64,
beta: f64,
c: *mut f64,
) {
let mut mt0 = _mm256_setzero_pd();
let mut mt1 = _mm256_setzero_pd();
let mut pa = pa;
let mut b = b;
for _ in 0..k {
let a0 = _mm256_loadu_pd(pa);
let a1 = _mm256_loadu_pd(pa.add(4));
let b0 = _mm256_broadcast_sd(&*b);
mt0 = fmadd_pd(a0, b0, mt0);
mt1 = fmadd_pd(a1, b0, mt1);
pa = pa.add(8);
b = b.add(1);
}
let alpha = _mm256_broadcast_sd(&alpha);
mt0 = _mm256_mul_pd(alpha, mt0);
mt1 = _mm256_mul_pd(alpha, mt1);
if beta != 0.0 {
let beta = _mm256_broadcast_sd(&beta);
mt0 = fmadd_pd(beta, _mm256_loadu_pd(c), mt0);
mt1 = fmadd_pd(beta, _mm256_loadu_pd(c.add(4)), mt1);
}
_mm256_storeu_pd(c, mt0);
_mm256_storeu_pd(c.add(4), mt1);
}
pub(crate) unsafe fn dgemm_pa_8x(k: usize, a: *const f64, lda: usize, pa: *mut f64) {
let mut a = a;
let mut pa = pa;
for _ in 0..k {
_mm256_storeu_pd(pa, _mm256_loadu_pd(a));
_mm256_storeu_pd(pa.add(4), _mm256_loadu_pd(a.add(4)));
pa = pa.add(8);
a = a.add(lda);
}
}

178
src/kernel/sse/l3s.rs Normal file
View File

@ -0,0 +1,178 @@
use super::fma::fmadd_ps;
use super::intrinsics::*;
use crate::kernel::params::single::NR;
use crate::matrix::{Matrix, MatrixMut, MutMatrix};
pub(crate) unsafe fn sgemm_sup_1x8<A: Matrix<f32>, C: MatrixMut<f32>>(
k: usize,
alpha: f32,
a: A,
pb: MutMatrix<f32>,
beta: f32,
c: C,
) {
let mut c0_3 = _mm_setzero_ps();
let mut c4 = 0.0f32;
// let mut c5 = 0.0f32;
// let mut c6 = 0.0f32;
// let mut c7 = 0.0f32;
let mut a = a;
let mut pb = pb;
for _ in 0..k {
let a0 = *a.ptr();
let a0_simd = _mm_broadcast_ss(&*a.ptr());
c0_3 = fmadd_ps(_mm_loadu_ps(pb.ptr()), a0_simd, c0_3);
c4 += *pb.col(4) * a0;
// c5 += *pb.add(5) * a0;
// c6 += *pb.add(6) * a0;
// c7 += *pb.add(7) * a0;
a.inc_row();
pb.shift_col(NR);
}
// c0 *= alpha;
// c1 *= alpha;
// c2 *= alpha;
// c3 *= alpha;
// c4 *= alpha;
// c5 *= alpha;
// c6 *= alpha;
// c7 *= alpha;
let ccol0_3 = c.ptr_mut();
let ccol4 = c.row_mut(4);
// let ccol5 = c.add(ldc * 5);
// let ccol6 = c.add(ldc * 6);
// let ccol7 = c.add(ldc * 7);
// if beta != 0.0 {
// c0 += beta * *ccol0;
// c1 += beta * *ccol1;
// c2 += beta * *ccol2;
// c3 += beta * *ccol3;
// c4 += beta * *ccol4;
// // c5 += beta * *ccol5;
// // c6 += beta * *ccol6;
// // c7 += beta * *ccol7;
// }
_mm_storeu_ps(ccol0_3, c0_3);
*ccol4 = c4;
// *ccol5 = c5;
// *ccol6 = c6;
// *ccol7 = c7;
}
pub(crate) unsafe fn sgemm_sup_1x8_t(
k: usize,
alpha: f32,
a: *const f32,
lda: usize,
pb: *const f32,
beta: f32,
c: *mut f32,
ldc: usize,
) {
let mut c03 = _mm_setzero_ps();
let mut c4 = 0.0f32;
// let mut c5 = 0.0f32;
// let mut c6 = 0.0f32;
// let mut c7 = 0.0f32;
let mut a = a;
let mut pb = pb;
for _ in 0..k {
let a0 = *a;
let a0_simd = _mm_set_ps1(a0);
c03 = fmadd_ps(_mm_loadu_ps(pb), a0_simd, c03);
c4 += *pb.add(4) * a0;
// c5 += *pb.add(5) * a0;
// c6 += *pb.add(6) * a0;
// c7 += *pb.add(7) * a0;
a = a.add(1);
pb = pb.add(NR);
}
// c0 *= alpha;
// c1 *= alpha;
// c2 *= alpha;
// c3 *= alpha;
// c4 *= alpha;
// c5 *= alpha;
// c6 *= alpha;
// c7 *= alpha;
// let ccol0 = c;
// let ccol1 = c.add(ldc);
// let ccol2 = c.add(ldc * 2);
// let ccol3 = c.add(ldc * 3);
let ccol4 = c.add(ldc * 4);
// let ccol5 = c.add(ldc * 5);
// let ccol6 = c.add(ldc * 6);
// let ccol7 = c.add(ldc * 7);
// if beta != 0.0 {
// c0 += beta * *ccol0;
// c1 += beta * *ccol1;
// c2 += beta * *ccol2;
// c3 += beta * *ccol3;
// c4 += beta * *ccol4;
// // c5 += beta * *ccol5;
// // c6 += beta * *ccol6;
// // c7 += beta * *ccol7;
// }
_mm_storeu_ps(c, c03);
// *ccol0 = c0;
// *ccol1 = c1;
// *ccol2 = c2;
// *ccol3 = c3;
*ccol4 = c4;
// *ccol5 = c5;
// *ccol6 = c6;
// *ccol7 = c7;
}
pub(crate) unsafe fn sgemm_pb_x8(k: usize, b: *const f32, ldb: usize, pb: *mut f32) {
let mut bcol0 = b;
let mut bcol1 = b.add(ldb);
let mut bcol2 = b.add(ldb * 2);
let mut bcol3 = b.add(ldb * 3);
let mut bcol4 = b.add(ldb * 4);
// let mut bcol5 = b.add(ldb * 5);
// let mut bcol6 = b.add(ldb * 6);
// let mut bcol7 = b.add(ldb * 7);
let mut pb = pb;
for _ in 0..k {
_mm_storeu_ps(pb, _mm_set_ps(*bcol3, *bcol2, *bcol1, *bcol0));
// *pb = *bcol0;
// *pb.add(1) = *bcol1;
// *pb.add(2) = *bcol2;
// *pb.add(3) = *bcol3;
*pb.add(4) = *bcol4;
// *pb.add(5) = *bcol5;
// *pb.add(6) = *bcol6;
// *pb.add(7) = *bcol7;
bcol0 = bcol0.add(1);
bcol1 = bcol1.add(1);
bcol2 = bcol2.add(1);
bcol3 = bcol3.add(1);
bcol4 = bcol4.add(1);
// bcol5 = bcol5.add(1);
// bcol6 = bcol6.add(1);
// bcol7 = bcol7.add(1);
pb = pb.add(NR);
}
}

86
src/kernel/sse/mod.rs Normal file
View File

@ -0,0 +1,86 @@
mod fma;
// mod hsum;
mod intrinsics;
// pub mod l1d;
// pub mod l1s;
// pub mod l3d;
pub mod l3s;
use core::marker::PhantomData;
use crate::matrix::{Number, MutMatrix, Matrix, MatrixMut};
use crate::kernel::{GemmKernel, GemmKernelSupNr, GemmKernelSupMr, GemmKernelSup};
use crate::dim::*;
pub struct SseKernel<F: Number, I>(PhantomData<fn(F, I)>);
impl<I> GemmKernelSupNr<f32, A5> for SseKernel<f32, I>
where I: GemmKernelSupNr<f32, A5>
{
#[inline]
unsafe fn sup_tr<A: Matrix<f32>, C: MatrixMut<f32>>(
alpha: f32,
a: A,
pb: MutMatrix<f32>,
beta: f32,
c: C,
) {
I::sup_tr(alpha, a, pb, beta, c);
}
}
impl<I> GemmKernelSupMr<f32, A16> for SseKernel<f32, I>
where I: GemmKernelSupMr<f32, A16>
{
#[inline]
unsafe fn sup_bl<B: Matrix<f32>, C: MatrixMut<f32>>(
alpha: f32,
pa: MutMatrix<f32>,
b: B,
beta: f32,
c: C,
) {
I::sup_bl(alpha, pa, b, beta, c);
}
}
impl<I> GemmKernelSup<f32> for SseKernel<f32, I>
where I: GemmKernelSup<f32>
{
#[inline]
unsafe fn sup_br<A: Matrix<f32>, B: Matrix<f32>, C: MatrixMut<f32>>(
k: usize,
alpha: f32,
a: A,
b: B,
beta: f32,
c: C,
) {
I::sup_br(k, alpha, a, b, beta, c);
}
}
impl<I> GemmKernel<f32, A16, A5> for SseKernel<f32, I>
where I: GemmKernel<f32, A16, A5>
{
#[inline]
unsafe fn pack_row_a<A: Matrix<f32>>(a: A, pa: MutMatrix<f32>, i: usize) {
I::pack_row_a(a, pa, i);
}
#[inline]
unsafe fn pack_row_b<B: Matrix<f32>>(b: B, pb: MutMatrix<f32>, j: usize) {
I::pack_row_b(b, pb, j);
}
#[inline]
unsafe fn main_tl<C: MatrixMut<f32>>(
alpha: f32,
pa: MutMatrix<f32>,
pb: MutMatrix<f32>,
beta: f32,
c: C,
) {
I::main_tl(alpha, pa, pb, beta, c);
}
}

17
src/lib.rs Normal file
View File

@ -0,0 +1,17 @@
pub mod matrix;
pub mod gemm;
mod sgemm;
mod aligned_alloc;
pub mod kernel;
pub mod dim;
pub mod executor;
#[cfg(test)]
extern crate blas;
#[cfg(test)]
extern crate openblas;
#[cfg(test)]
mod test;
pub use crate::sgemm::sgemm;

View File

@ -1,373 +1,16 @@
mod matrix;
mod gemm;
mod sgemm;
mod aligned_alloc;
mod kernel;
mod dim;
mod executor;
extern crate blas;
extern crate openblas;
use self::matrix::{ConstMatrix, MutMatrix, ConstTransposedMatrix, MutTransposedMatrix};
pub unsafe fn sgemm(
transa: bool,
transb: bool,
transc: bool,
m: usize,
n: usize,
k: usize,
alpha: f32,
a: *const f32,
lda: usize,
b: *const f32,
ldb: usize,
beta: f32,
c: *mut f32,
ldc: usize,
) {
match (transa, transb, transc) {
(false, false, false) => sgemm::sgemm(
m, n, k, alpha,
ConstMatrix::new(a, lda),
ConstMatrix::new(b, ldb),
beta,
MutMatrix::new(c, ldc)),
(false, false, true) => sgemm::sgemm(
m, n, k, alpha,
ConstMatrix::new(a, lda),
ConstMatrix::new(b, ldb),
beta,
MutTransposedMatrix::new(c, ldc)),
(false, true, false) => sgemm::sgemm(
m, n, k, alpha,
ConstMatrix::new(a, lda),
ConstTransposedMatrix::new(b, ldb),
beta,
MutMatrix::new(c, ldc)),
(false, true, true) => sgemm::sgemm(
m, n, k, alpha,
ConstMatrix::new(a, lda),
ConstTransposedMatrix::new(b, ldb),
beta,
MutTransposedMatrix::new(c, ldc)),
(true, false, false) => sgemm::sgemm(
m, n, k, alpha,
ConstTransposedMatrix::new(a, lda),
ConstMatrix::new(b, ldb),
beta,
MutMatrix::new(c, ldc)),
(true, false, true) => sgemm::sgemm(
m, n, k, alpha,
ConstTransposedMatrix::new(a, lda),
ConstMatrix::new(b, ldb),
beta,
MutTransposedMatrix::new(c, ldc)),
(true, true, false) => sgemm::sgemm(
m, n, k, alpha,
ConstTransposedMatrix::new(a, lda),
ConstTransposedMatrix::new(b, ldb),
beta,
MutMatrix::new(c, ldc)),
(true, true, true) => sgemm::sgemm(
m, n, k, alpha,
ConstTransposedMatrix::new(a, lda),
ConstTransposedMatrix::new(b, ldb),
beta,
MutTransposedMatrix::new(c, ldc)),
}
}
unsafe fn sgemm_ref_nn(
m: usize,
n: usize,
k: usize,
_alpha: f32,
a: *const f32,
lda: usize,
b: *const f32,
ldb: usize,
_beta: f32,
c: *mut f32,
ldc: usize,
) {
for j in 0..n {
for i in 0..m {
let mut ci = *c.add(i + j * ldc);
for p in 0..k {
ci += *a.add(i + p * lda) * *b.add(p + j * ldb);
}
*c.add(i + j * ldc) = ci;
}
}
}
unsafe fn sgemm_ref_nt(
m: usize,
n: usize,
k: usize,
_alpha: f32,
a: *const f32,
lda: usize,
b: *const f32,
ldb: usize,
_beta: f32,
c: *mut f32,
ldc: usize,
) {
for j in 0..n {
for i in 0..m {
let mut ci = *c.add(i + j * ldc);
for p in 0..k {
ci += *a.add(i + p * lda) * *b.add(j + p * ldb);
}
*c.add(i + j * ldc) = ci;
}
}
}
#[inline(never)]
pub fn gemm_nn(m: usize, n: usize, k: usize, alpha: f32,
a: &[f32], lda: usize,
b: &[f32], ldb: usize,
_bata: f32,
c: &mut [f32], ldc: usize) {
let mk = m * k;
let nk = n * k;
let mn = m * n;
let a = &a[0..mk];
let b = &b[0..nk];
let c = &mut c[0..mn];
for i_m in 0..m {
for i_k in 0..k {
let a_part = alpha * a[i_m * lda + i_k];
for i_n in 0..n {
c[i_m * ldc + i_n] += a_part * b[i_k * ldb + i_n];
}
}
}
}
fn main2() {
use rand::Rng;
let mut rng = rand::thread_rng();
const LEN: usize = 511;
let (m, n, k) = (LEN, LEN, LEN);
let mut a = vec![0.5; m * k];
let mut a_t = vec![0.5; m * k];
let mut b = vec![0.5; n * k];
let mut b_t = vec![0.5; n * k];
let mut c_nn = vec![0.0; m * n];
let mut c_nt = vec![0.0; m * n];
let mut c_tn = vec![0.0; m * n];
let mut c_tt = vec![0.0; m * n];
let mut cref = vec![0.0; m * n];
for i in 0..m {
for j in 0..k {
let v = rng.gen();
a[i + j * m] = v;
a_t[j + i * m] = v;
}
}
for i in 0..n {
for j in 0..k {
let v = rng.gen();
b[i + j * n] = v;
b_t[j + i * n] = v;
}
}
// let time = std::time::Instant::now();
// unsafe {
// gemm_nn(
// LEN,
// LEN,
// LEN,
// 1.0,
// a.as_slice(),
// LEN,
// b.as_slice(),
// LEN,
// 1.0,
// cref1.as_mut_slice(),
// LEN,
// )
// }
// println!("Naive (mine) {}", time.elapsed().as_millis());
let time = std::time::Instant::now();
unsafe {
sgemm_ref_nn(
m,
n,
k,
1.0,
a.as_ptr(),
m,
b.as_ptr(),
k,
1.0,
cref.as_mut_ptr(),
m,
)
}
println!("Naive {}", time.elapsed().as_millis());
let time = std::time::Instant::now();
unsafe {
sgemm(
false,
false,
false,
m,
n,
k,
1.0,
a.as_ptr(),
m,
b.as_ptr(),
k,
1.0,
c_nn.as_mut_ptr(),
m,
);
}
println!("[NN] Optimized {}", time.elapsed().as_millis());
for i in 0..LEN {
for j in 0..LEN {
let (a, b) = (c_nn[i + j * LEN], cref[i + j * LEN]);
assert!(feq(a, b), "a != b, a[{}]={}, b[{}]={}", i, a, j, b);
}
}
let time = std::time::Instant::now();
unsafe {
sgemm(
false,
true,
false,
m,
n,
k,
1.0,
a.as_ptr(),
m,
b_t.as_ptr(),
n,
1.0,
c_nt.as_mut_ptr(),
m,
);
}
println!("[NT] Optimized {}", time.elapsed().as_millis());
for i in 0..LEN {
for j in 0..LEN {
let (a, b) = (c_nt[i + j * LEN], cref[i + j * LEN]);
assert!(feq(a, b), "a != b, a[{}]={}, b[{}]={}", i, a, j, b);
}
}
let time = std::time::Instant::now();
unsafe {
sgemm(
true,
false,
false,
m,
n,
k,
1.0,
a_t.as_ptr(),
k,
b.as_ptr(),
k,
1.0,
c_tn.as_mut_ptr(),
m,
);
}
println!("[TN] Optimized {}", time.elapsed().as_millis());
for i in 0..LEN {
for j in 0..LEN {
let (a, b) = (c_tn[i + j * LEN], cref[i + j * LEN]);
assert!(feq(a, b), "a != b, a[{}]={}, b[{}]={}", i, a, j, b);
}
}
let time = std::time::Instant::now();
unsafe {
sgemm(
true,
true,
false,
m,
n,
k,
1.0,
a_t.as_ptr(),
k,
b_t.as_ptr(),
n,
1.0,
c_tt.as_mut_ptr(),
m,
);
}
println!("[TT] Optimized {}", time.elapsed().as_millis());
for i in 0..LEN {
for j in 0..LEN {
let (a, b) = (c_tt[i + j * LEN], cref[i + j * LEN]);
assert!(feq(a, b), "a != b, a[{}]={}, b[{}]={}", i, a, j, b);
}
}
}
use crate::sgemm::sgemm;
fn main() {
use rand::Rng;
@ -395,25 +38,24 @@ fn main() {
}
}
let time = std::time::Instant::now();
unsafe {
// blas::sgemm(
// b'N',
// b'N',
// m as i32,
// n as i32,
// k as i32,
// 1.0,
// a.as_slice(),
// m as i32,
// b.as_slice(),
// k as i32,
// 1.0,
// cref1.as_mut_slice(),
// m as i32,
// );
blas::sgemm(
b'N',
b'N',
m as i32,
n as i32,
k as i32,
1.0,
a.as_slice(),
m as i32,
b.as_slice(),
k as i32,
1.0,
cref1.as_mut_slice(),
m as i32,
);
// gemm_nn(
// m,
// n,
@ -456,6 +98,7 @@ fn main() {
let time = std::time::Instant::now();
unsafe {
sgemm(
&executor::RayonExecutor,
false,
false,
false,
@ -477,22 +120,5 @@ fn main() {
}
println!("Optimized {}", time.elapsed().as_millis());
// for i in 0..LEN {
// for j in 0..LEN {
// let (a, b) = (c[i + j * LEN], cref[i + j * LEN]);
// assert!(feq(a, b), "a != b, a[{}]={}, b[{}]={}", i, a, j, b);
// }
// }
}
fn feq(a: f32, b: f32) -> bool {
if a == b {
true
} else if a == 0.0 || b == 0.0 || (a.abs() + b.abs() < std::f32::MIN_POSITIVE) {
(a - b).abs() < std::f32::EPSILON * 10.0 * std::f32::MIN_POSITIVE
} else {
(a - b).abs() / (a.abs() + b.abs()) < std::f32::EPSILON * 10.0
}
}

View File

@ -1,36 +1,50 @@
pub(crate) trait Matrix: Copy {
pub trait Number: Copy + Send + Sync + 'static {}
impl Number for f32 {}
impl Number for f64 {}
pub trait Matrix<F: Number>: Copy + Send + Sync + 'static {
unsafe fn sub(&self, row: usize, col: usize) -> Self;
unsafe fn sub_col(&self, col: usize) -> Self;
unsafe fn sub_row(&self, row: usize) -> Self;
unsafe fn is_transposed(&self) -> bool;
unsafe fn stride(&self) -> usize;
unsafe fn get(&self, row: usize, col: usize) -> f32;
unsafe fn index(&self, row: usize, col: usize) -> *const f32;
unsafe fn row(&self, row: usize) -> *const f32;
unsafe fn col(&self, col: usize) -> *const f32;
unsafe fn get(&self, row: usize, col: usize) -> F;
unsafe fn index(&self, row: usize, col: usize) -> *const F;
unsafe fn row(&self, row: usize) -> *const F;
unsafe fn col(&self, col: usize) -> *const F;
unsafe fn ptr(&self) -> *const F;
unsafe fn inc_row(&mut self);
unsafe fn inc_col(&mut self);
unsafe fn shift_row(&mut self, rows: usize);
unsafe fn shift_col(&mut self, cols: usize);
}
pub(crate) trait MatrixMut: Matrix {
unsafe fn set(&self, row: usize, col: usize, val: f32);
unsafe fn index_mut(&self, row: usize, col: usize) -> *mut f32;
unsafe fn row_mut(&self, row: usize) -> *mut f32;
unsafe fn col_mut(&self, col: usize) -> *mut f32;
pub trait MatrixMut<F: Number>: Matrix<F> {
unsafe fn set(&self, row: usize, col: usize, val: F);
unsafe fn index_mut(&self, row: usize, col: usize) -> *mut F;
unsafe fn row_mut(&self, row: usize) -> *mut F;
unsafe fn col_mut(&self, col: usize) -> *mut F;
unsafe fn ptr_mut(&self) -> *mut F;
}
#[derive(Copy, Clone)]
pub(crate) struct ConstMatrix {
pub struct ConstMatrix<F: Number> {
pub stride: usize,
pub ptr: *const f32
pub ptr: *const F
}
unsafe impl Send for ConstMatrix {}
unsafe impl Sync for ConstMatrix {}
unsafe impl<F: Number> Send for ConstMatrix<F> {}
unsafe impl<F: Number> Sync for ConstMatrix<F> {}
impl ConstMatrix {
pub fn new(ptr: *const f32, stride: usize) -> Self {
impl<F: Number> ConstMatrix<F> {
pub fn new(ptr: *const F, stride: usize) -> Self {
Self { ptr, stride }
}
}
impl Matrix for ConstMatrix {
impl<F: Number> Matrix<F> for ConstMatrix<F> {
#[inline]
unsafe fn sub(&self, row: usize, col: usize) -> Self {
Self {
@ -39,6 +53,22 @@ impl Matrix for ConstMatrix {
}
}
#[inline]
unsafe fn sub_col(&self, col: usize) -> Self {
Self {
ptr: self.col(col),
stride: self.stride
}
}
#[inline]
unsafe fn sub_row(&self, row: usize) -> Self {
Self {
ptr: self.row(row),
stride: self.stride
}
}
#[inline]
unsafe fn is_transposed(&self) -> bool {
false
@ -50,42 +80,67 @@ impl Matrix for ConstMatrix {
}
#[inline]
unsafe fn get(&self, row: usize, col: usize) -> f32 {
unsafe fn get(&self, row: usize, col: usize) -> F {
*self.ptr.add(row * self.stride + col)
}
#[inline]
unsafe fn index(&self, row: usize, col: usize) -> *const f32 {
unsafe fn index(&self, row: usize, col: usize) -> *const F {
self.ptr.add(row * self.stride + col)
}
#[inline]
unsafe fn row(&self, row: usize) -> *const f32 {
unsafe fn row(&self, row: usize) -> *const F {
self.ptr.add(row * self.stride)
}
#[inline]
unsafe fn col(&self, col: usize) -> *const f32 {
unsafe fn col(&self, col: usize) -> *const F {
self.ptr.add(col)
}
#[inline]
unsafe fn ptr(&self) -> *const F {
self.ptr
}
#[inline]
unsafe fn inc_row(&mut self) {
self.ptr = self.row(1);
}
#[inline]
unsafe fn inc_col(&mut self) {
self.ptr = self.col(1);
}
#[inline]
unsafe fn shift_row(&mut self, rows: usize) {
self.ptr = self.row(rows);
}
#[inline]
unsafe fn shift_col(&mut self, cols: usize) {
self.ptr = self.col(cols);
}
}
#[derive(Copy, Clone)]
pub(crate) struct ConstTransposedMatrix {
pub struct ConstTransposedMatrix<F: Number> {
pub stride: usize,
pub ptr: *const f32
pub ptr: *const F
}
unsafe impl Send for ConstTransposedMatrix {}
unsafe impl Sync for ConstTransposedMatrix {}
unsafe impl<F: Number> Send for ConstTransposedMatrix<F> {}
unsafe impl<F: Number> Sync for ConstTransposedMatrix<F> {}
impl ConstTransposedMatrix {
pub fn new(ptr: *const f32, stride: usize) -> Self {
impl<F: Number> ConstTransposedMatrix<F> {
pub fn new(ptr: *const F, stride: usize) -> Self {
Self { ptr, stride }
}
}
impl Matrix for ConstTransposedMatrix {
impl<F: Number> Matrix<F> for ConstTransposedMatrix<F> {
#[inline]
unsafe fn sub(&self, row: usize, col: usize) -> Self {
Self {
@ -94,6 +149,22 @@ impl Matrix for ConstTransposedMatrix {
}
}
#[inline]
unsafe fn sub_col(&self, col: usize) -> Self {
Self {
ptr: self.col(col),
stride: self.stride
}
}
#[inline]
unsafe fn sub_row(&self, row: usize) -> Self {
Self {
ptr: self.row(row),
stride: self.stride
}
}
#[inline]
unsafe fn is_transposed(&self) -> bool {
true
@ -105,42 +176,67 @@ impl Matrix for ConstTransposedMatrix {
}
#[inline]
unsafe fn get(&self, row: usize, col: usize) -> f32 {
unsafe fn get(&self, row: usize, col: usize) -> F {
*self.ptr.add(col * self.stride + row)
}
#[inline]
unsafe fn index(&self, row: usize, col: usize) -> *const f32 {
unsafe fn index(&self, row: usize, col: usize) -> *const F {
self.ptr.add(col * self.stride + row)
}
#[inline]
unsafe fn row(&self, row: usize) -> *const f32 {
unsafe fn row(&self, row: usize) -> *const F {
self.ptr.add(row)
}
#[inline]
unsafe fn col(&self, col: usize) -> *const f32 {
unsafe fn col(&self, col: usize) -> *const F {
self.ptr.add(col * self.stride)
}
#[inline]
unsafe fn ptr(&self) -> *const F {
self.ptr
}
#[inline]
unsafe fn inc_row(&mut self) {
self.ptr = self.row(1);
}
#[inline]
unsafe fn inc_col(&mut self) {
self.ptr = self.col(1);
}
#[inline]
unsafe fn shift_row(&mut self, rows: usize) {
self.ptr = self.row(rows);
}
#[inline]
unsafe fn shift_col(&mut self, cols: usize) {
self.ptr = self.col(cols);
}
}
#[derive(Copy, Clone)]
pub(crate) struct MutMatrix {
pub struct MutMatrix<F: Number> {
pub stride: usize,
pub ptr: *mut f32
pub ptr: *mut F
}
unsafe impl Send for MutMatrix {}
unsafe impl Sync for MutMatrix {}
unsafe impl<F: Number> Send for MutMatrix<F> {}
unsafe impl<F: Number> Sync for MutMatrix<F> {}
impl MutMatrix {
pub fn new(ptr: *mut f32, stride: usize) -> Self {
impl<F: Number> MutMatrix<F> {
pub fn new(ptr: *mut F, stride: usize) -> Self {
Self { ptr, stride }
}
}
impl Matrix for MutMatrix {
impl<F: Number> Matrix<F> for MutMatrix<F> {
#[inline]
unsafe fn sub(&self, row: usize, col: usize) -> Self {
Self {
@ -149,6 +245,22 @@ impl Matrix for MutMatrix {
}
}
#[inline]
unsafe fn sub_col(&self, col: usize) -> Self {
Self {
ptr: self.col_mut(col),
stride: self.stride
}
}
#[inline]
unsafe fn sub_row(&self, row: usize) -> Self {
Self {
ptr: self.row_mut(row),
stride: self.stride
}
}
#[inline]
unsafe fn is_transposed(&self) -> bool {
false
@ -160,64 +272,94 @@ impl Matrix for MutMatrix {
}
#[inline]
unsafe fn get(&self, row: usize, col: usize) -> f32 {
unsafe fn get(&self, row: usize, col: usize) -> F {
*self.ptr.add(row * self.stride + col)
}
#[inline]
unsafe fn index(&self, row: usize, col: usize) -> *const f32 {
unsafe fn index(&self, row: usize, col: usize) -> *const F {
self.ptr.add(row * self.stride + col)
}
#[inline]
unsafe fn row(&self, row: usize) -> *const f32 {
unsafe fn row(&self, row: usize) -> *const F {
self.ptr.add(row * self.stride)
}
#[inline]
unsafe fn col(&self, col: usize) -> *const f32 {
unsafe fn col(&self, col: usize) -> *const F {
self.ptr.add(col)
}
#[inline]
unsafe fn ptr(&self) -> *const F {
self.ptr
}
#[inline]
unsafe fn inc_row(&mut self) {
self.ptr = self.row_mut(1);
}
#[inline]
unsafe fn inc_col(&mut self) {
self.ptr = self.col_mut(1);
}
#[inline]
unsafe fn shift_row(&mut self, rows: usize) {
self.ptr = self.row_mut(rows);
}
#[inline]
unsafe fn shift_col(&mut self, cols: usize) {
self.ptr = self.col_mut(cols);
}
}
impl MatrixMut for MutMatrix {
impl<F: Number> MatrixMut<F> for MutMatrix<F> {
#[inline]
unsafe fn set(&self, row: usize, col: usize, value: f32) {
unsafe fn set(&self, row: usize, col: usize, value: F) {
*self.ptr.add(row * self.stride + col) = value;
}
#[inline]
unsafe fn index_mut(&self, row: usize, col: usize) -> *mut f32 {
unsafe fn index_mut(&self, row: usize, col: usize) -> *mut F {
self.ptr.add(row * self.stride + col)
}
#[inline]
unsafe fn row_mut(&self, row: usize) -> *mut f32 {
unsafe fn row_mut(&self, row: usize) -> *mut F {
self.ptr.add(row * self.stride)
}
#[inline]
unsafe fn col_mut(&self, col: usize) -> *mut f32 {
unsafe fn col_mut(&self, col: usize) -> *mut F {
self.ptr.add(col)
}
#[inline]
unsafe fn ptr_mut(&self) -> *mut F {
self.ptr
}
}
#[derive(Copy, Clone)]
pub(crate) struct MutTransposedMatrix {
pub struct MutTransposedMatrix<F: Number> {
pub stride: usize,
pub ptr: *mut f32
pub ptr: *mut F
}
unsafe impl Send for MutTransposedMatrix {}
unsafe impl Sync for MutTransposedMatrix {}
unsafe impl<F: Number> Send for MutTransposedMatrix<F> {}
unsafe impl<F: Number> Sync for MutTransposedMatrix<F> {}
impl MutTransposedMatrix {
pub fn new(ptr: *mut f32, stride: usize) -> Self {
impl<F: Number> MutTransposedMatrix<F> {
pub fn new(ptr: *mut F, stride: usize) -> Self {
Self { ptr, stride }
}
}
impl Matrix for MutTransposedMatrix {
impl<F: Number> Matrix<F> for MutTransposedMatrix<F> {
#[inline]
unsafe fn sub(&self, row: usize, col: usize) -> Self {
Self {
@ -226,6 +368,22 @@ impl Matrix for MutTransposedMatrix {
}
}
#[inline]
unsafe fn sub_col(&self, col: usize) -> Self {
Self {
ptr: self.col_mut(col),
stride: self.stride
}
}
#[inline]
unsafe fn sub_row(&self, row: usize) -> Self {
Self {
ptr: self.row_mut(row),
stride: self.stride
}
}
#[inline]
unsafe fn is_transposed(&self) -> bool {
true
@ -237,47 +395,77 @@ impl Matrix for MutTransposedMatrix {
}
#[inline]
unsafe fn get(&self, row: usize, col: usize) -> f32 {
unsafe fn get(&self, row: usize, col: usize) -> F {
*self.ptr.add(col * self.stride + row)
}
#[inline]
unsafe fn index(&self, row: usize, col: usize) -> *const f32 {
self.ptr.add(col * self.stride + row) as *const f32
unsafe fn index(&self, row: usize, col: usize) -> *const F {
self.ptr.add(col * self.stride + row) as *const F
}
#[inline]
unsafe fn row(&self, row: usize) -> *const f32 {
self.ptr.add(row) as *const f32
unsafe fn row(&self, row: usize) -> *const F {
self.ptr.add(row) as *const F
}
#[inline]
unsafe fn col(&self, col: usize) -> *const f32 {
self.ptr.add(col * self.stride) as *const f32
}
unsafe fn col(&self, col: usize) -> *const F {
self.ptr.add(col * self.stride) as *const F
}
impl MatrixMut for MutTransposedMatrix {
#[inline]
unsafe fn set(&self, row: usize, col: usize, value: f32) {
unsafe fn ptr(&self) -> *const F {
self.ptr
}
#[inline]
unsafe fn inc_row(&mut self) {
self.ptr = self.row_mut(1);
}
#[inline]
unsafe fn inc_col(&mut self) {
self.ptr = self.col_mut(1);
}
#[inline]
unsafe fn shift_row(&mut self, rows: usize) {
self.ptr = self.row_mut(rows);
}
#[inline]
unsafe fn shift_col(&mut self, cols: usize) {
self.ptr = self.col_mut(cols);
}
}
impl<F: Number> MatrixMut<F> for MutTransposedMatrix<F> {
#[inline]
unsafe fn set(&self, row: usize, col: usize, value: F) {
*self.ptr.add(col * self.stride + row) = value;
}
#[inline]
unsafe fn index_mut(&self, row: usize, col: usize) -> *mut f32 {
unsafe fn index_mut(&self, row: usize, col: usize) -> *mut F {
self.ptr.add(col * self.stride + row)
}
#[inline]
unsafe fn row_mut(&self, row: usize) -> *mut f32 {
unsafe fn row_mut(&self, row: usize) -> *mut F {
self.ptr.add(row)
}
#[inline]
unsafe fn col_mut(&self, col: usize) -> *mut f32 {
unsafe fn col_mut(&self, col: usize) -> *mut F {
self.ptr.add(col * self.stride)
}
#[inline]
unsafe fn ptr_mut(&self) -> *mut F {
self.ptr
}
}

View File

@ -1,273 +1,30 @@
use crate::aligned_alloc;
use crate::matrix::{Matrix, MatrixMut, MutMatrix};
use crate::kernel::params::single::*;
use crate::kernel;
use crate::kernel::avx::AvxKernel;
use crate::kernel::generic::GenericKernel;
use crate::dim::*;
use crate::executor::Executor;
pub(crate) unsafe fn sgemm<A, B, C>(
pub unsafe fn sgemm<E: Executor>(
e: &E,
transa: bool,
transb: bool,
transc: bool,
m: usize,
n: usize,
k: usize,
alpha: f32,
a: A,
b: B,
a: *const f32,
lda: usize,
b: *const f32,
ldb: usize,
beta: f32,
c: C
)
where A: Matrix,
B: Matrix,
C: MatrixMut,
{
let packed_a = aligned_alloc::Alloc::new(MC * KC * std::mem::size_of::<f32>());
let packed_b = aligned_alloc::Alloc::new(KC * NC * std::mem::size_of::<f32>());
for j in (0..n).step_by(NC) {
let j_b = std::cmp::min(n - j, NC);
for p in (0..k).step_by(KC) {
let p_b = std::cmp::min(k - p, KC);
for i in (0..m).step_by(MC) {
let i_b = std::cmp::min(m - i, MC);
let pa = MutMatrix::new(packed_a.ptr_f32(), p_b);
let pb = MutMatrix::new(packed_b.ptr_f32(), p_b);
inner_kernel(
i_b,
j_b,
p_b,
alpha,
a.sub(p, i),
b.sub(j, p),
beta,
c.sub(j, i),
pa,
pb,
i == 0
);
}
}
}
}
// +----------------------+
// |A : : : |B |
// | : : : | |
// + - - - - - - - - - ---+
// | : : : | |
// | : : : | |
// + - - - - - - - - - ---+
// | : : : | |
// | : : : | |
// +----------------------+
// |C | | | |D |
// +----------------------+
#[inline]
unsafe fn sgemm_pa<A: Matrix>(a: A, pa: MutMatrix, i: usize) {
if a.is_transposed() {
kernel::sgemm_pa_t(pa.stride, a.col(i), a.stride(), pa.row_mut(i));
} else {
kernel::sgemm_pa_n(pa.stride, a.col(i), a.stride(), pa.row_mut(i));
}
}
#[inline]
unsafe fn sgemm_pb<B: Matrix>(b: B, pb: MutMatrix, j: usize) {
if b.is_transposed() {
kernel::sgemm_pb_t(pb.stride, b.row(j), b.stride(), pb.row_mut(j));
} else {
kernel::sgemm_pb_n(pb.stride, b.row(j), b.stride(), pb.row_mut(j));
}
}
#[inline]
unsafe fn sgemm_ukr<C: MatrixMut>(
i: usize, j: usize,
alpha: f32,
pa: MutMatrix,
pb: MutMatrix,
beta: f32,
c: C)
{
if c.is_transposed() {
unimplemented!()
} else {
kernel::sgemm_ukr(
pa.stride,
alpha,
pa.row(i),
pb.row(j),
beta,
c.index_mut(j, i),
c.stride())
}
}
#[inline]
unsafe fn sgemm_sup1<A: Matrix, C: MatrixMut>(
i: usize, j: usize,
alpha: f32,
a: A,
pb: MutMatrix,
beta: f32,
c: C)
{
if c.is_transposed() {
unimplemented!()
} else {
if a.is_transposed() {
kernel::sgemm_sup1_t(
pb.stride,
alpha,
a.col(i),
a.stride(),
pb.row(j),
beta,
c.index_mut(j, i),
c.stride())
} else {
kernel::sgemm_sup1(
pb.stride,
alpha,
a.col(i),
a.stride(),
pb.row(j),
beta,
c.index_mut(j, i),
c.stride())
}
}
}
#[inline]
unsafe fn sgemm_sup0<B: Matrix, C: MatrixMut>(
i: usize, j: usize,
alpha: f32,
pa: MutMatrix,
b: B,
beta: f32,
c: C)
{
if c.is_transposed() {
unimplemented!()
} else {
if b.is_transposed() {
kernel::sgemm_sup0_t(
pa.stride,
alpha,
pa.row(i),
b.row(j),
b.stride(),
beta,
c.index_mut(j, i))
} else {
kernel::sgemm_sup0(
pa.stride,
alpha,
pa.row(i),
b.row(j),
beta,
c.index_mut(j, i))
}
}
}
unsafe fn inner_kernel<A, B, C>(
m: usize,
n: usize,
k: usize,
alpha: f32,
a: A,
b: B,
beta: f32,
c: C,
pa: MutMatrix,
pb: MutMatrix,
first_time: bool,
)
where A: Matrix,
B: Matrix,
C: MatrixMut,
{
let n_left = n % NR;
let n_main = n - n_left;
let m_left = m % MR;
let m_main = m - m_left;
if first_time {
for j in (0..n_main).step_by(NR) {
sgemm_pb(b, pb, j);
}
}
for i in (0..m_main).step_by(MR) {
sgemm_pa(a, pa, i);
};
for j in (0..n_main).step_by(NR) {
// Section A
for i in (0..m_main).step_by(MR) {
sgemm_ukr(
i,
j,
alpha,
pa,
pb,
beta,
c,
);
}
// Section B
for i in m_main..m {
sgemm_sup1(
i,
j,
alpha,
a,
pb,
beta,
c,
);
}
}
for j in n_main..n {
// Section C
for i in (0..m_main).step_by(MR) {
sgemm_sup0(
i,
j,
alpha,
pa,
b,
beta,
c,
);
}
// Section D
for i in m_main..m {
let mut elem = 0.0;
for p in 0..k {
elem += a.get(p, i) * b.get(j, p);
}
elem *= alpha;
if beta != 0.0 {
elem += beta * c.get(j, i);
}
c.set(j, i, elem);
}
}
c: *mut f32,
ldc: usize,
) {
crate::gemm::gemm::<E, f32, AvxKernel<f32, GenericKernel>, A16, A5>(
e,
transa, transb, transc,
m, n, k, alpha,
a, lda, b, ldb,
beta, c, ldc);
}

308
src/test/l3s.rs Normal file
View File

@ -0,0 +1,308 @@
use rand::Rng;
const M_LEN: usize = 8192;
const N_LEN: usize = 8192;
const K_LEN: usize = 8192;
fn make_matrices() -> (Vec<f32>, Vec<f32>, Vec<f32>, Vec<f32>) {
let mut rng = rand::thread_rng();
let (m, n, k) = (M_LEN, N_LEN, K_LEN);
let mut a = vec![0.0; m * k];
let mut a_t = vec![0.0; m * k];
let mut b = vec![0.0; n * k];
let mut b_t = vec![0.0; n * k];
for row in 0..k {
for col in 0..m {
let v = rng.gen();
a[row * m + col] = v;
a_t[col * k + row] = v;
}
}
for row in 0..n {
for col in 0..k {
let v = rng.gen();
b[row * k + col] = v;
b_t[col * n + row] = v;
}
}
(a, a_t, b, b_t)
}
unsafe fn sgemm_ref(
m: usize,
n: usize,
k: usize,
_alpha: f32,
a: *const f32,
lda: usize,
b: *const f32,
ldb: usize,
_beta: f32,
c: *mut f32,
ldc: usize,
) {
for i_m in 0..m {
for i_k in 0..k {
let a_part = *a.add(i_m * lda + i_k);
for i_n in 0..n {
*c.add(i_m * ldc + i_n) += a_part * *b.add(i_k * ldb + i_n);
}
}
}
// for j in 0..n {
// for i in 0..m {
// let mut ci = *c.add(i + j * ldc);
// for p in 0..k {
// ci += *a.add(i + p * lda) * *b.add(p + j * ldb);
// }
// *c.add(i + j * ldc) = ci;
// }
// }
}
#[test]
fn test_sgemm_nn() {
let (m, n, k) = (M_LEN, N_LEN, K_LEN);
let (a, _, b, _) = make_matrices();
let mut c = vec![0.0; m * n];
let mut cref = vec![0.0; m * n];
unsafe {
blas::sgemm(
b'N',
b'N',
m as i32,
n as i32,
k as i32,
1.0,
a.as_slice(),
m as i32,
b.as_slice(),
k as i32,
0.0,
cref.as_mut_slice(),
m as i32,
)
}
unsafe {
crate::sgemm(
&crate::executor::DefaultExecutor,
false,
false,
false,
m,
n,
k,
1.0,
a.as_ptr(),
m,
b.as_ptr(),
k,
0.0,
c.as_mut_ptr(),
m,
);
}
for row in 0..N_LEN {
for col in 0..M_LEN {
let index = row * M_LEN + col;
let (a, b) = (c[index], cref[index]);
assert!(feq(a, b), "a != b, a[{}, {}]={}, b[{}, {}]={}", row, col, a, row, col, b);
}
}
}
#[test]
fn test_sgemm_nt() {
let (m, n, k) = (M_LEN, N_LEN, K_LEN);
let (a, _, b, b_t) = make_matrices();
let mut c = vec![0.0; m * n];
let mut cref = vec![0.0; m * n];
unsafe {
blas::sgemm(
b'N',
b'T',
m as i32,
n as i32,
k as i32,
1.0,
a.as_slice(),
m as i32,
b_t.as_slice(),
n as i32,
0.0,
cref.as_mut_slice(),
m as i32,
)
}
unsafe {
crate::sgemm(
&crate::executor::DefaultExecutor,
false,
true,
false,
m,
n,
k,
1.0,
a.as_ptr(),
m,
b_t.as_ptr(),
n,
0.0,
c.as_mut_ptr(),
m,
);
}
for row in 0..N_LEN {
for col in 0..M_LEN {
let index = row * M_LEN + col;
let (a, b) = (c[index], cref[index]);
assert!(feq(a, b), "a != b, a[{}, {}]={}, b[{}, {}]={}", row, col, a, row, col, b);
}
}
}
#[test]
fn test_sgemm_tn() {
let (m, n, k) = (M_LEN, N_LEN, K_LEN);
let (a, a_t, b, _) = make_matrices();
let mut c = vec![0.0; m * n];
let mut cref = vec![0.0; m * n];
unsafe {
blas::sgemm(
b'T',
b'N',
m as i32,
n as i32,
k as i32,
1.0,
a_t.as_slice(),
k as i32,
b.as_slice(),
k as i32,
0.0,
cref.as_mut_slice(),
m as i32,
)
}
unsafe {
crate::sgemm(
&crate::executor::DefaultExecutor,
true,
false,
false,
m,
n,
k,
1.0,
a_t.as_ptr(),
k,
b.as_ptr(),
k,
0.0,
c.as_mut_ptr(),
m,
);
}
for row in 0..N_LEN {
for col in 0..M_LEN {
let index = row * M_LEN + col;
let (a, b) = (c[index], cref[index]);
assert!(feq(a, b), "a != b, a[{}, {}]={}, b[{}, {}]={}", row, col, a, row, col, b);
}
}
}
#[test]
fn test_sgemm_tt() {
let (m, n, k) = (M_LEN, N_LEN, K_LEN);
let (a, a_t, b, b_t) = make_matrices();
let mut c = vec![0.0; m * n];
let mut cref = vec![0.0; m * n];
unsafe {
blas::sgemm(
b'T',
b'T',
m as i32,
n as i32,
k as i32,
1.0,
a_t.as_slice(),
k as i32,
b_t.as_slice(),
n as i32,
0.0,
cref.as_mut_slice(),
m as i32,
)
}
unsafe {
crate::sgemm(
&crate::executor::DefaultExecutor,
true,
true,
false,
m,
n,
k,
1.0,
a_t.as_ptr(),
k,
b_t.as_ptr(),
n,
0.0,
c.as_mut_ptr(),
m,
);
}
for row in 0..N_LEN {
for col in 0..M_LEN {
let index = row * M_LEN + col;
let (a, b) = (c[index], cref[index]);
assert!(feq(a, b), "a != b, a[{}, {}]={}, b[{}, {}]={}", row, col, a, row, col, b);
}
}
}
fn feq(a: f32, b: f32) -> bool {
if a == b {
true
} else if a == 0.0 || b == 0.0 || (a.abs() + b.abs() < std::f32::MIN_POSITIVE) {
(a - b).abs() < std::f32::EPSILON * 10.0 * std::f32::MIN_POSITIVE
} else {
(a - b).abs() / (a.abs() + b.abs()) < std::f32::EPSILON * 10.0
}
}

1
src/test/mod.rs Normal file
View File

@ -0,0 +1 @@
mod l3s;

View File

@ -1,5 +1,5 @@
#!/bin/bash
# export RUSTFLAGS="-C target-feature=-mmx,-sse,-sse2,-sse3,-ssse3,-sse4.1,-sse4.2,-3dnow,-3dnowa,-avx,-avx2"
export RUSTFLAGS="-C target-feature=-mmx,-sse,-sse2,-sse3,-ssse3,-sse4.1,-sse4.2,-3dnow,-3dnowa,+avx,+avx2"
cargo clean && cargo build --release
#cargo clean && cargo build --release
perf stat -B -e cache-references,cache-misses,cycles,instructions,branches,branch-misses,faults,migrations,context-switches ./target/release/gemm