Generalization + SSE
This commit is contained in:
parent
ffae99bafd
commit
aaf8ec3203
156
Cargo.lock
generated
156
Cargo.lock
generated
@ -36,14 +36,66 @@ name = "cfg-if"
|
||||
version = "0.1.10"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[[package]]
|
||||
name = "crossbeam-deque"
|
||||
version = "0.7.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"crossbeam-epoch 0.8.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"crossbeam-utils 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crossbeam-epoch"
|
||||
version = "0.8.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"autocfg 0.1.7 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"crossbeam-utils 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"memoffset 0.5.3 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"scopeguard 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crossbeam-queue"
|
||||
version = "0.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"crossbeam-utils 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crossbeam-utils"
|
||||
version = "0.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"autocfg 0.1.7 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crunchy"
|
||||
version = "0.2.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[[package]]
|
||||
name = "either"
|
||||
version = "1.5.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[[package]]
|
||||
name = "gemm"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"blas 0.20.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"matrixmultiply 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"crunchy 0.2.2 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"openblas 0.0.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"rand 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"rayon 1.2.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"threadpool 1.7.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -56,17 +108,30 @@ dependencies = [
|
||||
"wasi 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hermit-abi"
|
||||
version = "0.1.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"libc 0.2.65 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "lazy_static"
|
||||
version = "1.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[[package]]
|
||||
name = "libc"
|
||||
version = "0.2.65"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[[package]]
|
||||
name = "matrixmultiply"
|
||||
version = "0.2.3"
|
||||
name = "memoffset"
|
||||
version = "0.5.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"rawpointer 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"rustc_version 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -86,6 +151,15 @@ dependencies = [
|
||||
"autocfg 0.1.7 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num_cpus"
|
||||
version = "1.11.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"hermit-abi 0.1.3 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"libc 0.2.65 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "openblas"
|
||||
version = "0.0.1"
|
||||
@ -134,9 +208,60 @@ dependencies = [
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rawpointer"
|
||||
version = "0.2.1"
|
||||
name = "rayon"
|
||||
version = "1.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"crossbeam-deque 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"either 1.5.3 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"rayon-core 1.6.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rayon-core"
|
||||
version = "1.6.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"crossbeam-deque 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"crossbeam-queue 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"crossbeam-utils 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"num_cpus 1.11.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustc_version"
|
||||
version = "0.2.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"semver 0.9.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "scopeguard"
|
||||
version = "1.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[[package]]
|
||||
name = "semver"
|
||||
version = "0.9.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"semver-parser 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "semver-parser"
|
||||
version = "0.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[[package]]
|
||||
name = "threadpool"
|
||||
version = "1.7.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"num_cpus 1.11.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wasi"
|
||||
@ -149,16 +274,31 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
"checksum blas-sys 0.7.1 (registry+https://github.com/rust-lang/crates.io-index)" = "13b1b279ceb25d7c4faaea95a5f7addbe7d8c34f9462044bd8e630cebcfc2440"
|
||||
"checksum c2-chacha 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "214238caa1bf3a496ec3392968969cab8549f96ff30652c9e56885329315f6bb"
|
||||
"checksum cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)" = "4785bdd1c96b2a846b2bd7cc02e86b6b3dbf14e7e53446c4f54c92a361040822"
|
||||
"checksum crossbeam-deque 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)" = "c3aa945d63861bfe624b55d153a39684da1e8c0bc8fba932f7ee3a3c16cea3ca"
|
||||
"checksum crossbeam-epoch 0.8.0 (registry+https://github.com/rust-lang/crates.io-index)" = "5064ebdbf05ce3cb95e45c8b086f72263f4166b29b97f6baff7ef7fe047b55ac"
|
||||
"checksum crossbeam-queue 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "dfd6515864a82d2f877b42813d4553292c6659498c9a2aa31bab5a15243c2700"
|
||||
"checksum crossbeam-utils 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ce446db02cdc3165b94ae73111e570793400d0794e46125cc4056c81cbb039f4"
|
||||
"checksum crunchy 0.2.2 (registry+https://github.com/rust-lang/crates.io-index)" = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7"
|
||||
"checksum either 1.5.3 (registry+https://github.com/rust-lang/crates.io-index)" = "bb1f6b1ce1c140482ea30ddd3335fc0024ac7ee112895426e0a629a6c20adfe3"
|
||||
"checksum getrandom 0.1.13 (registry+https://github.com/rust-lang/crates.io-index)" = "e7db7ca94ed4cd01190ceee0d8a8052f08a247aa1b469a7f68c6a3b71afcf407"
|
||||
"checksum hermit-abi 0.1.3 (registry+https://github.com/rust-lang/crates.io-index)" = "307c3c9f937f38e3534b1d6447ecf090cafcc9744e4a6360e8b037b2cf5af120"
|
||||
"checksum lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
|
||||
"checksum libc 0.2.65 (registry+https://github.com/rust-lang/crates.io-index)" = "1a31a0627fdf1f6a39ec0dd577e101440b7db22672c0901fe00a9a6fbb5c24e8"
|
||||
"checksum matrixmultiply 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "d4f7ec66360130972f34830bfad9ef05c6610a43938a467bcc9ab9369ab3478f"
|
||||
"checksum memoffset 0.5.3 (registry+https://github.com/rust-lang/crates.io-index)" = "75189eb85871ea5c2e2c15abbdd541185f63b408415e5051f5cac122d8c774b9"
|
||||
"checksum num-complex 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "fcb0cf31fb3ff77e6d2a6ebd6800df7fdcd106f2ad89113c9130bcd07f93dffc"
|
||||
"checksum num-traits 0.2.10 (registry+https://github.com/rust-lang/crates.io-index)" = "d4c81ffc11c212fa327657cb19dd85eb7419e163b5b076bede2bdb5c974c07e4"
|
||||
"checksum num_cpus 1.11.1 (registry+https://github.com/rust-lang/crates.io-index)" = "76dac5ed2a876980778b8b85f75a71b6cbf0db0b1232ee12f826bccb00d09d72"
|
||||
"checksum openblas 0.0.1 (registry+https://github.com/rust-lang/crates.io-index)" = "f005eb075fb562358060245c5d00a877757b4ddec8e8cd69cab96d8cc55998e8"
|
||||
"checksum ppv-lite86 0.2.6 (registry+https://github.com/rust-lang/crates.io-index)" = "74490b50b9fbe561ac330df47c08f3f33073d2d00c150f719147d7c54522fa1b"
|
||||
"checksum rand 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)" = "3ae1b169243eaf61759b8475a998f0a385e42042370f3a7dbaf35246eacc8412"
|
||||
"checksum rand_chacha 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "03a2a90da8c7523f554344f921aa97283eadf6ac484a6d2a7d0212fa7f8d6853"
|
||||
"checksum rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)" = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19"
|
||||
"checksum rand_hc 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ca3129af7b92a17112d59ad498c6f81eaf463253766b90396d39ea7a39d6613c"
|
||||
"checksum rawpointer 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3"
|
||||
"checksum rayon 1.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "43739f8831493b276363637423d3622d4bd6394ab6f0a9c4a552e208aeb7fddd"
|
||||
"checksum rayon-core 1.6.1 (registry+https://github.com/rust-lang/crates.io-index)" = "f8bf17de6f23b05473c437eb958b9c850bfc8af0961fe17b4cc92d5a627b4791"
|
||||
"checksum rustc_version 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "138e3e0acb6c9fb258b19b67cb8abd63c00679d2851805ea151465464fe9030a"
|
||||
"checksum scopeguard 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)" = "b42e15e59b18a828bbf5c58ea01debb36b9b096346de35d941dcb89009f24a0d"
|
||||
"checksum semver 0.9.0 (registry+https://github.com/rust-lang/crates.io-index)" = "1d7eb9ef2c18661902cc47e535f9bc51b78acd254da71d375c2f6720d9a40403"
|
||||
"checksum semver-parser 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)" = "388a1df253eca08550bef6c72392cfe7c30914bf41df5269b68cbd6ff8f570a3"
|
||||
"checksum threadpool 1.7.1 (registry+https://github.com/rust-lang/crates.io-index)" = "e2f0c90a5f3459330ac8bc0d2f879c693bb7a2f59689c1083fc4ef83834da865"
|
||||
"checksum wasi 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)" = "b89c3ce4ce14bdc6fb6beaf9ec7928ca331de5df7e5ea278375642a2f478570d"
|
||||
|
@ -9,6 +9,11 @@ build = "build.rs"
|
||||
|
||||
[dependencies]
|
||||
rand = "0.7.2"
|
||||
matrixmultiply = "0.2.3"
|
||||
blas = "0.20.0"
|
||||
openblas = "0.0.1"
|
||||
crunchy = "0.2.2"
|
||||
rayon = "1.2.1"
|
||||
threadpool = "1.7.1"
|
||||
|
||||
[profile.release]
|
||||
debug = true
|
||||
|
@ -13,12 +13,8 @@ impl Alloc {
|
||||
Alloc { ptr, layout }
|
||||
}
|
||||
|
||||
pub fn ptr_f32(&self) -> *mut f32 {
|
||||
self.ptr as *mut f32
|
||||
}
|
||||
|
||||
pub fn ptr(&self) -> *mut u8 {
|
||||
self.ptr
|
||||
pub fn ptr<F>(&self) -> *mut F {
|
||||
self.ptr as *mut F
|
||||
}
|
||||
}
|
||||
|
||||
|
75
src/dim.rs
Normal file
75
src/dim.rs
Normal file
@ -0,0 +1,75 @@
|
||||
pub trait Dim {
|
||||
const DIM: usize;
|
||||
}
|
||||
|
||||
pub struct A1;
|
||||
impl Dim for A1 {
|
||||
const DIM: usize = 1;
|
||||
}
|
||||
pub struct A2;
|
||||
impl Dim for A2 {
|
||||
const DIM: usize = 2;
|
||||
}
|
||||
pub struct A3;
|
||||
impl Dim for A3 {
|
||||
const DIM: usize = 3;
|
||||
}
|
||||
pub struct A4;
|
||||
impl Dim for A4 {
|
||||
const DIM: usize = 4;
|
||||
}
|
||||
pub struct A5;
|
||||
impl Dim for A5 {
|
||||
const DIM: usize = 5;
|
||||
}
|
||||
pub struct A6;
|
||||
impl Dim for A6 {
|
||||
const DIM: usize = 6;
|
||||
}
|
||||
pub struct A7;
|
||||
impl Dim for A7 {
|
||||
const DIM: usize = 7;
|
||||
}
|
||||
|
||||
pub struct A8;
|
||||
impl Dim for A8 {
|
||||
const DIM: usize = 8;
|
||||
}
|
||||
|
||||
pub struct A16;
|
||||
impl Dim for A16 {
|
||||
const DIM: usize = 16;
|
||||
}
|
||||
|
||||
/*
|
||||
macro_rules! gen {
|
||||
(($name: ident, $val: lit)) => {
|
||||
|
||||
};
|
||||
}
|
||||
|
||||
gen! {
|
||||
(A0, 0),
|
||||
(A1, 1),
|
||||
(A2, 2),
|
||||
(A3, 3),
|
||||
(A4, 4),
|
||||
(A5, 5),
|
||||
(A6, 6),
|
||||
(A7, 7),
|
||||
(A8, 8),
|
||||
(A9, 9),
|
||||
(A10, 10),
|
||||
(A11, 11),
|
||||
(A12, 12),
|
||||
(A14, 14),
|
||||
(A16, 16),
|
||||
(A18, 18),
|
||||
(A20, 20),
|
||||
(A22, 22),
|
||||
(A24, 24),
|
||||
(A26, 26),
|
||||
(A28, 28),
|
||||
(A30, 30),
|
||||
(A32, 32),
|
||||
}*/
|
111
src/executor.rs
Normal file
111
src/executor.rs
Normal file
@ -0,0 +1,111 @@
|
||||
pub trait Executor {
|
||||
fn execute<F: FnOnce(usize) + Send + 'static + Copy + Sync>(
|
||||
&self,
|
||||
start: usize,
|
||||
end: usize,
|
||||
step: usize,
|
||||
f: F,
|
||||
);
|
||||
|
||||
fn synchronize(&self);
|
||||
}
|
||||
|
||||
pub struct DefaultExecutor;
|
||||
impl Executor for DefaultExecutor {
|
||||
#[inline]
|
||||
fn execute<F: FnOnce(usize) + Send + 'static + Copy + Sync>(
|
||||
&self,
|
||||
start: usize,
|
||||
end: usize,
|
||||
step: usize,
|
||||
f: F,
|
||||
) {
|
||||
for i in (start..end).step_by(step) {
|
||||
f(i);
|
||||
}
|
||||
}
|
||||
|
||||
fn synchronize(&self) {
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
pub struct RayonExecutor;
|
||||
impl Executor for RayonExecutor {
|
||||
#[inline]
|
||||
fn execute<F: FnOnce(usize) + Send + 'static + Copy + Sync>(
|
||||
&self,
|
||||
start: usize,
|
||||
end: usize,
|
||||
step: usize,
|
||||
f: F,
|
||||
) {
|
||||
let end = 1 + ((end - 1) / step);
|
||||
|
||||
use rayon::prelude::*;
|
||||
(start..end)
|
||||
.into_par_iter()
|
||||
.map(|x| x * step)
|
||||
.for_each(|x|f(x));
|
||||
}
|
||||
|
||||
fn synchronize(&self) {
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
pub struct ThreadPoolExecutor {
|
||||
thread_pool: threadpool::ThreadPool
|
||||
}
|
||||
|
||||
impl ThreadPoolExecutor {
|
||||
pub fn new() -> ThreadPoolExecutor {
|
||||
ThreadPoolExecutor {
|
||||
thread_pool: threadpool::Builder::new().build()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Executor for ThreadPoolExecutor {
|
||||
#[inline]
|
||||
fn execute<F: FnOnce(usize) + Send + 'static + Copy>(
|
||||
&self,
|
||||
start: usize,
|
||||
end: usize,
|
||||
step: usize,
|
||||
f: F,
|
||||
) {
|
||||
let thread_count = self.thread_pool.max_count();
|
||||
|
||||
let len = end - start;
|
||||
let num_steps = len / step;
|
||||
|
||||
let mut left_steps = num_steps % thread_count;
|
||||
let main_steps = num_steps - left_steps;
|
||||
|
||||
let job_size = main_steps / thread_count;
|
||||
|
||||
let mut prev_end = 0;
|
||||
|
||||
for _ in 0..thread_count {
|
||||
let mut now_end = prev_end + job_size;
|
||||
if left_steps > 0 {
|
||||
now_end += 1;
|
||||
left_steps -= 1;
|
||||
}
|
||||
self.thread_pool.execute(move || {
|
||||
for j in prev_end..now_end {
|
||||
f(start + j * step);
|
||||
}
|
||||
});
|
||||
|
||||
prev_end = now_end
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn synchronize(&self) {
|
||||
self.thread_pool.join();
|
||||
}
|
||||
}
|
242
src/gemm.rs
Normal file
242
src/gemm.rs
Normal file
@ -0,0 +1,242 @@
|
||||
use crate::aligned_alloc;
|
||||
use crate::matrix::{Number, Matrix, MatrixMut, MutMatrix, ConstMatrix, ConstTransposedMatrix, MutTransposedMatrix};
|
||||
use crate::kernel::params::single::*;
|
||||
use crate::kernel;
|
||||
use crate::kernel::GemmKernel;
|
||||
use crate::dim::Dim;
|
||||
use crate::executor::Executor;
|
||||
|
||||
pub unsafe fn gemm<E, F, K, MR, NR>(
|
||||
e: &E,
|
||||
transa: bool,
|
||||
transb: bool,
|
||||
transc: bool,
|
||||
m: usize,
|
||||
n: usize,
|
||||
k: usize,
|
||||
alpha: F,
|
||||
a: *const F,
|
||||
lda: usize,
|
||||
b: *const F,
|
||||
ldb: usize,
|
||||
beta: F,
|
||||
c: *mut F,
|
||||
ldc: usize,
|
||||
)
|
||||
where E: Executor,
|
||||
F: Number,
|
||||
MR: Dim, NR: Dim,
|
||||
K: GemmKernel<F, MR, NR>,
|
||||
{
|
||||
match (transa, transb, transc) {
|
||||
(false, false, false) => gemm_template::<E, F, K, MR, NR, _, _, _>(
|
||||
e, m, n, k, alpha,
|
||||
ConstMatrix::new(a, lda),
|
||||
ConstMatrix::new(b, ldb),
|
||||
beta,
|
||||
MutMatrix::new(c, ldc)),
|
||||
|
||||
(false, false, true) => gemm_template::<E, F, K, MR, NR, _, _, _>(
|
||||
e, m, n, k, alpha,
|
||||
ConstMatrix::new(a, lda),
|
||||
ConstMatrix::new(b, ldb),
|
||||
beta,
|
||||
MutTransposedMatrix::new(c, ldc)),
|
||||
|
||||
(false, true, false) => gemm_template::<E, F, K, MR, NR, _, _, _>(
|
||||
e, m, n, k, alpha,
|
||||
ConstMatrix::new(a, lda),
|
||||
ConstTransposedMatrix::new(b, ldb),
|
||||
beta,
|
||||
MutMatrix::new(c, ldc)),
|
||||
|
||||
(false, true, true) => gemm_template::<E, F, K, MR, NR, _, _, _>(
|
||||
e, m, n, k, alpha,
|
||||
ConstMatrix::new(a, lda),
|
||||
ConstTransposedMatrix::new(b, ldb),
|
||||
beta,
|
||||
MutTransposedMatrix::new(c, ldc)),
|
||||
|
||||
(true, false, false) => gemm_template::<E, F, K, MR, NR, _, _, _>(
|
||||
e, m, n, k, alpha,
|
||||
ConstTransposedMatrix::new(a, lda),
|
||||
ConstMatrix::new(b, ldb),
|
||||
beta,
|
||||
MutMatrix::new(c, ldc)),
|
||||
|
||||
(true, false, true) => gemm_template::<E, F, K, MR, NR, _, _, _>(
|
||||
e, m, n, k, alpha,
|
||||
ConstTransposedMatrix::new(a, lda),
|
||||
ConstMatrix::new(b, ldb),
|
||||
beta,
|
||||
MutTransposedMatrix::new(c, ldc)),
|
||||
|
||||
|
||||
(true, true, false) => gemm_template::<E, F, K, MR, NR, _, _, _>(
|
||||
e, m, n, k, alpha,
|
||||
ConstTransposedMatrix::new(a, lda),
|
||||
ConstTransposedMatrix::new(b, ldb),
|
||||
beta,
|
||||
MutMatrix::new(c, ldc)),
|
||||
|
||||
|
||||
(true, true, true) => gemm_template::<E, F, K, MR, NR, _, _, _>(
|
||||
e, m, n, k, alpha,
|
||||
ConstTransposedMatrix::new(a, lda),
|
||||
ConstTransposedMatrix::new(b, ldb),
|
||||
beta,
|
||||
MutTransposedMatrix::new(c, ldc)),
|
||||
}
|
||||
}
|
||||
|
||||
unsafe fn gemm_template<E, F, K, MR, NR, A, B, C>(
|
||||
e: &E,
|
||||
m: usize,
|
||||
n: usize,
|
||||
k: usize,
|
||||
alpha: F,
|
||||
a: A,
|
||||
b: B,
|
||||
beta: F,
|
||||
c: C
|
||||
)
|
||||
where E: Executor,
|
||||
F: Number,
|
||||
MR: Dim, NR: Dim,
|
||||
K: GemmKernel<F, MR, NR>,
|
||||
A: Matrix<F>,
|
||||
B: Matrix<F>,
|
||||
C: MatrixMut<F>,
|
||||
{
|
||||
let packed_a = aligned_alloc::Alloc::new(MC * KC * std::mem::size_of::<F>());
|
||||
let packed_b = aligned_alloc::Alloc::new(KC * NC * std::mem::size_of::<F>());
|
||||
|
||||
for j in (0..n).step_by(NC) {
|
||||
let j_b = std::cmp::min(n - j, NC);
|
||||
for p in (0..k).step_by(KC) {
|
||||
let p_b = std::cmp::min(k - p, KC);
|
||||
for i in (0..m).step_by(MC) {
|
||||
let i_b = std::cmp::min(m - i, MC);
|
||||
|
||||
let pa = MutMatrix::new(packed_a.ptr::<F>(), p_b);
|
||||
let pb = MutMatrix::new(packed_b.ptr::<F>(), p_b);
|
||||
|
||||
inner_kernel::<E, F, K, MR, NR, _, _, _>(
|
||||
e,
|
||||
i_b,
|
||||
j_b,
|
||||
p_b,
|
||||
alpha,
|
||||
a.sub(p, i),
|
||||
b.sub(j, p),
|
||||
beta,
|
||||
c.sub(j, i),
|
||||
pa,
|
||||
pb,
|
||||
i == 0
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// | MR |
|
||||
// +----------------------+
|
||||
// |TL : : : |TR|
|
||||
// | : : : | |
|
||||
// + - - - - - - - - - ---+----
|
||||
// | : : : | | NR
|
||||
// | : : : | |
|
||||
// + - - - - - - - - - ---+----
|
||||
// | : : : | |
|
||||
// | : : : | |
|
||||
// +----------------------+
|
||||
// |BL | | | |BR|
|
||||
// +----------------------+
|
||||
unsafe fn inner_kernel<E, F, K, MR, NR, A, B, C>(
|
||||
e: &E,
|
||||
m: usize,
|
||||
n: usize,
|
||||
k: usize,
|
||||
alpha: F,
|
||||
a: A,
|
||||
b: B,
|
||||
beta: F,
|
||||
c: C,
|
||||
pa: MutMatrix<F>,
|
||||
pb: MutMatrix<F>,
|
||||
first_time: bool,
|
||||
)
|
||||
where E: Executor,
|
||||
F: Number,
|
||||
MR: Dim,
|
||||
NR: Dim,
|
||||
K: kernel::GemmKernel<F, MR, NR>,
|
||||
A: Matrix<F>,
|
||||
B: Matrix<F>,
|
||||
C: MatrixMut<F>,
|
||||
|
||||
{
|
||||
let n_left = n % NR;
|
||||
let n_main = n - n_left;
|
||||
|
||||
let m_left = m % MR;
|
||||
let m_main = m - m_left;
|
||||
|
||||
if first_time {
|
||||
e.execute(0, n_main, NR, move |j|
|
||||
K::pack_row_b(b, pb, j));
|
||||
}
|
||||
|
||||
e.execute(0, m_main, MR, move |i|
|
||||
K::pack_row_a(a, pa, i));
|
||||
|
||||
e.synchronize();
|
||||
|
||||
e.execute(0, n_main, NR, move |j| {
|
||||
// Section TL
|
||||
for i in (0..m_main).step_by(MR) {
|
||||
K::main_tl(alpha,
|
||||
pa.sub_row(i),
|
||||
pb.sub_row(j),
|
||||
beta,
|
||||
c.sub(j, i));
|
||||
}
|
||||
|
||||
// Section TR
|
||||
for i in m_main..m {
|
||||
K::sup_tr(
|
||||
alpha,
|
||||
a.sub_col(i),
|
||||
pb.sub_row(j),
|
||||
beta,
|
||||
c.sub(j, i));
|
||||
}
|
||||
});
|
||||
|
||||
e.execute(n_main, n, 1, move |j| {
|
||||
// Section BL
|
||||
for i in (0..m_main).step_by(MR) {
|
||||
K::sup_bl(
|
||||
alpha,
|
||||
pa.sub_row(i),
|
||||
b.sub_row(j),
|
||||
beta,
|
||||
c.sub(j, i)
|
||||
);
|
||||
}
|
||||
|
||||
// Section BR
|
||||
for i in m_main..m {
|
||||
K::sup_br(
|
||||
k,
|
||||
alpha,
|
||||
a.sub_col(i),
|
||||
b.sub_row(j),
|
||||
beta,
|
||||
c.sub(j, i))
|
||||
}
|
||||
});
|
||||
|
||||
e.synchronize();
|
||||
}
|
@ -246,6 +246,7 @@ pub unsafe fn sdot(
|
||||
let mut acc1 = _mm256_setzero_ps();
|
||||
let mut acc2 = _mm256_setzero_ps();
|
||||
let mut acc3 = _mm256_setzero_ps();
|
||||
|
||||
for _ in 0..n / 32 {
|
||||
let x0 = _mm256_loadu_ps(x);
|
||||
let y0 = _mm256_loadu_ps(y);
|
||||
|
@ -1,29 +1,82 @@
|
||||
use super::fma::fmadd_ps;
|
||||
use super::intrinsics::*;
|
||||
use crate::kernel::params::single::{MR, NR};
|
||||
use crate::matrix::{Matrix, MutMatrix, MatrixMut};
|
||||
use crunchy::unroll;
|
||||
|
||||
pub(crate) unsafe fn sgemm_ukr_16x4(
|
||||
#[inline]
|
||||
pub(crate) unsafe fn sgemm_ukr_16x8<C: MatrixMut<f32>>(
|
||||
k: usize,
|
||||
alpha: f32,
|
||||
pa: *const f32,
|
||||
pb: *const f32,
|
||||
pa: MutMatrix<f32>,
|
||||
pb: MutMatrix<f32>,
|
||||
beta: f32,
|
||||
c: *mut f32,
|
||||
ldc: usize,
|
||||
c: C,
|
||||
) {
|
||||
let mut mt00 = _mm256_setzero_ps();
|
||||
let mut mt01 = _mm256_setzero_ps();
|
||||
let mut mt02 = _mm256_setzero_ps();
|
||||
let mut mt03 = _mm256_setzero_ps();
|
||||
let mut mt04 = _mm256_setzero_ps();
|
||||
// let mut mt05 = _mm256_setzero_ps();
|
||||
// let mut mt06 = _mm256_setzero_ps();
|
||||
// let mut mt07 = _mm256_setzero_ps();
|
||||
|
||||
let mut mt10 = _mm256_setzero_ps();
|
||||
let mut mt11 = _mm256_setzero_ps();
|
||||
let mut mt12 = _mm256_setzero_ps();
|
||||
let mut mt13 = _mm256_setzero_ps();
|
||||
let mut mt14 = _mm256_setzero_ps();
|
||||
// let mut mt15 = _mm256_setzero_ps();
|
||||
// let mut mt16 = _mm256_setzero_ps();
|
||||
// let mut mt17 = _mm256_setzero_ps();
|
||||
|
||||
let mut pa = pa;
|
||||
let mut pb = pb;
|
||||
let mut pa = pa.ptr();
|
||||
let mut pb = pb.ptr();
|
||||
|
||||
for _ in 0..k {
|
||||
const BATCH: usize = 16;
|
||||
|
||||
let k_right = k % BATCH;
|
||||
let k_main = k - k_right;
|
||||
|
||||
for _ in (0..k_main).step_by(BATCH) {
|
||||
unroll! {
|
||||
for i in 0..16 {
|
||||
let a0 = _mm256_load_ps(pa.add(i * MR));
|
||||
let a1 = _mm256_load_ps(pa.add(i * MR + 8));
|
||||
|
||||
let b0 = _mm256_broadcast_ss(&*pb.add(i * NR));
|
||||
let b1 = _mm256_broadcast_ss(&*pb.add(i * NR + 1));
|
||||
let b2 = _mm256_broadcast_ss(&*pb.add(i * NR + 2));
|
||||
let b3 = _mm256_broadcast_ss(&*pb.add(i * NR + 3));
|
||||
let b4 = _mm256_broadcast_ss(&*pb.add(i * NR + 4));
|
||||
// let b5 = _mm256_broadcast_ss(&*pb.add(5));
|
||||
// let b6 = _mm256_broadcast_ss(&*pb.add(6));
|
||||
// let b7 = _mm256_broadcast_ss(&*pb.add(7));
|
||||
|
||||
mt00 = fmadd_ps(a0, b0, mt00);
|
||||
mt01 = fmadd_ps(a0, b1, mt01);
|
||||
mt02 = fmadd_ps(a0, b2, mt02);
|
||||
mt03 = fmadd_ps(a0, b3, mt03);
|
||||
mt04 = fmadd_ps(a0, b4, mt04);
|
||||
// mt05 = fmadd_ps(a0, b5, mt05);
|
||||
// mt06 = fmadd_ps(a0, b6, mt06);
|
||||
// mt07 = fmadd_ps(a0, b7, mt07);
|
||||
|
||||
mt10 = fmadd_ps(a1, b0, mt10);
|
||||
mt11 = fmadd_ps(a1, b1, mt11);
|
||||
mt12 = fmadd_ps(a1, b2, mt12);
|
||||
mt13 = fmadd_ps(a1, b3, mt13);
|
||||
mt14 = fmadd_ps(a1, b4, mt14);
|
||||
}
|
||||
}
|
||||
|
||||
pa = pa.add(BATCH * MR);
|
||||
pb = pb.add(BATCH * NR);
|
||||
}
|
||||
|
||||
|
||||
for _ in k_main..k {
|
||||
let a0 = _mm256_load_ps(pa);
|
||||
let a1 = _mm256_load_ps(pa.add(8));
|
||||
|
||||
@ -31,66 +84,107 @@ pub(crate) unsafe fn sgemm_ukr_16x4(
|
||||
let b1 = _mm256_broadcast_ss(&*pb.add(1));
|
||||
let b2 = _mm256_broadcast_ss(&*pb.add(2));
|
||||
let b3 = _mm256_broadcast_ss(&*pb.add(3));
|
||||
let b4 = _mm256_broadcast_ss(&*pb.add(4));
|
||||
// let b5 = _mm256_broadcast_ss(&*pb.add(5));
|
||||
// let b6 = _mm256_broadcast_ss(&*pb.add(6));
|
||||
// let b7 = _mm256_broadcast_ss(&*pb.add(7));
|
||||
|
||||
mt00 = fmadd_ps(a0, b0, mt00);
|
||||
mt01 = fmadd_ps(a0, b1, mt01);
|
||||
mt02 = fmadd_ps(a0, b2, mt02);
|
||||
mt03 = fmadd_ps(a0, b3, mt03);
|
||||
mt04 = fmadd_ps(a0, b4, mt04);
|
||||
// mt05 = fmadd_ps(a0, b5, mt05);
|
||||
// mt06 = fmadd_ps(a0, b6, mt06);
|
||||
// mt07 = fmadd_ps(a0, b7, mt07);
|
||||
|
||||
mt10 = fmadd_ps(a1, b0, mt10);
|
||||
mt11 = fmadd_ps(a1, b1, mt11);
|
||||
mt12 = fmadd_ps(a1, b2, mt12);
|
||||
mt13 = fmadd_ps(a1, b3, mt13);
|
||||
mt14 = fmadd_ps(a1, b4, mt14);
|
||||
|
||||
pa = pa.add(16);
|
||||
pb = pb.add(4);
|
||||
pa = pa.add(MR);
|
||||
pb = pb.add(NR);
|
||||
}
|
||||
|
||||
let alpha = _mm256_broadcast_ss(&alpha);
|
||||
// let alpha = _mm256_broadcast_ss(&alpha);
|
||||
|
||||
mt00 = _mm256_mul_ps(alpha, mt00);
|
||||
mt01 = _mm256_mul_ps(alpha, mt01);
|
||||
mt02 = _mm256_mul_ps(alpha, mt02);
|
||||
mt03 = _mm256_mul_ps(alpha, mt03);
|
||||
mt10 = _mm256_mul_ps(alpha, mt10);
|
||||
mt11 = _mm256_mul_ps(alpha, mt11);
|
||||
mt12 = _mm256_mul_ps(alpha, mt12);
|
||||
mt13 = _mm256_mul_ps(alpha, mt13);
|
||||
// mt00 = _mm256_mul_ps(alpha, mt00);
|
||||
// mt01 = _mm256_mul_ps(alpha, mt01);
|
||||
// mt02 = _mm256_mul_ps(alpha, mt02);
|
||||
// mt03 = _mm256_mul_ps(alpha, mt03);
|
||||
// mt04 = _mm256_mul_ps(alpha, mt04);
|
||||
// // mt05 = _mm256_mul_ps(alpha, mt05);
|
||||
// // mt06 = _mm256_mul_ps(alpha, mt06);
|
||||
// // mt07 = _mm256_mul_ps(alpha, mt07);
|
||||
|
||||
let ccol0 = c;
|
||||
let ccol1 = c.add(ldc);
|
||||
let ccol2 = c.add(ldc * 2);
|
||||
let ccol3 = c.add(ldc * 3);
|
||||
// mt10 = _mm256_mul_ps(alpha, mt10);
|
||||
// mt11 = _mm256_mul_ps(alpha, mt11);
|
||||
// mt12 = _mm256_mul_ps(alpha, mt12);
|
||||
// mt13 = _mm256_mul_ps(alpha, mt13);
|
||||
// mt14 = _mm256_mul_ps(alpha, mt14);
|
||||
// // mt15 = _mm256_mul_ps(alpha, mt15);
|
||||
// // mt16 = _mm256_mul_ps(alpha, mt16);
|
||||
// // mt17 = _mm256_mul_ps(alpha, mt17);
|
||||
|
||||
if beta != 0.0 {
|
||||
let beta = _mm256_broadcast_ss(&beta);
|
||||
let ccol0 = c.ptr_mut();
|
||||
let ccol1 = c.row_mut(1);
|
||||
let ccol2 = c.row_mut(2);
|
||||
let ccol3 = c.row_mut(3);
|
||||
let ccol4 = c.row_mut(4);
|
||||
// let ccol5 = c.row_mut(5);
|
||||
// let ccol6 = c.row_mut(6);
|
||||
// let ccol7 = c.add(ldc * 7);
|
||||
|
||||
mt00 = fmadd_ps(beta, _mm256_loadu_ps(ccol0), mt00);
|
||||
mt01 = fmadd_ps(beta, _mm256_loadu_ps(ccol1), mt01);
|
||||
mt02 = fmadd_ps(beta, _mm256_loadu_ps(ccol2), mt02);
|
||||
mt03 = fmadd_ps(beta, _mm256_loadu_ps(ccol3), mt03);
|
||||
mt10 = fmadd_ps(beta, _mm256_loadu_ps(ccol0.add(8)), mt10);
|
||||
mt11 = fmadd_ps(beta, _mm256_loadu_ps(ccol1.add(8)), mt11);
|
||||
mt12 = fmadd_ps(beta, _mm256_loadu_ps(ccol2.add(8)), mt12);
|
||||
mt13 = fmadd_ps(beta, _mm256_loadu_ps(ccol3.add(8)), mt13);
|
||||
}
|
||||
// if beta != 0.0 {
|
||||
// let beta = _mm256_broadcast_ss(&beta);
|
||||
|
||||
// mt00 = fmadd_ps(beta, _mm256_loadu_ps(ccol0), mt00);
|
||||
// mt01 = fmadd_ps(beta, _mm256_loadu_ps(ccol1), mt01);
|
||||
// mt02 = fmadd_ps(beta, _mm256_loadu_ps(ccol2), mt02);
|
||||
// mt03 = fmadd_ps(beta, _mm256_loadu_ps(ccol3), mt03);
|
||||
// mt04 = fmadd_ps(beta, _mm256_loadu_ps(ccol4), mt04);
|
||||
// // mt05 = fmadd_ps(beta, _mm256_loadu_ps(ccol5), mt05);
|
||||
// // mt06 = fmadd_ps(beta, _mm256_loadu_ps(ccol6), mt06);
|
||||
// // mt07 = fmadd_ps(beta, _mm256_loadu_ps(ccol7), mt07);
|
||||
|
||||
// mt10 = fmadd_ps(beta, _mm256_loadu_ps(ccol0.add(8)), mt10);
|
||||
// mt11 = fmadd_ps(beta, _mm256_loadu_ps(ccol1.add(8)), mt11);
|
||||
// mt12 = fmadd_ps(beta, _mm256_loadu_ps(ccol2.add(8)), mt12);
|
||||
// mt13 = fmadd_ps(beta, _mm256_loadu_ps(ccol3.add(8)), mt13);
|
||||
// mt14 = fmadd_ps(beta, _mm256_loadu_ps(ccol4.add(8)), mt14);
|
||||
// // mt15 = fmadd_ps(beta, _mm256_loadu_ps(ccol5.add(8)), mt15);
|
||||
// // mt16 = fmadd_ps(beta, _mm256_loadu_ps(ccol6.add(8)), mt16);
|
||||
// // mt17 = fmadd_ps(beta, _mm256_loadu_ps(ccol7.add(8)), mt17);
|
||||
// }
|
||||
|
||||
_mm256_storeu_ps(ccol0, mt00);
|
||||
_mm256_storeu_ps(ccol1, mt01);
|
||||
_mm256_storeu_ps(ccol2, mt02);
|
||||
_mm256_storeu_ps(ccol3, mt03);
|
||||
_mm256_storeu_ps(ccol4, mt04);
|
||||
// _mm256_storeu_ps(ccol5, mt05);
|
||||
// _mm256_storeu_ps(ccol6, mt06);
|
||||
// _mm256_storeu_ps(ccol7, mt07);
|
||||
|
||||
_mm256_storeu_ps(ccol0.add(8), mt10);
|
||||
_mm256_storeu_ps(ccol1.add(8), mt11);
|
||||
_mm256_storeu_ps(ccol2.add(8), mt12);
|
||||
_mm256_storeu_ps(ccol3.add(8), mt13);
|
||||
_mm256_storeu_ps(ccol4.add(8), mt14);
|
||||
// _mm256_storeu_ps(ccol5.add(8), mt15);
|
||||
// _mm256_storeu_ps(ccol6.add(8), mt16);
|
||||
// _mm256_storeu_ps(ccol7.add(8), mt17);
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn sgemm_sup_16x1(
|
||||
pub(crate) unsafe fn sgemm_sup_16x1<B: Matrix<f32>, C: MatrixMut<f32>>(
|
||||
k: usize,
|
||||
alpha: f32,
|
||||
pa: *const f32,
|
||||
b: *const f32,
|
||||
pa: MutMatrix<f32>,
|
||||
b: B,
|
||||
beta: f32,
|
||||
c: *mut f32,
|
||||
c: C,
|
||||
) {
|
||||
let mut mt0 = _mm256_setzero_ps();
|
||||
let mut mt1 = _mm256_setzero_ps();
|
||||
@ -99,76 +193,32 @@ pub(crate) unsafe fn sgemm_sup_16x1(
|
||||
let mut b = b;
|
||||
|
||||
for _ in 0..k {
|
||||
let a0 = _mm256_load_ps(pa);
|
||||
let a1 = _mm256_load_ps(pa.add(8));
|
||||
let a0 = _mm256_load_ps(pa.ptr());
|
||||
let a1 = _mm256_load_ps(pa.col(8));
|
||||
|
||||
let b0 = _mm256_broadcast_ss(&*b);
|
||||
let b0 = _mm256_broadcast_ss(&*b.ptr());
|
||||
|
||||
mt0 = fmadd_ps(a0, b0, mt0);
|
||||
mt1 = fmadd_ps(a1, b0, mt1);
|
||||
|
||||
pa = pa.add(16);
|
||||
b = b.add(1);
|
||||
pa.shift_col(16);
|
||||
b.inc_col();
|
||||
}
|
||||
|
||||
let alpha = _mm256_broadcast_ss(&alpha);
|
||||
// let alpha = _mm256_broadcast_ss(&alpha);
|
||||
|
||||
mt0 = _mm256_mul_ps(alpha, mt0);
|
||||
mt1 = _mm256_mul_ps(alpha, mt1);
|
||||
// mt0 = _mm256_mul_ps(alpha, mt0);
|
||||
// mt1 = _mm256_mul_ps(alpha, mt1);
|
||||
|
||||
if beta != 0.0 {
|
||||
let beta = _mm256_broadcast_ss(&beta);
|
||||
// if beta != 0.0 {
|
||||
// let beta = _mm256_broadcast_ss(&beta);
|
||||
|
||||
mt0 = fmadd_ps(beta, _mm256_loadu_ps(c), mt0);
|
||||
mt1 = fmadd_ps(beta, _mm256_loadu_ps(c.add(8)), mt1);
|
||||
}
|
||||
// mt0 = fmadd_ps(beta, _mm256_loadu_ps(c), mt0);
|
||||
// mt1 = fmadd_ps(beta, _mm256_loadu_ps(c.add(8)), mt1);
|
||||
// }
|
||||
|
||||
_mm256_storeu_ps(c, mt0);
|
||||
_mm256_storeu_ps(c.add(8), mt1);
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn sgemm_sup0_t(
|
||||
k: usize,
|
||||
alpha: f32,
|
||||
pa: *const f32,
|
||||
b: *const f32,
|
||||
ldb: usize,
|
||||
beta: f32,
|
||||
c: *mut f32,
|
||||
) {
|
||||
let mut mt0 = _mm256_setzero_ps();
|
||||
let mut mt1 = _mm256_setzero_ps();
|
||||
|
||||
let mut pa = pa;
|
||||
let mut b = b;
|
||||
|
||||
for _ in 0..k {
|
||||
let a0 = _mm256_load_ps(pa);
|
||||
let a1 = _mm256_load_ps(pa.add(8));
|
||||
|
||||
let b0 = _mm256_broadcast_ss(&*b);
|
||||
|
||||
mt0 = fmadd_ps(a0, b0, mt0);
|
||||
mt1 = fmadd_ps(a1, b0, mt1);
|
||||
|
||||
pa = pa.add(16);
|
||||
b = b.add(ldb);
|
||||
}
|
||||
|
||||
let alpha = _mm256_broadcast_ss(&alpha);
|
||||
|
||||
mt0 = _mm256_mul_ps(alpha, mt0);
|
||||
mt1 = _mm256_mul_ps(alpha, mt1);
|
||||
|
||||
if beta != 0.0 {
|
||||
let beta = _mm256_broadcast_ss(&beta);
|
||||
|
||||
mt0 = fmadd_ps(beta, _mm256_loadu_ps(c), mt0);
|
||||
mt1 = fmadd_ps(beta, _mm256_loadu_ps(c.add(8)), mt1);
|
||||
}
|
||||
|
||||
_mm256_storeu_ps(c, mt0);
|
||||
_mm256_storeu_ps(c.add(8), mt1);
|
||||
_mm256_storeu_ps(c.ptr_mut(), mt0);
|
||||
_mm256_storeu_ps(c.col_mut(8), mt1);
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn sgemm_pa_16x(k: usize, a: *const f32, lda: usize, pa: *mut f32) {
|
||||
|
@ -5,3 +5,85 @@ pub mod l1d;
|
||||
pub mod l1s;
|
||||
pub mod l3d;
|
||||
pub mod l3s;
|
||||
|
||||
use core::marker::PhantomData;
|
||||
use crate::matrix::{Number, MutMatrix, Matrix, MatrixMut};
|
||||
use crate::kernel::{GemmKernel, GemmKernelSupNr, GemmKernelSupMr, GemmKernelSup};
|
||||
use crate::dim::*;
|
||||
|
||||
pub struct AvxKernel<F: Number, I>(PhantomData<fn(F, I)>);
|
||||
|
||||
impl<I> GemmKernelSupNr<f32, A5> for AvxKernel<f32, I>
|
||||
where I: GemmKernelSupNr<f32, A5>
|
||||
{
|
||||
#[inline]
|
||||
unsafe fn sup_tr<A: Matrix<f32>, C: MatrixMut<f32>>(
|
||||
alpha: f32,
|
||||
a: A,
|
||||
pb: MutMatrix<f32>,
|
||||
beta: f32,
|
||||
c: C,
|
||||
) {
|
||||
I::sup_tr(alpha, a, pb, beta, c);
|
||||
}
|
||||
}
|
||||
|
||||
impl<I> GemmKernelSupMr<f32, A16> for AvxKernel<f32, I>
|
||||
where I: GemmKernelSupMr<f32, A16>
|
||||
{
|
||||
#[inline]
|
||||
unsafe fn sup_bl<B: Matrix<f32>, C: MatrixMut<f32>>(
|
||||
alpha: f32,
|
||||
pa: MutMatrix<f32>,
|
||||
b: B,
|
||||
beta: f32,
|
||||
c: C,
|
||||
) {
|
||||
self::l3s::sgemm_sup_16x1(pa.stride, alpha, pa, b, beta, c);
|
||||
}
|
||||
}
|
||||
|
||||
impl<I> GemmKernelSup<f32> for AvxKernel<f32, I>
|
||||
where I: GemmKernelSup<f32>
|
||||
{
|
||||
#[inline]
|
||||
unsafe fn sup_br<A: Matrix<f32>, B: Matrix<f32>, C: MatrixMut<f32>>(
|
||||
k: usize,
|
||||
alpha: f32,
|
||||
a: A,
|
||||
b: B,
|
||||
beta: f32,
|
||||
c: C,
|
||||
) {
|
||||
I::sup_br(k, alpha, a, b, beta, c);
|
||||
}
|
||||
}
|
||||
|
||||
impl<I> GemmKernel<f32, A16, A5> for AvxKernel<f32, I>
|
||||
where I: GemmKernel<f32, A16, A5>
|
||||
{
|
||||
#[inline]
|
||||
unsafe fn pack_row_a<A: Matrix<f32>>(a: A, pa: MutMatrix<f32>, i: usize) {
|
||||
if a.is_transposed() {
|
||||
I::pack_row_a(a, pa, i);
|
||||
} else {
|
||||
self::l3s::sgemm_pa_16x(pa.stride, a.col(i), a.stride(), pa.row_mut(i));
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn pack_row_b<B: Matrix<f32>>(b: B, pb: MutMatrix<f32>, j: usize) {
|
||||
I::pack_row_b(b, pb, j);
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn main_tl<C: MatrixMut<f32>>(
|
||||
alpha: f32,
|
||||
pa: MutMatrix<f32>,
|
||||
pb: MutMatrix<f32>,
|
||||
beta: f32,
|
||||
c: C,
|
||||
) {
|
||||
self::l3s::sgemm_ukr_16x8(pa.stride, alpha, pa, pb, beta, c);
|
||||
}
|
||||
}
|
@ -1,114 +1,91 @@
|
||||
pub(crate) unsafe fn sgemm_sup_1x4(
|
||||
use crate::kernel::params::single::NR;
|
||||
use crate::matrix::{MutMatrix, Matrix, MatrixMut};
|
||||
|
||||
|
||||
pub(crate) unsafe fn sgemm_sup_1x8<A: Matrix<f32>, C: MatrixMut<f32>>(
|
||||
k: usize,
|
||||
alpha: f32,
|
||||
a: *const f32,
|
||||
lda: usize,
|
||||
pb: *const f32,
|
||||
a: A,
|
||||
pb: MutMatrix<f32>,
|
||||
beta: f32,
|
||||
c: *mut f32,
|
||||
ldc: usize,
|
||||
c: C,
|
||||
) {
|
||||
let mut c0 = 0.0;
|
||||
let mut c1 = 0.0;
|
||||
let mut c2 = 0.0;
|
||||
let mut c3 = 0.0;
|
||||
let mut c0 = 0.0f32;
|
||||
let mut c1 = 0.0f32;
|
||||
let mut c2 = 0.0f32;
|
||||
let mut c3 = 0.0f32;
|
||||
let mut c4 = 0.0f32;
|
||||
// let mut c5 = 0.0f32;
|
||||
// let mut c6 = 0.0f32;
|
||||
// let mut c7 = 0.0f32;
|
||||
|
||||
let mut a = a;
|
||||
let mut pb = pb;
|
||||
|
||||
for _ in 0..k {
|
||||
let a0 = *a;
|
||||
let a0 = *a.ptr();
|
||||
|
||||
c0 += *pb * a0;
|
||||
c1 += *pb.add(1) * a0;
|
||||
c2 += *pb.add(2) * a0;
|
||||
c3 += *pb.add(3) * a0;
|
||||
c0 += *pb.ptr() * a0;
|
||||
c1 += *pb.col(1) * a0;
|
||||
c2 += *pb.col(2) * a0;
|
||||
c3 += *pb.col(3) * a0;
|
||||
c4 += *pb.col(4) * a0;
|
||||
// c5 += *pb.col(5) * a0;
|
||||
// c6 += *pb.col(6) * a0;
|
||||
// c7 += *pb.col(7) * a0;
|
||||
|
||||
a = a.add(lda);
|
||||
pb = pb.add(4);
|
||||
a.inc_row();
|
||||
pb.shift_col(NR);
|
||||
}
|
||||
|
||||
c0 *= alpha;
|
||||
c1 *= alpha;
|
||||
c2 *= alpha;
|
||||
c3 *= alpha;
|
||||
// c0 *= alpha;
|
||||
// c1 *= alpha;
|
||||
// c2 *= alpha;
|
||||
// c3 *= alpha;
|
||||
// c4 *= alpha;
|
||||
// // c5 *= alpha;
|
||||
// // c6 *= alpha;
|
||||
// // c7 *= alpha;
|
||||
|
||||
let ccol0 = c;
|
||||
let ccol1 = c.add(ldc);
|
||||
let ccol2 = c.add(ldc * 2);
|
||||
let ccol3 = c.add(ldc * 3);
|
||||
let ccol0 = c.ptr_mut();
|
||||
let ccol1 = c.row_mut(1);
|
||||
let ccol2 = c.row_mut(2);
|
||||
let ccol3 = c.row_mut(3);
|
||||
let ccol4 = c.row_mut(4);
|
||||
// let ccol5 = c.row_mut(5);
|
||||
// let ccol6 = c.row_mut(6);
|
||||
// let ccol7 = c.add(ldc * 7);
|
||||
|
||||
if beta != 0.0 {
|
||||
c0 += beta * *ccol0;
|
||||
c1 += beta * *ccol1;
|
||||
c2 += beta * *ccol2;
|
||||
c3 += beta * *ccol3;
|
||||
}
|
||||
// if beta != 0.0 {
|
||||
// c0 += beta * *ccol0;
|
||||
// c1 += beta * *ccol1;
|
||||
// c2 += beta * *ccol2;
|
||||
// c3 += beta * *ccol3;
|
||||
// c4 += beta * *ccol4;
|
||||
// // c5 += beta * *ccol5;
|
||||
// // c6 += beta * *ccol6;
|
||||
// // c7 += beta * *ccol7;
|
||||
// }
|
||||
|
||||
*ccol0 = c0;
|
||||
*ccol1 = c1;
|
||||
*ccol2 = c2;
|
||||
*ccol3 = c3;
|
||||
*ccol4 = c4;
|
||||
// *ccol5 = c5;
|
||||
// *ccol6 = c6;
|
||||
// *ccol7 = c7;
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn sgemm_sup1_t(
|
||||
k: usize,
|
||||
alpha: f32,
|
||||
a: *const f32,
|
||||
lda: usize,
|
||||
pb: *const f32,
|
||||
beta: f32,
|
||||
c: *mut f32,
|
||||
ldc: usize,
|
||||
) {
|
||||
let mut c0 = 0.0;
|
||||
let mut c1 = 0.0;
|
||||
let mut c2 = 0.0;
|
||||
let mut c3 = 0.0;
|
||||
|
||||
let mut a = a;
|
||||
let mut pb = pb;
|
||||
|
||||
for _ in 0..k {
|
||||
let a0 = *a;
|
||||
|
||||
c0 += *pb * a0;
|
||||
c1 += *pb.add(1) * a0;
|
||||
c2 += *pb.add(2) * a0;
|
||||
c3 += *pb.add(3) * a0;
|
||||
|
||||
a = a.add(1);
|
||||
pb = pb.add(4);
|
||||
}
|
||||
|
||||
c0 *= alpha;
|
||||
c1 *= alpha;
|
||||
c2 *= alpha;
|
||||
c3 *= alpha;
|
||||
|
||||
let ccol0 = c;
|
||||
let ccol1 = c.add(ldc);
|
||||
let ccol2 = c.add(ldc * 2);
|
||||
let ccol3 = c.add(ldc * 3);
|
||||
|
||||
if beta != 0.0 {
|
||||
c0 += beta * *ccol0;
|
||||
c1 += beta * *ccol1;
|
||||
c2 += beta * *ccol2;
|
||||
c3 += beta * *ccol3;
|
||||
}
|
||||
|
||||
*ccol0 = c0;
|
||||
*ccol1 = c1;
|
||||
*ccol2 = c2;
|
||||
*ccol3 = c3;
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn sgemm_pb_x4(k: usize, b: *const f32, ldb: usize, pb: *mut f32) {
|
||||
pub(crate) unsafe fn sgemm_pb_x8(k: usize, b: *const f32, ldb: usize, pb: *mut f32) {
|
||||
let mut bcol0 = b;
|
||||
let mut bcol1 = b.add(ldb);
|
||||
let mut bcol2 = b.add(ldb * 2);
|
||||
let mut bcol3 = b.add(ldb * 3);
|
||||
let mut bcol4 = b.add(ldb * 4);
|
||||
// let mut bcol5 = b.add(ldb * 5);
|
||||
// let mut bcol6 = b.add(ldb * 6);
|
||||
// let mut bcol7 = b.row(7);
|
||||
|
||||
let mut pb = pb;
|
||||
|
||||
@ -117,27 +94,24 @@ pub(crate) unsafe fn sgemm_pb_x4(k: usize, b: *const f32, ldb: usize, pb: *mut f
|
||||
*pb.add(1) = *bcol1;
|
||||
*pb.add(2) = *bcol2;
|
||||
*pb.add(3) = *bcol3;
|
||||
*pb.add(4) = *bcol4;
|
||||
// *pb.add(5) = *bcol5;
|
||||
// *pb.add(6) = *bcol6;
|
||||
// *pb.col(7) = *bcol7;
|
||||
|
||||
bcol0 = bcol0.add(1);
|
||||
bcol1 = bcol1.add(1);
|
||||
bcol2 = bcol2.add(1);
|
||||
bcol3 = bcol3.add(1);
|
||||
pb = pb.add(4);
|
||||
bcol4 = bcol4.add(1);
|
||||
// bcol5 = bcol5.add(1);
|
||||
// bcol6 = bcol6.add(1);
|
||||
// bcol7 = bcol7.add(1);
|
||||
|
||||
pb = pb.add(NR);
|
||||
}
|
||||
}
|
||||
|
||||
// pub(crate) unsafe fn sgemm_pa_n(k: usize, a: *const f32, lda: usize, pa: *mut f32) {
|
||||
// use crate::kernel::params::single::MR;
|
||||
// let mut a = a;
|
||||
// let mut pa = pa;
|
||||
|
||||
// for p in 0..k {
|
||||
// for j in 0..MR {
|
||||
// *pa.add(j * k + p) = *a.add(p * lda + j);
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
pub(crate) unsafe fn sgemm_pb_t(k: usize, b: *const f32, ldb: usize, pb: *mut f32) {
|
||||
use crate::kernel::params::single::NR;
|
||||
let mut b = b;
|
||||
|
@ -1,2 +1,98 @@
|
||||
pub mod l3d;
|
||||
pub mod l3s;
|
||||
|
||||
use crate::matrix::{Matrix, MatrixMut, MutMatrix};
|
||||
use crate::kernel::{GemmKernel, GemmKernelSupNr, GemmKernelSupMr, GemmKernelSup};
|
||||
use crate::dim::*;
|
||||
|
||||
pub struct GenericKernel;
|
||||
|
||||
impl GemmKernelSupNr<f32, A5> for GenericKernel {
|
||||
#[inline]
|
||||
unsafe fn sup_tr<A: Matrix<f32>, C: MatrixMut<f32>>(
|
||||
alpha: f32,
|
||||
a: A,
|
||||
pb: MutMatrix<f32>,
|
||||
beta: f32,
|
||||
c: C,
|
||||
) {
|
||||
self::l3s::sgemm_sup_1x8(
|
||||
pb.stride,
|
||||
alpha,
|
||||
a,
|
||||
pb,
|
||||
beta,
|
||||
c)
|
||||
}
|
||||
}
|
||||
|
||||
impl GemmKernelSupMr<f32, A16> for GenericKernel {
|
||||
#[inline]
|
||||
unsafe fn sup_bl<B: Matrix<f32>, C: MatrixMut<f32>>(
|
||||
alpha: f32,
|
||||
pa: MutMatrix<f32>,
|
||||
b: B,
|
||||
beta: f32,
|
||||
c: C,
|
||||
) {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
|
||||
impl GemmKernelSup<f32> for GenericKernel {
|
||||
#[inline]
|
||||
unsafe fn sup_br<A: Matrix<f32>, B: Matrix<f32>, C: MatrixMut<f32>>(
|
||||
k: usize,
|
||||
alpha: f32,
|
||||
a: A,
|
||||
b: B,
|
||||
beta: f32,
|
||||
c: C,
|
||||
) {
|
||||
let mut elem = 0.0;
|
||||
|
||||
for p in 0..k {
|
||||
elem += *a.row(p) * *b.col(p);
|
||||
}
|
||||
|
||||
elem *= alpha;
|
||||
|
||||
if beta != 0.0 {
|
||||
elem += beta * *c.ptr();
|
||||
}
|
||||
|
||||
*c.ptr_mut() = elem;
|
||||
}
|
||||
}
|
||||
|
||||
impl GemmKernel<f32, A16, A5> for GenericKernel {
|
||||
|
||||
#[inline]
|
||||
unsafe fn pack_row_a<A: Matrix<f32>>(a: A, pa: MutMatrix<f32>, i: usize) {
|
||||
if a.is_transposed() {
|
||||
self::l3s::sgemm_pa_t(pa.stride, a.col(i), a.stride(), pa.row_mut(i));
|
||||
} else {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn pack_row_b<B: Matrix<f32>>(b: B, pb: MutMatrix<f32>, j: usize) {
|
||||
if b.is_transposed() {
|
||||
self::l3s::sgemm_pb_t(pb.stride, b.row(j), b.stride(), pb.row_mut(j));
|
||||
} else {
|
||||
self::l3s::sgemm_pb_x8(pb.stride, b.row(j), b.stride(), pb.row_mut(j));
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn main_tl<C: MatrixMut<f32>>(
|
||||
alpha: f32,
|
||||
pa: MutMatrix<f32>,
|
||||
pb: MutMatrix<f32>,
|
||||
beta: f32,
|
||||
c: C,
|
||||
) {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
@ -2,53 +2,20 @@
|
||||
any(target_arch = "x86_64", target_arch = "x86"),
|
||||
target_feature = "avx"
|
||||
))]
|
||||
mod avx;
|
||||
|
||||
#[cfg(all(
|
||||
any(target_arch = "x86_64", target_arch = "x86"),
|
||||
target_feature = "avx"
|
||||
))]
|
||||
pub use avx::{l1d::*, l1s::*};
|
||||
|
||||
#[cfg(all(
|
||||
any(target_arch = "x86_64", target_arch = "x86"),
|
||||
target_feature = "avx"
|
||||
))]
|
||||
pub(crate) use avx::l3s::{
|
||||
sgemm_pa_16x as sgemm_pa_n, sgemm_sup_16x1 as sgemm_sup0, sgemm_ukr_16x4 as sgemm_ukr, sgemm_sup0_t,
|
||||
};
|
||||
|
||||
#[cfg(all(
|
||||
any(target_arch = "x86_64", target_arch = "x86"),
|
||||
target_feature = "avx"
|
||||
))]
|
||||
pub(crate) use avx::l3d::{
|
||||
dgemm_pa_8x as dgemm_pa, dgemm_sup_8x1 as dgemm_sup0, dgemm_ukr_8x4 as dgemm_ukr,
|
||||
};
|
||||
|
||||
#[cfg(all(
|
||||
any(target_arch = "x86_64", target_arch = "x86"),
|
||||
target_feature = "avx"
|
||||
))]
|
||||
pub(crate) use generic::l3s::{sgemm_pb_x4 as sgemm_pb_n, sgemm_pb_t, sgemm_pa_t, sgemm_sup_1x4 as sgemm_sup1, sgemm_sup1_t};
|
||||
|
||||
#[cfg(all(
|
||||
any(target_arch = "x86_64", target_arch = "x86"),
|
||||
target_feature = "avx"
|
||||
))]
|
||||
pub(crate) use generic::l3d::{dgemm_pb_x4 as dgemm_pb, dgemm_sup_1x4 as dgemm_sup1};
|
||||
|
||||
|
||||
mod generic;
|
||||
pub mod avx;
|
||||
pub mod generic;
|
||||
pub mod sse;
|
||||
|
||||
use crate::matrix::{Number, Matrix, MatrixMut, MutMatrix};
|
||||
use crate::dim::Dim;
|
||||
|
||||
pub mod params {
|
||||
pub mod single {
|
||||
pub const MC: usize = 256;
|
||||
pub const KC: usize = 128;
|
||||
pub const NC: usize = 1024;
|
||||
pub const MC: usize = 128;
|
||||
pub const KC: usize = 256;
|
||||
pub const NC: usize = 8*1024;
|
||||
pub const MR: usize = 16;
|
||||
pub const NR: usize = 4;
|
||||
pub const NR: usize = 5;
|
||||
}
|
||||
pub mod double {
|
||||
pub const MC: usize = 256;
|
||||
@ -58,3 +25,60 @@ pub mod params {
|
||||
pub const NR: usize = 4;
|
||||
}
|
||||
}
|
||||
|
||||
// +----------------------+
|
||||
// |TL : : : |TR|
|
||||
// | : : : | |
|
||||
// + - - - - - - - - - ---+
|
||||
// | : : : | |
|
||||
// | : : : | |
|
||||
// + - - - - - - - - - ---+
|
||||
// | : : : | |
|
||||
// | : : : | |
|
||||
// +----------------------+
|
||||
// |BL | | | |BR|
|
||||
// +----------------------+
|
||||
|
||||
pub trait GemmKernelSup<F: Number> {
|
||||
unsafe fn sup_br<A: Matrix<F>, B: Matrix<F>, C: MatrixMut<F>>(
|
||||
k: usize,
|
||||
alpha: F,
|
||||
a: A,
|
||||
b: B,
|
||||
beta: F,
|
||||
c: C,
|
||||
);
|
||||
}
|
||||
|
||||
pub trait GemmKernelSupMr<F: Number, MR: Dim> {
|
||||
unsafe fn sup_bl<B: Matrix<F>, C: MatrixMut<F>>(
|
||||
alpha: F,
|
||||
pa: MutMatrix<F>,
|
||||
b: B,
|
||||
beta: F,
|
||||
c: C,
|
||||
);
|
||||
}
|
||||
|
||||
pub trait GemmKernelSupNr<F: Number, NR: Dim> {
|
||||
unsafe fn sup_tr<A: Matrix<F>, C: MatrixMut<F>>(
|
||||
alpha: F,
|
||||
a: A,
|
||||
pb: MutMatrix<F>,
|
||||
beta: F,
|
||||
c: C,
|
||||
);
|
||||
}
|
||||
|
||||
pub trait GemmKernel<F: Number, MR: Dim, NR: Dim>: GemmKernelSupMr<F, MR> + GemmKernelSupNr<F, NR> + GemmKernelSup<F> {
|
||||
unsafe fn pack_row_a<A: Matrix<F>>(a: A, pa: MutMatrix<F>, i: usize);
|
||||
unsafe fn pack_row_b<B: Matrix<F>>(b: B, pb: MutMatrix<F>, j: usize);
|
||||
|
||||
unsafe fn main_tl<C: MatrixMut<F>>(
|
||||
alpha: F,
|
||||
pa: MutMatrix<F>,
|
||||
pb: MutMatrix<F>,
|
||||
beta: F,
|
||||
c: C,
|
||||
);
|
||||
}
|
47
src/kernel/sse/fma.rs
Normal file
47
src/kernel/sse/fma.rs
Normal file
@ -0,0 +1,47 @@
|
||||
use super::intrinsics::*;
|
||||
|
||||
#[cfg(target_feature = "fma")]
|
||||
#[inline(always)]
|
||||
pub unsafe fn fmadd_ps(a: __m128, b: __m128, c: __m128) -> __m128 {
|
||||
_mm_fmadd_ps(a, b, c)
|
||||
}
|
||||
|
||||
#[cfg(not(target_feature = "fma"))]
|
||||
#[inline(always)]
|
||||
pub unsafe fn fmadd_ps(a: __m128, b: __m128, c: __m128) -> __m128 {
|
||||
_mm_add_ps(_mm_mul_ps(a, b), c)
|
||||
}
|
||||
|
||||
#[cfg(target_feature = "fma")]
|
||||
#[inline(always)]
|
||||
pub unsafe fn fmsub_ps(a: __m128, b: __m128, c: __m128) -> __m128 {
|
||||
_mm_fmsub_ps(a, b, c)
|
||||
}
|
||||
|
||||
#[cfg(not(target_feature = "fma"))]
|
||||
pub unsafe fn fmsub_ps(a: __m128, b: __m128, c: __m128) -> __m128 {
|
||||
_mm_sub_ps(_mm_mul_ps(a, b), c)
|
||||
}
|
||||
|
||||
#[cfg(target_feature = "fma")]
|
||||
#[inline(always)]
|
||||
pub unsafe fn fmadd_pd(a: __m128d, b: __m128d, c: __m128d) -> __m128d {
|
||||
_mm_fmadd_pd(a, b, c)
|
||||
}
|
||||
|
||||
#[cfg(not(target_feature = "fma"))]
|
||||
#[inline(always)]
|
||||
pub unsafe fn fmadd_pd(a: __m128d, b: __m128d, c: __m128d) -> __m128d {
|
||||
_mm_add_pd(_mm_mul_pd(a, b), c)
|
||||
}
|
||||
|
||||
#[cfg(target_feature = "fma")]
|
||||
#[inline(always)]
|
||||
pub unsafe fn fmsub_pd(a: __m128d, b: __m128d, c: __m128d) -> __m128d {
|
||||
_mm_fmsub_pd(a, b, c)
|
||||
}
|
||||
|
||||
#[cfg(not(target_feature = "fma"))]
|
||||
pub unsafe fn fmsub_pd(a: __m128d, b: __m128d, c: __m128d) -> __m128d {
|
||||
_mm_sub_pd(_mm_mul_pd(a, b), c)
|
||||
}
|
23
src/kernel/sse/hsum.rs
Normal file
23
src/kernel/sse/hsum.rs
Normal file
@ -0,0 +1,23 @@
|
||||
use super::intrinsics::*;
|
||||
|
||||
#[inline(always)]
|
||||
pub unsafe fn hsum_ps(v: __m256) -> f32 {
|
||||
let qhigh = _mm256_extractf128_ps(v, 1);
|
||||
let qlow = _mm256_castps256_ps128(v);
|
||||
let qsum = _mm_add_ps(qhigh, qlow);
|
||||
let dhigh = _mm_movehl_ps(qsum, qsum);
|
||||
let dlow = qsum;
|
||||
let dsum = _mm_add_ps(dhigh, dlow);
|
||||
let high = _mm_shuffle_ps(dsum, dsum, 1);
|
||||
let low = dsum;
|
||||
_mm_cvtss_f32(_mm_add_ss(high, low))
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub unsafe fn hsum_pd(v: __m256d) -> f64 {
|
||||
let vhigh = _mm256_extractf128_pd(v, 1);
|
||||
let vlow = _mm256_castpd256_pd128(v);
|
||||
let vsum = _mm_add_pd(vlow, vhigh);
|
||||
let h64 = _mm_unpackhi_pd(vsum, vsum);
|
||||
_mm_cvtsd_f64(_mm_add_sd(vsum, h64))
|
||||
}
|
5
src/kernel/sse/intrinsics.rs
Normal file
5
src/kernel/sse/intrinsics.rs
Normal file
@ -0,0 +1,5 @@
|
||||
#[cfg(target_arch = "x86")]
|
||||
pub use std::arch::x86::*;
|
||||
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
pub use std::arch::x86_64::*;
|
396
src/kernel/sse/l1d.rs
Normal file
396
src/kernel/sse/l1d.rs
Normal file
@ -0,0 +1,396 @@
|
||||
use super::fma::{fmadd_pd, fmsub_pd};
|
||||
use super::hsum::hsum_pd;
|
||||
use super::intrinsics::*;
|
||||
|
||||
pub unsafe fn drot(
|
||||
n: usize,
|
||||
mut x: *mut f64,
|
||||
incx: usize,
|
||||
mut y: *mut f64,
|
||||
incy: usize,
|
||||
c: f64,
|
||||
s: f64,
|
||||
) {
|
||||
if incx == 1 && incy == 1 {
|
||||
let c0 = _mm256_broadcast_sd(&c);
|
||||
let s0 = _mm256_broadcast_sd(&s);
|
||||
|
||||
for _ in 0..n / 16 {
|
||||
let x0 = _mm256_loadu_pd(x);
|
||||
let y0 = _mm256_loadu_pd(y);
|
||||
let x1 = _mm256_loadu_pd(x.add(4));
|
||||
let y1 = _mm256_loadu_pd(y.add(4));
|
||||
let x2 = _mm256_loadu_pd(x.add(8));
|
||||
let y2 = _mm256_loadu_pd(y.add(8));
|
||||
let x3 = _mm256_loadu_pd(x.add(12));
|
||||
let y3 = _mm256_loadu_pd(y.add(12));
|
||||
|
||||
_mm256_storeu_pd(x, fmadd_pd(c0, x0, _mm256_mul_pd(s0, y0)));
|
||||
_mm256_storeu_pd(y, fmsub_pd(c0, y0, _mm256_mul_pd(s0, x0)));
|
||||
_mm256_storeu_pd(x.add(4), fmadd_pd(c0, x1, _mm256_mul_pd(s0, y1)));
|
||||
_mm256_storeu_pd(y.add(4), fmsub_pd(c0, y1, _mm256_mul_pd(s0, x1)));
|
||||
_mm256_storeu_pd(x.add(8), fmadd_pd(c0, x2, _mm256_mul_pd(s0, y2)));
|
||||
_mm256_storeu_pd(y.add(8), fmsub_pd(c0, y2, _mm256_mul_pd(s0, x2)));
|
||||
_mm256_storeu_pd(x.add(12), fmadd_pd(c0, x3, _mm256_mul_pd(s0, y3)));
|
||||
_mm256_storeu_pd(y.add(12), fmsub_pd(c0, y3, _mm256_mul_pd(s0, x3)));
|
||||
|
||||
x = x.add(16);
|
||||
y = y.add(16);
|
||||
}
|
||||
|
||||
for _ in 0..n % 16 {
|
||||
let x0 = *x;
|
||||
let y0 = *y;
|
||||
|
||||
*x = c * x0 + s * y0;
|
||||
*y = c * y0 - s * x0;
|
||||
|
||||
x = x.add(1);
|
||||
y = y.add(1);
|
||||
}
|
||||
} else {
|
||||
for _ in 0..n {
|
||||
let x0 = *x;
|
||||
let y0 = *y;
|
||||
|
||||
*x = c * x0 + s * y0;
|
||||
*y = c * y0 - s * x0;
|
||||
|
||||
x = x.add(incx);
|
||||
y = y.add(incy);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub unsafe fn dswap(n: usize, mut x: *mut f64, incx: usize, mut y: *mut f64, incy: usize) {
|
||||
if incx == 1 && incy == 1 {
|
||||
for _ in 0..n / 16 {
|
||||
let x0 = _mm256_loadu_pd(x);
|
||||
let y0 = _mm256_loadu_pd(y);
|
||||
let x1 = _mm256_loadu_pd(x.add(4));
|
||||
let y1 = _mm256_loadu_pd(y.add(4));
|
||||
let x2 = _mm256_loadu_pd(x.add(8));
|
||||
let y2 = _mm256_loadu_pd(y.add(8));
|
||||
let x3 = _mm256_loadu_pd(x.add(12));
|
||||
let y3 = _mm256_loadu_pd(y.add(12));
|
||||
|
||||
_mm256_storeu_pd(x, y0);
|
||||
_mm256_storeu_pd(y, x0);
|
||||
_mm256_storeu_pd(x.add(4), y1);
|
||||
_mm256_storeu_pd(y.add(4), x1);
|
||||
_mm256_storeu_pd(x.add(8), y2);
|
||||
_mm256_storeu_pd(y.add(8), x2);
|
||||
_mm256_storeu_pd(x.add(12), y3);
|
||||
_mm256_storeu_pd(y.add(12), x3);
|
||||
|
||||
x = x.add(16);
|
||||
y = y.add(16);
|
||||
}
|
||||
|
||||
for _ in 0..n % 16 {
|
||||
let x0 = *x;
|
||||
|
||||
*x = *y;
|
||||
*y = x0;
|
||||
|
||||
x = x.add(1);
|
||||
y = y.add(1);
|
||||
}
|
||||
} else {
|
||||
for _ in 0..n {
|
||||
let x0 = *x;
|
||||
|
||||
*x = *y;
|
||||
*y = x0;
|
||||
|
||||
x = x.add(incx);
|
||||
y = y.add(incy);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub unsafe fn dscal(n: usize, a: f64, mut x: *mut f64, incx: usize) {
|
||||
if incx == 1 {
|
||||
let a0 = _mm256_broadcast_sd(&a);
|
||||
for _ in 0..n / 32 {
|
||||
let mut x0 = _mm256_loadu_pd(x);
|
||||
let mut x1 = _mm256_loadu_pd(x.add(4));
|
||||
let mut x2 = _mm256_loadu_pd(x.add(8));
|
||||
let mut x3 = _mm256_loadu_pd(x.add(12));
|
||||
let mut x4 = _mm256_loadu_pd(x.add(16));
|
||||
let mut x5 = _mm256_loadu_pd(x.add(20));
|
||||
let mut x6 = _mm256_loadu_pd(x.add(24));
|
||||
let mut x7 = _mm256_loadu_pd(x.add(28));
|
||||
|
||||
x0 = _mm256_mul_pd(a0, x0);
|
||||
x1 = _mm256_mul_pd(a0, x1);
|
||||
x2 = _mm256_mul_pd(a0, x2);
|
||||
x3 = _mm256_mul_pd(a0, x3);
|
||||
x4 = _mm256_mul_pd(a0, x4);
|
||||
x5 = _mm256_mul_pd(a0, x5);
|
||||
x6 = _mm256_mul_pd(a0, x6);
|
||||
x7 = _mm256_mul_pd(a0, x7);
|
||||
|
||||
_mm256_storeu_pd(x, x0);
|
||||
_mm256_storeu_pd(x.add(4), x1);
|
||||
_mm256_storeu_pd(x.add(8), x2);
|
||||
_mm256_storeu_pd(x.add(12), x3);
|
||||
_mm256_storeu_pd(x.add(16), x4);
|
||||
_mm256_storeu_pd(x.add(20), x5);
|
||||
_mm256_storeu_pd(x.add(24), x6);
|
||||
_mm256_storeu_pd(x.add(28), x7);
|
||||
|
||||
x = x.add(32);
|
||||
}
|
||||
for _ in 0..n % 32 {
|
||||
*x *= a;
|
||||
x = x.add(1);
|
||||
}
|
||||
} else {
|
||||
for _ in 0..n {
|
||||
*x *= a;
|
||||
x = x.add(incx);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub unsafe fn dcopy(n: usize, mut x: *const f64, incx: usize, mut y: *mut f64, incy: usize) {
|
||||
if incx == 1 && incy == 1 {
|
||||
for _ in 0..n / 32 {
|
||||
let x0 = _mm256_loadu_pd(x);
|
||||
let x1 = _mm256_loadu_pd(x.add(4));
|
||||
let x2 = _mm256_loadu_pd(x.add(8));
|
||||
let x3 = _mm256_loadu_pd(x.add(12));
|
||||
let x4 = _mm256_loadu_pd(x.add(16));
|
||||
let x5 = _mm256_loadu_pd(x.add(20));
|
||||
let x6 = _mm256_loadu_pd(x.add(24));
|
||||
let x7 = _mm256_loadu_pd(x.add(28));
|
||||
|
||||
_mm256_storeu_pd(y, x0);
|
||||
_mm256_storeu_pd(y.add(4), x1);
|
||||
_mm256_storeu_pd(y.add(8), x2);
|
||||
_mm256_storeu_pd(y.add(12), x3);
|
||||
_mm256_storeu_pd(y.add(16), x4);
|
||||
_mm256_storeu_pd(y.add(20), x5);
|
||||
_mm256_storeu_pd(y.add(24), x6);
|
||||
_mm256_storeu_pd(y.add(28), x7);
|
||||
|
||||
x = x.add(32);
|
||||
y = y.add(32);
|
||||
}
|
||||
for _ in 0..n % 32 {
|
||||
*y = *x;
|
||||
x = x.add(1);
|
||||
y = y.add(1);
|
||||
}
|
||||
} else {
|
||||
for _ in 0..n {
|
||||
*y = *x;
|
||||
x = x.add(incx);
|
||||
y = y.add(incy);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub unsafe fn daxpy(
|
||||
n: usize,
|
||||
a: f64,
|
||||
mut x: *const f64,
|
||||
incx: usize,
|
||||
mut y: *mut f64,
|
||||
incy: usize,
|
||||
) {
|
||||
if incx == 1 && incy == 1 {
|
||||
let a0 = _mm256_broadcast_sd(&a);
|
||||
for _ in 0..n / 16 {
|
||||
let x0 = _mm256_loadu_pd(x);
|
||||
let y0 = _mm256_loadu_pd(y);
|
||||
let x1 = _mm256_loadu_pd(x.add(4));
|
||||
let y1 = _mm256_loadu_pd(y.add(4));
|
||||
let x2 = _mm256_loadu_pd(x.add(8));
|
||||
let y2 = _mm256_loadu_pd(y.add(8));
|
||||
let x3 = _mm256_loadu_pd(x.add(12));
|
||||
let y3 = _mm256_loadu_pd(y.add(12));
|
||||
|
||||
_mm256_storeu_pd(y, fmadd_pd(a0, x0, y0));
|
||||
_mm256_storeu_pd(y.add(4), fmadd_pd(a0, x1, y1));
|
||||
_mm256_storeu_pd(y.add(8), fmadd_pd(a0, x2, y2));
|
||||
_mm256_storeu_pd(y.add(12), fmadd_pd(a0, x3, y3));
|
||||
|
||||
x = x.add(16);
|
||||
y = y.add(16);
|
||||
}
|
||||
for _ in 0..n % 16 {
|
||||
*y += a * *x;
|
||||
x = x.add(1);
|
||||
y = y.add(1);
|
||||
}
|
||||
} else {
|
||||
for _ in 0..n {
|
||||
*y += a * *x;
|
||||
x = x.add(incx);
|
||||
y = y.add(incy);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub unsafe fn ddot(
|
||||
n: usize,
|
||||
mut x: *const f64,
|
||||
incx: usize,
|
||||
mut y: *const f64,
|
||||
incy: usize,
|
||||
) -> f64 {
|
||||
if incx == 1 && incy == 1 {
|
||||
let mut acc0 = _mm256_setzero_pd();
|
||||
let mut acc1 = _mm256_setzero_pd();
|
||||
let mut acc2 = _mm256_setzero_pd();
|
||||
let mut acc3 = _mm256_setzero_pd();
|
||||
for _ in 0..n / 16 {
|
||||
let x0 = _mm256_loadu_pd(x);
|
||||
let y0 = _mm256_loadu_pd(y);
|
||||
let x1 = _mm256_loadu_pd(x.add(4));
|
||||
let y1 = _mm256_loadu_pd(y.add(4));
|
||||
let x2 = _mm256_loadu_pd(x.add(8));
|
||||
let y2 = _mm256_loadu_pd(y.add(8));
|
||||
let x3 = _mm256_loadu_pd(x.add(12));
|
||||
let y3 = _mm256_loadu_pd(y.add(12));
|
||||
|
||||
acc0 = fmadd_pd(x0, y0, acc0);
|
||||
acc1 = fmadd_pd(x1, y1, acc1);
|
||||
acc2 = fmadd_pd(x2, y2, acc2);
|
||||
acc3 = fmadd_pd(x3, y3, acc3);
|
||||
|
||||
x = x.add(16);
|
||||
y = y.add(16);
|
||||
}
|
||||
acc0 = _mm256_add_pd(acc0, acc1);
|
||||
acc2 = _mm256_add_pd(acc2, acc3);
|
||||
acc0 = _mm256_add_pd(acc0, acc2);
|
||||
|
||||
let mut acc = hsum_pd(acc0);
|
||||
for _ in 0..n % 16 {
|
||||
acc += *x * *y;
|
||||
x = x.add(1);
|
||||
y = y.add(1);
|
||||
}
|
||||
acc
|
||||
} else {
|
||||
let mut acc = 0.0;
|
||||
for _ in 0..n {
|
||||
acc += *x * *y;
|
||||
x = x.add(incx);
|
||||
y = y.add(incy);
|
||||
}
|
||||
acc
|
||||
}
|
||||
}
|
||||
|
||||
pub unsafe fn dnrm2(n: usize, mut x: *const f64, incx: usize) -> f64 {
|
||||
if incx == 1 {
|
||||
let mut acc0 = _mm256_setzero_pd();
|
||||
let mut acc1 = _mm256_setzero_pd();
|
||||
let mut acc2 = _mm256_setzero_pd();
|
||||
let mut acc3 = _mm256_setzero_pd();
|
||||
for _ in 0..n / 16 {
|
||||
let x0 = _mm256_loadu_pd(x);
|
||||
let x1 = _mm256_loadu_pd(x.add(4));
|
||||
let x2 = _mm256_loadu_pd(x.add(8));
|
||||
let x3 = _mm256_loadu_pd(x.add(12));
|
||||
|
||||
acc0 = fmadd_pd(x0, x0, acc0);
|
||||
acc1 = fmadd_pd(x1, x1, acc1);
|
||||
acc2 = fmadd_pd(x2, x2, acc2);
|
||||
acc3 = fmadd_pd(x3, x3, acc3);
|
||||
|
||||
x = x.add(16);
|
||||
}
|
||||
acc0 = _mm256_add_pd(acc0, acc1);
|
||||
acc2 = _mm256_add_pd(acc2, acc3);
|
||||
acc0 = _mm256_add_pd(acc0, acc2);
|
||||
|
||||
let mut acc = hsum_pd(acc0);
|
||||
for _ in 0..n % 16 {
|
||||
let x0 = *x;
|
||||
acc += x0 * x0;
|
||||
x = x.add(1);
|
||||
}
|
||||
acc.sqrt()
|
||||
} else {
|
||||
let mut acc = 0.0;
|
||||
for _ in 0..n {
|
||||
let x0 = *x;
|
||||
acc += x0 * x0;
|
||||
x = x.add(incx);
|
||||
}
|
||||
acc.sqrt()
|
||||
}
|
||||
}
|
||||
|
||||
pub unsafe fn dasum(n: usize, mut x: *const f64, incx: usize) -> f64 {
|
||||
if incx == 1 {
|
||||
let mask = _mm256_broadcast_sd(&f64::from_bits(0x7FFF_FFFF_FFFF_FFFF));
|
||||
|
||||
let mut acc0 = _mm256_setzero_pd();
|
||||
let mut acc1 = _mm256_setzero_pd();
|
||||
let mut acc2 = _mm256_setzero_pd();
|
||||
let mut acc3 = _mm256_setzero_pd();
|
||||
let mut acc4 = _mm256_setzero_pd();
|
||||
let mut acc5 = _mm256_setzero_pd();
|
||||
let mut acc6 = _mm256_setzero_pd();
|
||||
let mut acc7 = _mm256_setzero_pd();
|
||||
for _ in 0..n / 32 {
|
||||
let mut x0 = _mm256_loadu_pd(x);
|
||||
let mut x1 = _mm256_loadu_pd(x.add(4));
|
||||
let mut x2 = _mm256_loadu_pd(x.add(8));
|
||||
let mut x3 = _mm256_loadu_pd(x.add(12));
|
||||
let mut x4 = _mm256_loadu_pd(x.add(16));
|
||||
let mut x5 = _mm256_loadu_pd(x.add(20));
|
||||
let mut x6 = _mm256_loadu_pd(x.add(24));
|
||||
let mut x7 = _mm256_loadu_pd(x.add(28));
|
||||
|
||||
x0 = _mm256_and_pd(mask, x0);
|
||||
x1 = _mm256_and_pd(mask, x1);
|
||||
x2 = _mm256_and_pd(mask, x2);
|
||||
x3 = _mm256_and_pd(mask, x3);
|
||||
x4 = _mm256_and_pd(mask, x4);
|
||||
x5 = _mm256_and_pd(mask, x5);
|
||||
x6 = _mm256_and_pd(mask, x6);
|
||||
x7 = _mm256_and_pd(mask, x7);
|
||||
|
||||
acc0 = _mm256_add_pd(acc0, x0);
|
||||
acc1 = _mm256_add_pd(acc1, x1);
|
||||
acc2 = _mm256_add_pd(acc2, x2);
|
||||
acc3 = _mm256_add_pd(acc3, x3);
|
||||
acc4 = _mm256_add_pd(acc4, x4);
|
||||
acc5 = _mm256_add_pd(acc5, x5);
|
||||
acc6 = _mm256_add_pd(acc6, x6);
|
||||
acc7 = _mm256_add_pd(acc7, x7);
|
||||
|
||||
x = x.add(32);
|
||||
}
|
||||
acc0 = _mm256_add_pd(acc0, acc1);
|
||||
acc2 = _mm256_add_pd(acc2, acc3);
|
||||
acc4 = _mm256_add_pd(acc4, acc5);
|
||||
acc6 = _mm256_add_pd(acc6, acc7);
|
||||
|
||||
acc0 = _mm256_add_pd(acc0, acc2);
|
||||
acc4 = _mm256_add_pd(acc4, acc6);
|
||||
|
||||
acc0 = _mm256_add_pd(acc0, acc4);
|
||||
|
||||
let mut acc = hsum_pd(acc0);
|
||||
for _ in 0..n % 32 {
|
||||
acc += (*x).abs();
|
||||
x = x.add(1);
|
||||
}
|
||||
acc
|
||||
} else {
|
||||
let mut acc = 0.0;
|
||||
for _ in 0..n {
|
||||
acc += (*x).abs();
|
||||
x = x.add(incx);
|
||||
}
|
||||
acc
|
||||
}
|
||||
}
|
396
src/kernel/sse/l1s.rs
Normal file
396
src/kernel/sse/l1s.rs
Normal file
@ -0,0 +1,396 @@
|
||||
use super::fma::{fmadd_ps, fmsub_ps};
|
||||
use super::hsum::hsum_ps;
|
||||
use super::intrinsics::*;
|
||||
|
||||
pub unsafe fn srot(
|
||||
n: usize,
|
||||
mut x: *mut f32,
|
||||
incx: usize,
|
||||
mut y: *mut f32,
|
||||
incy: usize,
|
||||
c: f32,
|
||||
s: f32,
|
||||
) {
|
||||
if incx == 1 && incy == 1 {
|
||||
let c0 = _mm256_broadcast_ss(&c);
|
||||
let s0 = _mm256_broadcast_ss(&s);
|
||||
|
||||
for _ in 0..n / 32 {
|
||||
let x0 = _mm256_loadu_ps(x);
|
||||
let y0 = _mm256_loadu_ps(y);
|
||||
let x1 = _mm256_loadu_ps(x.add(8));
|
||||
let y1 = _mm256_loadu_ps(y.add(8));
|
||||
let x2 = _mm256_loadu_ps(x.add(16));
|
||||
let y2 = _mm256_loadu_ps(y.add(16));
|
||||
let x3 = _mm256_loadu_ps(x.add(24));
|
||||
let y3 = _mm256_loadu_ps(y.add(24));
|
||||
|
||||
_mm256_storeu_ps(x, fmadd_ps(c0, x0, _mm256_mul_ps(s0, y0)));
|
||||
_mm256_storeu_ps(y, fmsub_ps(c0, y0, _mm256_mul_ps(s0, x0)));
|
||||
_mm256_storeu_ps(x.add(8), fmadd_ps(c0, x1, _mm256_mul_ps(s0, y1)));
|
||||
_mm256_storeu_ps(y.add(8), fmsub_ps(c0, y1, _mm256_mul_ps(s0, x1)));
|
||||
_mm256_storeu_ps(x.add(16), fmadd_ps(c0, x2, _mm256_mul_ps(s0, y2)));
|
||||
_mm256_storeu_ps(y.add(16), fmsub_ps(c0, y2, _mm256_mul_ps(s0, x2)));
|
||||
_mm256_storeu_ps(x.add(24), fmadd_ps(c0, x3, _mm256_mul_ps(s0, y3)));
|
||||
_mm256_storeu_ps(y.add(24), fmsub_ps(c0, y3, _mm256_mul_ps(s0, x3)));
|
||||
|
||||
x = x.add(32);
|
||||
y = y.add(32);
|
||||
}
|
||||
|
||||
for _ in 0..n % 32 {
|
||||
let x0 = *x;
|
||||
let y0 = *y;
|
||||
|
||||
*x = c * x0 + s * y0;
|
||||
*y = c * y0 - s * x0;
|
||||
|
||||
x = x.add(1);
|
||||
y = y.add(1);
|
||||
}
|
||||
} else {
|
||||
for _ in 0..n {
|
||||
let x0 = *x;
|
||||
let y0 = *y;
|
||||
|
||||
*x = c * x0 + s * y0;
|
||||
*y = c * y0 - s * x0;
|
||||
|
||||
x = x.add(incx);
|
||||
y = y.add(incy);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub unsafe fn sswap(n: usize, mut x: *mut f32, incx: usize, mut y: *mut f32, incy: usize) {
|
||||
if incx == 1 && incy == 1 {
|
||||
for _ in 0..n / 32 {
|
||||
let x0 = _mm256_loadu_ps(x);
|
||||
let y0 = _mm256_loadu_ps(y);
|
||||
let x1 = _mm256_loadu_ps(x.add(8));
|
||||
let y1 = _mm256_loadu_ps(y.add(8));
|
||||
let x2 = _mm256_loadu_ps(x.add(16));
|
||||
let y2 = _mm256_loadu_ps(y.add(16));
|
||||
let x3 = _mm256_loadu_ps(x.add(24));
|
||||
let y3 = _mm256_loadu_ps(y.add(24));
|
||||
|
||||
_mm256_storeu_ps(x, y0);
|
||||
_mm256_storeu_ps(y, x0);
|
||||
_mm256_storeu_ps(x.add(8), y1);
|
||||
_mm256_storeu_ps(y.add(8), x1);
|
||||
_mm256_storeu_ps(x.add(16), y2);
|
||||
_mm256_storeu_ps(y.add(16), x2);
|
||||
_mm256_storeu_ps(x.add(24), y3);
|
||||
_mm256_storeu_ps(y.add(24), x3);
|
||||
|
||||
x = x.add(32);
|
||||
y = y.add(32);
|
||||
}
|
||||
|
||||
for _ in 0..n % 32 {
|
||||
let x0 = *x;
|
||||
|
||||
*x = *y;
|
||||
*y = x0;
|
||||
|
||||
x = x.add(1);
|
||||
y = y.add(1);
|
||||
}
|
||||
} else {
|
||||
for _ in 0..n {
|
||||
let x0 = *x;
|
||||
|
||||
*x = *y;
|
||||
*y = x0;
|
||||
|
||||
x = x.add(incx);
|
||||
y = y.add(incy);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub unsafe fn sscal(n: usize, a: f32, mut x: *mut f32, incx: usize) {
|
||||
if incx == 1 {
|
||||
let a0 = _mm256_broadcast_ss(&a);
|
||||
for _ in 0..n / 64 {
|
||||
let mut x0 = _mm256_loadu_ps(x);
|
||||
let mut x1 = _mm256_loadu_ps(x.add(8));
|
||||
let mut x2 = _mm256_loadu_ps(x.add(16));
|
||||
let mut x3 = _mm256_loadu_ps(x.add(24));
|
||||
let mut x4 = _mm256_loadu_ps(x.add(32));
|
||||
let mut x5 = _mm256_loadu_ps(x.add(40));
|
||||
let mut x6 = _mm256_loadu_ps(x.add(48));
|
||||
let mut x7 = _mm256_loadu_ps(x.add(56));
|
||||
|
||||
x0 = _mm256_mul_ps(a0, x0);
|
||||
x1 = _mm256_mul_ps(a0, x1);
|
||||
x2 = _mm256_mul_ps(a0, x2);
|
||||
x3 = _mm256_mul_ps(a0, x3);
|
||||
x4 = _mm256_mul_ps(a0, x4);
|
||||
x5 = _mm256_mul_ps(a0, x5);
|
||||
x6 = _mm256_mul_ps(a0, x6);
|
||||
x7 = _mm256_mul_ps(a0, x7);
|
||||
|
||||
_mm256_storeu_ps(x, x0);
|
||||
_mm256_storeu_ps(x.add(8), x1);
|
||||
_mm256_storeu_ps(x.add(16), x2);
|
||||
_mm256_storeu_ps(x.add(24), x3);
|
||||
_mm256_storeu_ps(x.add(32), x4);
|
||||
_mm256_storeu_ps(x.add(40), x5);
|
||||
_mm256_storeu_ps(x.add(48), x6);
|
||||
_mm256_storeu_ps(x.add(56), x7);
|
||||
|
||||
x = x.add(64);
|
||||
}
|
||||
for _ in 0..n % 64 {
|
||||
*x *= a;
|
||||
x = x.add(1);
|
||||
}
|
||||
} else {
|
||||
for _ in 0..n {
|
||||
*x *= a;
|
||||
x = x.add(incx);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub unsafe fn scopy(n: usize, mut x: *const f32, incx: usize, mut y: *mut f32, incy: usize) {
|
||||
if incx == 1 && incy == 1 {
|
||||
for _ in 0..n / 64 {
|
||||
let x0 = _mm256_loadu_ps(x);
|
||||
let x1 = _mm256_loadu_ps(x.add(8));
|
||||
let x2 = _mm256_loadu_ps(x.add(16));
|
||||
let x3 = _mm256_loadu_ps(x.add(24));
|
||||
let x4 = _mm256_loadu_ps(x.add(32));
|
||||
let x5 = _mm256_loadu_ps(x.add(40));
|
||||
let x6 = _mm256_loadu_ps(x.add(48));
|
||||
let x7 = _mm256_loadu_ps(x.add(56));
|
||||
|
||||
_mm256_storeu_ps(y, x0);
|
||||
_mm256_storeu_ps(y.add(8), x1);
|
||||
_mm256_storeu_ps(y.add(16), x2);
|
||||
_mm256_storeu_ps(y.add(24), x3);
|
||||
_mm256_storeu_ps(y.add(32), x4);
|
||||
_mm256_storeu_ps(y.add(40), x5);
|
||||
_mm256_storeu_ps(y.add(48), x6);
|
||||
_mm256_storeu_ps(y.add(56), x7);
|
||||
|
||||
x = x.add(64);
|
||||
y = y.add(64);
|
||||
}
|
||||
for _ in 0..n % 64 {
|
||||
*y = *x;
|
||||
x = x.add(1);
|
||||
y = y.add(1);
|
||||
}
|
||||
} else {
|
||||
for _ in 0..n {
|
||||
*y = *x;
|
||||
x = x.add(incx);
|
||||
y = y.add(incy);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub unsafe fn saxpy(
|
||||
n: usize,
|
||||
a: f32,
|
||||
mut x: *const f32,
|
||||
incx: usize,
|
||||
mut y: *mut f32,
|
||||
incy: usize,
|
||||
) {
|
||||
if incx == 1 && incy == 1 {
|
||||
let a0 = _mm256_broadcast_ss(&a);
|
||||
for _ in 0..n / 32 {
|
||||
let x0 = _mm256_loadu_ps(x);
|
||||
let y0 = _mm256_loadu_ps(y);
|
||||
let x1 = _mm256_loadu_ps(x.add(8));
|
||||
let y1 = _mm256_loadu_ps(y.add(8));
|
||||
let x2 = _mm256_loadu_ps(x.add(16));
|
||||
let y2 = _mm256_loadu_ps(y.add(16));
|
||||
let x3 = _mm256_loadu_ps(x.add(24));
|
||||
let y3 = _mm256_loadu_ps(y.add(24));
|
||||
|
||||
_mm256_storeu_ps(y, fmadd_ps(a0, x0, y0));
|
||||
_mm256_storeu_ps(y.add(8), fmadd_ps(a0, x1, y1));
|
||||
_mm256_storeu_ps(y.add(16), fmadd_ps(a0, x2, y2));
|
||||
_mm256_storeu_ps(y.add(24), fmadd_ps(a0, x3, y3));
|
||||
|
||||
x = x.add(32);
|
||||
y = y.add(32);
|
||||
}
|
||||
for _ in 0..n % 32 {
|
||||
*y += a * *x;
|
||||
x = x.add(1);
|
||||
y = y.add(1);
|
||||
}
|
||||
} else {
|
||||
for _ in 0..n {
|
||||
*y += a * *x;
|
||||
x = x.add(incx);
|
||||
y = y.add(incy);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub unsafe fn sdot(
|
||||
n: usize,
|
||||
mut x: *const f32,
|
||||
incx: usize,
|
||||
mut y: *const f32,
|
||||
incy: usize,
|
||||
) -> f32 {
|
||||
if incx == 1 && incy == 1 {
|
||||
let mut acc0 = _mm256_setzero_ps();
|
||||
let mut acc1 = _mm256_setzero_ps();
|
||||
let mut acc2 = _mm256_setzero_ps();
|
||||
let mut acc3 = _mm256_setzero_ps();
|
||||
for _ in 0..n / 32 {
|
||||
let x0 = _mm256_loadu_ps(x);
|
||||
let y0 = _mm256_loadu_ps(y);
|
||||
let x1 = _mm256_loadu_ps(x.add(8));
|
||||
let y1 = _mm256_loadu_ps(y.add(8));
|
||||
let x2 = _mm256_loadu_ps(x.add(16));
|
||||
let y2 = _mm256_loadu_ps(y.add(16));
|
||||
let x3 = _mm256_loadu_ps(x.add(24));
|
||||
let y3 = _mm256_loadu_ps(y.add(24));
|
||||
|
||||
acc0 = fmadd_ps(x0, y0, acc0);
|
||||
acc1 = fmadd_ps(x1, y1, acc1);
|
||||
acc2 = fmadd_ps(x2, y2, acc2);
|
||||
acc3 = fmadd_ps(x3, y3, acc3);
|
||||
|
||||
x = x.add(32);
|
||||
y = y.add(32);
|
||||
}
|
||||
acc0 = _mm256_add_ps(acc0, acc1);
|
||||
acc2 = _mm256_add_ps(acc2, acc3);
|
||||
acc0 = _mm256_add_ps(acc0, acc2);
|
||||
|
||||
let mut acc = hsum_ps(acc0);
|
||||
for _ in 0..n % 32 {
|
||||
acc += *x * *y;
|
||||
x = x.add(1);
|
||||
y = y.add(1);
|
||||
}
|
||||
acc
|
||||
} else {
|
||||
let mut acc = 0.0;
|
||||
for _ in 0..n {
|
||||
acc += *x * *y;
|
||||
x = x.add(incx);
|
||||
y = y.add(incy);
|
||||
}
|
||||
acc
|
||||
}
|
||||
}
|
||||
|
||||
pub unsafe fn snrm2(n: usize, mut x: *const f32, incx: usize) -> f32 {
|
||||
if incx == 1 {
|
||||
let mut acc0 = _mm256_setzero_ps();
|
||||
let mut acc1 = _mm256_setzero_ps();
|
||||
let mut acc2 = _mm256_setzero_ps();
|
||||
let mut acc3 = _mm256_setzero_ps();
|
||||
for _ in 0..n / 32 {
|
||||
let x0 = _mm256_loadu_ps(x);
|
||||
let x1 = _mm256_loadu_ps(x.add(8));
|
||||
let x2 = _mm256_loadu_ps(x.add(16));
|
||||
let x3 = _mm256_loadu_ps(x.add(24));
|
||||
|
||||
acc0 = fmadd_ps(x0, x0, acc0);
|
||||
acc1 = fmadd_ps(x1, x1, acc1);
|
||||
acc2 = fmadd_ps(x2, x2, acc2);
|
||||
acc3 = fmadd_ps(x3, x3, acc3);
|
||||
|
||||
x = x.add(32);
|
||||
}
|
||||
acc0 = _mm256_add_ps(acc0, acc1);
|
||||
acc2 = _mm256_add_ps(acc2, acc3);
|
||||
acc0 = _mm256_add_ps(acc0, acc2);
|
||||
|
||||
let mut acc = hsum_ps(acc0);
|
||||
for _ in 0..n % 32 {
|
||||
let x0 = *x;
|
||||
acc += x0 * x0;
|
||||
x = x.add(1);
|
||||
}
|
||||
acc.sqrt()
|
||||
} else {
|
||||
let mut acc = 0.0;
|
||||
for _ in 0..n {
|
||||
let x0 = *x;
|
||||
acc += x0 * x0;
|
||||
x = x.add(incx);
|
||||
}
|
||||
acc.sqrt()
|
||||
}
|
||||
}
|
||||
|
||||
pub unsafe fn sasum(n: usize, mut x: *const f32, incx: usize) -> f32 {
|
||||
if incx == 1 {
|
||||
let mask = _mm256_broadcast_ss(&f32::from_bits(0x7FFF_FFFF));
|
||||
|
||||
let mut acc0 = _mm256_setzero_ps();
|
||||
let mut acc1 = _mm256_setzero_ps();
|
||||
let mut acc2 = _mm256_setzero_ps();
|
||||
let mut acc3 = _mm256_setzero_ps();
|
||||
let mut acc4 = _mm256_setzero_ps();
|
||||
let mut acc5 = _mm256_setzero_ps();
|
||||
let mut acc6 = _mm256_setzero_ps();
|
||||
let mut acc7 = _mm256_setzero_ps();
|
||||
for _ in 0..n / 64 {
|
||||
let mut x0 = _mm256_loadu_ps(x);
|
||||
let mut x1 = _mm256_loadu_ps(x.add(8));
|
||||
let mut x2 = _mm256_loadu_ps(x.add(16));
|
||||
let mut x3 = _mm256_loadu_ps(x.add(24));
|
||||
let mut x4 = _mm256_loadu_ps(x.add(32));
|
||||
let mut x5 = _mm256_loadu_ps(x.add(40));
|
||||
let mut x6 = _mm256_loadu_ps(x.add(48));
|
||||
let mut x7 = _mm256_loadu_ps(x.add(56));
|
||||
|
||||
x0 = _mm256_and_ps(mask, x0);
|
||||
x1 = _mm256_and_ps(mask, x1);
|
||||
x2 = _mm256_and_ps(mask, x2);
|
||||
x3 = _mm256_and_ps(mask, x3);
|
||||
x4 = _mm256_and_ps(mask, x4);
|
||||
x5 = _mm256_and_ps(mask, x5);
|
||||
x6 = _mm256_and_ps(mask, x6);
|
||||
x7 = _mm256_and_ps(mask, x7);
|
||||
|
||||
acc0 = _mm256_add_ps(acc0, x0);
|
||||
acc1 = _mm256_add_ps(acc1, x1);
|
||||
acc2 = _mm256_add_ps(acc2, x2);
|
||||
acc3 = _mm256_add_ps(acc3, x3);
|
||||
acc4 = _mm256_add_ps(acc4, x4);
|
||||
acc5 = _mm256_add_ps(acc5, x5);
|
||||
acc6 = _mm256_add_ps(acc6, x6);
|
||||
acc7 = _mm256_add_ps(acc7, x7);
|
||||
|
||||
x = x.add(64);
|
||||
}
|
||||
acc0 = _mm256_add_ps(acc0, acc1);
|
||||
acc2 = _mm256_add_ps(acc2, acc3);
|
||||
acc4 = _mm256_add_ps(acc4, acc5);
|
||||
acc6 = _mm256_add_ps(acc6, acc7);
|
||||
|
||||
acc0 = _mm256_add_ps(acc0, acc2);
|
||||
acc4 = _mm256_add_ps(acc4, acc6);
|
||||
|
||||
acc0 = _mm256_add_ps(acc0, acc4);
|
||||
|
||||
let mut acc = hsum_ps(acc0);
|
||||
for _ in 0..n % 64 {
|
||||
acc += (*x).abs();
|
||||
x = x.add(1);
|
||||
}
|
||||
acc
|
||||
} else {
|
||||
let mut acc = 0.0;
|
||||
for _ in 0..n {
|
||||
acc += (*x).abs();
|
||||
x = x.add(incx);
|
||||
}
|
||||
acc
|
||||
}
|
||||
}
|
140
src/kernel/sse/l3d.rs
Normal file
140
src/kernel/sse/l3d.rs
Normal file
@ -0,0 +1,140 @@
|
||||
use super::fma::fmadd_pd;
|
||||
use super::intrinsics::*;
|
||||
|
||||
pub(crate) unsafe fn dgemm_ukr_8x4(
|
||||
k: usize,
|
||||
alpha: f64,
|
||||
pa: *const f64,
|
||||
pb: *const f64,
|
||||
beta: f64,
|
||||
c: *mut f64,
|
||||
ldc: usize,
|
||||
) {
|
||||
let mut mt00 = _mm256_setzero_pd();
|
||||
let mut mt01 = _mm256_setzero_pd();
|
||||
let mut mt02 = _mm256_setzero_pd();
|
||||
let mut mt03 = _mm256_setzero_pd();
|
||||
let mut mt10 = _mm256_setzero_pd();
|
||||
let mut mt11 = _mm256_setzero_pd();
|
||||
let mut mt12 = _mm256_setzero_pd();
|
||||
let mut mt13 = _mm256_setzero_pd();
|
||||
|
||||
let mut pa = pa;
|
||||
let mut pb = pb;
|
||||
|
||||
for _ in 0..k {
|
||||
let a0 = _mm256_loadu_pd(pa);
|
||||
let a1 = _mm256_loadu_pd(pa.add(4));
|
||||
|
||||
let b0 = _mm256_broadcast_sd(&*pb);
|
||||
let b1 = _mm256_broadcast_sd(&*pb.add(1));
|
||||
let b2 = _mm256_broadcast_sd(&*pb.add(2));
|
||||
let b3 = _mm256_broadcast_sd(&*pb.add(3));
|
||||
|
||||
mt00 = fmadd_pd(a0, b0, mt00);
|
||||
mt01 = fmadd_pd(a0, b1, mt01);
|
||||
mt02 = fmadd_pd(a0, b2, mt02);
|
||||
mt03 = fmadd_pd(a0, b3, mt03);
|
||||
mt10 = fmadd_pd(a1, b0, mt10);
|
||||
mt11 = fmadd_pd(a1, b1, mt11);
|
||||
mt12 = fmadd_pd(a1, b2, mt12);
|
||||
mt13 = fmadd_pd(a1, b3, mt13);
|
||||
|
||||
pa = pa.add(8);
|
||||
pb = pb.add(4);
|
||||
}
|
||||
|
||||
let alpha = _mm256_broadcast_sd(&alpha);
|
||||
|
||||
mt00 = _mm256_mul_pd(alpha, mt00);
|
||||
mt01 = _mm256_mul_pd(alpha, mt01);
|
||||
mt02 = _mm256_mul_pd(alpha, mt02);
|
||||
mt03 = _mm256_mul_pd(alpha, mt03);
|
||||
mt10 = _mm256_mul_pd(alpha, mt10);
|
||||
mt11 = _mm256_mul_pd(alpha, mt11);
|
||||
mt12 = _mm256_mul_pd(alpha, mt12);
|
||||
mt13 = _mm256_mul_pd(alpha, mt13);
|
||||
|
||||
let ccol0 = c;
|
||||
let ccol1 = c.add(ldc);
|
||||
let ccol2 = c.add(ldc * 2);
|
||||
let ccol3 = c.add(ldc * 3);
|
||||
|
||||
if beta != 0.0 {
|
||||
let beta = _mm256_broadcast_sd(&beta);
|
||||
|
||||
mt00 = fmadd_pd(beta, _mm256_loadu_pd(ccol0), mt00);
|
||||
mt01 = fmadd_pd(beta, _mm256_loadu_pd(ccol1), mt01);
|
||||
mt02 = fmadd_pd(beta, _mm256_loadu_pd(ccol2), mt02);
|
||||
mt03 = fmadd_pd(beta, _mm256_loadu_pd(ccol3), mt03);
|
||||
mt10 = fmadd_pd(beta, _mm256_loadu_pd(ccol0.add(4)), mt10);
|
||||
mt11 = fmadd_pd(beta, _mm256_loadu_pd(ccol1.add(4)), mt11);
|
||||
mt12 = fmadd_pd(beta, _mm256_loadu_pd(ccol2.add(4)), mt12);
|
||||
mt13 = fmadd_pd(beta, _mm256_loadu_pd(ccol3.add(4)), mt13);
|
||||
}
|
||||
|
||||
_mm256_storeu_pd(ccol0, mt00);
|
||||
_mm256_storeu_pd(ccol1, mt01);
|
||||
_mm256_storeu_pd(ccol2, mt02);
|
||||
_mm256_storeu_pd(ccol3, mt03);
|
||||
_mm256_storeu_pd(ccol0.add(4), mt10);
|
||||
_mm256_storeu_pd(ccol1.add(4), mt11);
|
||||
_mm256_storeu_pd(ccol2.add(4), mt12);
|
||||
_mm256_storeu_pd(ccol3.add(4), mt13);
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn dgemm_sup_8x1(
|
||||
k: usize,
|
||||
alpha: f64,
|
||||
pa: *const f64,
|
||||
b: *const f64,
|
||||
beta: f64,
|
||||
c: *mut f64,
|
||||
) {
|
||||
let mut mt0 = _mm256_setzero_pd();
|
||||
let mut mt1 = _mm256_setzero_pd();
|
||||
|
||||
let mut pa = pa;
|
||||
let mut b = b;
|
||||
|
||||
for _ in 0..k {
|
||||
let a0 = _mm256_loadu_pd(pa);
|
||||
let a1 = _mm256_loadu_pd(pa.add(4));
|
||||
|
||||
let b0 = _mm256_broadcast_sd(&*b);
|
||||
|
||||
mt0 = fmadd_pd(a0, b0, mt0);
|
||||
mt1 = fmadd_pd(a1, b0, mt1);
|
||||
|
||||
pa = pa.add(8);
|
||||
b = b.add(1);
|
||||
}
|
||||
|
||||
let alpha = _mm256_broadcast_sd(&alpha);
|
||||
|
||||
mt0 = _mm256_mul_pd(alpha, mt0);
|
||||
mt1 = _mm256_mul_pd(alpha, mt1);
|
||||
|
||||
if beta != 0.0 {
|
||||
let beta = _mm256_broadcast_sd(&beta);
|
||||
|
||||
mt0 = fmadd_pd(beta, _mm256_loadu_pd(c), mt0);
|
||||
mt1 = fmadd_pd(beta, _mm256_loadu_pd(c.add(4)), mt1);
|
||||
}
|
||||
|
||||
_mm256_storeu_pd(c, mt0);
|
||||
_mm256_storeu_pd(c.add(4), mt1);
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn dgemm_pa_8x(k: usize, a: *const f64, lda: usize, pa: *mut f64) {
|
||||
let mut a = a;
|
||||
let mut pa = pa;
|
||||
|
||||
for _ in 0..k {
|
||||
_mm256_storeu_pd(pa, _mm256_loadu_pd(a));
|
||||
_mm256_storeu_pd(pa.add(4), _mm256_loadu_pd(a.add(4)));
|
||||
|
||||
pa = pa.add(8);
|
||||
a = a.add(lda);
|
||||
}
|
||||
}
|
178
src/kernel/sse/l3s.rs
Normal file
178
src/kernel/sse/l3s.rs
Normal file
@ -0,0 +1,178 @@
|
||||
use super::fma::fmadd_ps;
|
||||
use super::intrinsics::*;
|
||||
use crate::kernel::params::single::NR;
|
||||
use crate::matrix::{Matrix, MatrixMut, MutMatrix};
|
||||
|
||||
pub(crate) unsafe fn sgemm_sup_1x8<A: Matrix<f32>, C: MatrixMut<f32>>(
|
||||
k: usize,
|
||||
alpha: f32,
|
||||
a: A,
|
||||
pb: MutMatrix<f32>,
|
||||
beta: f32,
|
||||
c: C,
|
||||
) {
|
||||
let mut c0_3 = _mm_setzero_ps();
|
||||
let mut c4 = 0.0f32;
|
||||
// let mut c5 = 0.0f32;
|
||||
// let mut c6 = 0.0f32;
|
||||
// let mut c7 = 0.0f32;
|
||||
|
||||
let mut a = a;
|
||||
let mut pb = pb;
|
||||
|
||||
for _ in 0..k {
|
||||
let a0 = *a.ptr();
|
||||
let a0_simd = _mm_broadcast_ss(&*a.ptr());
|
||||
|
||||
c0_3 = fmadd_ps(_mm_loadu_ps(pb.ptr()), a0_simd, c0_3);
|
||||
c4 += *pb.col(4) * a0;
|
||||
// c5 += *pb.add(5) * a0;
|
||||
// c6 += *pb.add(6) * a0;
|
||||
// c7 += *pb.add(7) * a0;
|
||||
|
||||
a.inc_row();
|
||||
pb.shift_col(NR);
|
||||
}
|
||||
|
||||
// c0 *= alpha;
|
||||
// c1 *= alpha;
|
||||
// c2 *= alpha;
|
||||
// c3 *= alpha;
|
||||
// c4 *= alpha;
|
||||
// c5 *= alpha;
|
||||
// c6 *= alpha;
|
||||
// c7 *= alpha;
|
||||
|
||||
let ccol0_3 = c.ptr_mut();
|
||||
let ccol4 = c.row_mut(4);
|
||||
// let ccol5 = c.add(ldc * 5);
|
||||
// let ccol6 = c.add(ldc * 6);
|
||||
// let ccol7 = c.add(ldc * 7);
|
||||
|
||||
// if beta != 0.0 {
|
||||
// c0 += beta * *ccol0;
|
||||
// c1 += beta * *ccol1;
|
||||
// c2 += beta * *ccol2;
|
||||
// c3 += beta * *ccol3;
|
||||
// c4 += beta * *ccol4;
|
||||
// // c5 += beta * *ccol5;
|
||||
// // c6 += beta * *ccol6;
|
||||
// // c7 += beta * *ccol7;
|
||||
// }
|
||||
|
||||
_mm_storeu_ps(ccol0_3, c0_3);
|
||||
*ccol4 = c4;
|
||||
// *ccol5 = c5;
|
||||
// *ccol6 = c6;
|
||||
// *ccol7 = c7;
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn sgemm_sup_1x8_t(
|
||||
k: usize,
|
||||
alpha: f32,
|
||||
a: *const f32,
|
||||
lda: usize,
|
||||
pb: *const f32,
|
||||
beta: f32,
|
||||
c: *mut f32,
|
||||
ldc: usize,
|
||||
) {
|
||||
let mut c03 = _mm_setzero_ps();
|
||||
let mut c4 = 0.0f32;
|
||||
// let mut c5 = 0.0f32;
|
||||
// let mut c6 = 0.0f32;
|
||||
// let mut c7 = 0.0f32;
|
||||
|
||||
let mut a = a;
|
||||
let mut pb = pb;
|
||||
|
||||
for _ in 0..k {
|
||||
let a0 = *a;
|
||||
let a0_simd = _mm_set_ps1(a0);
|
||||
|
||||
c03 = fmadd_ps(_mm_loadu_ps(pb), a0_simd, c03);
|
||||
c4 += *pb.add(4) * a0;
|
||||
// c5 += *pb.add(5) * a0;
|
||||
// c6 += *pb.add(6) * a0;
|
||||
// c7 += *pb.add(7) * a0;
|
||||
|
||||
a = a.add(1);
|
||||
pb = pb.add(NR);
|
||||
}
|
||||
|
||||
// c0 *= alpha;
|
||||
// c1 *= alpha;
|
||||
// c2 *= alpha;
|
||||
// c3 *= alpha;
|
||||
// c4 *= alpha;
|
||||
// c5 *= alpha;
|
||||
// c6 *= alpha;
|
||||
// c7 *= alpha;
|
||||
|
||||
// let ccol0 = c;
|
||||
// let ccol1 = c.add(ldc);
|
||||
// let ccol2 = c.add(ldc * 2);
|
||||
// let ccol3 = c.add(ldc * 3);
|
||||
let ccol4 = c.add(ldc * 4);
|
||||
// let ccol5 = c.add(ldc * 5);
|
||||
// let ccol6 = c.add(ldc * 6);
|
||||
// let ccol7 = c.add(ldc * 7);
|
||||
|
||||
// if beta != 0.0 {
|
||||
// c0 += beta * *ccol0;
|
||||
// c1 += beta * *ccol1;
|
||||
// c2 += beta * *ccol2;
|
||||
// c3 += beta * *ccol3;
|
||||
// c4 += beta * *ccol4;
|
||||
// // c5 += beta * *ccol5;
|
||||
// // c6 += beta * *ccol6;
|
||||
// // c7 += beta * *ccol7;
|
||||
// }
|
||||
|
||||
_mm_storeu_ps(c, c03);
|
||||
// *ccol0 = c0;
|
||||
// *ccol1 = c1;
|
||||
// *ccol2 = c2;
|
||||
// *ccol3 = c3;
|
||||
*ccol4 = c4;
|
||||
// *ccol5 = c5;
|
||||
// *ccol6 = c6;
|
||||
// *ccol7 = c7;
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn sgemm_pb_x8(k: usize, b: *const f32, ldb: usize, pb: *mut f32) {
|
||||
let mut bcol0 = b;
|
||||
let mut bcol1 = b.add(ldb);
|
||||
let mut bcol2 = b.add(ldb * 2);
|
||||
let mut bcol3 = b.add(ldb * 3);
|
||||
let mut bcol4 = b.add(ldb * 4);
|
||||
// let mut bcol5 = b.add(ldb * 5);
|
||||
// let mut bcol6 = b.add(ldb * 6);
|
||||
// let mut bcol7 = b.add(ldb * 7);
|
||||
|
||||
let mut pb = pb;
|
||||
|
||||
for _ in 0..k {
|
||||
_mm_storeu_ps(pb, _mm_set_ps(*bcol3, *bcol2, *bcol1, *bcol0));
|
||||
|
||||
// *pb = *bcol0;
|
||||
// *pb.add(1) = *bcol1;
|
||||
// *pb.add(2) = *bcol2;
|
||||
// *pb.add(3) = *bcol3;
|
||||
*pb.add(4) = *bcol4;
|
||||
// *pb.add(5) = *bcol5;
|
||||
// *pb.add(6) = *bcol6;
|
||||
// *pb.add(7) = *bcol7;
|
||||
|
||||
bcol0 = bcol0.add(1);
|
||||
bcol1 = bcol1.add(1);
|
||||
bcol2 = bcol2.add(1);
|
||||
bcol3 = bcol3.add(1);
|
||||
bcol4 = bcol4.add(1);
|
||||
// bcol5 = bcol5.add(1);
|
||||
// bcol6 = bcol6.add(1);
|
||||
// bcol7 = bcol7.add(1);
|
||||
|
||||
pb = pb.add(NR);
|
||||
}
|
||||
}
|
86
src/kernel/sse/mod.rs
Normal file
86
src/kernel/sse/mod.rs
Normal file
@ -0,0 +1,86 @@
|
||||
mod fma;
|
||||
// mod hsum;
|
||||
mod intrinsics;
|
||||
// pub mod l1d;
|
||||
// pub mod l1s;
|
||||
// pub mod l3d;
|
||||
pub mod l3s;
|
||||
|
||||
use core::marker::PhantomData;
|
||||
use crate::matrix::{Number, MutMatrix, Matrix, MatrixMut};
|
||||
use crate::kernel::{GemmKernel, GemmKernelSupNr, GemmKernelSupMr, GemmKernelSup};
|
||||
use crate::dim::*;
|
||||
|
||||
|
||||
pub struct SseKernel<F: Number, I>(PhantomData<fn(F, I)>);
|
||||
|
||||
impl<I> GemmKernelSupNr<f32, A5> for SseKernel<f32, I>
|
||||
where I: GemmKernelSupNr<f32, A5>
|
||||
{
|
||||
#[inline]
|
||||
unsafe fn sup_tr<A: Matrix<f32>, C: MatrixMut<f32>>(
|
||||
alpha: f32,
|
||||
a: A,
|
||||
pb: MutMatrix<f32>,
|
||||
beta: f32,
|
||||
c: C,
|
||||
) {
|
||||
I::sup_tr(alpha, a, pb, beta, c);
|
||||
}
|
||||
}
|
||||
|
||||
impl<I> GemmKernelSupMr<f32, A16> for SseKernel<f32, I>
|
||||
where I: GemmKernelSupMr<f32, A16>
|
||||
{
|
||||
#[inline]
|
||||
unsafe fn sup_bl<B: Matrix<f32>, C: MatrixMut<f32>>(
|
||||
alpha: f32,
|
||||
pa: MutMatrix<f32>,
|
||||
b: B,
|
||||
beta: f32,
|
||||
c: C,
|
||||
) {
|
||||
I::sup_bl(alpha, pa, b, beta, c);
|
||||
}
|
||||
}
|
||||
|
||||
impl<I> GemmKernelSup<f32> for SseKernel<f32, I>
|
||||
where I: GemmKernelSup<f32>
|
||||
{
|
||||
#[inline]
|
||||
unsafe fn sup_br<A: Matrix<f32>, B: Matrix<f32>, C: MatrixMut<f32>>(
|
||||
k: usize,
|
||||
alpha: f32,
|
||||
a: A,
|
||||
b: B,
|
||||
beta: f32,
|
||||
c: C,
|
||||
) {
|
||||
I::sup_br(k, alpha, a, b, beta, c);
|
||||
}
|
||||
}
|
||||
|
||||
impl<I> GemmKernel<f32, A16, A5> for SseKernel<f32, I>
|
||||
where I: GemmKernel<f32, A16, A5>
|
||||
{
|
||||
#[inline]
|
||||
unsafe fn pack_row_a<A: Matrix<f32>>(a: A, pa: MutMatrix<f32>, i: usize) {
|
||||
I::pack_row_a(a, pa, i);
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn pack_row_b<B: Matrix<f32>>(b: B, pb: MutMatrix<f32>, j: usize) {
|
||||
I::pack_row_b(b, pb, j);
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn main_tl<C: MatrixMut<f32>>(
|
||||
alpha: f32,
|
||||
pa: MutMatrix<f32>,
|
||||
pb: MutMatrix<f32>,
|
||||
beta: f32,
|
||||
c: C,
|
||||
) {
|
||||
I::main_tl(alpha, pa, pb, beta, c);
|
||||
}
|
||||
}
|
17
src/lib.rs
Normal file
17
src/lib.rs
Normal file
@ -0,0 +1,17 @@
|
||||
pub mod matrix;
|
||||
pub mod gemm;
|
||||
mod sgemm;
|
||||
mod aligned_alloc;
|
||||
pub mod kernel;
|
||||
pub mod dim;
|
||||
pub mod executor;
|
||||
|
||||
#[cfg(test)]
|
||||
extern crate blas;
|
||||
#[cfg(test)]
|
||||
extern crate openblas;
|
||||
#[cfg(test)]
|
||||
mod test;
|
||||
|
||||
|
||||
pub use crate::sgemm::sgemm;
|
418
src/main.rs
418
src/main.rs
@ -1,373 +1,16 @@
|
||||
mod matrix;
|
||||
mod gemm;
|
||||
mod sgemm;
|
||||
mod aligned_alloc;
|
||||
mod kernel;
|
||||
mod dim;
|
||||
mod executor;
|
||||
|
||||
|
||||
extern crate blas;
|
||||
extern crate openblas;
|
||||
|
||||
use self::matrix::{ConstMatrix, MutMatrix, ConstTransposedMatrix, MutTransposedMatrix};
|
||||
|
||||
pub unsafe fn sgemm(
|
||||
transa: bool,
|
||||
transb: bool,
|
||||
transc: bool,
|
||||
m: usize,
|
||||
n: usize,
|
||||
k: usize,
|
||||
alpha: f32,
|
||||
a: *const f32,
|
||||
lda: usize,
|
||||
b: *const f32,
|
||||
ldb: usize,
|
||||
beta: f32,
|
||||
c: *mut f32,
|
||||
ldc: usize,
|
||||
) {
|
||||
match (transa, transb, transc) {
|
||||
(false, false, false) => sgemm::sgemm(
|
||||
m, n, k, alpha,
|
||||
ConstMatrix::new(a, lda),
|
||||
ConstMatrix::new(b, ldb),
|
||||
beta,
|
||||
MutMatrix::new(c, ldc)),
|
||||
|
||||
(false, false, true) => sgemm::sgemm(
|
||||
m, n, k, alpha,
|
||||
ConstMatrix::new(a, lda),
|
||||
ConstMatrix::new(b, ldb),
|
||||
beta,
|
||||
MutTransposedMatrix::new(c, ldc)),
|
||||
|
||||
(false, true, false) => sgemm::sgemm(
|
||||
m, n, k, alpha,
|
||||
ConstMatrix::new(a, lda),
|
||||
ConstTransposedMatrix::new(b, ldb),
|
||||
beta,
|
||||
MutMatrix::new(c, ldc)),
|
||||
|
||||
(false, true, true) => sgemm::sgemm(
|
||||
m, n, k, alpha,
|
||||
ConstMatrix::new(a, lda),
|
||||
ConstTransposedMatrix::new(b, ldb),
|
||||
beta,
|
||||
MutTransposedMatrix::new(c, ldc)),
|
||||
|
||||
(true, false, false) => sgemm::sgemm(
|
||||
m, n, k, alpha,
|
||||
ConstTransposedMatrix::new(a, lda),
|
||||
ConstMatrix::new(b, ldb),
|
||||
beta,
|
||||
MutMatrix::new(c, ldc)),
|
||||
|
||||
(true, false, true) => sgemm::sgemm(
|
||||
m, n, k, alpha,
|
||||
ConstTransposedMatrix::new(a, lda),
|
||||
ConstMatrix::new(b, ldb),
|
||||
beta,
|
||||
MutTransposedMatrix::new(c, ldc)),
|
||||
|
||||
|
||||
(true, true, false) => sgemm::sgemm(
|
||||
m, n, k, alpha,
|
||||
ConstTransposedMatrix::new(a, lda),
|
||||
ConstTransposedMatrix::new(b, ldb),
|
||||
beta,
|
||||
MutMatrix::new(c, ldc)),
|
||||
|
||||
|
||||
(true, true, true) => sgemm::sgemm(
|
||||
m, n, k, alpha,
|
||||
ConstTransposedMatrix::new(a, lda),
|
||||
ConstTransposedMatrix::new(b, ldb),
|
||||
beta,
|
||||
MutTransposedMatrix::new(c, ldc)),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
unsafe fn sgemm_ref_nn(
|
||||
m: usize,
|
||||
n: usize,
|
||||
k: usize,
|
||||
_alpha: f32,
|
||||
a: *const f32,
|
||||
lda: usize,
|
||||
b: *const f32,
|
||||
ldb: usize,
|
||||
_beta: f32,
|
||||
c: *mut f32,
|
||||
ldc: usize,
|
||||
) {
|
||||
for j in 0..n {
|
||||
for i in 0..m {
|
||||
let mut ci = *c.add(i + j * ldc);
|
||||
for p in 0..k {
|
||||
ci += *a.add(i + p * lda) * *b.add(p + j * ldb);
|
||||
}
|
||||
*c.add(i + j * ldc) = ci;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe fn sgemm_ref_nt(
|
||||
m: usize,
|
||||
n: usize,
|
||||
k: usize,
|
||||
_alpha: f32,
|
||||
a: *const f32,
|
||||
lda: usize,
|
||||
b: *const f32,
|
||||
ldb: usize,
|
||||
_beta: f32,
|
||||
c: *mut f32,
|
||||
ldc: usize,
|
||||
) {
|
||||
for j in 0..n {
|
||||
for i in 0..m {
|
||||
let mut ci = *c.add(i + j * ldc);
|
||||
|
||||
for p in 0..k {
|
||||
ci += *a.add(i + p * lda) * *b.add(j + p * ldb);
|
||||
}
|
||||
|
||||
*c.add(i + j * ldc) = ci;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn gemm_nn(m: usize, n: usize, k: usize, alpha: f32,
|
||||
a: &[f32], lda: usize,
|
||||
b: &[f32], ldb: usize,
|
||||
_bata: f32,
|
||||
c: &mut [f32], ldc: usize) {
|
||||
|
||||
let mk = m * k;
|
||||
let nk = n * k;
|
||||
let mn = m * n;
|
||||
let a = &a[0..mk];
|
||||
let b = &b[0..nk];
|
||||
let c = &mut c[0..mn];
|
||||
|
||||
for i_m in 0..m {
|
||||
for i_k in 0..k {
|
||||
let a_part = alpha * a[i_m * lda + i_k];
|
||||
|
||||
for i_n in 0..n {
|
||||
c[i_m * ldc + i_n] += a_part * b[i_k * ldb + i_n];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
fn main2() {
|
||||
use rand::Rng;
|
||||
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
const LEN: usize = 511;
|
||||
let (m, n, k) = (LEN, LEN, LEN);
|
||||
|
||||
let mut a = vec![0.5; m * k];
|
||||
let mut a_t = vec![0.5; m * k];
|
||||
let mut b = vec![0.5; n * k];
|
||||
let mut b_t = vec![0.5; n * k];
|
||||
let mut c_nn = vec![0.0; m * n];
|
||||
let mut c_nt = vec![0.0; m * n];
|
||||
let mut c_tn = vec![0.0; m * n];
|
||||
let mut c_tt = vec![0.0; m * n];
|
||||
let mut cref = vec![0.0; m * n];
|
||||
|
||||
for i in 0..m {
|
||||
for j in 0..k {
|
||||
let v = rng.gen();
|
||||
a[i + j * m] = v;
|
||||
a_t[j + i * m] = v;
|
||||
}
|
||||
}
|
||||
|
||||
for i in 0..n {
|
||||
for j in 0..k {
|
||||
let v = rng.gen();
|
||||
b[i + j * n] = v;
|
||||
b_t[j + i * n] = v;
|
||||
}
|
||||
}
|
||||
|
||||
// let time = std::time::Instant::now();
|
||||
// unsafe {
|
||||
// gemm_nn(
|
||||
// LEN,
|
||||
// LEN,
|
||||
// LEN,
|
||||
// 1.0,
|
||||
// a.as_slice(),
|
||||
// LEN,
|
||||
// b.as_slice(),
|
||||
// LEN,
|
||||
// 1.0,
|
||||
// cref1.as_mut_slice(),
|
||||
// LEN,
|
||||
// )
|
||||
// }
|
||||
|
||||
// println!("Naive (mine) {}", time.elapsed().as_millis());
|
||||
|
||||
let time = std::time::Instant::now();
|
||||
unsafe {
|
||||
sgemm_ref_nn(
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
|
||||
1.0,
|
||||
a.as_ptr(),
|
||||
m,
|
||||
|
||||
b.as_ptr(),
|
||||
k,
|
||||
|
||||
1.0,
|
||||
cref.as_mut_ptr(),
|
||||
m,
|
||||
)
|
||||
}
|
||||
|
||||
println!("Naive {}", time.elapsed().as_millis());
|
||||
|
||||
let time = std::time::Instant::now();
|
||||
unsafe {
|
||||
sgemm(
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
|
||||
1.0,
|
||||
a.as_ptr(),
|
||||
m,
|
||||
|
||||
b.as_ptr(),
|
||||
k,
|
||||
|
||||
1.0,
|
||||
c_nn.as_mut_ptr(),
|
||||
m,
|
||||
);
|
||||
}
|
||||
println!("[NN] Optimized {}", time.elapsed().as_millis());
|
||||
|
||||
for i in 0..LEN {
|
||||
for j in 0..LEN {
|
||||
let (a, b) = (c_nn[i + j * LEN], cref[i + j * LEN]);
|
||||
assert!(feq(a, b), "a != b, a[{}]={}, b[{}]={}", i, a, j, b);
|
||||
}
|
||||
}
|
||||
|
||||
let time = std::time::Instant::now();
|
||||
unsafe {
|
||||
sgemm(
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
|
||||
1.0,
|
||||
a.as_ptr(),
|
||||
m,
|
||||
|
||||
b_t.as_ptr(),
|
||||
n,
|
||||
|
||||
1.0,
|
||||
c_nt.as_mut_ptr(),
|
||||
m,
|
||||
);
|
||||
}
|
||||
|
||||
println!("[NT] Optimized {}", time.elapsed().as_millis());
|
||||
|
||||
for i in 0..LEN {
|
||||
for j in 0..LEN {
|
||||
let (a, b) = (c_nt[i + j * LEN], cref[i + j * LEN]);
|
||||
assert!(feq(a, b), "a != b, a[{}]={}, b[{}]={}", i, a, j, b);
|
||||
}
|
||||
}
|
||||
|
||||
let time = std::time::Instant::now();
|
||||
unsafe {
|
||||
sgemm(
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
|
||||
1.0,
|
||||
a_t.as_ptr(),
|
||||
k,
|
||||
|
||||
b.as_ptr(),
|
||||
k,
|
||||
|
||||
1.0,
|
||||
c_tn.as_mut_ptr(),
|
||||
m,
|
||||
);
|
||||
}
|
||||
|
||||
println!("[TN] Optimized {}", time.elapsed().as_millis());
|
||||
|
||||
for i in 0..LEN {
|
||||
for j in 0..LEN {
|
||||
let (a, b) = (c_tn[i + j * LEN], cref[i + j * LEN]);
|
||||
assert!(feq(a, b), "a != b, a[{}]={}, b[{}]={}", i, a, j, b);
|
||||
}
|
||||
}
|
||||
|
||||
let time = std::time::Instant::now();
|
||||
unsafe {
|
||||
sgemm(
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
|
||||
1.0,
|
||||
a_t.as_ptr(),
|
||||
k,
|
||||
|
||||
b_t.as_ptr(),
|
||||
n,
|
||||
|
||||
1.0,
|
||||
c_tt.as_mut_ptr(),
|
||||
m,
|
||||
);
|
||||
}
|
||||
println!("[TT] Optimized {}", time.elapsed().as_millis());
|
||||
|
||||
|
||||
for i in 0..LEN {
|
||||
for j in 0..LEN {
|
||||
let (a, b) = (c_tt[i + j * LEN], cref[i + j * LEN]);
|
||||
assert!(feq(a, b), "a != b, a[{}]={}, b[{}]={}", i, a, j, b);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
use crate::sgemm::sgemm;
|
||||
|
||||
fn main() {
|
||||
use rand::Rng;
|
||||
@ -395,25 +38,24 @@ fn main() {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
let time = std::time::Instant::now();
|
||||
unsafe {
|
||||
// blas::sgemm(
|
||||
// b'N',
|
||||
// b'N',
|
||||
// m as i32,
|
||||
// n as i32,
|
||||
// k as i32,
|
||||
// 1.0,
|
||||
// a.as_slice(),
|
||||
// m as i32,
|
||||
// b.as_slice(),
|
||||
// k as i32,
|
||||
// 1.0,
|
||||
// cref1.as_mut_slice(),
|
||||
// m as i32,
|
||||
// );
|
||||
blas::sgemm(
|
||||
b'N',
|
||||
b'N',
|
||||
m as i32,
|
||||
n as i32,
|
||||
k as i32,
|
||||
1.0,
|
||||
a.as_slice(),
|
||||
m as i32,
|
||||
b.as_slice(),
|
||||
k as i32,
|
||||
1.0,
|
||||
cref1.as_mut_slice(),
|
||||
m as i32,
|
||||
);
|
||||
|
||||
// gemm_nn(
|
||||
// m,
|
||||
// n,
|
||||
@ -456,6 +98,7 @@ fn main() {
|
||||
let time = std::time::Instant::now();
|
||||
unsafe {
|
||||
sgemm(
|
||||
&executor::RayonExecutor,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
@ -477,22 +120,5 @@ fn main() {
|
||||
}
|
||||
|
||||
println!("Optimized {}", time.elapsed().as_millis());
|
||||
|
||||
|
||||
// for i in 0..LEN {
|
||||
// for j in 0..LEN {
|
||||
// let (a, b) = (c[i + j * LEN], cref[i + j * LEN]);
|
||||
// assert!(feq(a, b), "a != b, a[{}]={}, b[{}]={}", i, a, j, b);
|
||||
// }
|
||||
// }
|
||||
}
|
||||
|
||||
fn feq(a: f32, b: f32) -> bool {
|
||||
if a == b {
|
||||
true
|
||||
} else if a == 0.0 || b == 0.0 || (a.abs() + b.abs() < std::f32::MIN_POSITIVE) {
|
||||
(a - b).abs() < std::f32::EPSILON * 10.0 * std::f32::MIN_POSITIVE
|
||||
} else {
|
||||
(a - b).abs() / (a.abs() + b.abs()) < std::f32::EPSILON * 10.0
|
||||
}
|
||||
}
|
||||
|
322
src/matrix.rs
322
src/matrix.rs
@ -1,36 +1,50 @@
|
||||
pub(crate) trait Matrix: Copy {
|
||||
|
||||
pub trait Number: Copy + Send + Sync + 'static {}
|
||||
|
||||
impl Number for f32 {}
|
||||
impl Number for f64 {}
|
||||
|
||||
pub trait Matrix<F: Number>: Copy + Send + Sync + 'static {
|
||||
unsafe fn sub(&self, row: usize, col: usize) -> Self;
|
||||
unsafe fn sub_col(&self, col: usize) -> Self;
|
||||
unsafe fn sub_row(&self, row: usize) -> Self;
|
||||
unsafe fn is_transposed(&self) -> bool;
|
||||
unsafe fn stride(&self) -> usize;
|
||||
unsafe fn get(&self, row: usize, col: usize) -> f32;
|
||||
unsafe fn index(&self, row: usize, col: usize) -> *const f32;
|
||||
unsafe fn row(&self, row: usize) -> *const f32;
|
||||
unsafe fn col(&self, col: usize) -> *const f32;
|
||||
unsafe fn get(&self, row: usize, col: usize) -> F;
|
||||
unsafe fn index(&self, row: usize, col: usize) -> *const F;
|
||||
unsafe fn row(&self, row: usize) -> *const F;
|
||||
unsafe fn col(&self, col: usize) -> *const F;
|
||||
unsafe fn ptr(&self) -> *const F;
|
||||
unsafe fn inc_row(&mut self);
|
||||
unsafe fn inc_col(&mut self);
|
||||
unsafe fn shift_row(&mut self, rows: usize);
|
||||
unsafe fn shift_col(&mut self, cols: usize);
|
||||
}
|
||||
|
||||
pub(crate) trait MatrixMut: Matrix {
|
||||
unsafe fn set(&self, row: usize, col: usize, val: f32);
|
||||
unsafe fn index_mut(&self, row: usize, col: usize) -> *mut f32;
|
||||
unsafe fn row_mut(&self, row: usize) -> *mut f32;
|
||||
unsafe fn col_mut(&self, col: usize) -> *mut f32;
|
||||
pub trait MatrixMut<F: Number>: Matrix<F> {
|
||||
unsafe fn set(&self, row: usize, col: usize, val: F);
|
||||
unsafe fn index_mut(&self, row: usize, col: usize) -> *mut F;
|
||||
unsafe fn row_mut(&self, row: usize) -> *mut F;
|
||||
unsafe fn col_mut(&self, col: usize) -> *mut F;
|
||||
unsafe fn ptr_mut(&self) -> *mut F;
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
pub(crate) struct ConstMatrix {
|
||||
pub struct ConstMatrix<F: Number> {
|
||||
pub stride: usize,
|
||||
pub ptr: *const f32
|
||||
pub ptr: *const F
|
||||
}
|
||||
|
||||
unsafe impl Send for ConstMatrix {}
|
||||
unsafe impl Sync for ConstMatrix {}
|
||||
unsafe impl<F: Number> Send for ConstMatrix<F> {}
|
||||
unsafe impl<F: Number> Sync for ConstMatrix<F> {}
|
||||
|
||||
impl ConstMatrix {
|
||||
pub fn new(ptr: *const f32, stride: usize) -> Self {
|
||||
impl<F: Number> ConstMatrix<F> {
|
||||
pub fn new(ptr: *const F, stride: usize) -> Self {
|
||||
Self { ptr, stride }
|
||||
}
|
||||
}
|
||||
|
||||
impl Matrix for ConstMatrix {
|
||||
impl<F: Number> Matrix<F> for ConstMatrix<F> {
|
||||
#[inline]
|
||||
unsafe fn sub(&self, row: usize, col: usize) -> Self {
|
||||
Self {
|
||||
@ -39,6 +53,22 @@ impl Matrix for ConstMatrix {
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn sub_col(&self, col: usize) -> Self {
|
||||
Self {
|
||||
ptr: self.col(col),
|
||||
stride: self.stride
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn sub_row(&self, row: usize) -> Self {
|
||||
Self {
|
||||
ptr: self.row(row),
|
||||
stride: self.stride
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn is_transposed(&self) -> bool {
|
||||
false
|
||||
@ -50,42 +80,67 @@ impl Matrix for ConstMatrix {
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn get(&self, row: usize, col: usize) -> f32 {
|
||||
unsafe fn get(&self, row: usize, col: usize) -> F {
|
||||
*self.ptr.add(row * self.stride + col)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn index(&self, row: usize, col: usize) -> *const f32 {
|
||||
unsafe fn index(&self, row: usize, col: usize) -> *const F {
|
||||
self.ptr.add(row * self.stride + col)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn row(&self, row: usize) -> *const f32 {
|
||||
unsafe fn row(&self, row: usize) -> *const F {
|
||||
self.ptr.add(row * self.stride)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn col(&self, col: usize) -> *const f32 {
|
||||
unsafe fn col(&self, col: usize) -> *const F {
|
||||
self.ptr.add(col)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn ptr(&self) -> *const F {
|
||||
self.ptr
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn inc_row(&mut self) {
|
||||
self.ptr = self.row(1);
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn inc_col(&mut self) {
|
||||
self.ptr = self.col(1);
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn shift_row(&mut self, rows: usize) {
|
||||
self.ptr = self.row(rows);
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn shift_col(&mut self, cols: usize) {
|
||||
self.ptr = self.col(cols);
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
pub(crate) struct ConstTransposedMatrix {
|
||||
pub struct ConstTransposedMatrix<F: Number> {
|
||||
pub stride: usize,
|
||||
pub ptr: *const f32
|
||||
pub ptr: *const F
|
||||
}
|
||||
|
||||
unsafe impl Send for ConstTransposedMatrix {}
|
||||
unsafe impl Sync for ConstTransposedMatrix {}
|
||||
unsafe impl<F: Number> Send for ConstTransposedMatrix<F> {}
|
||||
unsafe impl<F: Number> Sync for ConstTransposedMatrix<F> {}
|
||||
|
||||
impl ConstTransposedMatrix {
|
||||
pub fn new(ptr: *const f32, stride: usize) -> Self {
|
||||
impl<F: Number> ConstTransposedMatrix<F> {
|
||||
pub fn new(ptr: *const F, stride: usize) -> Self {
|
||||
Self { ptr, stride }
|
||||
}
|
||||
}
|
||||
|
||||
impl Matrix for ConstTransposedMatrix {
|
||||
impl<F: Number> Matrix<F> for ConstTransposedMatrix<F> {
|
||||
#[inline]
|
||||
unsafe fn sub(&self, row: usize, col: usize) -> Self {
|
||||
Self {
|
||||
@ -94,6 +149,22 @@ impl Matrix for ConstTransposedMatrix {
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn sub_col(&self, col: usize) -> Self {
|
||||
Self {
|
||||
ptr: self.col(col),
|
||||
stride: self.stride
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn sub_row(&self, row: usize) -> Self {
|
||||
Self {
|
||||
ptr: self.row(row),
|
||||
stride: self.stride
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn is_transposed(&self) -> bool {
|
||||
true
|
||||
@ -105,42 +176,67 @@ impl Matrix for ConstTransposedMatrix {
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn get(&self, row: usize, col: usize) -> f32 {
|
||||
unsafe fn get(&self, row: usize, col: usize) -> F {
|
||||
*self.ptr.add(col * self.stride + row)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn index(&self, row: usize, col: usize) -> *const f32 {
|
||||
unsafe fn index(&self, row: usize, col: usize) -> *const F {
|
||||
self.ptr.add(col * self.stride + row)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn row(&self, row: usize) -> *const f32 {
|
||||
unsafe fn row(&self, row: usize) -> *const F {
|
||||
self.ptr.add(row)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn col(&self, col: usize) -> *const f32 {
|
||||
unsafe fn col(&self, col: usize) -> *const F {
|
||||
self.ptr.add(col * self.stride)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn ptr(&self) -> *const F {
|
||||
self.ptr
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn inc_row(&mut self) {
|
||||
self.ptr = self.row(1);
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn inc_col(&mut self) {
|
||||
self.ptr = self.col(1);
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn shift_row(&mut self, rows: usize) {
|
||||
self.ptr = self.row(rows);
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn shift_col(&mut self, cols: usize) {
|
||||
self.ptr = self.col(cols);
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
pub(crate) struct MutMatrix {
|
||||
pub struct MutMatrix<F: Number> {
|
||||
pub stride: usize,
|
||||
pub ptr: *mut f32
|
||||
pub ptr: *mut F
|
||||
}
|
||||
|
||||
unsafe impl Send for MutMatrix {}
|
||||
unsafe impl Sync for MutMatrix {}
|
||||
unsafe impl<F: Number> Send for MutMatrix<F> {}
|
||||
unsafe impl<F: Number> Sync for MutMatrix<F> {}
|
||||
|
||||
impl MutMatrix {
|
||||
pub fn new(ptr: *mut f32, stride: usize) -> Self {
|
||||
impl<F: Number> MutMatrix<F> {
|
||||
pub fn new(ptr: *mut F, stride: usize) -> Self {
|
||||
Self { ptr, stride }
|
||||
}
|
||||
}
|
||||
|
||||
impl Matrix for MutMatrix {
|
||||
impl<F: Number> Matrix<F> for MutMatrix<F> {
|
||||
#[inline]
|
||||
unsafe fn sub(&self, row: usize, col: usize) -> Self {
|
||||
Self {
|
||||
@ -149,6 +245,22 @@ impl Matrix for MutMatrix {
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn sub_col(&self, col: usize) -> Self {
|
||||
Self {
|
||||
ptr: self.col_mut(col),
|
||||
stride: self.stride
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn sub_row(&self, row: usize) -> Self {
|
||||
Self {
|
||||
ptr: self.row_mut(row),
|
||||
stride: self.stride
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn is_transposed(&self) -> bool {
|
||||
false
|
||||
@ -160,64 +272,94 @@ impl Matrix for MutMatrix {
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn get(&self, row: usize, col: usize) -> f32 {
|
||||
unsafe fn get(&self, row: usize, col: usize) -> F {
|
||||
*self.ptr.add(row * self.stride + col)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn index(&self, row: usize, col: usize) -> *const f32 {
|
||||
unsafe fn index(&self, row: usize, col: usize) -> *const F {
|
||||
self.ptr.add(row * self.stride + col)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn row(&self, row: usize) -> *const f32 {
|
||||
unsafe fn row(&self, row: usize) -> *const F {
|
||||
self.ptr.add(row * self.stride)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn col(&self, col: usize) -> *const f32 {
|
||||
unsafe fn col(&self, col: usize) -> *const F {
|
||||
self.ptr.add(col)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn ptr(&self) -> *const F {
|
||||
self.ptr
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn inc_row(&mut self) {
|
||||
self.ptr = self.row_mut(1);
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn inc_col(&mut self) {
|
||||
self.ptr = self.col_mut(1);
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn shift_row(&mut self, rows: usize) {
|
||||
self.ptr = self.row_mut(rows);
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn shift_col(&mut self, cols: usize) {
|
||||
self.ptr = self.col_mut(cols);
|
||||
}
|
||||
}
|
||||
|
||||
impl MatrixMut for MutMatrix {
|
||||
impl<F: Number> MatrixMut<F> for MutMatrix<F> {
|
||||
#[inline]
|
||||
unsafe fn set(&self, row: usize, col: usize, value: f32) {
|
||||
unsafe fn set(&self, row: usize, col: usize, value: F) {
|
||||
*self.ptr.add(row * self.stride + col) = value;
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn index_mut(&self, row: usize, col: usize) -> *mut f32 {
|
||||
unsafe fn index_mut(&self, row: usize, col: usize) -> *mut F {
|
||||
self.ptr.add(row * self.stride + col)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn row_mut(&self, row: usize) -> *mut f32 {
|
||||
unsafe fn row_mut(&self, row: usize) -> *mut F {
|
||||
self.ptr.add(row * self.stride)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn col_mut(&self, col: usize) -> *mut f32 {
|
||||
unsafe fn col_mut(&self, col: usize) -> *mut F {
|
||||
self.ptr.add(col)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn ptr_mut(&self) -> *mut F {
|
||||
self.ptr
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
pub(crate) struct MutTransposedMatrix {
|
||||
pub struct MutTransposedMatrix<F: Number> {
|
||||
pub stride: usize,
|
||||
pub ptr: *mut f32
|
||||
pub ptr: *mut F
|
||||
}
|
||||
|
||||
unsafe impl Send for MutTransposedMatrix {}
|
||||
unsafe impl Sync for MutTransposedMatrix {}
|
||||
unsafe impl<F: Number> Send for MutTransposedMatrix<F> {}
|
||||
unsafe impl<F: Number> Sync for MutTransposedMatrix<F> {}
|
||||
|
||||
impl MutTransposedMatrix {
|
||||
pub fn new(ptr: *mut f32, stride: usize) -> Self {
|
||||
impl<F: Number> MutTransposedMatrix<F> {
|
||||
pub fn new(ptr: *mut F, stride: usize) -> Self {
|
||||
Self { ptr, stride }
|
||||
}
|
||||
}
|
||||
|
||||
impl Matrix for MutTransposedMatrix {
|
||||
impl<F: Number> Matrix<F> for MutTransposedMatrix<F> {
|
||||
#[inline]
|
||||
unsafe fn sub(&self, row: usize, col: usize) -> Self {
|
||||
Self {
|
||||
@ -226,6 +368,22 @@ impl Matrix for MutTransposedMatrix {
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn sub_col(&self, col: usize) -> Self {
|
||||
Self {
|
||||
ptr: self.col_mut(col),
|
||||
stride: self.stride
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn sub_row(&self, row: usize) -> Self {
|
||||
Self {
|
||||
ptr: self.row_mut(row),
|
||||
stride: self.stride
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn is_transposed(&self) -> bool {
|
||||
true
|
||||
@ -237,47 +395,77 @@ impl Matrix for MutTransposedMatrix {
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn get(&self, row: usize, col: usize) -> f32 {
|
||||
unsafe fn get(&self, row: usize, col: usize) -> F {
|
||||
*self.ptr.add(col * self.stride + row)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn index(&self, row: usize, col: usize) -> *const f32 {
|
||||
self.ptr.add(col * self.stride + row) as *const f32
|
||||
unsafe fn index(&self, row: usize, col: usize) -> *const F {
|
||||
self.ptr.add(col * self.stride + row) as *const F
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn row(&self, row: usize) -> *const f32 {
|
||||
self.ptr.add(row) as *const f32
|
||||
unsafe fn row(&self, row: usize) -> *const F {
|
||||
self.ptr.add(row) as *const F
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn col(&self, col: usize) -> *const f32 {
|
||||
self.ptr.add(col * self.stride) as *const f32
|
||||
unsafe fn col(&self, col: usize) -> *const F {
|
||||
self.ptr.add(col * self.stride) as *const F
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn ptr(&self) -> *const F {
|
||||
self.ptr
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn inc_row(&mut self) {
|
||||
self.ptr = self.row_mut(1);
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn inc_col(&mut self) {
|
||||
self.ptr = self.col_mut(1);
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn shift_row(&mut self, rows: usize) {
|
||||
self.ptr = self.row_mut(rows);
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn shift_col(&mut self, cols: usize) {
|
||||
self.ptr = self.col_mut(cols);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
impl MatrixMut for MutTransposedMatrix {
|
||||
impl<F: Number> MatrixMut<F> for MutTransposedMatrix<F> {
|
||||
#[inline]
|
||||
unsafe fn set(&self, row: usize, col: usize, value: f32) {
|
||||
unsafe fn set(&self, row: usize, col: usize, value: F) {
|
||||
*self.ptr.add(col * self.stride + row) = value;
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn index_mut(&self, row: usize, col: usize) -> *mut f32 {
|
||||
unsafe fn index_mut(&self, row: usize, col: usize) -> *mut F {
|
||||
self.ptr.add(col * self.stride + row)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn row_mut(&self, row: usize) -> *mut f32 {
|
||||
unsafe fn row_mut(&self, row: usize) -> *mut F {
|
||||
self.ptr.add(row)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn col_mut(&self, col: usize) -> *mut f32 {
|
||||
unsafe fn col_mut(&self, col: usize) -> *mut F {
|
||||
self.ptr.add(col * self.stride)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn ptr_mut(&self) -> *mut F {
|
||||
self.ptr
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
287
src/sgemm.rs
287
src/sgemm.rs
@ -1,273 +1,30 @@
|
||||
use crate::aligned_alloc;
|
||||
use crate::matrix::{Matrix, MatrixMut, MutMatrix};
|
||||
use crate::kernel::params::single::*;
|
||||
use crate::kernel;
|
||||
|
||||
use crate::kernel::avx::AvxKernel;
|
||||
use crate::kernel::generic::GenericKernel;
|
||||
use crate::dim::*;
|
||||
use crate::executor::Executor;
|
||||
|
||||
pub(crate) unsafe fn sgemm<A, B, C>(
|
||||
pub unsafe fn sgemm<E: Executor>(
|
||||
e: &E,
|
||||
transa: bool,
|
||||
transb: bool,
|
||||
transc: bool,
|
||||
m: usize,
|
||||
n: usize,
|
||||
k: usize,
|
||||
alpha: f32,
|
||||
a: A,
|
||||
b: B,
|
||||
a: *const f32,
|
||||
lda: usize,
|
||||
b: *const f32,
|
||||
ldb: usize,
|
||||
beta: f32,
|
||||
c: C
|
||||
)
|
||||
where A: Matrix,
|
||||
B: Matrix,
|
||||
C: MatrixMut,
|
||||
{
|
||||
let packed_a = aligned_alloc::Alloc::new(MC * KC * std::mem::size_of::<f32>());
|
||||
let packed_b = aligned_alloc::Alloc::new(KC * NC * std::mem::size_of::<f32>());
|
||||
|
||||
for j in (0..n).step_by(NC) {
|
||||
let j_b = std::cmp::min(n - j, NC);
|
||||
for p in (0..k).step_by(KC) {
|
||||
let p_b = std::cmp::min(k - p, KC);
|
||||
for i in (0..m).step_by(MC) {
|
||||
let i_b = std::cmp::min(m - i, MC);
|
||||
|
||||
let pa = MutMatrix::new(packed_a.ptr_f32(), p_b);
|
||||
let pb = MutMatrix::new(packed_b.ptr_f32(), p_b);
|
||||
|
||||
inner_kernel(
|
||||
i_b,
|
||||
j_b,
|
||||
p_b,
|
||||
alpha,
|
||||
a.sub(p, i),
|
||||
b.sub(j, p),
|
||||
beta,
|
||||
c.sub(j, i),
|
||||
pa,
|
||||
pb,
|
||||
i == 0
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// +----------------------+
|
||||
// |A : : : |B |
|
||||
// | : : : | |
|
||||
// + - - - - - - - - - ---+
|
||||
// | : : : | |
|
||||
// | : : : | |
|
||||
// + - - - - - - - - - ---+
|
||||
// | : : : | |
|
||||
// | : : : | |
|
||||
// +----------------------+
|
||||
// |C | | | |D |
|
||||
// +----------------------+
|
||||
|
||||
|
||||
#[inline]
|
||||
unsafe fn sgemm_pa<A: Matrix>(a: A, pa: MutMatrix, i: usize) {
|
||||
if a.is_transposed() {
|
||||
kernel::sgemm_pa_t(pa.stride, a.col(i), a.stride(), pa.row_mut(i));
|
||||
} else {
|
||||
kernel::sgemm_pa_n(pa.stride, a.col(i), a.stride(), pa.row_mut(i));
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn sgemm_pb<B: Matrix>(b: B, pb: MutMatrix, j: usize) {
|
||||
if b.is_transposed() {
|
||||
kernel::sgemm_pb_t(pb.stride, b.row(j), b.stride(), pb.row_mut(j));
|
||||
} else {
|
||||
kernel::sgemm_pb_n(pb.stride, b.row(j), b.stride(), pb.row_mut(j));
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn sgemm_ukr<C: MatrixMut>(
|
||||
i: usize, j: usize,
|
||||
alpha: f32,
|
||||
pa: MutMatrix,
|
||||
pb: MutMatrix,
|
||||
beta: f32,
|
||||
c: C)
|
||||
{
|
||||
if c.is_transposed() {
|
||||
unimplemented!()
|
||||
} else {
|
||||
kernel::sgemm_ukr(
|
||||
pa.stride,
|
||||
alpha,
|
||||
pa.row(i),
|
||||
pb.row(j),
|
||||
beta,
|
||||
c.index_mut(j, i),
|
||||
c.stride())
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn sgemm_sup1<A: Matrix, C: MatrixMut>(
|
||||
i: usize, j: usize,
|
||||
alpha: f32,
|
||||
a: A,
|
||||
pb: MutMatrix,
|
||||
beta: f32,
|
||||
c: C)
|
||||
{
|
||||
if c.is_transposed() {
|
||||
unimplemented!()
|
||||
} else {
|
||||
if a.is_transposed() {
|
||||
kernel::sgemm_sup1_t(
|
||||
pb.stride,
|
||||
alpha,
|
||||
a.col(i),
|
||||
a.stride(),
|
||||
pb.row(j),
|
||||
beta,
|
||||
c.index_mut(j, i),
|
||||
c.stride())
|
||||
} else {
|
||||
kernel::sgemm_sup1(
|
||||
pb.stride,
|
||||
alpha,
|
||||
a.col(i),
|
||||
a.stride(),
|
||||
pb.row(j),
|
||||
beta,
|
||||
c.index_mut(j, i),
|
||||
c.stride())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn sgemm_sup0<B: Matrix, C: MatrixMut>(
|
||||
i: usize, j: usize,
|
||||
alpha: f32,
|
||||
pa: MutMatrix,
|
||||
b: B,
|
||||
beta: f32,
|
||||
c: C)
|
||||
{
|
||||
if c.is_transposed() {
|
||||
unimplemented!()
|
||||
} else {
|
||||
if b.is_transposed() {
|
||||
kernel::sgemm_sup0_t(
|
||||
pa.stride,
|
||||
alpha,
|
||||
pa.row(i),
|
||||
b.row(j),
|
||||
b.stride(),
|
||||
beta,
|
||||
c.index_mut(j, i))
|
||||
} else {
|
||||
kernel::sgemm_sup0(
|
||||
pa.stride,
|
||||
alpha,
|
||||
pa.row(i),
|
||||
b.row(j),
|
||||
beta,
|
||||
c.index_mut(j, i))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe fn inner_kernel<A, B, C>(
|
||||
m: usize,
|
||||
n: usize,
|
||||
k: usize,
|
||||
alpha: f32,
|
||||
a: A,
|
||||
b: B,
|
||||
beta: f32,
|
||||
c: C,
|
||||
pa: MutMatrix,
|
||||
pb: MutMatrix,
|
||||
first_time: bool,
|
||||
)
|
||||
where A: Matrix,
|
||||
B: Matrix,
|
||||
C: MatrixMut,
|
||||
|
||||
{
|
||||
let n_left = n % NR;
|
||||
let n_main = n - n_left;
|
||||
|
||||
let m_left = m % MR;
|
||||
let m_main = m - m_left;
|
||||
|
||||
if first_time {
|
||||
for j in (0..n_main).step_by(NR) {
|
||||
sgemm_pb(b, pb, j);
|
||||
}
|
||||
}
|
||||
|
||||
for i in (0..m_main).step_by(MR) {
|
||||
sgemm_pa(a, pa, i);
|
||||
};
|
||||
|
||||
|
||||
for j in (0..n_main).step_by(NR) {
|
||||
|
||||
// Section A
|
||||
for i in (0..m_main).step_by(MR) {
|
||||
sgemm_ukr(
|
||||
i,
|
||||
j,
|
||||
alpha,
|
||||
pa,
|
||||
pb,
|
||||
beta,
|
||||
c,
|
||||
);
|
||||
}
|
||||
|
||||
// Section B
|
||||
for i in m_main..m {
|
||||
sgemm_sup1(
|
||||
i,
|
||||
j,
|
||||
alpha,
|
||||
a,
|
||||
pb,
|
||||
beta,
|
||||
c,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
for j in n_main..n {
|
||||
|
||||
// Section C
|
||||
for i in (0..m_main).step_by(MR) {
|
||||
sgemm_sup0(
|
||||
i,
|
||||
j,
|
||||
alpha,
|
||||
pa,
|
||||
b,
|
||||
beta,
|
||||
c,
|
||||
);
|
||||
}
|
||||
|
||||
// Section D
|
||||
for i in m_main..m {
|
||||
let mut elem = 0.0;
|
||||
|
||||
for p in 0..k {
|
||||
elem += a.get(p, i) * b.get(j, p);
|
||||
}
|
||||
|
||||
elem *= alpha;
|
||||
|
||||
if beta != 0.0 {
|
||||
elem += beta * c.get(j, i);
|
||||
}
|
||||
|
||||
c.set(j, i, elem);
|
||||
}
|
||||
}
|
||||
c: *mut f32,
|
||||
ldc: usize,
|
||||
) {
|
||||
crate::gemm::gemm::<E, f32, AvxKernel<f32, GenericKernel>, A16, A5>(
|
||||
e,
|
||||
transa, transb, transc,
|
||||
m, n, k, alpha,
|
||||
a, lda, b, ldb,
|
||||
beta, c, ldc);
|
||||
}
|
308
src/test/l3s.rs
Normal file
308
src/test/l3s.rs
Normal file
@ -0,0 +1,308 @@
|
||||
use rand::Rng;
|
||||
|
||||
const M_LEN: usize = 8192;
|
||||
const N_LEN: usize = 8192;
|
||||
const K_LEN: usize = 8192;
|
||||
|
||||
fn make_matrices() -> (Vec<f32>, Vec<f32>, Vec<f32>, Vec<f32>) {
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
let (m, n, k) = (M_LEN, N_LEN, K_LEN);
|
||||
|
||||
let mut a = vec![0.0; m * k];
|
||||
let mut a_t = vec![0.0; m * k];
|
||||
|
||||
let mut b = vec![0.0; n * k];
|
||||
let mut b_t = vec![0.0; n * k];
|
||||
|
||||
for row in 0..k {
|
||||
for col in 0..m {
|
||||
let v = rng.gen();
|
||||
a[row * m + col] = v;
|
||||
a_t[col * k + row] = v;
|
||||
}
|
||||
}
|
||||
|
||||
for row in 0..n {
|
||||
for col in 0..k {
|
||||
let v = rng.gen();
|
||||
b[row * k + col] = v;
|
||||
b_t[col * n + row] = v;
|
||||
}
|
||||
}
|
||||
|
||||
(a, a_t, b, b_t)
|
||||
}
|
||||
|
||||
|
||||
unsafe fn sgemm_ref(
|
||||
m: usize,
|
||||
n: usize,
|
||||
k: usize,
|
||||
_alpha: f32,
|
||||
a: *const f32,
|
||||
lda: usize,
|
||||
b: *const f32,
|
||||
ldb: usize,
|
||||
_beta: f32,
|
||||
c: *mut f32,
|
||||
ldc: usize,
|
||||
) {
|
||||
for i_m in 0..m {
|
||||
for i_k in 0..k {
|
||||
let a_part = *a.add(i_m * lda + i_k);
|
||||
|
||||
for i_n in 0..n {
|
||||
*c.add(i_m * ldc + i_n) += a_part * *b.add(i_k * ldb + i_n);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// for j in 0..n {
|
||||
// for i in 0..m {
|
||||
// let mut ci = *c.add(i + j * ldc);
|
||||
|
||||
// for p in 0..k {
|
||||
// ci += *a.add(i + p * lda) * *b.add(p + j * ldb);
|
||||
// }
|
||||
|
||||
// *c.add(i + j * ldc) = ci;
|
||||
// }
|
||||
// }
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_sgemm_nn() {
|
||||
let (m, n, k) = (M_LEN, N_LEN, K_LEN);
|
||||
let (a, _, b, _) = make_matrices();
|
||||
|
||||
let mut c = vec![0.0; m * n];
|
||||
let mut cref = vec![0.0; m * n];
|
||||
|
||||
unsafe {
|
||||
blas::sgemm(
|
||||
b'N',
|
||||
b'N',
|
||||
m as i32,
|
||||
n as i32,
|
||||
k as i32,
|
||||
1.0,
|
||||
a.as_slice(),
|
||||
m as i32,
|
||||
b.as_slice(),
|
||||
k as i32,
|
||||
0.0,
|
||||
cref.as_mut_slice(),
|
||||
m as i32,
|
||||
)
|
||||
}
|
||||
|
||||
unsafe {
|
||||
crate::sgemm(
|
||||
&crate::executor::DefaultExecutor,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
1.0,
|
||||
a.as_ptr(),
|
||||
m,
|
||||
b.as_ptr(),
|
||||
k,
|
||||
0.0,
|
||||
c.as_mut_ptr(),
|
||||
m,
|
||||
);
|
||||
}
|
||||
|
||||
for row in 0..N_LEN {
|
||||
for col in 0..M_LEN {
|
||||
let index = row * M_LEN + col;
|
||||
|
||||
let (a, b) = (c[index], cref[index]);
|
||||
|
||||
assert!(feq(a, b), "a != b, a[{}, {}]={}, b[{}, {}]={}", row, col, a, row, col, b);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_sgemm_nt() {
|
||||
let (m, n, k) = (M_LEN, N_LEN, K_LEN);
|
||||
let (a, _, b, b_t) = make_matrices();
|
||||
|
||||
let mut c = vec![0.0; m * n];
|
||||
let mut cref = vec![0.0; m * n];
|
||||
|
||||
unsafe {
|
||||
blas::sgemm(
|
||||
b'N',
|
||||
b'T',
|
||||
m as i32,
|
||||
n as i32,
|
||||
k as i32,
|
||||
1.0,
|
||||
a.as_slice(),
|
||||
m as i32,
|
||||
b_t.as_slice(),
|
||||
n as i32,
|
||||
0.0,
|
||||
cref.as_mut_slice(),
|
||||
m as i32,
|
||||
)
|
||||
}
|
||||
|
||||
unsafe {
|
||||
crate::sgemm(
|
||||
&crate::executor::DefaultExecutor,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
1.0,
|
||||
a.as_ptr(),
|
||||
m,
|
||||
b_t.as_ptr(),
|
||||
n,
|
||||
0.0,
|
||||
c.as_mut_ptr(),
|
||||
m,
|
||||
);
|
||||
}
|
||||
|
||||
for row in 0..N_LEN {
|
||||
for col in 0..M_LEN {
|
||||
let index = row * M_LEN + col;
|
||||
let (a, b) = (c[index], cref[index]);
|
||||
assert!(feq(a, b), "a != b, a[{}, {}]={}, b[{}, {}]={}", row, col, a, row, col, b);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_sgemm_tn() {
|
||||
let (m, n, k) = (M_LEN, N_LEN, K_LEN);
|
||||
let (a, a_t, b, _) = make_matrices();
|
||||
|
||||
let mut c = vec![0.0; m * n];
|
||||
let mut cref = vec![0.0; m * n];
|
||||
|
||||
unsafe {
|
||||
blas::sgemm(
|
||||
b'T',
|
||||
b'N',
|
||||
m as i32,
|
||||
n as i32,
|
||||
k as i32,
|
||||
1.0,
|
||||
a_t.as_slice(),
|
||||
k as i32,
|
||||
b.as_slice(),
|
||||
k as i32,
|
||||
0.0,
|
||||
cref.as_mut_slice(),
|
||||
m as i32,
|
||||
)
|
||||
}
|
||||
|
||||
unsafe {
|
||||
crate::sgemm(
|
||||
&crate::executor::DefaultExecutor,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
1.0,
|
||||
a_t.as_ptr(),
|
||||
k,
|
||||
b.as_ptr(),
|
||||
k,
|
||||
0.0,
|
||||
c.as_mut_ptr(),
|
||||
m,
|
||||
);
|
||||
}
|
||||
|
||||
for row in 0..N_LEN {
|
||||
for col in 0..M_LEN {
|
||||
let index = row * M_LEN + col;
|
||||
let (a, b) = (c[index], cref[index]);
|
||||
assert!(feq(a, b), "a != b, a[{}, {}]={}, b[{}, {}]={}", row, col, a, row, col, b);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_sgemm_tt() {
|
||||
let (m, n, k) = (M_LEN, N_LEN, K_LEN);
|
||||
let (a, a_t, b, b_t) = make_matrices();
|
||||
|
||||
let mut c = vec![0.0; m * n];
|
||||
let mut cref = vec![0.0; m * n];
|
||||
|
||||
unsafe {
|
||||
blas::sgemm(
|
||||
b'T',
|
||||
b'T',
|
||||
m as i32,
|
||||
n as i32,
|
||||
k as i32,
|
||||
1.0,
|
||||
a_t.as_slice(),
|
||||
k as i32,
|
||||
b_t.as_slice(),
|
||||
n as i32,
|
||||
0.0,
|
||||
cref.as_mut_slice(),
|
||||
m as i32,
|
||||
)
|
||||
}
|
||||
|
||||
unsafe {
|
||||
crate::sgemm(
|
||||
&crate::executor::DefaultExecutor,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
1.0,
|
||||
a_t.as_ptr(),
|
||||
k,
|
||||
b_t.as_ptr(),
|
||||
n,
|
||||
0.0,
|
||||
c.as_mut_ptr(),
|
||||
m,
|
||||
);
|
||||
}
|
||||
|
||||
for row in 0..N_LEN {
|
||||
for col in 0..M_LEN {
|
||||
let index = row * M_LEN + col;
|
||||
let (a, b) = (c[index], cref[index]);
|
||||
assert!(feq(a, b), "a != b, a[{}, {}]={}, b[{}, {}]={}", row, col, a, row, col, b);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn feq(a: f32, b: f32) -> bool {
|
||||
if a == b {
|
||||
true
|
||||
} else if a == 0.0 || b == 0.0 || (a.abs() + b.abs() < std::f32::MIN_POSITIVE) {
|
||||
(a - b).abs() < std::f32::EPSILON * 10.0 * std::f32::MIN_POSITIVE
|
||||
} else {
|
||||
(a - b).abs() / (a.abs() + b.abs()) < std::f32::EPSILON * 10.0
|
||||
}
|
||||
}
|
1
src/test/mod.rs
Normal file
1
src/test/mod.rs
Normal file
@ -0,0 +1 @@
|
||||
mod l3s;
|
2
stats.sh
2
stats.sh
@ -1,5 +1,5 @@
|
||||
#!/bin/bash
|
||||
# export RUSTFLAGS="-C target-feature=-mmx,-sse,-sse2,-sse3,-ssse3,-sse4.1,-sse4.2,-3dnow,-3dnowa,-avx,-avx2"
|
||||
export RUSTFLAGS="-C target-feature=-mmx,-sse,-sse2,-sse3,-ssse3,-sse4.1,-sse4.2,-3dnow,-3dnowa,+avx,+avx2"
|
||||
cargo clean && cargo build --release
|
||||
#cargo clean && cargo build --release
|
||||
perf stat -B -e cache-references,cache-misses,cycles,instructions,branches,branch-misses,faults,migrations,context-switches ./target/release/gemm
|
||||
|
Loading…
Reference in New Issue
Block a user