diff --git a/Cargo.lock b/Cargo.lock index 42794ef..0572918 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -36,14 +36,66 @@ name = "cfg-if" version = "0.1.10" source = "registry+https://github.com/rust-lang/crates.io-index" +[[package]] +name = "crossbeam-deque" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "crossbeam-epoch 0.8.0 (registry+https://github.com/rust-lang/crates.io-index)", + "crossbeam-utils 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "autocfg 0.1.7 (registry+https://github.com/rust-lang/crates.io-index)", + "cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)", + "crossbeam-utils 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)", + "lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)", + "memoffset 0.5.3 (registry+https://github.com/rust-lang/crates.io-index)", + "scopeguard 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "crossbeam-queue" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "crossbeam-utils 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "crossbeam-utils" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "autocfg 0.1.7 (registry+https://github.com/rust-lang/crates.io-index)", + "cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)", + "lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "crunchy" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" + +[[package]] +name = "either" +version = "1.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" + [[package]] name = "gemm" version = "0.1.0" dependencies = [ "blas 0.20.0 (registry+https://github.com/rust-lang/crates.io-index)", - "matrixmultiply 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)", + "crunchy 0.2.2 (registry+https://github.com/rust-lang/crates.io-index)", "openblas 0.0.1 (registry+https://github.com/rust-lang/crates.io-index)", "rand 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)", + "rayon 1.2.1 (registry+https://github.com/rust-lang/crates.io-index)", + "threadpool 1.7.1 (registry+https://github.com/rust-lang/crates.io-index)", ] [[package]] @@ -56,17 +108,30 @@ dependencies = [ "wasi 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)", ] +[[package]] +name = "hermit-abi" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "libc 0.2.65 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "lazy_static" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" + [[package]] name = "libc" version = "0.2.65" source = "registry+https://github.com/rust-lang/crates.io-index" [[package]] -name = "matrixmultiply" -version = "0.2.3" +name = "memoffset" +version = "0.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" dependencies = [ - "rawpointer 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)", + "rustc_version 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)", ] [[package]] @@ -86,6 +151,15 @@ dependencies = [ "autocfg 0.1.7 (registry+https://github.com/rust-lang/crates.io-index)", ] +[[package]] +name = "num_cpus" +version = "1.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "hermit-abi 0.1.3 (registry+https://github.com/rust-lang/crates.io-index)", + "libc 0.2.65 (registry+https://github.com/rust-lang/crates.io-index)", +] + [[package]] name = "openblas" version = "0.0.1" @@ -134,9 +208,60 @@ dependencies = [ ] [[package]] -name = "rawpointer" -version = "0.2.1" +name = "rayon" +version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "crossbeam-deque 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)", + "either 1.5.3 (registry+https://github.com/rust-lang/crates.io-index)", + "rayon-core 1.6.1 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "rayon-core" +version = "1.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "crossbeam-deque 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)", + "crossbeam-queue 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)", + "crossbeam-utils 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)", + "lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)", + "num_cpus 1.11.1 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "rustc_version" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "semver 0.9.0 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "scopeguard" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" + +[[package]] +name = "semver" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "semver-parser 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "semver-parser" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" + +[[package]] +name = "threadpool" +version = "1.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "num_cpus 1.11.1 (registry+https://github.com/rust-lang/crates.io-index)", +] [[package]] name = "wasi" @@ -149,16 +274,31 @@ source = "registry+https://github.com/rust-lang/crates.io-index" "checksum blas-sys 0.7.1 (registry+https://github.com/rust-lang/crates.io-index)" = "13b1b279ceb25d7c4faaea95a5f7addbe7d8c34f9462044bd8e630cebcfc2440" "checksum c2-chacha 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "214238caa1bf3a496ec3392968969cab8549f96ff30652c9e56885329315f6bb" "checksum cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)" = "4785bdd1c96b2a846b2bd7cc02e86b6b3dbf14e7e53446c4f54c92a361040822" +"checksum crossbeam-deque 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)" = "c3aa945d63861bfe624b55d153a39684da1e8c0bc8fba932f7ee3a3c16cea3ca" +"checksum crossbeam-epoch 0.8.0 (registry+https://github.com/rust-lang/crates.io-index)" = "5064ebdbf05ce3cb95e45c8b086f72263f4166b29b97f6baff7ef7fe047b55ac" +"checksum crossbeam-queue 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "dfd6515864a82d2f877b42813d4553292c6659498c9a2aa31bab5a15243c2700" +"checksum crossbeam-utils 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ce446db02cdc3165b94ae73111e570793400d0794e46125cc4056c81cbb039f4" +"checksum crunchy 0.2.2 (registry+https://github.com/rust-lang/crates.io-index)" = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" +"checksum either 1.5.3 (registry+https://github.com/rust-lang/crates.io-index)" = "bb1f6b1ce1c140482ea30ddd3335fc0024ac7ee112895426e0a629a6c20adfe3" "checksum getrandom 0.1.13 (registry+https://github.com/rust-lang/crates.io-index)" = "e7db7ca94ed4cd01190ceee0d8a8052f08a247aa1b469a7f68c6a3b71afcf407" +"checksum hermit-abi 0.1.3 (registry+https://github.com/rust-lang/crates.io-index)" = "307c3c9f937f38e3534b1d6447ecf090cafcc9744e4a6360e8b037b2cf5af120" +"checksum lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" "checksum libc 0.2.65 (registry+https://github.com/rust-lang/crates.io-index)" = "1a31a0627fdf1f6a39ec0dd577e101440b7db22672c0901fe00a9a6fbb5c24e8" -"checksum matrixmultiply 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "d4f7ec66360130972f34830bfad9ef05c6610a43938a467bcc9ab9369ab3478f" +"checksum memoffset 0.5.3 (registry+https://github.com/rust-lang/crates.io-index)" = "75189eb85871ea5c2e2c15abbdd541185f63b408415e5051f5cac122d8c774b9" "checksum num-complex 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "fcb0cf31fb3ff77e6d2a6ebd6800df7fdcd106f2ad89113c9130bcd07f93dffc" "checksum num-traits 0.2.10 (registry+https://github.com/rust-lang/crates.io-index)" = "d4c81ffc11c212fa327657cb19dd85eb7419e163b5b076bede2bdb5c974c07e4" +"checksum num_cpus 1.11.1 (registry+https://github.com/rust-lang/crates.io-index)" = "76dac5ed2a876980778b8b85f75a71b6cbf0db0b1232ee12f826bccb00d09d72" "checksum openblas 0.0.1 (registry+https://github.com/rust-lang/crates.io-index)" = "f005eb075fb562358060245c5d00a877757b4ddec8e8cd69cab96d8cc55998e8" "checksum ppv-lite86 0.2.6 (registry+https://github.com/rust-lang/crates.io-index)" = "74490b50b9fbe561ac330df47c08f3f33073d2d00c150f719147d7c54522fa1b" "checksum rand 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)" = "3ae1b169243eaf61759b8475a998f0a385e42042370f3a7dbaf35246eacc8412" "checksum rand_chacha 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "03a2a90da8c7523f554344f921aa97283eadf6ac484a6d2a7d0212fa7f8d6853" "checksum rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)" = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19" "checksum rand_hc 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ca3129af7b92a17112d59ad498c6f81eaf463253766b90396d39ea7a39d6613c" -"checksum rawpointer 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" +"checksum rayon 1.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "43739f8831493b276363637423d3622d4bd6394ab6f0a9c4a552e208aeb7fddd" +"checksum rayon-core 1.6.1 (registry+https://github.com/rust-lang/crates.io-index)" = "f8bf17de6f23b05473c437eb958b9c850bfc8af0961fe17b4cc92d5a627b4791" +"checksum rustc_version 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "138e3e0acb6c9fb258b19b67cb8abd63c00679d2851805ea151465464fe9030a" +"checksum scopeguard 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)" = "b42e15e59b18a828bbf5c58ea01debb36b9b096346de35d941dcb89009f24a0d" +"checksum semver 0.9.0 (registry+https://github.com/rust-lang/crates.io-index)" = "1d7eb9ef2c18661902cc47e535f9bc51b78acd254da71d375c2f6720d9a40403" +"checksum semver-parser 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)" = "388a1df253eca08550bef6c72392cfe7c30914bf41df5269b68cbd6ff8f570a3" +"checksum threadpool 1.7.1 (registry+https://github.com/rust-lang/crates.io-index)" = "e2f0c90a5f3459330ac8bc0d2f879c693bb7a2f59689c1083fc4ef83834da865" "checksum wasi 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)" = "b89c3ce4ce14bdc6fb6beaf9ec7928ca331de5df7e5ea278375642a2f478570d" diff --git a/Cargo.toml b/Cargo.toml index 9f9fe11..086c320 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,6 +9,11 @@ build = "build.rs" [dependencies] rand = "0.7.2" -matrixmultiply = "0.2.3" blas = "0.20.0" openblas = "0.0.1" +crunchy = "0.2.2" +rayon = "1.2.1" +threadpool = "1.7.1" + +[profile.release] +debug = true diff --git a/src/aligned_alloc.rs b/src/aligned_alloc.rs index 16edd94..dd912a2 100644 --- a/src/aligned_alloc.rs +++ b/src/aligned_alloc.rs @@ -13,12 +13,8 @@ impl Alloc { Alloc { ptr, layout } } - pub fn ptr_f32(&self) -> *mut f32 { - self.ptr as *mut f32 - } - - pub fn ptr(&self) -> *mut u8 { - self.ptr + pub fn ptr(&self) -> *mut F { + self.ptr as *mut F } } diff --git a/src/dim.rs b/src/dim.rs new file mode 100644 index 0000000..36a9c4b --- /dev/null +++ b/src/dim.rs @@ -0,0 +1,75 @@ +pub trait Dim { + const DIM: usize; +} + +pub struct A1; +impl Dim for A1 { + const DIM: usize = 1; +} +pub struct A2; +impl Dim for A2 { + const DIM: usize = 2; +} +pub struct A3; +impl Dim for A3 { + const DIM: usize = 3; +} +pub struct A4; +impl Dim for A4 { + const DIM: usize = 4; +} +pub struct A5; +impl Dim for A5 { + const DIM: usize = 5; +} +pub struct A6; +impl Dim for A6 { + const DIM: usize = 6; +} +pub struct A7; +impl Dim for A7 { + const DIM: usize = 7; +} + +pub struct A8; +impl Dim for A8 { + const DIM: usize = 8; +} + +pub struct A16; +impl Dim for A16 { + const DIM: usize = 16; +} + +/* +macro_rules! gen { + (($name: ident, $val: lit)) => { + + }; +} + +gen! { + (A0, 0), + (A1, 1), + (A2, 2), + (A3, 3), + (A4, 4), + (A5, 5), + (A6, 6), + (A7, 7), + (A8, 8), + (A9, 9), + (A10, 10), + (A11, 11), + (A12, 12), + (A14, 14), + (A16, 16), + (A18, 18), + (A20, 20), + (A22, 22), + (A24, 24), + (A26, 26), + (A28, 28), + (A30, 30), + (A32, 32), +}*/ \ No newline at end of file diff --git a/src/executor.rs b/src/executor.rs new file mode 100644 index 0000000..4ee1c70 --- /dev/null +++ b/src/executor.rs @@ -0,0 +1,111 @@ +pub trait Executor { + fn execute( + &self, + start: usize, + end: usize, + step: usize, + f: F, + ); + + fn synchronize(&self); +} + +pub struct DefaultExecutor; +impl Executor for DefaultExecutor { + #[inline] + fn execute( + &self, + start: usize, + end: usize, + step: usize, + f: F, + ) { + for i in (start..end).step_by(step) { + f(i); + } + } + + fn synchronize(&self) { + + } +} + +pub struct RayonExecutor; +impl Executor for RayonExecutor { + #[inline] + fn execute( + &self, + start: usize, + end: usize, + step: usize, + f: F, + ) { + let end = 1 + ((end - 1) / step); + + use rayon::prelude::*; + (start..end) + .into_par_iter() + .map(|x| x * step) + .for_each(|x|f(x)); + } + + fn synchronize(&self) { + + } +} + + +pub struct ThreadPoolExecutor { + thread_pool: threadpool::ThreadPool +} + +impl ThreadPoolExecutor { + pub fn new() -> ThreadPoolExecutor { + ThreadPoolExecutor { + thread_pool: threadpool::Builder::new().build() + } + } +} + +impl Executor for ThreadPoolExecutor { + #[inline] + fn execute( + &self, + start: usize, + end: usize, + step: usize, + f: F, + ) { + let thread_count = self.thread_pool.max_count(); + + let len = end - start; + let num_steps = len / step; + + let mut left_steps = num_steps % thread_count; + let main_steps = num_steps - left_steps; + + let job_size = main_steps / thread_count; + + let mut prev_end = 0; + + for _ in 0..thread_count { + let mut now_end = prev_end + job_size; + if left_steps > 0 { + now_end += 1; + left_steps -= 1; + } + self.thread_pool.execute(move || { + for j in prev_end..now_end { + f(start + j * step); + } + }); + + prev_end = now_end + } + } + + #[inline] + fn synchronize(&self) { + self.thread_pool.join(); + } +} diff --git a/src/gemm.rs b/src/gemm.rs new file mode 100644 index 0000000..9d7516e --- /dev/null +++ b/src/gemm.rs @@ -0,0 +1,242 @@ +use crate::aligned_alloc; +use crate::matrix::{Number, Matrix, MatrixMut, MutMatrix, ConstMatrix, ConstTransposedMatrix, MutTransposedMatrix}; +use crate::kernel::params::single::*; +use crate::kernel; +use crate::kernel::GemmKernel; +use crate::dim::Dim; +use crate::executor::Executor; + +pub unsafe fn gemm( + e: &E, + transa: bool, + transb: bool, + transc: bool, + m: usize, + n: usize, + k: usize, + alpha: F, + a: *const F, + lda: usize, + b: *const F, + ldb: usize, + beta: F, + c: *mut F, + ldc: usize, +) +where E: Executor, + F: Number, + MR: Dim, NR: Dim, + K: GemmKernel, +{ + match (transa, transb, transc) { + (false, false, false) => gemm_template::( + e, m, n, k, alpha, + ConstMatrix::new(a, lda), + ConstMatrix::new(b, ldb), + beta, + MutMatrix::new(c, ldc)), + + (false, false, true) => gemm_template::( + e, m, n, k, alpha, + ConstMatrix::new(a, lda), + ConstMatrix::new(b, ldb), + beta, + MutTransposedMatrix::new(c, ldc)), + + (false, true, false) => gemm_template::( + e, m, n, k, alpha, + ConstMatrix::new(a, lda), + ConstTransposedMatrix::new(b, ldb), + beta, + MutMatrix::new(c, ldc)), + + (false, true, true) => gemm_template::( + e, m, n, k, alpha, + ConstMatrix::new(a, lda), + ConstTransposedMatrix::new(b, ldb), + beta, + MutTransposedMatrix::new(c, ldc)), + + (true, false, false) => gemm_template::( + e, m, n, k, alpha, + ConstTransposedMatrix::new(a, lda), + ConstMatrix::new(b, ldb), + beta, + MutMatrix::new(c, ldc)), + + (true, false, true) => gemm_template::( + e, m, n, k, alpha, + ConstTransposedMatrix::new(a, lda), + ConstMatrix::new(b, ldb), + beta, + MutTransposedMatrix::new(c, ldc)), + + + (true, true, false) => gemm_template::( + e, m, n, k, alpha, + ConstTransposedMatrix::new(a, lda), + ConstTransposedMatrix::new(b, ldb), + beta, + MutMatrix::new(c, ldc)), + + + (true, true, true) => gemm_template::( + e, m, n, k, alpha, + ConstTransposedMatrix::new(a, lda), + ConstTransposedMatrix::new(b, ldb), + beta, + MutTransposedMatrix::new(c, ldc)), + } +} + +unsafe fn gemm_template( + e: &E, + m: usize, + n: usize, + k: usize, + alpha: F, + a: A, + b: B, + beta: F, + c: C +) +where E: Executor, + F: Number, + MR: Dim, NR: Dim, + K: GemmKernel, + A: Matrix, + B: Matrix, + C: MatrixMut, +{ + let packed_a = aligned_alloc::Alloc::new(MC * KC * std::mem::size_of::()); + let packed_b = aligned_alloc::Alloc::new(KC * NC * std::mem::size_of::()); + + for j in (0..n).step_by(NC) { + let j_b = std::cmp::min(n - j, NC); + for p in (0..k).step_by(KC) { + let p_b = std::cmp::min(k - p, KC); + for i in (0..m).step_by(MC) { + let i_b = std::cmp::min(m - i, MC); + + let pa = MutMatrix::new(packed_a.ptr::(), p_b); + let pb = MutMatrix::new(packed_b.ptr::(), p_b); + + inner_kernel::( + e, + i_b, + j_b, + p_b, + alpha, + a.sub(p, i), + b.sub(j, p), + beta, + c.sub(j, i), + pa, + pb, + i == 0 + ); + } + } + } +} + +// | MR | +// +----------------------+ +// |TL : : : |TR| +// | : : : | | +// + - - - - - - - - - ---+---- +// | : : : | | NR +// | : : : | | +// + - - - - - - - - - ---+---- +// | : : : | | +// | : : : | | +// +----------------------+ +// |BL | | | |BR| +// +----------------------+ +unsafe fn inner_kernel( + e: &E, + m: usize, + n: usize, + k: usize, + alpha: F, + a: A, + b: B, + beta: F, + c: C, + pa: MutMatrix, + pb: MutMatrix, + first_time: bool, +) + where E: Executor, + F: Number, + MR: Dim, + NR: Dim, + K: kernel::GemmKernel, + A: Matrix, + B: Matrix, + C: MatrixMut, + +{ + let n_left = n % NR; + let n_main = n - n_left; + + let m_left = m % MR; + let m_main = m - m_left; + + if first_time { + e.execute(0, n_main, NR, move |j| + K::pack_row_b(b, pb, j)); + } + + e.execute(0, m_main, MR, move |i| + K::pack_row_a(a, pa, i)); + + e.synchronize(); + + e.execute(0, n_main, NR, move |j| { + // Section TL + for i in (0..m_main).step_by(MR) { + K::main_tl(alpha, + pa.sub_row(i), + pb.sub_row(j), + beta, + c.sub(j, i)); + } + + // Section TR + for i in m_main..m { + K::sup_tr( + alpha, + a.sub_col(i), + pb.sub_row(j), + beta, + c.sub(j, i)); + } + }); + + e.execute(n_main, n, 1, move |j| { + // Section BL + for i in (0..m_main).step_by(MR) { + K::sup_bl( + alpha, + pa.sub_row(i), + b.sub_row(j), + beta, + c.sub(j, i) + ); + } + + // Section BR + for i in m_main..m { + K::sup_br( + k, + alpha, + a.sub_col(i), + b.sub_row(j), + beta, + c.sub(j, i)) + } + }); + + e.synchronize(); +} \ No newline at end of file diff --git a/src/kernel/avx/l1s.rs b/src/kernel/avx/l1s.rs index 7384396..e315ade 100644 --- a/src/kernel/avx/l1s.rs +++ b/src/kernel/avx/l1s.rs @@ -246,6 +246,7 @@ pub unsafe fn sdot( let mut acc1 = _mm256_setzero_ps(); let mut acc2 = _mm256_setzero_ps(); let mut acc3 = _mm256_setzero_ps(); + for _ in 0..n / 32 { let x0 = _mm256_loadu_ps(x); let y0 = _mm256_loadu_ps(y); diff --git a/src/kernel/avx/l3s.rs b/src/kernel/avx/l3s.rs index 0fbc586..3f0e1a5 100644 --- a/src/kernel/avx/l3s.rs +++ b/src/kernel/avx/l3s.rs @@ -1,29 +1,82 @@ use super::fma::fmadd_ps; use super::intrinsics::*; +use crate::kernel::params::single::{MR, NR}; +use crate::matrix::{Matrix, MutMatrix, MatrixMut}; +use crunchy::unroll; -pub(crate) unsafe fn sgemm_ukr_16x4( +#[inline] +pub(crate) unsafe fn sgemm_ukr_16x8>( k: usize, alpha: f32, - pa: *const f32, - pb: *const f32, + pa: MutMatrix, + pb: MutMatrix, beta: f32, - c: *mut f32, - ldc: usize, + c: C, ) { let mut mt00 = _mm256_setzero_ps(); let mut mt01 = _mm256_setzero_ps(); let mut mt02 = _mm256_setzero_ps(); let mut mt03 = _mm256_setzero_ps(); - + let mut mt04 = _mm256_setzero_ps(); + // let mut mt05 = _mm256_setzero_ps(); + // let mut mt06 = _mm256_setzero_ps(); + // let mut mt07 = _mm256_setzero_ps(); + let mut mt10 = _mm256_setzero_ps(); let mut mt11 = _mm256_setzero_ps(); let mut mt12 = _mm256_setzero_ps(); let mut mt13 = _mm256_setzero_ps(); + let mut mt14 = _mm256_setzero_ps(); + // let mut mt15 = _mm256_setzero_ps(); + // let mut mt16 = _mm256_setzero_ps(); + // let mut mt17 = _mm256_setzero_ps(); + + let mut pa = pa.ptr(); + let mut pb = pb.ptr(); - let mut pa = pa; - let mut pb = pb; + const BATCH: usize = 16; - for _ in 0..k { + let k_right = k % BATCH; + let k_main = k - k_right; + + for _ in (0..k_main).step_by(BATCH) { + unroll! { + for i in 0..16 { + let a0 = _mm256_load_ps(pa.add(i * MR)); + let a1 = _mm256_load_ps(pa.add(i * MR + 8)); + + let b0 = _mm256_broadcast_ss(&*pb.add(i * NR)); + let b1 = _mm256_broadcast_ss(&*pb.add(i * NR + 1)); + let b2 = _mm256_broadcast_ss(&*pb.add(i * NR + 2)); + let b3 = _mm256_broadcast_ss(&*pb.add(i * NR + 3)); + let b4 = _mm256_broadcast_ss(&*pb.add(i * NR + 4)); + // let b5 = _mm256_broadcast_ss(&*pb.add(5)); + // let b6 = _mm256_broadcast_ss(&*pb.add(6)); + // let b7 = _mm256_broadcast_ss(&*pb.add(7)); + + mt00 = fmadd_ps(a0, b0, mt00); + mt01 = fmadd_ps(a0, b1, mt01); + mt02 = fmadd_ps(a0, b2, mt02); + mt03 = fmadd_ps(a0, b3, mt03); + mt04 = fmadd_ps(a0, b4, mt04); + // mt05 = fmadd_ps(a0, b5, mt05); + // mt06 = fmadd_ps(a0, b6, mt06); + // mt07 = fmadd_ps(a0, b7, mt07); + + mt10 = fmadd_ps(a1, b0, mt10); + mt11 = fmadd_ps(a1, b1, mt11); + mt12 = fmadd_ps(a1, b2, mt12); + mt13 = fmadd_ps(a1, b3, mt13); + mt14 = fmadd_ps(a1, b4, mt14); + } + } + + pa = pa.add(BATCH * MR); + pb = pb.add(BATCH * NR); + } + + + for _ in k_main..k { let a0 = _mm256_load_ps(pa); let a1 = _mm256_load_ps(pa.add(8)); @@ -31,66 +84,107 @@ pub(crate) unsafe fn sgemm_ukr_16x4( let b1 = _mm256_broadcast_ss(&*pb.add(1)); let b2 = _mm256_broadcast_ss(&*pb.add(2)); let b3 = _mm256_broadcast_ss(&*pb.add(3)); - + let b4 = _mm256_broadcast_ss(&*pb.add(4)); + // let b5 = _mm256_broadcast_ss(&*pb.add(5)); + // let b6 = _mm256_broadcast_ss(&*pb.add(6)); + // let b7 = _mm256_broadcast_ss(&*pb.add(7)); + mt00 = fmadd_ps(a0, b0, mt00); mt01 = fmadd_ps(a0, b1, mt01); mt02 = fmadd_ps(a0, b2, mt02); mt03 = fmadd_ps(a0, b3, mt03); + mt04 = fmadd_ps(a0, b4, mt04); + // mt05 = fmadd_ps(a0, b5, mt05); + // mt06 = fmadd_ps(a0, b6, mt06); + // mt07 = fmadd_ps(a0, b7, mt07); + mt10 = fmadd_ps(a1, b0, mt10); mt11 = fmadd_ps(a1, b1, mt11); mt12 = fmadd_ps(a1, b2, mt12); mt13 = fmadd_ps(a1, b3, mt13); + mt14 = fmadd_ps(a1, b4, mt14); - pa = pa.add(16); - pb = pb.add(4); + pa = pa.add(MR); + pb = pb.add(NR); } - let alpha = _mm256_broadcast_ss(&alpha); + // let alpha = _mm256_broadcast_ss(&alpha); - mt00 = _mm256_mul_ps(alpha, mt00); - mt01 = _mm256_mul_ps(alpha, mt01); - mt02 = _mm256_mul_ps(alpha, mt02); - mt03 = _mm256_mul_ps(alpha, mt03); - mt10 = _mm256_mul_ps(alpha, mt10); - mt11 = _mm256_mul_ps(alpha, mt11); - mt12 = _mm256_mul_ps(alpha, mt12); - mt13 = _mm256_mul_ps(alpha, mt13); + // mt00 = _mm256_mul_ps(alpha, mt00); + // mt01 = _mm256_mul_ps(alpha, mt01); + // mt02 = _mm256_mul_ps(alpha, mt02); + // mt03 = _mm256_mul_ps(alpha, mt03); + // mt04 = _mm256_mul_ps(alpha, mt04); + // // mt05 = _mm256_mul_ps(alpha, mt05); + // // mt06 = _mm256_mul_ps(alpha, mt06); + // // mt07 = _mm256_mul_ps(alpha, mt07); - let ccol0 = c; - let ccol1 = c.add(ldc); - let ccol2 = c.add(ldc * 2); - let ccol3 = c.add(ldc * 3); + // mt10 = _mm256_mul_ps(alpha, mt10); + // mt11 = _mm256_mul_ps(alpha, mt11); + // mt12 = _mm256_mul_ps(alpha, mt12); + // mt13 = _mm256_mul_ps(alpha, mt13); + // mt14 = _mm256_mul_ps(alpha, mt14); + // // mt15 = _mm256_mul_ps(alpha, mt15); + // // mt16 = _mm256_mul_ps(alpha, mt16); + // // mt17 = _mm256_mul_ps(alpha, mt17); - if beta != 0.0 { - let beta = _mm256_broadcast_ss(&beta); + let ccol0 = c.ptr_mut(); + let ccol1 = c.row_mut(1); + let ccol2 = c.row_mut(2); + let ccol3 = c.row_mut(3); + let ccol4 = c.row_mut(4); + // let ccol5 = c.row_mut(5); + // let ccol6 = c.row_mut(6); + // let ccol7 = c.add(ldc * 7); + + // if beta != 0.0 { + // let beta = _mm256_broadcast_ss(&beta); - mt00 = fmadd_ps(beta, _mm256_loadu_ps(ccol0), mt00); - mt01 = fmadd_ps(beta, _mm256_loadu_ps(ccol1), mt01); - mt02 = fmadd_ps(beta, _mm256_loadu_ps(ccol2), mt02); - mt03 = fmadd_ps(beta, _mm256_loadu_ps(ccol3), mt03); - mt10 = fmadd_ps(beta, _mm256_loadu_ps(ccol0.add(8)), mt10); - mt11 = fmadd_ps(beta, _mm256_loadu_ps(ccol1.add(8)), mt11); - mt12 = fmadd_ps(beta, _mm256_loadu_ps(ccol2.add(8)), mt12); - mt13 = fmadd_ps(beta, _mm256_loadu_ps(ccol3.add(8)), mt13); - } + // mt00 = fmadd_ps(beta, _mm256_loadu_ps(ccol0), mt00); + // mt01 = fmadd_ps(beta, _mm256_loadu_ps(ccol1), mt01); + // mt02 = fmadd_ps(beta, _mm256_loadu_ps(ccol2), mt02); + // mt03 = fmadd_ps(beta, _mm256_loadu_ps(ccol3), mt03); + // mt04 = fmadd_ps(beta, _mm256_loadu_ps(ccol4), mt04); + // // mt05 = fmadd_ps(beta, _mm256_loadu_ps(ccol5), mt05); + // // mt06 = fmadd_ps(beta, _mm256_loadu_ps(ccol6), mt06); + // // mt07 = fmadd_ps(beta, _mm256_loadu_ps(ccol7), mt07); + + // mt10 = fmadd_ps(beta, _mm256_loadu_ps(ccol0.add(8)), mt10); + // mt11 = fmadd_ps(beta, _mm256_loadu_ps(ccol1.add(8)), mt11); + // mt12 = fmadd_ps(beta, _mm256_loadu_ps(ccol2.add(8)), mt12); + // mt13 = fmadd_ps(beta, _mm256_loadu_ps(ccol3.add(8)), mt13); + // mt14 = fmadd_ps(beta, _mm256_loadu_ps(ccol4.add(8)), mt14); + // // mt15 = fmadd_ps(beta, _mm256_loadu_ps(ccol5.add(8)), mt15); + // // mt16 = fmadd_ps(beta, _mm256_loadu_ps(ccol6.add(8)), mt16); + // // mt17 = fmadd_ps(beta, _mm256_loadu_ps(ccol7.add(8)), mt17); + // } _mm256_storeu_ps(ccol0, mt00); _mm256_storeu_ps(ccol1, mt01); _mm256_storeu_ps(ccol2, mt02); _mm256_storeu_ps(ccol3, mt03); + _mm256_storeu_ps(ccol4, mt04); + // _mm256_storeu_ps(ccol5, mt05); + // _mm256_storeu_ps(ccol6, mt06); + // _mm256_storeu_ps(ccol7, mt07); + _mm256_storeu_ps(ccol0.add(8), mt10); _mm256_storeu_ps(ccol1.add(8), mt11); _mm256_storeu_ps(ccol2.add(8), mt12); _mm256_storeu_ps(ccol3.add(8), mt13); + _mm256_storeu_ps(ccol4.add(8), mt14); + // _mm256_storeu_ps(ccol5.add(8), mt15); + // _mm256_storeu_ps(ccol6.add(8), mt16); + // _mm256_storeu_ps(ccol7.add(8), mt17); } -pub(crate) unsafe fn sgemm_sup_16x1( +pub(crate) unsafe fn sgemm_sup_16x1, C: MatrixMut>( k: usize, alpha: f32, - pa: *const f32, - b: *const f32, + pa: MutMatrix, + b: B, beta: f32, - c: *mut f32, + c: C, ) { let mut mt0 = _mm256_setzero_ps(); let mut mt1 = _mm256_setzero_ps(); @@ -99,76 +193,32 @@ pub(crate) unsafe fn sgemm_sup_16x1( let mut b = b; for _ in 0..k { - let a0 = _mm256_load_ps(pa); - let a1 = _mm256_load_ps(pa.add(8)); + let a0 = _mm256_load_ps(pa.ptr()); + let a1 = _mm256_load_ps(pa.col(8)); - let b0 = _mm256_broadcast_ss(&*b); + let b0 = _mm256_broadcast_ss(&*b.ptr()); mt0 = fmadd_ps(a0, b0, mt0); mt1 = fmadd_ps(a1, b0, mt1); - pa = pa.add(16); - b = b.add(1); + pa.shift_col(16); + b.inc_col(); } - let alpha = _mm256_broadcast_ss(&alpha); + // let alpha = _mm256_broadcast_ss(&alpha); - mt0 = _mm256_mul_ps(alpha, mt0); - mt1 = _mm256_mul_ps(alpha, mt1); + // mt0 = _mm256_mul_ps(alpha, mt0); + // mt1 = _mm256_mul_ps(alpha, mt1); - if beta != 0.0 { - let beta = _mm256_broadcast_ss(&beta); + // if beta != 0.0 { + // let beta = _mm256_broadcast_ss(&beta); - mt0 = fmadd_ps(beta, _mm256_loadu_ps(c), mt0); - mt1 = fmadd_ps(beta, _mm256_loadu_ps(c.add(8)), mt1); - } + // mt0 = fmadd_ps(beta, _mm256_loadu_ps(c), mt0); + // mt1 = fmadd_ps(beta, _mm256_loadu_ps(c.add(8)), mt1); + // } - _mm256_storeu_ps(c, mt0); - _mm256_storeu_ps(c.add(8), mt1); -} - -pub(crate) unsafe fn sgemm_sup0_t( - k: usize, - alpha: f32, - pa: *const f32, - b: *const f32, - ldb: usize, - beta: f32, - c: *mut f32, -) { - let mut mt0 = _mm256_setzero_ps(); - let mut mt1 = _mm256_setzero_ps(); - - let mut pa = pa; - let mut b = b; - - for _ in 0..k { - let a0 = _mm256_load_ps(pa); - let a1 = _mm256_load_ps(pa.add(8)); - - let b0 = _mm256_broadcast_ss(&*b); - - mt0 = fmadd_ps(a0, b0, mt0); - mt1 = fmadd_ps(a1, b0, mt1); - - pa = pa.add(16); - b = b.add(ldb); - } - - let alpha = _mm256_broadcast_ss(&alpha); - - mt0 = _mm256_mul_ps(alpha, mt0); - mt1 = _mm256_mul_ps(alpha, mt1); - - if beta != 0.0 { - let beta = _mm256_broadcast_ss(&beta); - - mt0 = fmadd_ps(beta, _mm256_loadu_ps(c), mt0); - mt1 = fmadd_ps(beta, _mm256_loadu_ps(c.add(8)), mt1); - } - - _mm256_storeu_ps(c, mt0); - _mm256_storeu_ps(c.add(8), mt1); + _mm256_storeu_ps(c.ptr_mut(), mt0); + _mm256_storeu_ps(c.col_mut(8), mt1); } pub(crate) unsafe fn sgemm_pa_16x(k: usize, a: *const f32, lda: usize, pa: *mut f32) { diff --git a/src/kernel/avx/mod.rs b/src/kernel/avx/mod.rs index 788e110..2e4c6f3 100644 --- a/src/kernel/avx/mod.rs +++ b/src/kernel/avx/mod.rs @@ -5,3 +5,85 @@ pub mod l1d; pub mod l1s; pub mod l3d; pub mod l3s; + +use core::marker::PhantomData; +use crate::matrix::{Number, MutMatrix, Matrix, MatrixMut}; +use crate::kernel::{GemmKernel, GemmKernelSupNr, GemmKernelSupMr, GemmKernelSup}; +use crate::dim::*; + +pub struct AvxKernel(PhantomData); + +impl GemmKernelSupNr for AvxKernel + where I: GemmKernelSupNr +{ + #[inline] + unsafe fn sup_tr, C: MatrixMut>( + alpha: f32, + a: A, + pb: MutMatrix, + beta: f32, + c: C, + ) { + I::sup_tr(alpha, a, pb, beta, c); + } +} + +impl GemmKernelSupMr for AvxKernel + where I: GemmKernelSupMr +{ + #[inline] + unsafe fn sup_bl, C: MatrixMut>( + alpha: f32, + pa: MutMatrix, + b: B, + beta: f32, + c: C, + ) { + self::l3s::sgemm_sup_16x1(pa.stride, alpha, pa, b, beta, c); + } +} + +impl GemmKernelSup for AvxKernel + where I: GemmKernelSup +{ + #[inline] + unsafe fn sup_br, B: Matrix, C: MatrixMut>( + k: usize, + alpha: f32, + a: A, + b: B, + beta: f32, + c: C, + ) { + I::sup_br(k, alpha, a, b, beta, c); + } +} + +impl GemmKernel for AvxKernel + where I: GemmKernel +{ + #[inline] + unsafe fn pack_row_a>(a: A, pa: MutMatrix, i: usize) { + if a.is_transposed() { + I::pack_row_a(a, pa, i); + } else { + self::l3s::sgemm_pa_16x(pa.stride, a.col(i), a.stride(), pa.row_mut(i)); + } + } + + #[inline] + unsafe fn pack_row_b>(b: B, pb: MutMatrix, j: usize) { + I::pack_row_b(b, pb, j); + } + + #[inline] + unsafe fn main_tl>( + alpha: f32, + pa: MutMatrix, + pb: MutMatrix, + beta: f32, + c: C, + ) { + self::l3s::sgemm_ukr_16x8(pa.stride, alpha, pa, pb, beta, c); + } +} \ No newline at end of file diff --git a/src/kernel/generic/l3s.rs b/src/kernel/generic/l3s.rs index 61ca3d0..8dc7555 100644 --- a/src/kernel/generic/l3s.rs +++ b/src/kernel/generic/l3s.rs @@ -1,114 +1,91 @@ -pub(crate) unsafe fn sgemm_sup_1x4( +use crate::kernel::params::single::NR; +use crate::matrix::{MutMatrix, Matrix, MatrixMut}; + + +pub(crate) unsafe fn sgemm_sup_1x8, C: MatrixMut>( k: usize, alpha: f32, - a: *const f32, - lda: usize, - pb: *const f32, + a: A, + pb: MutMatrix, beta: f32, - c: *mut f32, - ldc: usize, + c: C, ) { - let mut c0 = 0.0; - let mut c1 = 0.0; - let mut c2 = 0.0; - let mut c3 = 0.0; + let mut c0 = 0.0f32; + let mut c1 = 0.0f32; + let mut c2 = 0.0f32; + let mut c3 = 0.0f32; + let mut c4 = 0.0f32; + // let mut c5 = 0.0f32; + // let mut c6 = 0.0f32; + // let mut c7 = 0.0f32; let mut a = a; let mut pb = pb; for _ in 0..k { - let a0 = *a; + let a0 = *a.ptr(); - c0 += *pb * a0; - c1 += *pb.add(1) * a0; - c2 += *pb.add(2) * a0; - c3 += *pb.add(3) * a0; + c0 += *pb.ptr() * a0; + c1 += *pb.col(1) * a0; + c2 += *pb.col(2) * a0; + c3 += *pb.col(3) * a0; + c4 += *pb.col(4) * a0; + // c5 += *pb.col(5) * a0; + // c6 += *pb.col(6) * a0; + // c7 += *pb.col(7) * a0; - a = a.add(lda); - pb = pb.add(4); + a.inc_row(); + pb.shift_col(NR); } - c0 *= alpha; - c1 *= alpha; - c2 *= alpha; - c3 *= alpha; + // c0 *= alpha; + // c1 *= alpha; + // c2 *= alpha; + // c3 *= alpha; + // c4 *= alpha; + // // c5 *= alpha; + // // c6 *= alpha; + // // c7 *= alpha; + + let ccol0 = c.ptr_mut(); + let ccol1 = c.row_mut(1); + let ccol2 = c.row_mut(2); + let ccol3 = c.row_mut(3); + let ccol4 = c.row_mut(4); + // let ccol5 = c.row_mut(5); + // let ccol6 = c.row_mut(6); + // let ccol7 = c.add(ldc * 7); - let ccol0 = c; - let ccol1 = c.add(ldc); - let ccol2 = c.add(ldc * 2); - let ccol3 = c.add(ldc * 3); - - if beta != 0.0 { - c0 += beta * *ccol0; - c1 += beta * *ccol1; - c2 += beta * *ccol2; - c3 += beta * *ccol3; - } + // if beta != 0.0 { + // c0 += beta * *ccol0; + // c1 += beta * *ccol1; + // c2 += beta * *ccol2; + // c3 += beta * *ccol3; + // c4 += beta * *ccol4; + // // c5 += beta * *ccol5; + // // c6 += beta * *ccol6; + // // c7 += beta * *ccol7; + // } *ccol0 = c0; *ccol1 = c1; *ccol2 = c2; *ccol3 = c3; + *ccol4 = c4; + // *ccol5 = c5; + // *ccol6 = c6; + // *ccol7 = c7; } -pub(crate) unsafe fn sgemm_sup1_t( - k: usize, - alpha: f32, - a: *const f32, - lda: usize, - pb: *const f32, - beta: f32, - c: *mut f32, - ldc: usize, -) { - let mut c0 = 0.0; - let mut c1 = 0.0; - let mut c2 = 0.0; - let mut c3 = 0.0; - - let mut a = a; - let mut pb = pb; - - for _ in 0..k { - let a0 = *a; - - c0 += *pb * a0; - c1 += *pb.add(1) * a0; - c2 += *pb.add(2) * a0; - c3 += *pb.add(3) * a0; - - a = a.add(1); - pb = pb.add(4); - } - - c0 *= alpha; - c1 *= alpha; - c2 *= alpha; - c3 *= alpha; - - let ccol0 = c; - let ccol1 = c.add(ldc); - let ccol2 = c.add(ldc * 2); - let ccol3 = c.add(ldc * 3); - - if beta != 0.0 { - c0 += beta * *ccol0; - c1 += beta * *ccol1; - c2 += beta * *ccol2; - c3 += beta * *ccol3; - } - - *ccol0 = c0; - *ccol1 = c1; - *ccol2 = c2; - *ccol3 = c3; -} - -pub(crate) unsafe fn sgemm_pb_x4(k: usize, b: *const f32, ldb: usize, pb: *mut f32) { +pub(crate) unsafe fn sgemm_pb_x8(k: usize, b: *const f32, ldb: usize, pb: *mut f32) { let mut bcol0 = b; let mut bcol1 = b.add(ldb); let mut bcol2 = b.add(ldb * 2); let mut bcol3 = b.add(ldb * 3); + let mut bcol4 = b.add(ldb * 4); + // let mut bcol5 = b.add(ldb * 5); + // let mut bcol6 = b.add(ldb * 6); + // let mut bcol7 = b.row(7); let mut pb = pb; @@ -117,27 +94,24 @@ pub(crate) unsafe fn sgemm_pb_x4(k: usize, b: *const f32, ldb: usize, pb: *mut f *pb.add(1) = *bcol1; *pb.add(2) = *bcol2; *pb.add(3) = *bcol3; + *pb.add(4) = *bcol4; + // *pb.add(5) = *bcol5; + // *pb.add(6) = *bcol6; + // *pb.col(7) = *bcol7; bcol0 = bcol0.add(1); bcol1 = bcol1.add(1); bcol2 = bcol2.add(1); bcol3 = bcol3.add(1); - pb = pb.add(4); + bcol4 = bcol4.add(1); + // bcol5 = bcol5.add(1); + // bcol6 = bcol6.add(1); + // bcol7 = bcol7.add(1); + + pb = pb.add(NR); } } -// pub(crate) unsafe fn sgemm_pa_n(k: usize, a: *const f32, lda: usize, pa: *mut f32) { -// use crate::kernel::params::single::MR; -// let mut a = a; -// let mut pa = pa; - -// for p in 0..k { -// for j in 0..MR { -// *pa.add(j * k + p) = *a.add(p * lda + j); -// } -// } -// } - pub(crate) unsafe fn sgemm_pb_t(k: usize, b: *const f32, ldb: usize, pb: *mut f32) { use crate::kernel::params::single::NR; let mut b = b; diff --git a/src/kernel/generic/mod.rs b/src/kernel/generic/mod.rs index 0df46f7..ce7ee60 100644 --- a/src/kernel/generic/mod.rs +++ b/src/kernel/generic/mod.rs @@ -1,2 +1,98 @@ pub mod l3d; pub mod l3s; + +use crate::matrix::{Matrix, MatrixMut, MutMatrix}; +use crate::kernel::{GemmKernel, GemmKernelSupNr, GemmKernelSupMr, GemmKernelSup}; +use crate::dim::*; + +pub struct GenericKernel; + +impl GemmKernelSupNr for GenericKernel { + #[inline] + unsafe fn sup_tr, C: MatrixMut>( + alpha: f32, + a: A, + pb: MutMatrix, + beta: f32, + c: C, + ) { + self::l3s::sgemm_sup_1x8( + pb.stride, + alpha, + a, + pb, + beta, + c) + } +} + +impl GemmKernelSupMr for GenericKernel { + #[inline] + unsafe fn sup_bl, C: MatrixMut>( + alpha: f32, + pa: MutMatrix, + b: B, + beta: f32, + c: C, + ) { + unimplemented!() + } +} + +impl GemmKernelSup for GenericKernel { + #[inline] + unsafe fn sup_br, B: Matrix, C: MatrixMut>( + k: usize, + alpha: f32, + a: A, + b: B, + beta: f32, + c: C, + ) { + let mut elem = 0.0; + + for p in 0..k { + elem += *a.row(p) * *b.col(p); + } + + elem *= alpha; + + if beta != 0.0 { + elem += beta * *c.ptr(); + } + + *c.ptr_mut() = elem; + } +} + +impl GemmKernel for GenericKernel { + + #[inline] + unsafe fn pack_row_a>(a: A, pa: MutMatrix, i: usize) { + if a.is_transposed() { + self::l3s::sgemm_pa_t(pa.stride, a.col(i), a.stride(), pa.row_mut(i)); + } else { + unimplemented!() + } + } + + #[inline] + unsafe fn pack_row_b>(b: B, pb: MutMatrix, j: usize) { + if b.is_transposed() { + self::l3s::sgemm_pb_t(pb.stride, b.row(j), b.stride(), pb.row_mut(j)); + } else { + self::l3s::sgemm_pb_x8(pb.stride, b.row(j), b.stride(), pb.row_mut(j)); + } + } + + #[inline] + unsafe fn main_tl>( + alpha: f32, + pa: MutMatrix, + pb: MutMatrix, + beta: f32, + c: C, + ) { + unimplemented!() + } +} \ No newline at end of file diff --git a/src/kernel/mod.rs b/src/kernel/mod.rs index b8fb8bd..c12b9af 100644 --- a/src/kernel/mod.rs +++ b/src/kernel/mod.rs @@ -2,53 +2,20 @@ any(target_arch = "x86_64", target_arch = "x86"), target_feature = "avx" ))] -mod avx; - -#[cfg(all( - any(target_arch = "x86_64", target_arch = "x86"), - target_feature = "avx" -))] -pub use avx::{l1d::*, l1s::*}; - -#[cfg(all( - any(target_arch = "x86_64", target_arch = "x86"), - target_feature = "avx" -))] -pub(crate) use avx::l3s::{ - sgemm_pa_16x as sgemm_pa_n, sgemm_sup_16x1 as sgemm_sup0, sgemm_ukr_16x4 as sgemm_ukr, sgemm_sup0_t, -}; - -#[cfg(all( - any(target_arch = "x86_64", target_arch = "x86"), - target_feature = "avx" -))] -pub(crate) use avx::l3d::{ - dgemm_pa_8x as dgemm_pa, dgemm_sup_8x1 as dgemm_sup0, dgemm_ukr_8x4 as dgemm_ukr, -}; - -#[cfg(all( - any(target_arch = "x86_64", target_arch = "x86"), - target_feature = "avx" -))] -pub(crate) use generic::l3s::{sgemm_pb_x4 as sgemm_pb_n, sgemm_pb_t, sgemm_pa_t, sgemm_sup_1x4 as sgemm_sup1, sgemm_sup1_t}; - -#[cfg(all( - any(target_arch = "x86_64", target_arch = "x86"), - target_feature = "avx" -))] -pub(crate) use generic::l3d::{dgemm_pb_x4 as dgemm_pb, dgemm_sup_1x4 as dgemm_sup1}; - - -mod generic; +pub mod avx; +pub mod generic; +pub mod sse; +use crate::matrix::{Number, Matrix, MatrixMut, MutMatrix}; +use crate::dim::Dim; pub mod params { pub mod single { - pub const MC: usize = 256; - pub const KC: usize = 128; - pub const NC: usize = 1024; + pub const MC: usize = 128; + pub const KC: usize = 256; + pub const NC: usize = 8*1024; pub const MR: usize = 16; - pub const NR: usize = 4; + pub const NR: usize = 5; } pub mod double { pub const MC: usize = 256; @@ -57,4 +24,61 @@ pub mod params { pub const MR: usize = 8; pub const NR: usize = 4; } +} + +// +----------------------+ +// |TL : : : |TR| +// | : : : | | +// + - - - - - - - - - ---+ +// | : : : | | +// | : : : | | +// + - - - - - - - - - ---+ +// | : : : | | +// | : : : | | +// +----------------------+ +// |BL | | | |BR| +// +----------------------+ + +pub trait GemmKernelSup { + unsafe fn sup_br, B: Matrix, C: MatrixMut>( + k: usize, + alpha: F, + a: A, + b: B, + beta: F, + c: C, + ); +} + +pub trait GemmKernelSupMr { + unsafe fn sup_bl, C: MatrixMut>( + alpha: F, + pa: MutMatrix, + b: B, + beta: F, + c: C, + ); +} + +pub trait GemmKernelSupNr { + unsafe fn sup_tr, C: MatrixMut>( + alpha: F, + a: A, + pb: MutMatrix, + beta: F, + c: C, + ); +} + +pub trait GemmKernel: GemmKernelSupMr + GemmKernelSupNr + GemmKernelSup { + unsafe fn pack_row_a>(a: A, pa: MutMatrix, i: usize); + unsafe fn pack_row_b>(b: B, pb: MutMatrix, j: usize); + + unsafe fn main_tl>( + alpha: F, + pa: MutMatrix, + pb: MutMatrix, + beta: F, + c: C, + ); } \ No newline at end of file diff --git a/src/kernel/sse/fma.rs b/src/kernel/sse/fma.rs new file mode 100644 index 0000000..793ccda --- /dev/null +++ b/src/kernel/sse/fma.rs @@ -0,0 +1,47 @@ +use super::intrinsics::*; + +#[cfg(target_feature = "fma")] +#[inline(always)] +pub unsafe fn fmadd_ps(a: __m128, b: __m128, c: __m128) -> __m128 { + _mm_fmadd_ps(a, b, c) +} + +#[cfg(not(target_feature = "fma"))] +#[inline(always)] +pub unsafe fn fmadd_ps(a: __m128, b: __m128, c: __m128) -> __m128 { + _mm_add_ps(_mm_mul_ps(a, b), c) +} + +#[cfg(target_feature = "fma")] +#[inline(always)] +pub unsafe fn fmsub_ps(a: __m128, b: __m128, c: __m128) -> __m128 { + _mm_fmsub_ps(a, b, c) +} + +#[cfg(not(target_feature = "fma"))] +pub unsafe fn fmsub_ps(a: __m128, b: __m128, c: __m128) -> __m128 { + _mm_sub_ps(_mm_mul_ps(a, b), c) +} + +#[cfg(target_feature = "fma")] +#[inline(always)] +pub unsafe fn fmadd_pd(a: __m128d, b: __m128d, c: __m128d) -> __m128d { + _mm_fmadd_pd(a, b, c) +} + +#[cfg(not(target_feature = "fma"))] +#[inline(always)] +pub unsafe fn fmadd_pd(a: __m128d, b: __m128d, c: __m128d) -> __m128d { + _mm_add_pd(_mm_mul_pd(a, b), c) +} + +#[cfg(target_feature = "fma")] +#[inline(always)] +pub unsafe fn fmsub_pd(a: __m128d, b: __m128d, c: __m128d) -> __m128d { + _mm_fmsub_pd(a, b, c) +} + +#[cfg(not(target_feature = "fma"))] +pub unsafe fn fmsub_pd(a: __m128d, b: __m128d, c: __m128d) -> __m128d { + _mm_sub_pd(_mm_mul_pd(a, b), c) +} diff --git a/src/kernel/sse/hsum.rs b/src/kernel/sse/hsum.rs new file mode 100644 index 0000000..7ef8dae --- /dev/null +++ b/src/kernel/sse/hsum.rs @@ -0,0 +1,23 @@ +use super::intrinsics::*; + +#[inline(always)] +pub unsafe fn hsum_ps(v: __m256) -> f32 { + let qhigh = _mm256_extractf128_ps(v, 1); + let qlow = _mm256_castps256_ps128(v); + let qsum = _mm_add_ps(qhigh, qlow); + let dhigh = _mm_movehl_ps(qsum, qsum); + let dlow = qsum; + let dsum = _mm_add_ps(dhigh, dlow); + let high = _mm_shuffle_ps(dsum, dsum, 1); + let low = dsum; + _mm_cvtss_f32(_mm_add_ss(high, low)) +} + +#[inline(always)] +pub unsafe fn hsum_pd(v: __m256d) -> f64 { + let vhigh = _mm256_extractf128_pd(v, 1); + let vlow = _mm256_castpd256_pd128(v); + let vsum = _mm_add_pd(vlow, vhigh); + let h64 = _mm_unpackhi_pd(vsum, vsum); + _mm_cvtsd_f64(_mm_add_sd(vsum, h64)) +} diff --git a/src/kernel/sse/intrinsics.rs b/src/kernel/sse/intrinsics.rs new file mode 100644 index 0000000..e95292a --- /dev/null +++ b/src/kernel/sse/intrinsics.rs @@ -0,0 +1,5 @@ +#[cfg(target_arch = "x86")] +pub use std::arch::x86::*; + +#[cfg(target_arch = "x86_64")] +pub use std::arch::x86_64::*; diff --git a/src/kernel/sse/l1d.rs b/src/kernel/sse/l1d.rs new file mode 100644 index 0000000..eb77ab2 --- /dev/null +++ b/src/kernel/sse/l1d.rs @@ -0,0 +1,396 @@ +use super::fma::{fmadd_pd, fmsub_pd}; +use super::hsum::hsum_pd; +use super::intrinsics::*; + +pub unsafe fn drot( + n: usize, + mut x: *mut f64, + incx: usize, + mut y: *mut f64, + incy: usize, + c: f64, + s: f64, +) { + if incx == 1 && incy == 1 { + let c0 = _mm256_broadcast_sd(&c); + let s0 = _mm256_broadcast_sd(&s); + + for _ in 0..n / 16 { + let x0 = _mm256_loadu_pd(x); + let y0 = _mm256_loadu_pd(y); + let x1 = _mm256_loadu_pd(x.add(4)); + let y1 = _mm256_loadu_pd(y.add(4)); + let x2 = _mm256_loadu_pd(x.add(8)); + let y2 = _mm256_loadu_pd(y.add(8)); + let x3 = _mm256_loadu_pd(x.add(12)); + let y3 = _mm256_loadu_pd(y.add(12)); + + _mm256_storeu_pd(x, fmadd_pd(c0, x0, _mm256_mul_pd(s0, y0))); + _mm256_storeu_pd(y, fmsub_pd(c0, y0, _mm256_mul_pd(s0, x0))); + _mm256_storeu_pd(x.add(4), fmadd_pd(c0, x1, _mm256_mul_pd(s0, y1))); + _mm256_storeu_pd(y.add(4), fmsub_pd(c0, y1, _mm256_mul_pd(s0, x1))); + _mm256_storeu_pd(x.add(8), fmadd_pd(c0, x2, _mm256_mul_pd(s0, y2))); + _mm256_storeu_pd(y.add(8), fmsub_pd(c0, y2, _mm256_mul_pd(s0, x2))); + _mm256_storeu_pd(x.add(12), fmadd_pd(c0, x3, _mm256_mul_pd(s0, y3))); + _mm256_storeu_pd(y.add(12), fmsub_pd(c0, y3, _mm256_mul_pd(s0, x3))); + + x = x.add(16); + y = y.add(16); + } + + for _ in 0..n % 16 { + let x0 = *x; + let y0 = *y; + + *x = c * x0 + s * y0; + *y = c * y0 - s * x0; + + x = x.add(1); + y = y.add(1); + } + } else { + for _ in 0..n { + let x0 = *x; + let y0 = *y; + + *x = c * x0 + s * y0; + *y = c * y0 - s * x0; + + x = x.add(incx); + y = y.add(incy); + } + } +} + +pub unsafe fn dswap(n: usize, mut x: *mut f64, incx: usize, mut y: *mut f64, incy: usize) { + if incx == 1 && incy == 1 { + for _ in 0..n / 16 { + let x0 = _mm256_loadu_pd(x); + let y0 = _mm256_loadu_pd(y); + let x1 = _mm256_loadu_pd(x.add(4)); + let y1 = _mm256_loadu_pd(y.add(4)); + let x2 = _mm256_loadu_pd(x.add(8)); + let y2 = _mm256_loadu_pd(y.add(8)); + let x3 = _mm256_loadu_pd(x.add(12)); + let y3 = _mm256_loadu_pd(y.add(12)); + + _mm256_storeu_pd(x, y0); + _mm256_storeu_pd(y, x0); + _mm256_storeu_pd(x.add(4), y1); + _mm256_storeu_pd(y.add(4), x1); + _mm256_storeu_pd(x.add(8), y2); + _mm256_storeu_pd(y.add(8), x2); + _mm256_storeu_pd(x.add(12), y3); + _mm256_storeu_pd(y.add(12), x3); + + x = x.add(16); + y = y.add(16); + } + + for _ in 0..n % 16 { + let x0 = *x; + + *x = *y; + *y = x0; + + x = x.add(1); + y = y.add(1); + } + } else { + for _ in 0..n { + let x0 = *x; + + *x = *y; + *y = x0; + + x = x.add(incx); + y = y.add(incy); + } + } +} + +pub unsafe fn dscal(n: usize, a: f64, mut x: *mut f64, incx: usize) { + if incx == 1 { + let a0 = _mm256_broadcast_sd(&a); + for _ in 0..n / 32 { + let mut x0 = _mm256_loadu_pd(x); + let mut x1 = _mm256_loadu_pd(x.add(4)); + let mut x2 = _mm256_loadu_pd(x.add(8)); + let mut x3 = _mm256_loadu_pd(x.add(12)); + let mut x4 = _mm256_loadu_pd(x.add(16)); + let mut x5 = _mm256_loadu_pd(x.add(20)); + let mut x6 = _mm256_loadu_pd(x.add(24)); + let mut x7 = _mm256_loadu_pd(x.add(28)); + + x0 = _mm256_mul_pd(a0, x0); + x1 = _mm256_mul_pd(a0, x1); + x2 = _mm256_mul_pd(a0, x2); + x3 = _mm256_mul_pd(a0, x3); + x4 = _mm256_mul_pd(a0, x4); + x5 = _mm256_mul_pd(a0, x5); + x6 = _mm256_mul_pd(a0, x6); + x7 = _mm256_mul_pd(a0, x7); + + _mm256_storeu_pd(x, x0); + _mm256_storeu_pd(x.add(4), x1); + _mm256_storeu_pd(x.add(8), x2); + _mm256_storeu_pd(x.add(12), x3); + _mm256_storeu_pd(x.add(16), x4); + _mm256_storeu_pd(x.add(20), x5); + _mm256_storeu_pd(x.add(24), x6); + _mm256_storeu_pd(x.add(28), x7); + + x = x.add(32); + } + for _ in 0..n % 32 { + *x *= a; + x = x.add(1); + } + } else { + for _ in 0..n { + *x *= a; + x = x.add(incx); + } + } +} + +pub unsafe fn dcopy(n: usize, mut x: *const f64, incx: usize, mut y: *mut f64, incy: usize) { + if incx == 1 && incy == 1 { + for _ in 0..n / 32 { + let x0 = _mm256_loadu_pd(x); + let x1 = _mm256_loadu_pd(x.add(4)); + let x2 = _mm256_loadu_pd(x.add(8)); + let x3 = _mm256_loadu_pd(x.add(12)); + let x4 = _mm256_loadu_pd(x.add(16)); + let x5 = _mm256_loadu_pd(x.add(20)); + let x6 = _mm256_loadu_pd(x.add(24)); + let x7 = _mm256_loadu_pd(x.add(28)); + + _mm256_storeu_pd(y, x0); + _mm256_storeu_pd(y.add(4), x1); + _mm256_storeu_pd(y.add(8), x2); + _mm256_storeu_pd(y.add(12), x3); + _mm256_storeu_pd(y.add(16), x4); + _mm256_storeu_pd(y.add(20), x5); + _mm256_storeu_pd(y.add(24), x6); + _mm256_storeu_pd(y.add(28), x7); + + x = x.add(32); + y = y.add(32); + } + for _ in 0..n % 32 { + *y = *x; + x = x.add(1); + y = y.add(1); + } + } else { + for _ in 0..n { + *y = *x; + x = x.add(incx); + y = y.add(incy); + } + } +} + +pub unsafe fn daxpy( + n: usize, + a: f64, + mut x: *const f64, + incx: usize, + mut y: *mut f64, + incy: usize, +) { + if incx == 1 && incy == 1 { + let a0 = _mm256_broadcast_sd(&a); + for _ in 0..n / 16 { + let x0 = _mm256_loadu_pd(x); + let y0 = _mm256_loadu_pd(y); + let x1 = _mm256_loadu_pd(x.add(4)); + let y1 = _mm256_loadu_pd(y.add(4)); + let x2 = _mm256_loadu_pd(x.add(8)); + let y2 = _mm256_loadu_pd(y.add(8)); + let x3 = _mm256_loadu_pd(x.add(12)); + let y3 = _mm256_loadu_pd(y.add(12)); + + _mm256_storeu_pd(y, fmadd_pd(a0, x0, y0)); + _mm256_storeu_pd(y.add(4), fmadd_pd(a0, x1, y1)); + _mm256_storeu_pd(y.add(8), fmadd_pd(a0, x2, y2)); + _mm256_storeu_pd(y.add(12), fmadd_pd(a0, x3, y3)); + + x = x.add(16); + y = y.add(16); + } + for _ in 0..n % 16 { + *y += a * *x; + x = x.add(1); + y = y.add(1); + } + } else { + for _ in 0..n { + *y += a * *x; + x = x.add(incx); + y = y.add(incy); + } + } +} + +pub unsafe fn ddot( + n: usize, + mut x: *const f64, + incx: usize, + mut y: *const f64, + incy: usize, +) -> f64 { + if incx == 1 && incy == 1 { + let mut acc0 = _mm256_setzero_pd(); + let mut acc1 = _mm256_setzero_pd(); + let mut acc2 = _mm256_setzero_pd(); + let mut acc3 = _mm256_setzero_pd(); + for _ in 0..n / 16 { + let x0 = _mm256_loadu_pd(x); + let y0 = _mm256_loadu_pd(y); + let x1 = _mm256_loadu_pd(x.add(4)); + let y1 = _mm256_loadu_pd(y.add(4)); + let x2 = _mm256_loadu_pd(x.add(8)); + let y2 = _mm256_loadu_pd(y.add(8)); + let x3 = _mm256_loadu_pd(x.add(12)); + let y3 = _mm256_loadu_pd(y.add(12)); + + acc0 = fmadd_pd(x0, y0, acc0); + acc1 = fmadd_pd(x1, y1, acc1); + acc2 = fmadd_pd(x2, y2, acc2); + acc3 = fmadd_pd(x3, y3, acc3); + + x = x.add(16); + y = y.add(16); + } + acc0 = _mm256_add_pd(acc0, acc1); + acc2 = _mm256_add_pd(acc2, acc3); + acc0 = _mm256_add_pd(acc0, acc2); + + let mut acc = hsum_pd(acc0); + for _ in 0..n % 16 { + acc += *x * *y; + x = x.add(1); + y = y.add(1); + } + acc + } else { + let mut acc = 0.0; + for _ in 0..n { + acc += *x * *y; + x = x.add(incx); + y = y.add(incy); + } + acc + } +} + +pub unsafe fn dnrm2(n: usize, mut x: *const f64, incx: usize) -> f64 { + if incx == 1 { + let mut acc0 = _mm256_setzero_pd(); + let mut acc1 = _mm256_setzero_pd(); + let mut acc2 = _mm256_setzero_pd(); + let mut acc3 = _mm256_setzero_pd(); + for _ in 0..n / 16 { + let x0 = _mm256_loadu_pd(x); + let x1 = _mm256_loadu_pd(x.add(4)); + let x2 = _mm256_loadu_pd(x.add(8)); + let x3 = _mm256_loadu_pd(x.add(12)); + + acc0 = fmadd_pd(x0, x0, acc0); + acc1 = fmadd_pd(x1, x1, acc1); + acc2 = fmadd_pd(x2, x2, acc2); + acc3 = fmadd_pd(x3, x3, acc3); + + x = x.add(16); + } + acc0 = _mm256_add_pd(acc0, acc1); + acc2 = _mm256_add_pd(acc2, acc3); + acc0 = _mm256_add_pd(acc0, acc2); + + let mut acc = hsum_pd(acc0); + for _ in 0..n % 16 { + let x0 = *x; + acc += x0 * x0; + x = x.add(1); + } + acc.sqrt() + } else { + let mut acc = 0.0; + for _ in 0..n { + let x0 = *x; + acc += x0 * x0; + x = x.add(incx); + } + acc.sqrt() + } +} + +pub unsafe fn dasum(n: usize, mut x: *const f64, incx: usize) -> f64 { + if incx == 1 { + let mask = _mm256_broadcast_sd(&f64::from_bits(0x7FFF_FFFF_FFFF_FFFF)); + + let mut acc0 = _mm256_setzero_pd(); + let mut acc1 = _mm256_setzero_pd(); + let mut acc2 = _mm256_setzero_pd(); + let mut acc3 = _mm256_setzero_pd(); + let mut acc4 = _mm256_setzero_pd(); + let mut acc5 = _mm256_setzero_pd(); + let mut acc6 = _mm256_setzero_pd(); + let mut acc7 = _mm256_setzero_pd(); + for _ in 0..n / 32 { + let mut x0 = _mm256_loadu_pd(x); + let mut x1 = _mm256_loadu_pd(x.add(4)); + let mut x2 = _mm256_loadu_pd(x.add(8)); + let mut x3 = _mm256_loadu_pd(x.add(12)); + let mut x4 = _mm256_loadu_pd(x.add(16)); + let mut x5 = _mm256_loadu_pd(x.add(20)); + let mut x6 = _mm256_loadu_pd(x.add(24)); + let mut x7 = _mm256_loadu_pd(x.add(28)); + + x0 = _mm256_and_pd(mask, x0); + x1 = _mm256_and_pd(mask, x1); + x2 = _mm256_and_pd(mask, x2); + x3 = _mm256_and_pd(mask, x3); + x4 = _mm256_and_pd(mask, x4); + x5 = _mm256_and_pd(mask, x5); + x6 = _mm256_and_pd(mask, x6); + x7 = _mm256_and_pd(mask, x7); + + acc0 = _mm256_add_pd(acc0, x0); + acc1 = _mm256_add_pd(acc1, x1); + acc2 = _mm256_add_pd(acc2, x2); + acc3 = _mm256_add_pd(acc3, x3); + acc4 = _mm256_add_pd(acc4, x4); + acc5 = _mm256_add_pd(acc5, x5); + acc6 = _mm256_add_pd(acc6, x6); + acc7 = _mm256_add_pd(acc7, x7); + + x = x.add(32); + } + acc0 = _mm256_add_pd(acc0, acc1); + acc2 = _mm256_add_pd(acc2, acc3); + acc4 = _mm256_add_pd(acc4, acc5); + acc6 = _mm256_add_pd(acc6, acc7); + + acc0 = _mm256_add_pd(acc0, acc2); + acc4 = _mm256_add_pd(acc4, acc6); + + acc0 = _mm256_add_pd(acc0, acc4); + + let mut acc = hsum_pd(acc0); + for _ in 0..n % 32 { + acc += (*x).abs(); + x = x.add(1); + } + acc + } else { + let mut acc = 0.0; + for _ in 0..n { + acc += (*x).abs(); + x = x.add(incx); + } + acc + } +} diff --git a/src/kernel/sse/l1s.rs b/src/kernel/sse/l1s.rs new file mode 100644 index 0000000..7384396 --- /dev/null +++ b/src/kernel/sse/l1s.rs @@ -0,0 +1,396 @@ +use super::fma::{fmadd_ps, fmsub_ps}; +use super::hsum::hsum_ps; +use super::intrinsics::*; + +pub unsafe fn srot( + n: usize, + mut x: *mut f32, + incx: usize, + mut y: *mut f32, + incy: usize, + c: f32, + s: f32, +) { + if incx == 1 && incy == 1 { + let c0 = _mm256_broadcast_ss(&c); + let s0 = _mm256_broadcast_ss(&s); + + for _ in 0..n / 32 { + let x0 = _mm256_loadu_ps(x); + let y0 = _mm256_loadu_ps(y); + let x1 = _mm256_loadu_ps(x.add(8)); + let y1 = _mm256_loadu_ps(y.add(8)); + let x2 = _mm256_loadu_ps(x.add(16)); + let y2 = _mm256_loadu_ps(y.add(16)); + let x3 = _mm256_loadu_ps(x.add(24)); + let y3 = _mm256_loadu_ps(y.add(24)); + + _mm256_storeu_ps(x, fmadd_ps(c0, x0, _mm256_mul_ps(s0, y0))); + _mm256_storeu_ps(y, fmsub_ps(c0, y0, _mm256_mul_ps(s0, x0))); + _mm256_storeu_ps(x.add(8), fmadd_ps(c0, x1, _mm256_mul_ps(s0, y1))); + _mm256_storeu_ps(y.add(8), fmsub_ps(c0, y1, _mm256_mul_ps(s0, x1))); + _mm256_storeu_ps(x.add(16), fmadd_ps(c0, x2, _mm256_mul_ps(s0, y2))); + _mm256_storeu_ps(y.add(16), fmsub_ps(c0, y2, _mm256_mul_ps(s0, x2))); + _mm256_storeu_ps(x.add(24), fmadd_ps(c0, x3, _mm256_mul_ps(s0, y3))); + _mm256_storeu_ps(y.add(24), fmsub_ps(c0, y3, _mm256_mul_ps(s0, x3))); + + x = x.add(32); + y = y.add(32); + } + + for _ in 0..n % 32 { + let x0 = *x; + let y0 = *y; + + *x = c * x0 + s * y0; + *y = c * y0 - s * x0; + + x = x.add(1); + y = y.add(1); + } + } else { + for _ in 0..n { + let x0 = *x; + let y0 = *y; + + *x = c * x0 + s * y0; + *y = c * y0 - s * x0; + + x = x.add(incx); + y = y.add(incy); + } + } +} + +pub unsafe fn sswap(n: usize, mut x: *mut f32, incx: usize, mut y: *mut f32, incy: usize) { + if incx == 1 && incy == 1 { + for _ in 0..n / 32 { + let x0 = _mm256_loadu_ps(x); + let y0 = _mm256_loadu_ps(y); + let x1 = _mm256_loadu_ps(x.add(8)); + let y1 = _mm256_loadu_ps(y.add(8)); + let x2 = _mm256_loadu_ps(x.add(16)); + let y2 = _mm256_loadu_ps(y.add(16)); + let x3 = _mm256_loadu_ps(x.add(24)); + let y3 = _mm256_loadu_ps(y.add(24)); + + _mm256_storeu_ps(x, y0); + _mm256_storeu_ps(y, x0); + _mm256_storeu_ps(x.add(8), y1); + _mm256_storeu_ps(y.add(8), x1); + _mm256_storeu_ps(x.add(16), y2); + _mm256_storeu_ps(y.add(16), x2); + _mm256_storeu_ps(x.add(24), y3); + _mm256_storeu_ps(y.add(24), x3); + + x = x.add(32); + y = y.add(32); + } + + for _ in 0..n % 32 { + let x0 = *x; + + *x = *y; + *y = x0; + + x = x.add(1); + y = y.add(1); + } + } else { + for _ in 0..n { + let x0 = *x; + + *x = *y; + *y = x0; + + x = x.add(incx); + y = y.add(incy); + } + } +} + +pub unsafe fn sscal(n: usize, a: f32, mut x: *mut f32, incx: usize) { + if incx == 1 { + let a0 = _mm256_broadcast_ss(&a); + for _ in 0..n / 64 { + let mut x0 = _mm256_loadu_ps(x); + let mut x1 = _mm256_loadu_ps(x.add(8)); + let mut x2 = _mm256_loadu_ps(x.add(16)); + let mut x3 = _mm256_loadu_ps(x.add(24)); + let mut x4 = _mm256_loadu_ps(x.add(32)); + let mut x5 = _mm256_loadu_ps(x.add(40)); + let mut x6 = _mm256_loadu_ps(x.add(48)); + let mut x7 = _mm256_loadu_ps(x.add(56)); + + x0 = _mm256_mul_ps(a0, x0); + x1 = _mm256_mul_ps(a0, x1); + x2 = _mm256_mul_ps(a0, x2); + x3 = _mm256_mul_ps(a0, x3); + x4 = _mm256_mul_ps(a0, x4); + x5 = _mm256_mul_ps(a0, x5); + x6 = _mm256_mul_ps(a0, x6); + x7 = _mm256_mul_ps(a0, x7); + + _mm256_storeu_ps(x, x0); + _mm256_storeu_ps(x.add(8), x1); + _mm256_storeu_ps(x.add(16), x2); + _mm256_storeu_ps(x.add(24), x3); + _mm256_storeu_ps(x.add(32), x4); + _mm256_storeu_ps(x.add(40), x5); + _mm256_storeu_ps(x.add(48), x6); + _mm256_storeu_ps(x.add(56), x7); + + x = x.add(64); + } + for _ in 0..n % 64 { + *x *= a; + x = x.add(1); + } + } else { + for _ in 0..n { + *x *= a; + x = x.add(incx); + } + } +} + +pub unsafe fn scopy(n: usize, mut x: *const f32, incx: usize, mut y: *mut f32, incy: usize) { + if incx == 1 && incy == 1 { + for _ in 0..n / 64 { + let x0 = _mm256_loadu_ps(x); + let x1 = _mm256_loadu_ps(x.add(8)); + let x2 = _mm256_loadu_ps(x.add(16)); + let x3 = _mm256_loadu_ps(x.add(24)); + let x4 = _mm256_loadu_ps(x.add(32)); + let x5 = _mm256_loadu_ps(x.add(40)); + let x6 = _mm256_loadu_ps(x.add(48)); + let x7 = _mm256_loadu_ps(x.add(56)); + + _mm256_storeu_ps(y, x0); + _mm256_storeu_ps(y.add(8), x1); + _mm256_storeu_ps(y.add(16), x2); + _mm256_storeu_ps(y.add(24), x3); + _mm256_storeu_ps(y.add(32), x4); + _mm256_storeu_ps(y.add(40), x5); + _mm256_storeu_ps(y.add(48), x6); + _mm256_storeu_ps(y.add(56), x7); + + x = x.add(64); + y = y.add(64); + } + for _ in 0..n % 64 { + *y = *x; + x = x.add(1); + y = y.add(1); + } + } else { + for _ in 0..n { + *y = *x; + x = x.add(incx); + y = y.add(incy); + } + } +} + +pub unsafe fn saxpy( + n: usize, + a: f32, + mut x: *const f32, + incx: usize, + mut y: *mut f32, + incy: usize, +) { + if incx == 1 && incy == 1 { + let a0 = _mm256_broadcast_ss(&a); + for _ in 0..n / 32 { + let x0 = _mm256_loadu_ps(x); + let y0 = _mm256_loadu_ps(y); + let x1 = _mm256_loadu_ps(x.add(8)); + let y1 = _mm256_loadu_ps(y.add(8)); + let x2 = _mm256_loadu_ps(x.add(16)); + let y2 = _mm256_loadu_ps(y.add(16)); + let x3 = _mm256_loadu_ps(x.add(24)); + let y3 = _mm256_loadu_ps(y.add(24)); + + _mm256_storeu_ps(y, fmadd_ps(a0, x0, y0)); + _mm256_storeu_ps(y.add(8), fmadd_ps(a0, x1, y1)); + _mm256_storeu_ps(y.add(16), fmadd_ps(a0, x2, y2)); + _mm256_storeu_ps(y.add(24), fmadd_ps(a0, x3, y3)); + + x = x.add(32); + y = y.add(32); + } + for _ in 0..n % 32 { + *y += a * *x; + x = x.add(1); + y = y.add(1); + } + } else { + for _ in 0..n { + *y += a * *x; + x = x.add(incx); + y = y.add(incy); + } + } +} + +pub unsafe fn sdot( + n: usize, + mut x: *const f32, + incx: usize, + mut y: *const f32, + incy: usize, +) -> f32 { + if incx == 1 && incy == 1 { + let mut acc0 = _mm256_setzero_ps(); + let mut acc1 = _mm256_setzero_ps(); + let mut acc2 = _mm256_setzero_ps(); + let mut acc3 = _mm256_setzero_ps(); + for _ in 0..n / 32 { + let x0 = _mm256_loadu_ps(x); + let y0 = _mm256_loadu_ps(y); + let x1 = _mm256_loadu_ps(x.add(8)); + let y1 = _mm256_loadu_ps(y.add(8)); + let x2 = _mm256_loadu_ps(x.add(16)); + let y2 = _mm256_loadu_ps(y.add(16)); + let x3 = _mm256_loadu_ps(x.add(24)); + let y3 = _mm256_loadu_ps(y.add(24)); + + acc0 = fmadd_ps(x0, y0, acc0); + acc1 = fmadd_ps(x1, y1, acc1); + acc2 = fmadd_ps(x2, y2, acc2); + acc3 = fmadd_ps(x3, y3, acc3); + + x = x.add(32); + y = y.add(32); + } + acc0 = _mm256_add_ps(acc0, acc1); + acc2 = _mm256_add_ps(acc2, acc3); + acc0 = _mm256_add_ps(acc0, acc2); + + let mut acc = hsum_ps(acc0); + for _ in 0..n % 32 { + acc += *x * *y; + x = x.add(1); + y = y.add(1); + } + acc + } else { + let mut acc = 0.0; + for _ in 0..n { + acc += *x * *y; + x = x.add(incx); + y = y.add(incy); + } + acc + } +} + +pub unsafe fn snrm2(n: usize, mut x: *const f32, incx: usize) -> f32 { + if incx == 1 { + let mut acc0 = _mm256_setzero_ps(); + let mut acc1 = _mm256_setzero_ps(); + let mut acc2 = _mm256_setzero_ps(); + let mut acc3 = _mm256_setzero_ps(); + for _ in 0..n / 32 { + let x0 = _mm256_loadu_ps(x); + let x1 = _mm256_loadu_ps(x.add(8)); + let x2 = _mm256_loadu_ps(x.add(16)); + let x3 = _mm256_loadu_ps(x.add(24)); + + acc0 = fmadd_ps(x0, x0, acc0); + acc1 = fmadd_ps(x1, x1, acc1); + acc2 = fmadd_ps(x2, x2, acc2); + acc3 = fmadd_ps(x3, x3, acc3); + + x = x.add(32); + } + acc0 = _mm256_add_ps(acc0, acc1); + acc2 = _mm256_add_ps(acc2, acc3); + acc0 = _mm256_add_ps(acc0, acc2); + + let mut acc = hsum_ps(acc0); + for _ in 0..n % 32 { + let x0 = *x; + acc += x0 * x0; + x = x.add(1); + } + acc.sqrt() + } else { + let mut acc = 0.0; + for _ in 0..n { + let x0 = *x; + acc += x0 * x0; + x = x.add(incx); + } + acc.sqrt() + } +} + +pub unsafe fn sasum(n: usize, mut x: *const f32, incx: usize) -> f32 { + if incx == 1 { + let mask = _mm256_broadcast_ss(&f32::from_bits(0x7FFF_FFFF)); + + let mut acc0 = _mm256_setzero_ps(); + let mut acc1 = _mm256_setzero_ps(); + let mut acc2 = _mm256_setzero_ps(); + let mut acc3 = _mm256_setzero_ps(); + let mut acc4 = _mm256_setzero_ps(); + let mut acc5 = _mm256_setzero_ps(); + let mut acc6 = _mm256_setzero_ps(); + let mut acc7 = _mm256_setzero_ps(); + for _ in 0..n / 64 { + let mut x0 = _mm256_loadu_ps(x); + let mut x1 = _mm256_loadu_ps(x.add(8)); + let mut x2 = _mm256_loadu_ps(x.add(16)); + let mut x3 = _mm256_loadu_ps(x.add(24)); + let mut x4 = _mm256_loadu_ps(x.add(32)); + let mut x5 = _mm256_loadu_ps(x.add(40)); + let mut x6 = _mm256_loadu_ps(x.add(48)); + let mut x7 = _mm256_loadu_ps(x.add(56)); + + x0 = _mm256_and_ps(mask, x0); + x1 = _mm256_and_ps(mask, x1); + x2 = _mm256_and_ps(mask, x2); + x3 = _mm256_and_ps(mask, x3); + x4 = _mm256_and_ps(mask, x4); + x5 = _mm256_and_ps(mask, x5); + x6 = _mm256_and_ps(mask, x6); + x7 = _mm256_and_ps(mask, x7); + + acc0 = _mm256_add_ps(acc0, x0); + acc1 = _mm256_add_ps(acc1, x1); + acc2 = _mm256_add_ps(acc2, x2); + acc3 = _mm256_add_ps(acc3, x3); + acc4 = _mm256_add_ps(acc4, x4); + acc5 = _mm256_add_ps(acc5, x5); + acc6 = _mm256_add_ps(acc6, x6); + acc7 = _mm256_add_ps(acc7, x7); + + x = x.add(64); + } + acc0 = _mm256_add_ps(acc0, acc1); + acc2 = _mm256_add_ps(acc2, acc3); + acc4 = _mm256_add_ps(acc4, acc5); + acc6 = _mm256_add_ps(acc6, acc7); + + acc0 = _mm256_add_ps(acc0, acc2); + acc4 = _mm256_add_ps(acc4, acc6); + + acc0 = _mm256_add_ps(acc0, acc4); + + let mut acc = hsum_ps(acc0); + for _ in 0..n % 64 { + acc += (*x).abs(); + x = x.add(1); + } + acc + } else { + let mut acc = 0.0; + for _ in 0..n { + acc += (*x).abs(); + x = x.add(incx); + } + acc + } +} diff --git a/src/kernel/sse/l3d.rs b/src/kernel/sse/l3d.rs new file mode 100644 index 0000000..82e7238 --- /dev/null +++ b/src/kernel/sse/l3d.rs @@ -0,0 +1,140 @@ +use super::fma::fmadd_pd; +use super::intrinsics::*; + +pub(crate) unsafe fn dgemm_ukr_8x4( + k: usize, + alpha: f64, + pa: *const f64, + pb: *const f64, + beta: f64, + c: *mut f64, + ldc: usize, +) { + let mut mt00 = _mm256_setzero_pd(); + let mut mt01 = _mm256_setzero_pd(); + let mut mt02 = _mm256_setzero_pd(); + let mut mt03 = _mm256_setzero_pd(); + let mut mt10 = _mm256_setzero_pd(); + let mut mt11 = _mm256_setzero_pd(); + let mut mt12 = _mm256_setzero_pd(); + let mut mt13 = _mm256_setzero_pd(); + + let mut pa = pa; + let mut pb = pb; + + for _ in 0..k { + let a0 = _mm256_loadu_pd(pa); + let a1 = _mm256_loadu_pd(pa.add(4)); + + let b0 = _mm256_broadcast_sd(&*pb); + let b1 = _mm256_broadcast_sd(&*pb.add(1)); + let b2 = _mm256_broadcast_sd(&*pb.add(2)); + let b3 = _mm256_broadcast_sd(&*pb.add(3)); + + mt00 = fmadd_pd(a0, b0, mt00); + mt01 = fmadd_pd(a0, b1, mt01); + mt02 = fmadd_pd(a0, b2, mt02); + mt03 = fmadd_pd(a0, b3, mt03); + mt10 = fmadd_pd(a1, b0, mt10); + mt11 = fmadd_pd(a1, b1, mt11); + mt12 = fmadd_pd(a1, b2, mt12); + mt13 = fmadd_pd(a1, b3, mt13); + + pa = pa.add(8); + pb = pb.add(4); + } + + let alpha = _mm256_broadcast_sd(&alpha); + + mt00 = _mm256_mul_pd(alpha, mt00); + mt01 = _mm256_mul_pd(alpha, mt01); + mt02 = _mm256_mul_pd(alpha, mt02); + mt03 = _mm256_mul_pd(alpha, mt03); + mt10 = _mm256_mul_pd(alpha, mt10); + mt11 = _mm256_mul_pd(alpha, mt11); + mt12 = _mm256_mul_pd(alpha, mt12); + mt13 = _mm256_mul_pd(alpha, mt13); + + let ccol0 = c; + let ccol1 = c.add(ldc); + let ccol2 = c.add(ldc * 2); + let ccol3 = c.add(ldc * 3); + + if beta != 0.0 { + let beta = _mm256_broadcast_sd(&beta); + + mt00 = fmadd_pd(beta, _mm256_loadu_pd(ccol0), mt00); + mt01 = fmadd_pd(beta, _mm256_loadu_pd(ccol1), mt01); + mt02 = fmadd_pd(beta, _mm256_loadu_pd(ccol2), mt02); + mt03 = fmadd_pd(beta, _mm256_loadu_pd(ccol3), mt03); + mt10 = fmadd_pd(beta, _mm256_loadu_pd(ccol0.add(4)), mt10); + mt11 = fmadd_pd(beta, _mm256_loadu_pd(ccol1.add(4)), mt11); + mt12 = fmadd_pd(beta, _mm256_loadu_pd(ccol2.add(4)), mt12); + mt13 = fmadd_pd(beta, _mm256_loadu_pd(ccol3.add(4)), mt13); + } + + _mm256_storeu_pd(ccol0, mt00); + _mm256_storeu_pd(ccol1, mt01); + _mm256_storeu_pd(ccol2, mt02); + _mm256_storeu_pd(ccol3, mt03); + _mm256_storeu_pd(ccol0.add(4), mt10); + _mm256_storeu_pd(ccol1.add(4), mt11); + _mm256_storeu_pd(ccol2.add(4), mt12); + _mm256_storeu_pd(ccol3.add(4), mt13); +} + +pub(crate) unsafe fn dgemm_sup_8x1( + k: usize, + alpha: f64, + pa: *const f64, + b: *const f64, + beta: f64, + c: *mut f64, +) { + let mut mt0 = _mm256_setzero_pd(); + let mut mt1 = _mm256_setzero_pd(); + + let mut pa = pa; + let mut b = b; + + for _ in 0..k { + let a0 = _mm256_loadu_pd(pa); + let a1 = _mm256_loadu_pd(pa.add(4)); + + let b0 = _mm256_broadcast_sd(&*b); + + mt0 = fmadd_pd(a0, b0, mt0); + mt1 = fmadd_pd(a1, b0, mt1); + + pa = pa.add(8); + b = b.add(1); + } + + let alpha = _mm256_broadcast_sd(&alpha); + + mt0 = _mm256_mul_pd(alpha, mt0); + mt1 = _mm256_mul_pd(alpha, mt1); + + if beta != 0.0 { + let beta = _mm256_broadcast_sd(&beta); + + mt0 = fmadd_pd(beta, _mm256_loadu_pd(c), mt0); + mt1 = fmadd_pd(beta, _mm256_loadu_pd(c.add(4)), mt1); + } + + _mm256_storeu_pd(c, mt0); + _mm256_storeu_pd(c.add(4), mt1); +} + +pub(crate) unsafe fn dgemm_pa_8x(k: usize, a: *const f64, lda: usize, pa: *mut f64) { + let mut a = a; + let mut pa = pa; + + for _ in 0..k { + _mm256_storeu_pd(pa, _mm256_loadu_pd(a)); + _mm256_storeu_pd(pa.add(4), _mm256_loadu_pd(a.add(4))); + + pa = pa.add(8); + a = a.add(lda); + } +} diff --git a/src/kernel/sse/l3s.rs b/src/kernel/sse/l3s.rs new file mode 100644 index 0000000..8ec2471 --- /dev/null +++ b/src/kernel/sse/l3s.rs @@ -0,0 +1,178 @@ +use super::fma::fmadd_ps; +use super::intrinsics::*; +use crate::kernel::params::single::NR; +use crate::matrix::{Matrix, MatrixMut, MutMatrix}; + +pub(crate) unsafe fn sgemm_sup_1x8, C: MatrixMut>( + k: usize, + alpha: f32, + a: A, + pb: MutMatrix, + beta: f32, + c: C, +) { + let mut c0_3 = _mm_setzero_ps(); + let mut c4 = 0.0f32; + // let mut c5 = 0.0f32; + // let mut c6 = 0.0f32; + // let mut c7 = 0.0f32; + + let mut a = a; + let mut pb = pb; + + for _ in 0..k { + let a0 = *a.ptr(); + let a0_simd = _mm_broadcast_ss(&*a.ptr()); + + c0_3 = fmadd_ps(_mm_loadu_ps(pb.ptr()), a0_simd, c0_3); + c4 += *pb.col(4) * a0; + // c5 += *pb.add(5) * a0; + // c6 += *pb.add(6) * a0; + // c7 += *pb.add(7) * a0; + + a.inc_row(); + pb.shift_col(NR); + } + + // c0 *= alpha; + // c1 *= alpha; + // c2 *= alpha; + // c3 *= alpha; + // c4 *= alpha; + // c5 *= alpha; + // c6 *= alpha; + // c7 *= alpha; + + let ccol0_3 = c.ptr_mut(); + let ccol4 = c.row_mut(4); + // let ccol5 = c.add(ldc * 5); + // let ccol6 = c.add(ldc * 6); + // let ccol7 = c.add(ldc * 7); + + // if beta != 0.0 { + // c0 += beta * *ccol0; + // c1 += beta * *ccol1; + // c2 += beta * *ccol2; + // c3 += beta * *ccol3; + // c4 += beta * *ccol4; + // // c5 += beta * *ccol5; + // // c6 += beta * *ccol6; + // // c7 += beta * *ccol7; + // } + + _mm_storeu_ps(ccol0_3, c0_3); + *ccol4 = c4; + // *ccol5 = c5; + // *ccol6 = c6; + // *ccol7 = c7; +} + +pub(crate) unsafe fn sgemm_sup_1x8_t( + k: usize, + alpha: f32, + a: *const f32, + lda: usize, + pb: *const f32, + beta: f32, + c: *mut f32, + ldc: usize, +) { + let mut c03 = _mm_setzero_ps(); + let mut c4 = 0.0f32; + // let mut c5 = 0.0f32; + // let mut c6 = 0.0f32; + // let mut c7 = 0.0f32; + + let mut a = a; + let mut pb = pb; + + for _ in 0..k { + let a0 = *a; + let a0_simd = _mm_set_ps1(a0); + + c03 = fmadd_ps(_mm_loadu_ps(pb), a0_simd, c03); + c4 += *pb.add(4) * a0; + // c5 += *pb.add(5) * a0; + // c6 += *pb.add(6) * a0; + // c7 += *pb.add(7) * a0; + + a = a.add(1); + pb = pb.add(NR); + } + + // c0 *= alpha; + // c1 *= alpha; + // c2 *= alpha; + // c3 *= alpha; + // c4 *= alpha; + // c5 *= alpha; + // c6 *= alpha; + // c7 *= alpha; + + // let ccol0 = c; + // let ccol1 = c.add(ldc); + // let ccol2 = c.add(ldc * 2); + // let ccol3 = c.add(ldc * 3); + let ccol4 = c.add(ldc * 4); + // let ccol5 = c.add(ldc * 5); + // let ccol6 = c.add(ldc * 6); + // let ccol7 = c.add(ldc * 7); + + // if beta != 0.0 { + // c0 += beta * *ccol0; + // c1 += beta * *ccol1; + // c2 += beta * *ccol2; + // c3 += beta * *ccol3; + // c4 += beta * *ccol4; + // // c5 += beta * *ccol5; + // // c6 += beta * *ccol6; + // // c7 += beta * *ccol7; + // } + + _mm_storeu_ps(c, c03); + // *ccol0 = c0; + // *ccol1 = c1; + // *ccol2 = c2; + // *ccol3 = c3; + *ccol4 = c4; + // *ccol5 = c5; + // *ccol6 = c6; + // *ccol7 = c7; +} + +pub(crate) unsafe fn sgemm_pb_x8(k: usize, b: *const f32, ldb: usize, pb: *mut f32) { + let mut bcol0 = b; + let mut bcol1 = b.add(ldb); + let mut bcol2 = b.add(ldb * 2); + let mut bcol3 = b.add(ldb * 3); + let mut bcol4 = b.add(ldb * 4); + // let mut bcol5 = b.add(ldb * 5); + // let mut bcol6 = b.add(ldb * 6); + // let mut bcol7 = b.add(ldb * 7); + + let mut pb = pb; + + for _ in 0..k { + _mm_storeu_ps(pb, _mm_set_ps(*bcol3, *bcol2, *bcol1, *bcol0)); + + // *pb = *bcol0; + // *pb.add(1) = *bcol1; + // *pb.add(2) = *bcol2; + // *pb.add(3) = *bcol3; + *pb.add(4) = *bcol4; + // *pb.add(5) = *bcol5; + // *pb.add(6) = *bcol6; + // *pb.add(7) = *bcol7; + + bcol0 = bcol0.add(1); + bcol1 = bcol1.add(1); + bcol2 = bcol2.add(1); + bcol3 = bcol3.add(1); + bcol4 = bcol4.add(1); + // bcol5 = bcol5.add(1); + // bcol6 = bcol6.add(1); + // bcol7 = bcol7.add(1); + + pb = pb.add(NR); + } +} \ No newline at end of file diff --git a/src/kernel/sse/mod.rs b/src/kernel/sse/mod.rs new file mode 100644 index 0000000..dbc9170 --- /dev/null +++ b/src/kernel/sse/mod.rs @@ -0,0 +1,86 @@ +mod fma; +// mod hsum; +mod intrinsics; +// pub mod l1d; +// pub mod l1s; +// pub mod l3d; +pub mod l3s; + +use core::marker::PhantomData; +use crate::matrix::{Number, MutMatrix, Matrix, MatrixMut}; +use crate::kernel::{GemmKernel, GemmKernelSupNr, GemmKernelSupMr, GemmKernelSup}; +use crate::dim::*; + + +pub struct SseKernel(PhantomData); + +impl GemmKernelSupNr for SseKernel + where I: GemmKernelSupNr +{ + #[inline] + unsafe fn sup_tr, C: MatrixMut>( + alpha: f32, + a: A, + pb: MutMatrix, + beta: f32, + c: C, + ) { + I::sup_tr(alpha, a, pb, beta, c); + } +} + +impl GemmKernelSupMr for SseKernel + where I: GemmKernelSupMr +{ + #[inline] + unsafe fn sup_bl, C: MatrixMut>( + alpha: f32, + pa: MutMatrix, + b: B, + beta: f32, + c: C, + ) { + I::sup_bl(alpha, pa, b, beta, c); + } +} + +impl GemmKernelSup for SseKernel + where I: GemmKernelSup +{ + #[inline] + unsafe fn sup_br, B: Matrix, C: MatrixMut>( + k: usize, + alpha: f32, + a: A, + b: B, + beta: f32, + c: C, + ) { + I::sup_br(k, alpha, a, b, beta, c); + } +} + +impl GemmKernel for SseKernel + where I: GemmKernel +{ + #[inline] + unsafe fn pack_row_a>(a: A, pa: MutMatrix, i: usize) { + I::pack_row_a(a, pa, i); + } + + #[inline] + unsafe fn pack_row_b>(b: B, pb: MutMatrix, j: usize) { + I::pack_row_b(b, pb, j); + } + + #[inline] + unsafe fn main_tl>( + alpha: f32, + pa: MutMatrix, + pb: MutMatrix, + beta: f32, + c: C, + ) { + I::main_tl(alpha, pa, pb, beta, c); + } +} \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..0a9b138 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,17 @@ +pub mod matrix; +pub mod gemm; +mod sgemm; +mod aligned_alloc; +pub mod kernel; +pub mod dim; +pub mod executor; + +#[cfg(test)] +extern crate blas; +#[cfg(test)] +extern crate openblas; +#[cfg(test)] +mod test; + + +pub use crate::sgemm::sgemm; diff --git a/src/main.rs b/src/main.rs index ac7f145..3e1832d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,373 +1,16 @@ mod matrix; +mod gemm; mod sgemm; mod aligned_alloc; mod kernel; +mod dim; +mod executor; + extern crate blas; extern crate openblas; -use self::matrix::{ConstMatrix, MutMatrix, ConstTransposedMatrix, MutTransposedMatrix}; - -pub unsafe fn sgemm( - transa: bool, - transb: bool, - transc: bool, - m: usize, - n: usize, - k: usize, - alpha: f32, - a: *const f32, - lda: usize, - b: *const f32, - ldb: usize, - beta: f32, - c: *mut f32, - ldc: usize, -) { - match (transa, transb, transc) { - (false, false, false) => sgemm::sgemm( - m, n, k, alpha, - ConstMatrix::new(a, lda), - ConstMatrix::new(b, ldb), - beta, - MutMatrix::new(c, ldc)), - - (false, false, true) => sgemm::sgemm( - m, n, k, alpha, - ConstMatrix::new(a, lda), - ConstMatrix::new(b, ldb), - beta, - MutTransposedMatrix::new(c, ldc)), - - (false, true, false) => sgemm::sgemm( - m, n, k, alpha, - ConstMatrix::new(a, lda), - ConstTransposedMatrix::new(b, ldb), - beta, - MutMatrix::new(c, ldc)), - - (false, true, true) => sgemm::sgemm( - m, n, k, alpha, - ConstMatrix::new(a, lda), - ConstTransposedMatrix::new(b, ldb), - beta, - MutTransposedMatrix::new(c, ldc)), - - (true, false, false) => sgemm::sgemm( - m, n, k, alpha, - ConstTransposedMatrix::new(a, lda), - ConstMatrix::new(b, ldb), - beta, - MutMatrix::new(c, ldc)), - - (true, false, true) => sgemm::sgemm( - m, n, k, alpha, - ConstTransposedMatrix::new(a, lda), - ConstMatrix::new(b, ldb), - beta, - MutTransposedMatrix::new(c, ldc)), - - - (true, true, false) => sgemm::sgemm( - m, n, k, alpha, - ConstTransposedMatrix::new(a, lda), - ConstTransposedMatrix::new(b, ldb), - beta, - MutMatrix::new(c, ldc)), - - - (true, true, true) => sgemm::sgemm( - m, n, k, alpha, - ConstTransposedMatrix::new(a, lda), - ConstTransposedMatrix::new(b, ldb), - beta, - MutTransposedMatrix::new(c, ldc)), - } -} - - - - - -unsafe fn sgemm_ref_nn( - m: usize, - n: usize, - k: usize, - _alpha: f32, - a: *const f32, - lda: usize, - b: *const f32, - ldb: usize, - _beta: f32, - c: *mut f32, - ldc: usize, -) { - for j in 0..n { - for i in 0..m { - let mut ci = *c.add(i + j * ldc); - for p in 0..k { - ci += *a.add(i + p * lda) * *b.add(p + j * ldb); - } - *c.add(i + j * ldc) = ci; - } - } -} - -unsafe fn sgemm_ref_nt( - m: usize, - n: usize, - k: usize, - _alpha: f32, - a: *const f32, - lda: usize, - b: *const f32, - ldb: usize, - _beta: f32, - c: *mut f32, - ldc: usize, -) { - for j in 0..n { - for i in 0..m { - let mut ci = *c.add(i + j * ldc); - - for p in 0..k { - ci += *a.add(i + p * lda) * *b.add(j + p * ldb); - } - - *c.add(i + j * ldc) = ci; - } - } -} - -#[inline(never)] -pub fn gemm_nn(m: usize, n: usize, k: usize, alpha: f32, - a: &[f32], lda: usize, - b: &[f32], ldb: usize, - _bata: f32, - c: &mut [f32], ldc: usize) { - - let mk = m * k; - let nk = n * k; - let mn = m * n; - let a = &a[0..mk]; - let b = &b[0..nk]; - let c = &mut c[0..mn]; - - for i_m in 0..m { - for i_k in 0..k { - let a_part = alpha * a[i_m * lda + i_k]; - - for i_n in 0..n { - c[i_m * ldc + i_n] += a_part * b[i_k * ldb + i_n]; - } - } - } -} - - -fn main2() { - use rand::Rng; - - let mut rng = rand::thread_rng(); - - const LEN: usize = 511; - let (m, n, k) = (LEN, LEN, LEN); - - let mut a = vec![0.5; m * k]; - let mut a_t = vec![0.5; m * k]; - let mut b = vec![0.5; n * k]; - let mut b_t = vec![0.5; n * k]; - let mut c_nn = vec![0.0; m * n]; - let mut c_nt = vec![0.0; m * n]; - let mut c_tn = vec![0.0; m * n]; - let mut c_tt = vec![0.0; m * n]; - let mut cref = vec![0.0; m * n]; - - for i in 0..m { - for j in 0..k { - let v = rng.gen(); - a[i + j * m] = v; - a_t[j + i * m] = v; - } - } - - for i in 0..n { - for j in 0..k { - let v = rng.gen(); - b[i + j * n] = v; - b_t[j + i * n] = v; - } - } - - // let time = std::time::Instant::now(); - // unsafe { - // gemm_nn( - // LEN, - // LEN, - // LEN, - // 1.0, - // a.as_slice(), - // LEN, - // b.as_slice(), - // LEN, - // 1.0, - // cref1.as_mut_slice(), - // LEN, - // ) - // } - - // println!("Naive (mine) {}", time.elapsed().as_millis()); - - let time = std::time::Instant::now(); - unsafe { - sgemm_ref_nn( - m, - n, - k, - - 1.0, - a.as_ptr(), - m, - - b.as_ptr(), - k, - - 1.0, - cref.as_mut_ptr(), - m, - ) - } - - println!("Naive {}", time.elapsed().as_millis()); - - let time = std::time::Instant::now(); - unsafe { - sgemm( - false, - false, - false, - m, - n, - k, - - 1.0, - a.as_ptr(), - m, - - b.as_ptr(), - k, - - 1.0, - c_nn.as_mut_ptr(), - m, - ); - } - println!("[NN] Optimized {}", time.elapsed().as_millis()); - - for i in 0..LEN { - for j in 0..LEN { - let (a, b) = (c_nn[i + j * LEN], cref[i + j * LEN]); - assert!(feq(a, b), "a != b, a[{}]={}, b[{}]={}", i, a, j, b); - } - } - - let time = std::time::Instant::now(); - unsafe { - sgemm( - false, - true, - false, - m, - n, - k, - - 1.0, - a.as_ptr(), - m, - - b_t.as_ptr(), - n, - - 1.0, - c_nt.as_mut_ptr(), - m, - ); - } - - println!("[NT] Optimized {}", time.elapsed().as_millis()); - - for i in 0..LEN { - for j in 0..LEN { - let (a, b) = (c_nt[i + j * LEN], cref[i + j * LEN]); - assert!(feq(a, b), "a != b, a[{}]={}, b[{}]={}", i, a, j, b); - } - } - - let time = std::time::Instant::now(); - unsafe { - sgemm( - true, - false, - false, - m, - n, - k, - - 1.0, - a_t.as_ptr(), - k, - - b.as_ptr(), - k, - - 1.0, - c_tn.as_mut_ptr(), - m, - ); - } - - println!("[TN] Optimized {}", time.elapsed().as_millis()); - - for i in 0..LEN { - for j in 0..LEN { - let (a, b) = (c_tn[i + j * LEN], cref[i + j * LEN]); - assert!(feq(a, b), "a != b, a[{}]={}, b[{}]={}", i, a, j, b); - } - } - - let time = std::time::Instant::now(); - unsafe { - sgemm( - true, - true, - false, - m, - n, - k, - - 1.0, - a_t.as_ptr(), - k, - - b_t.as_ptr(), - n, - - 1.0, - c_tt.as_mut_ptr(), - m, - ); - } - println!("[TT] Optimized {}", time.elapsed().as_millis()); - - - for i in 0..LEN { - for j in 0..LEN { - let (a, b) = (c_tt[i + j * LEN], cref[i + j * LEN]); - assert!(feq(a, b), "a != b, a[{}]={}, b[{}]={}", i, a, j, b); - } - } -} - +use crate::sgemm::sgemm; fn main() { use rand::Rng; @@ -395,25 +38,24 @@ fn main() { } } - - let time = std::time::Instant::now(); unsafe { - // blas::sgemm( - // b'N', - // b'N', - // m as i32, - // n as i32, - // k as i32, - // 1.0, - // a.as_slice(), - // m as i32, - // b.as_slice(), - // k as i32, - // 1.0, - // cref1.as_mut_slice(), - // m as i32, - // ); + blas::sgemm( + b'N', + b'N', + m as i32, + n as i32, + k as i32, + 1.0, + a.as_slice(), + m as i32, + b.as_slice(), + k as i32, + 1.0, + cref1.as_mut_slice(), + m as i32, + ); + // gemm_nn( // m, // n, @@ -456,6 +98,7 @@ fn main() { let time = std::time::Instant::now(); unsafe { sgemm( + &executor::RayonExecutor, false, false, false, @@ -477,22 +120,5 @@ fn main() { } println!("Optimized {}", time.elapsed().as_millis()); - - - // for i in 0..LEN { - // for j in 0..LEN { - // let (a, b) = (c[i + j * LEN], cref[i + j * LEN]); - // assert!(feq(a, b), "a != b, a[{}]={}, b[{}]={}", i, a, j, b); - // } - // } } -fn feq(a: f32, b: f32) -> bool { - if a == b { - true - } else if a == 0.0 || b == 0.0 || (a.abs() + b.abs() < std::f32::MIN_POSITIVE) { - (a - b).abs() < std::f32::EPSILON * 10.0 * std::f32::MIN_POSITIVE - } else { - (a - b).abs() / (a.abs() + b.abs()) < std::f32::EPSILON * 10.0 - } -} diff --git a/src/matrix.rs b/src/matrix.rs index 9ca71ec..9ef97ea 100644 --- a/src/matrix.rs +++ b/src/matrix.rs @@ -1,36 +1,50 @@ -pub(crate) trait Matrix: Copy { + +pub trait Number: Copy + Send + Sync + 'static {} + +impl Number for f32 {} +impl Number for f64 {} + +pub trait Matrix: Copy + Send + Sync + 'static { unsafe fn sub(&self, row: usize, col: usize) -> Self; + unsafe fn sub_col(&self, col: usize) -> Self; + unsafe fn sub_row(&self, row: usize) -> Self; unsafe fn is_transposed(&self) -> bool; unsafe fn stride(&self) -> usize; - unsafe fn get(&self, row: usize, col: usize) -> f32; - unsafe fn index(&self, row: usize, col: usize) -> *const f32; - unsafe fn row(&self, row: usize) -> *const f32; - unsafe fn col(&self, col: usize) -> *const f32; + unsafe fn get(&self, row: usize, col: usize) -> F; + unsafe fn index(&self, row: usize, col: usize) -> *const F; + unsafe fn row(&self, row: usize) -> *const F; + unsafe fn col(&self, col: usize) -> *const F; + unsafe fn ptr(&self) -> *const F; + unsafe fn inc_row(&mut self); + unsafe fn inc_col(&mut self); + unsafe fn shift_row(&mut self, rows: usize); + unsafe fn shift_col(&mut self, cols: usize); } -pub(crate) trait MatrixMut: Matrix { - unsafe fn set(&self, row: usize, col: usize, val: f32); - unsafe fn index_mut(&self, row: usize, col: usize) -> *mut f32; - unsafe fn row_mut(&self, row: usize) -> *mut f32; - unsafe fn col_mut(&self, col: usize) -> *mut f32; +pub trait MatrixMut: Matrix { + unsafe fn set(&self, row: usize, col: usize, val: F); + unsafe fn index_mut(&self, row: usize, col: usize) -> *mut F; + unsafe fn row_mut(&self, row: usize) -> *mut F; + unsafe fn col_mut(&self, col: usize) -> *mut F; + unsafe fn ptr_mut(&self) -> *mut F; } #[derive(Copy, Clone)] -pub(crate) struct ConstMatrix { +pub struct ConstMatrix { pub stride: usize, - pub ptr: *const f32 + pub ptr: *const F } -unsafe impl Send for ConstMatrix {} -unsafe impl Sync for ConstMatrix {} +unsafe impl Send for ConstMatrix {} +unsafe impl Sync for ConstMatrix {} -impl ConstMatrix { - pub fn new(ptr: *const f32, stride: usize) -> Self { +impl ConstMatrix { + pub fn new(ptr: *const F, stride: usize) -> Self { Self { ptr, stride } } } -impl Matrix for ConstMatrix { +impl Matrix for ConstMatrix { #[inline] unsafe fn sub(&self, row: usize, col: usize) -> Self { Self { @@ -39,6 +53,22 @@ impl Matrix for ConstMatrix { } } + #[inline] + unsafe fn sub_col(&self, col: usize) -> Self { + Self { + ptr: self.col(col), + stride: self.stride + } + } + + #[inline] + unsafe fn sub_row(&self, row: usize) -> Self { + Self { + ptr: self.row(row), + stride: self.stride + } + } + #[inline] unsafe fn is_transposed(&self) -> bool { false @@ -50,42 +80,67 @@ impl Matrix for ConstMatrix { } #[inline] - unsafe fn get(&self, row: usize, col: usize) -> f32 { + unsafe fn get(&self, row: usize, col: usize) -> F { *self.ptr.add(row * self.stride + col) } #[inline] - unsafe fn index(&self, row: usize, col: usize) -> *const f32 { + unsafe fn index(&self, row: usize, col: usize) -> *const F { self.ptr.add(row * self.stride + col) } #[inline] - unsafe fn row(&self, row: usize) -> *const f32 { + unsafe fn row(&self, row: usize) -> *const F { self.ptr.add(row * self.stride) } #[inline] - unsafe fn col(&self, col: usize) -> *const f32 { + unsafe fn col(&self, col: usize) -> *const F { self.ptr.add(col) } + + #[inline] + unsafe fn ptr(&self) -> *const F { + self.ptr + } + + #[inline] + unsafe fn inc_row(&mut self) { + self.ptr = self.row(1); + } + + #[inline] + unsafe fn inc_col(&mut self) { + self.ptr = self.col(1); + } + + #[inline] + unsafe fn shift_row(&mut self, rows: usize) { + self.ptr = self.row(rows); + } + + #[inline] + unsafe fn shift_col(&mut self, cols: usize) { + self.ptr = self.col(cols); + } } #[derive(Copy, Clone)] -pub(crate) struct ConstTransposedMatrix { +pub struct ConstTransposedMatrix { pub stride: usize, - pub ptr: *const f32 + pub ptr: *const F } -unsafe impl Send for ConstTransposedMatrix {} -unsafe impl Sync for ConstTransposedMatrix {} +unsafe impl Send for ConstTransposedMatrix {} +unsafe impl Sync for ConstTransposedMatrix {} -impl ConstTransposedMatrix { - pub fn new(ptr: *const f32, stride: usize) -> Self { +impl ConstTransposedMatrix { + pub fn new(ptr: *const F, stride: usize) -> Self { Self { ptr, stride } } } -impl Matrix for ConstTransposedMatrix { +impl Matrix for ConstTransposedMatrix { #[inline] unsafe fn sub(&self, row: usize, col: usize) -> Self { Self { @@ -94,6 +149,22 @@ impl Matrix for ConstTransposedMatrix { } } + #[inline] + unsafe fn sub_col(&self, col: usize) -> Self { + Self { + ptr: self.col(col), + stride: self.stride + } + } + + #[inline] + unsafe fn sub_row(&self, row: usize) -> Self { + Self { + ptr: self.row(row), + stride: self.stride + } + } + #[inline] unsafe fn is_transposed(&self) -> bool { true @@ -105,42 +176,67 @@ impl Matrix for ConstTransposedMatrix { } #[inline] - unsafe fn get(&self, row: usize, col: usize) -> f32 { + unsafe fn get(&self, row: usize, col: usize) -> F { *self.ptr.add(col * self.stride + row) } #[inline] - unsafe fn index(&self, row: usize, col: usize) -> *const f32 { + unsafe fn index(&self, row: usize, col: usize) -> *const F { self.ptr.add(col * self.stride + row) } #[inline] - unsafe fn row(&self, row: usize) -> *const f32 { + unsafe fn row(&self, row: usize) -> *const F { self.ptr.add(row) } #[inline] - unsafe fn col(&self, col: usize) -> *const f32 { + unsafe fn col(&self, col: usize) -> *const F { self.ptr.add(col * self.stride) } + + #[inline] + unsafe fn ptr(&self) -> *const F { + self.ptr + } + + #[inline] + unsafe fn inc_row(&mut self) { + self.ptr = self.row(1); + } + + #[inline] + unsafe fn inc_col(&mut self) { + self.ptr = self.col(1); + } + + #[inline] + unsafe fn shift_row(&mut self, rows: usize) { + self.ptr = self.row(rows); + } + + #[inline] + unsafe fn shift_col(&mut self, cols: usize) { + self.ptr = self.col(cols); + } } #[derive(Copy, Clone)] -pub(crate) struct MutMatrix { +pub struct MutMatrix { pub stride: usize, - pub ptr: *mut f32 + pub ptr: *mut F } -unsafe impl Send for MutMatrix {} -unsafe impl Sync for MutMatrix {} +unsafe impl Send for MutMatrix {} +unsafe impl Sync for MutMatrix {} -impl MutMatrix { - pub fn new(ptr: *mut f32, stride: usize) -> Self { +impl MutMatrix { + pub fn new(ptr: *mut F, stride: usize) -> Self { Self { ptr, stride } } } -impl Matrix for MutMatrix { +impl Matrix for MutMatrix { #[inline] unsafe fn sub(&self, row: usize, col: usize) -> Self { Self { @@ -149,6 +245,22 @@ impl Matrix for MutMatrix { } } + #[inline] + unsafe fn sub_col(&self, col: usize) -> Self { + Self { + ptr: self.col_mut(col), + stride: self.stride + } + } + + #[inline] + unsafe fn sub_row(&self, row: usize) -> Self { + Self { + ptr: self.row_mut(row), + stride: self.stride + } + } + #[inline] unsafe fn is_transposed(&self) -> bool { false @@ -160,64 +272,94 @@ impl Matrix for MutMatrix { } #[inline] - unsafe fn get(&self, row: usize, col: usize) -> f32 { + unsafe fn get(&self, row: usize, col: usize) -> F { *self.ptr.add(row * self.stride + col) } #[inline] - unsafe fn index(&self, row: usize, col: usize) -> *const f32 { + unsafe fn index(&self, row: usize, col: usize) -> *const F { self.ptr.add(row * self.stride + col) } #[inline] - unsafe fn row(&self, row: usize) -> *const f32 { + unsafe fn row(&self, row: usize) -> *const F { self.ptr.add(row * self.stride) } #[inline] - unsafe fn col(&self, col: usize) -> *const f32 { + unsafe fn col(&self, col: usize) -> *const F { self.ptr.add(col) } + + #[inline] + unsafe fn ptr(&self) -> *const F { + self.ptr + } + + #[inline] + unsafe fn inc_row(&mut self) { + self.ptr = self.row_mut(1); + } + + #[inline] + unsafe fn inc_col(&mut self) { + self.ptr = self.col_mut(1); + } + + #[inline] + unsafe fn shift_row(&mut self, rows: usize) { + self.ptr = self.row_mut(rows); + } + + #[inline] + unsafe fn shift_col(&mut self, cols: usize) { + self.ptr = self.col_mut(cols); + } } -impl MatrixMut for MutMatrix { +impl MatrixMut for MutMatrix { #[inline] - unsafe fn set(&self, row: usize, col: usize, value: f32) { + unsafe fn set(&self, row: usize, col: usize, value: F) { *self.ptr.add(row * self.stride + col) = value; } #[inline] - unsafe fn index_mut(&self, row: usize, col: usize) -> *mut f32 { + unsafe fn index_mut(&self, row: usize, col: usize) -> *mut F { self.ptr.add(row * self.stride + col) } #[inline] - unsafe fn row_mut(&self, row: usize) -> *mut f32 { + unsafe fn row_mut(&self, row: usize) -> *mut F { self.ptr.add(row * self.stride) } #[inline] - unsafe fn col_mut(&self, col: usize) -> *mut f32 { + unsafe fn col_mut(&self, col: usize) -> *mut F { self.ptr.add(col) } + + #[inline] + unsafe fn ptr_mut(&self) -> *mut F { + self.ptr + } } #[derive(Copy, Clone)] -pub(crate) struct MutTransposedMatrix { +pub struct MutTransposedMatrix { pub stride: usize, - pub ptr: *mut f32 + pub ptr: *mut F } -unsafe impl Send for MutTransposedMatrix {} -unsafe impl Sync for MutTransposedMatrix {} +unsafe impl Send for MutTransposedMatrix {} +unsafe impl Sync for MutTransposedMatrix {} -impl MutTransposedMatrix { - pub fn new(ptr: *mut f32, stride: usize) -> Self { +impl MutTransposedMatrix { + pub fn new(ptr: *mut F, stride: usize) -> Self { Self { ptr, stride } } } -impl Matrix for MutTransposedMatrix { +impl Matrix for MutTransposedMatrix { #[inline] unsafe fn sub(&self, row: usize, col: usize) -> Self { Self { @@ -226,6 +368,22 @@ impl Matrix for MutTransposedMatrix { } } + #[inline] + unsafe fn sub_col(&self, col: usize) -> Self { + Self { + ptr: self.col_mut(col), + stride: self.stride + } + } + + #[inline] + unsafe fn sub_row(&self, row: usize) -> Self { + Self { + ptr: self.row_mut(row), + stride: self.stride + } + } + #[inline] unsafe fn is_transposed(&self) -> bool { true @@ -237,47 +395,77 @@ impl Matrix for MutTransposedMatrix { } #[inline] - unsafe fn get(&self, row: usize, col: usize) -> f32 { + unsafe fn get(&self, row: usize, col: usize) -> F { *self.ptr.add(col * self.stride + row) } #[inline] - unsafe fn index(&self, row: usize, col: usize) -> *const f32 { - self.ptr.add(col * self.stride + row) as *const f32 + unsafe fn index(&self, row: usize, col: usize) -> *const F { + self.ptr.add(col * self.stride + row) as *const F } #[inline] - unsafe fn row(&self, row: usize) -> *const f32 { - self.ptr.add(row) as *const f32 + unsafe fn row(&self, row: usize) -> *const F { + self.ptr.add(row) as *const F } #[inline] - unsafe fn col(&self, col: usize) -> *const f32 { - self.ptr.add(col * self.stride) as *const f32 + unsafe fn col(&self, col: usize) -> *const F { + self.ptr.add(col * self.stride) as *const F + } + + #[inline] + unsafe fn ptr(&self) -> *const F { + self.ptr + } + + #[inline] + unsafe fn inc_row(&mut self) { + self.ptr = self.row_mut(1); + } + + #[inline] + unsafe fn inc_col(&mut self) { + self.ptr = self.col_mut(1); + } + + #[inline] + unsafe fn shift_row(&mut self, rows: usize) { + self.ptr = self.row_mut(rows); + } + + #[inline] + unsafe fn shift_col(&mut self, cols: usize) { + self.ptr = self.col_mut(cols); } } -impl MatrixMut for MutTransposedMatrix { +impl MatrixMut for MutTransposedMatrix { #[inline] - unsafe fn set(&self, row: usize, col: usize, value: f32) { + unsafe fn set(&self, row: usize, col: usize, value: F) { *self.ptr.add(col * self.stride + row) = value; } #[inline] - unsafe fn index_mut(&self, row: usize, col: usize) -> *mut f32 { + unsafe fn index_mut(&self, row: usize, col: usize) -> *mut F { self.ptr.add(col * self.stride + row) } #[inline] - unsafe fn row_mut(&self, row: usize) -> *mut f32 { + unsafe fn row_mut(&self, row: usize) -> *mut F { self.ptr.add(row) } #[inline] - unsafe fn col_mut(&self, col: usize) -> *mut f32 { + unsafe fn col_mut(&self, col: usize) -> *mut F { self.ptr.add(col * self.stride) } + + #[inline] + unsafe fn ptr_mut(&self) -> *mut F { + self.ptr + } } diff --git a/src/sgemm.rs b/src/sgemm.rs index 95f3fe2..8da14e2 100644 --- a/src/sgemm.rs +++ b/src/sgemm.rs @@ -1,273 +1,30 @@ -use crate::aligned_alloc; -use crate::matrix::{Matrix, MatrixMut, MutMatrix}; -use crate::kernel::params::single::*; -use crate::kernel; +use crate::kernel::avx::AvxKernel; +use crate::kernel::generic::GenericKernel; +use crate::dim::*; +use crate::executor::Executor; -pub(crate) unsafe fn sgemm( +pub unsafe fn sgemm( + e: &E, + transa: bool, + transb: bool, + transc: bool, m: usize, n: usize, k: usize, alpha: f32, - a: A, - b: B, + a: *const f32, + lda: usize, + b: *const f32, + ldb: usize, beta: f32, - c: C -) -where A: Matrix, - B: Matrix, - C: MatrixMut, -{ - let packed_a = aligned_alloc::Alloc::new(MC * KC * std::mem::size_of::()); - let packed_b = aligned_alloc::Alloc::new(KC * NC * std::mem::size_of::()); - - for j in (0..n).step_by(NC) { - let j_b = std::cmp::min(n - j, NC); - for p in (0..k).step_by(KC) { - let p_b = std::cmp::min(k - p, KC); - for i in (0..m).step_by(MC) { - let i_b = std::cmp::min(m - i, MC); - - let pa = MutMatrix::new(packed_a.ptr_f32(), p_b); - let pb = MutMatrix::new(packed_b.ptr_f32(), p_b); - - inner_kernel( - i_b, - j_b, - p_b, - alpha, - a.sub(p, i), - b.sub(j, p), - beta, - c.sub(j, i), - pa, - pb, - i == 0 - ); - } - } - } -} - - -// +----------------------+ -// |A : : : |B | -// | : : : | | -// + - - - - - - - - - ---+ -// | : : : | | -// | : : : | | -// + - - - - - - - - - ---+ -// | : : : | | -// | : : : | | -// +----------------------+ -// |C | | | |D | -// +----------------------+ - - -#[inline] -unsafe fn sgemm_pa(a: A, pa: MutMatrix, i: usize) { - if a.is_transposed() { - kernel::sgemm_pa_t(pa.stride, a.col(i), a.stride(), pa.row_mut(i)); - } else { - kernel::sgemm_pa_n(pa.stride, a.col(i), a.stride(), pa.row_mut(i)); - } -} - -#[inline] -unsafe fn sgemm_pb(b: B, pb: MutMatrix, j: usize) { - if b.is_transposed() { - kernel::sgemm_pb_t(pb.stride, b.row(j), b.stride(), pb.row_mut(j)); - } else { - kernel::sgemm_pb_n(pb.stride, b.row(j), b.stride(), pb.row_mut(j)); - } -} - -#[inline] -unsafe fn sgemm_ukr( - i: usize, j: usize, - alpha: f32, - pa: MutMatrix, - pb: MutMatrix, - beta: f32, - c: C) -{ - if c.is_transposed() { - unimplemented!() - } else { - kernel::sgemm_ukr( - pa.stride, - alpha, - pa.row(i), - pb.row(j), - beta, - c.index_mut(j, i), - c.stride()) - } -} - -#[inline] -unsafe fn sgemm_sup1( - i: usize, j: usize, - alpha: f32, - a: A, - pb: MutMatrix, - beta: f32, - c: C) -{ - if c.is_transposed() { - unimplemented!() - } else { - if a.is_transposed() { - kernel::sgemm_sup1_t( - pb.stride, - alpha, - a.col(i), - a.stride(), - pb.row(j), - beta, - c.index_mut(j, i), - c.stride()) - } else { - kernel::sgemm_sup1( - pb.stride, - alpha, - a.col(i), - a.stride(), - pb.row(j), - beta, - c.index_mut(j, i), - c.stride()) - } - } -} - -#[inline] -unsafe fn sgemm_sup0( - i: usize, j: usize, - alpha: f32, - pa: MutMatrix, - b: B, - beta: f32, - c: C) -{ - if c.is_transposed() { - unimplemented!() - } else { - if b.is_transposed() { - kernel::sgemm_sup0_t( - pa.stride, - alpha, - pa.row(i), - b.row(j), - b.stride(), - beta, - c.index_mut(j, i)) - } else { - kernel::sgemm_sup0( - pa.stride, - alpha, - pa.row(i), - b.row(j), - beta, - c.index_mut(j, i)) - } - } -} - -unsafe fn inner_kernel( - m: usize, - n: usize, - k: usize, - alpha: f32, - a: A, - b: B, - beta: f32, - c: C, - pa: MutMatrix, - pb: MutMatrix, - first_time: bool, -) - where A: Matrix, - B: Matrix, - C: MatrixMut, - -{ - let n_left = n % NR; - let n_main = n - n_left; - - let m_left = m % MR; - let m_main = m - m_left; - - if first_time { - for j in (0..n_main).step_by(NR) { - sgemm_pb(b, pb, j); - } - } - - for i in (0..m_main).step_by(MR) { - sgemm_pa(a, pa, i); - }; - - - for j in (0..n_main).step_by(NR) { - - // Section A - for i in (0..m_main).step_by(MR) { - sgemm_ukr( - i, - j, - alpha, - pa, - pb, - beta, - c, - ); - } - - // Section B - for i in m_main..m { - sgemm_sup1( - i, - j, - alpha, - a, - pb, - beta, - c, - ); - } - } - - for j in n_main..n { - - // Section C - for i in (0..m_main).step_by(MR) { - sgemm_sup0( - i, - j, - alpha, - pa, - b, - beta, - c, - ); - } - - // Section D - for i in m_main..m { - let mut elem = 0.0; - - for p in 0..k { - elem += a.get(p, i) * b.get(j, p); - } - - elem *= alpha; - - if beta != 0.0 { - elem += beta * c.get(j, i); - } - - c.set(j, i, elem); - } - } + c: *mut f32, + ldc: usize, +) { + crate::gemm::gemm::, A16, A5>( + e, + transa, transb, transc, + m, n, k, alpha, + a, lda, b, ldb, + beta, c, ldc); } \ No newline at end of file diff --git a/src/test/l3s.rs b/src/test/l3s.rs new file mode 100644 index 0000000..e621063 --- /dev/null +++ b/src/test/l3s.rs @@ -0,0 +1,308 @@ +use rand::Rng; + +const M_LEN: usize = 8192; +const N_LEN: usize = 8192; +const K_LEN: usize = 8192; + +fn make_matrices() -> (Vec, Vec, Vec, Vec) { + let mut rng = rand::thread_rng(); + + let (m, n, k) = (M_LEN, N_LEN, K_LEN); + + let mut a = vec![0.0; m * k]; + let mut a_t = vec![0.0; m * k]; + + let mut b = vec![0.0; n * k]; + let mut b_t = vec![0.0; n * k]; + + for row in 0..k { + for col in 0..m { + let v = rng.gen(); + a[row * m + col] = v; + a_t[col * k + row] = v; + } + } + + for row in 0..n { + for col in 0..k { + let v = rng.gen(); + b[row * k + col] = v; + b_t[col * n + row] = v; + } + } + + (a, a_t, b, b_t) +} + + +unsafe fn sgemm_ref( + m: usize, + n: usize, + k: usize, + _alpha: f32, + a: *const f32, + lda: usize, + b: *const f32, + ldb: usize, + _beta: f32, + c: *mut f32, + ldc: usize, +) { + for i_m in 0..m { + for i_k in 0..k { + let a_part = *a.add(i_m * lda + i_k); + + for i_n in 0..n { + *c.add(i_m * ldc + i_n) += a_part * *b.add(i_k * ldb + i_n); + } + } + } + + // for j in 0..n { + // for i in 0..m { + // let mut ci = *c.add(i + j * ldc); + + // for p in 0..k { + // ci += *a.add(i + p * lda) * *b.add(p + j * ldb); + // } + + // *c.add(i + j * ldc) = ci; + // } + // } +} + + +#[test] +fn test_sgemm_nn() { + let (m, n, k) = (M_LEN, N_LEN, K_LEN); + let (a, _, b, _) = make_matrices(); + + let mut c = vec![0.0; m * n]; + let mut cref = vec![0.0; m * n]; + + unsafe { + blas::sgemm( + b'N', + b'N', + m as i32, + n as i32, + k as i32, + 1.0, + a.as_slice(), + m as i32, + b.as_slice(), + k as i32, + 0.0, + cref.as_mut_slice(), + m as i32, + ) + } + + unsafe { + crate::sgemm( + &crate::executor::DefaultExecutor, + false, + false, + false, + m, + n, + k, + 1.0, + a.as_ptr(), + m, + b.as_ptr(), + k, + 0.0, + c.as_mut_ptr(), + m, + ); + } + + for row in 0..N_LEN { + for col in 0..M_LEN { + let index = row * M_LEN + col; + + let (a, b) = (c[index], cref[index]); + + assert!(feq(a, b), "a != b, a[{}, {}]={}, b[{}, {}]={}", row, col, a, row, col, b); + } + } +} + + +#[test] +fn test_sgemm_nt() { + let (m, n, k) = (M_LEN, N_LEN, K_LEN); + let (a, _, b, b_t) = make_matrices(); + + let mut c = vec![0.0; m * n]; + let mut cref = vec![0.0; m * n]; + + unsafe { + blas::sgemm( + b'N', + b'T', + m as i32, + n as i32, + k as i32, + 1.0, + a.as_slice(), + m as i32, + b_t.as_slice(), + n as i32, + 0.0, + cref.as_mut_slice(), + m as i32, + ) + } + + unsafe { + crate::sgemm( + &crate::executor::DefaultExecutor, + false, + true, + false, + m, + n, + k, + 1.0, + a.as_ptr(), + m, + b_t.as_ptr(), + n, + 0.0, + c.as_mut_ptr(), + m, + ); + } + + for row in 0..N_LEN { + for col in 0..M_LEN { + let index = row * M_LEN + col; + let (a, b) = (c[index], cref[index]); + assert!(feq(a, b), "a != b, a[{}, {}]={}, b[{}, {}]={}", row, col, a, row, col, b); + } + } +} + + +#[test] +fn test_sgemm_tn() { + let (m, n, k) = (M_LEN, N_LEN, K_LEN); + let (a, a_t, b, _) = make_matrices(); + + let mut c = vec![0.0; m * n]; + let mut cref = vec![0.0; m * n]; + + unsafe { + blas::sgemm( + b'T', + b'N', + m as i32, + n as i32, + k as i32, + 1.0, + a_t.as_slice(), + k as i32, + b.as_slice(), + k as i32, + 0.0, + cref.as_mut_slice(), + m as i32, + ) + } + + unsafe { + crate::sgemm( + &crate::executor::DefaultExecutor, + true, + false, + false, + m, + n, + k, + 1.0, + a_t.as_ptr(), + k, + b.as_ptr(), + k, + 0.0, + c.as_mut_ptr(), + m, + ); + } + + for row in 0..N_LEN { + for col in 0..M_LEN { + let index = row * M_LEN + col; + let (a, b) = (c[index], cref[index]); + assert!(feq(a, b), "a != b, a[{}, {}]={}, b[{}, {}]={}", row, col, a, row, col, b); + } + } +} + + +#[test] +fn test_sgemm_tt() { + let (m, n, k) = (M_LEN, N_LEN, K_LEN); + let (a, a_t, b, b_t) = make_matrices(); + + let mut c = vec![0.0; m * n]; + let mut cref = vec![0.0; m * n]; + + unsafe { + blas::sgemm( + b'T', + b'T', + m as i32, + n as i32, + k as i32, + 1.0, + a_t.as_slice(), + k as i32, + b_t.as_slice(), + n as i32, + 0.0, + cref.as_mut_slice(), + m as i32, + ) + } + + unsafe { + crate::sgemm( + &crate::executor::DefaultExecutor, + true, + true, + false, + m, + n, + k, + 1.0, + a_t.as_ptr(), + k, + b_t.as_ptr(), + n, + 0.0, + c.as_mut_ptr(), + m, + ); + } + + for row in 0..N_LEN { + for col in 0..M_LEN { + let index = row * M_LEN + col; + let (a, b) = (c[index], cref[index]); + assert!(feq(a, b), "a != b, a[{}, {}]={}, b[{}, {}]={}", row, col, a, row, col, b); + } + } +} + +fn feq(a: f32, b: f32) -> bool { + if a == b { + true + } else if a == 0.0 || b == 0.0 || (a.abs() + b.abs() < std::f32::MIN_POSITIVE) { + (a - b).abs() < std::f32::EPSILON * 10.0 * std::f32::MIN_POSITIVE + } else { + (a - b).abs() / (a.abs() + b.abs()) < std::f32::EPSILON * 10.0 + } +} \ No newline at end of file diff --git a/src/test/mod.rs b/src/test/mod.rs new file mode 100644 index 0000000..5b99a65 --- /dev/null +++ b/src/test/mod.rs @@ -0,0 +1 @@ +mod l3s; \ No newline at end of file diff --git a/stats.sh b/stats.sh index 12f7a20..93e0973 100755 --- a/stats.sh +++ b/stats.sh @@ -1,5 +1,5 @@ #!/bin/bash # export RUSTFLAGS="-C target-feature=-mmx,-sse,-sse2,-sse3,-ssse3,-sse4.1,-sse4.2,-3dnow,-3dnowa,-avx,-avx2" export RUSTFLAGS="-C target-feature=-mmx,-sse,-sse2,-sse3,-ssse3,-sse4.1,-sse4.2,-3dnow,-3dnowa,+avx,+avx2" -cargo clean && cargo build --release +#cargo clean && cargo build --release perf stat -B -e cache-references,cache-misses,cycles,instructions,branches,branch-misses,faults,migrations,context-switches ./target/release/gemm