diff --git a/.gitignore b/.gitignore index 53eaa21..c402185 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ /target +/demo/target **/*.rs.bk diff --git a/demo/Cargo.lock b/demo/Cargo.lock new file mode 100644 index 0000000..c425c5d --- /dev/null +++ b/demo/Cargo.lock @@ -0,0 +1,314 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +[[package]] +name = "autocfg" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" + +[[package]] +name = "blas" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "blas-sys 0.7.1 (registry+https://github.com/rust-lang/crates.io-index)", + "libc 0.2.66 (registry+https://github.com/rust-lang/crates.io-index)", + "num-complex 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "blas-sys" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "libc 0.2.66 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "c2-chacha" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "ppv-lite86 0.2.6 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "cfg-if" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" + +[[package]] +name = "crossbeam-deque" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "crossbeam-epoch 0.8.0 (registry+https://github.com/rust-lang/crates.io-index)", + "crossbeam-utils 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "autocfg 0.1.7 (registry+https://github.com/rust-lang/crates.io-index)", + "cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)", + "crossbeam-utils 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)", + "lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)", + "memoffset 0.5.3 (registry+https://github.com/rust-lang/crates.io-index)", + "scopeguard 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "crossbeam-queue" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "crossbeam-utils 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "crossbeam-utils" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "autocfg 0.1.7 (registry+https://github.com/rust-lang/crates.io-index)", + "cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)", + "lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "crunchy" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" + +[[package]] +name = "demo" +version = "0.1.0" +dependencies = [ + "blas 0.20.0 (registry+https://github.com/rust-lang/crates.io-index)", + "gemm 0.1.0", + "openblas 0.0.1 (registry+https://github.com/rust-lang/crates.io-index)", + "rand 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "either" +version = "1.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" + +[[package]] +name = "gemm" +version = "0.1.0" +dependencies = [ + "blas 0.20.0 (registry+https://github.com/rust-lang/crates.io-index)", + "crunchy 0.2.2 (registry+https://github.com/rust-lang/crates.io-index)", + "openblas 0.0.1 (registry+https://github.com/rust-lang/crates.io-index)", + "rand 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)", + "rayon 1.2.1 (registry+https://github.com/rust-lang/crates.io-index)", + "threadpool 1.7.1 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "getrandom" +version = "0.1.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)", + "libc 0.2.66 (registry+https://github.com/rust-lang/crates.io-index)", + "wasi 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "hermit-abi" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "libc 0.2.66 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "lazy_static" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" + +[[package]] +name = "libc" +version = "0.2.66" +source = "registry+https://github.com/rust-lang/crates.io-index" + +[[package]] +name = "memoffset" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "rustc_version 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "num-complex" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "autocfg 0.1.7 (registry+https://github.com/rust-lang/crates.io-index)", + "num-traits 0.2.10 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "num-traits" +version = "0.2.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "autocfg 0.1.7 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "num_cpus" +version = "1.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "hermit-abi 0.1.3 (registry+https://github.com/rust-lang/crates.io-index)", + "libc 0.2.66 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "openblas" +version = "0.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" + +[[package]] +name = "ppv-lite86" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" + +[[package]] +name = "rand" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "getrandom 0.1.13 (registry+https://github.com/rust-lang/crates.io-index)", + "libc 0.2.66 (registry+https://github.com/rust-lang/crates.io-index)", + "rand_chacha 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)", + "rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)", + "rand_hc 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "rand_chacha" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "c2-chacha 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)", + "rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "rand_core" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "getrandom 0.1.13 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "rand_hc" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "rayon" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "crossbeam-deque 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)", + "either 1.5.3 (registry+https://github.com/rust-lang/crates.io-index)", + "rayon-core 1.6.1 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "rayon-core" +version = "1.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "crossbeam-deque 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)", + "crossbeam-queue 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)", + "crossbeam-utils 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)", + "lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)", + "num_cpus 1.11.1 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "rustc_version" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "semver 0.9.0 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "scopeguard" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" + +[[package]] +name = "semver" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "semver-parser 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "semver-parser" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" + +[[package]] +name = "threadpool" +version = "1.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "num_cpus 1.11.1 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "wasi" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" + +[metadata] +"checksum autocfg 0.1.7 (registry+https://github.com/rust-lang/crates.io-index)" = "1d49d90015b3c36167a20fe2810c5cd875ad504b39cff3d4eae7977e6b7c1cb2" +"checksum blas 0.20.0 (registry+https://github.com/rust-lang/crates.io-index)" = "e4b3b6399f2fe6e70f780b06e278a64770aae273e0005c3860f94fc2f5f01ba7" +"checksum blas-sys 0.7.1 (registry+https://github.com/rust-lang/crates.io-index)" = "13b1b279ceb25d7c4faaea95a5f7addbe7d8c34f9462044bd8e630cebcfc2440" +"checksum c2-chacha 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "214238caa1bf3a496ec3392968969cab8549f96ff30652c9e56885329315f6bb" +"checksum cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)" = "4785bdd1c96b2a846b2bd7cc02e86b6b3dbf14e7e53446c4f54c92a361040822" +"checksum crossbeam-deque 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)" = "c3aa945d63861bfe624b55d153a39684da1e8c0bc8fba932f7ee3a3c16cea3ca" +"checksum crossbeam-epoch 0.8.0 (registry+https://github.com/rust-lang/crates.io-index)" = "5064ebdbf05ce3cb95e45c8b086f72263f4166b29b97f6baff7ef7fe047b55ac" +"checksum crossbeam-queue 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "dfd6515864a82d2f877b42813d4553292c6659498c9a2aa31bab5a15243c2700" +"checksum crossbeam-utils 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ce446db02cdc3165b94ae73111e570793400d0794e46125cc4056c81cbb039f4" +"checksum crunchy 0.2.2 (registry+https://github.com/rust-lang/crates.io-index)" = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" +"checksum either 1.5.3 (registry+https://github.com/rust-lang/crates.io-index)" = "bb1f6b1ce1c140482ea30ddd3335fc0024ac7ee112895426e0a629a6c20adfe3" +"checksum getrandom 0.1.13 (registry+https://github.com/rust-lang/crates.io-index)" = "e7db7ca94ed4cd01190ceee0d8a8052f08a247aa1b469a7f68c6a3b71afcf407" +"checksum hermit-abi 0.1.3 (registry+https://github.com/rust-lang/crates.io-index)" = "307c3c9f937f38e3534b1d6447ecf090cafcc9744e4a6360e8b037b2cf5af120" +"checksum lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" +"checksum libc 0.2.66 (registry+https://github.com/rust-lang/crates.io-index)" = "d515b1f41455adea1313a4a2ac8a8a477634fbae63cc6100e3aebb207ce61558" +"checksum memoffset 0.5.3 (registry+https://github.com/rust-lang/crates.io-index)" = "75189eb85871ea5c2e2c15abbdd541185f63b408415e5051f5cac122d8c774b9" +"checksum num-complex 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "fcb0cf31fb3ff77e6d2a6ebd6800df7fdcd106f2ad89113c9130bcd07f93dffc" +"checksum num-traits 0.2.10 (registry+https://github.com/rust-lang/crates.io-index)" = "d4c81ffc11c212fa327657cb19dd85eb7419e163b5b076bede2bdb5c974c07e4" +"checksum num_cpus 1.11.1 (registry+https://github.com/rust-lang/crates.io-index)" = "76dac5ed2a876980778b8b85f75a71b6cbf0db0b1232ee12f826bccb00d09d72" +"checksum openblas 0.0.1 (registry+https://github.com/rust-lang/crates.io-index)" = "f005eb075fb562358060245c5d00a877757b4ddec8e8cd69cab96d8cc55998e8" +"checksum ppv-lite86 0.2.6 (registry+https://github.com/rust-lang/crates.io-index)" = "74490b50b9fbe561ac330df47c08f3f33073d2d00c150f719147d7c54522fa1b" +"checksum rand 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)" = "3ae1b169243eaf61759b8475a998f0a385e42042370f3a7dbaf35246eacc8412" +"checksum rand_chacha 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "03a2a90da8c7523f554344f921aa97283eadf6ac484a6d2a7d0212fa7f8d6853" +"checksum rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)" = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19" +"checksum rand_hc 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ca3129af7b92a17112d59ad498c6f81eaf463253766b90396d39ea7a39d6613c" +"checksum rayon 1.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "43739f8831493b276363637423d3622d4bd6394ab6f0a9c4a552e208aeb7fddd" +"checksum rayon-core 1.6.1 (registry+https://github.com/rust-lang/crates.io-index)" = "f8bf17de6f23b05473c437eb958b9c850bfc8af0961fe17b4cc92d5a627b4791" +"checksum rustc_version 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "138e3e0acb6c9fb258b19b67cb8abd63c00679d2851805ea151465464fe9030a" +"checksum scopeguard 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)" = "b42e15e59b18a828bbf5c58ea01debb36b9b096346de35d941dcb89009f24a0d" +"checksum semver 0.9.0 (registry+https://github.com/rust-lang/crates.io-index)" = "1d7eb9ef2c18661902cc47e535f9bc51b78acd254da71d375c2f6720d9a40403" +"checksum semver-parser 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)" = "388a1df253eca08550bef6c72392cfe7c30914bf41df5269b68cbd6ff8f570a3" +"checksum threadpool 1.7.1 (registry+https://github.com/rust-lang/crates.io-index)" = "e2f0c90a5f3459330ac8bc0d2f879c693bb7a2f59689c1083fc4ef83834da865" +"checksum wasi 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)" = "b89c3ce4ce14bdc6fb6beaf9ec7928ca331de5df7e5ea278375642a2f478570d" diff --git a/demo/Cargo.toml b/demo/Cargo.toml new file mode 100644 index 0000000..05ed04e --- /dev/null +++ b/demo/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "demo" +version = "0.1.0" +authors = ["Andrey Tkachenko "] +edition = "2018" +build = "build.rs" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +gemm = {path="../"} +openblas = "0.0.1" +blas = "0.20.0" +rand = "0.7.2" diff --git a/demo/build.rs b/demo/build.rs new file mode 100644 index 0000000..7a805fb --- /dev/null +++ b/demo/build.rs @@ -0,0 +1,7 @@ +fn main() { + // Always linking openblas + // Compiling blas just for testing is tedious -- install it on your system + // and run this. + println!("cargo:rerun-if-changed=build.rs"); + println!("cargo:rustc-link-lib={}=openblas", "dylib"); +} \ No newline at end of file diff --git a/src/main.rs b/demo/src/main.rs similarity index 92% rename from src/main.rs rename to demo/src/main.rs index 3e1832d..bea7f9c 100644 --- a/src/main.rs +++ b/demo/src/main.rs @@ -1,16 +1,5 @@ -mod matrix; -mod gemm; -mod sgemm; -mod aligned_alloc; -mod kernel; -mod dim; -mod executor; - - -extern crate blas; extern crate openblas; - -use crate::sgemm::sgemm; +extern crate blas; fn main() { use rand::Rng; @@ -97,11 +86,12 @@ fn main() { let time = std::time::Instant::now(); unsafe { - sgemm( - &executor::RayonExecutor, + gemm::sgemm( + &gemm::executor::DefaultExecutor, false, false, false, + m, n, k, diff --git a/src/executor.rs b/src/executor.rs index 4ee1c70..9719e41 100644 --- a/src/executor.rs +++ b/src/executor.rs @@ -7,7 +7,7 @@ pub trait Executor { f: F, ); - fn synchronize(&self); + fn synchronize(&self) {} } pub struct DefaultExecutor; @@ -24,10 +24,6 @@ impl Executor for DefaultExecutor { f(i); } } - - fn synchronize(&self) { - - } } pub struct RayonExecutor; @@ -48,10 +44,6 @@ impl Executor for RayonExecutor { .map(|x| x * step) .for_each(|x|f(x)); } - - fn synchronize(&self) { - - } } diff --git a/src/gemm.rs b/src/gemm.rs index 9d7516e..77ecd92 100644 --- a/src/gemm.rs +++ b/src/gemm.rs @@ -1,6 +1,6 @@ use crate::aligned_alloc; use crate::matrix::{Number, Matrix, MatrixMut, MutMatrix, ConstMatrix, ConstTransposedMatrix, MutTransposedMatrix}; -use crate::kernel::params::single::*; +use crate::kernel::params::single::{NC, MC, KC}; use crate::kernel; use crate::kernel::GemmKernel; use crate::dim::Dim; @@ -140,19 +140,25 @@ where E: Executor, } } -// | MR | -// +----------------------+ -// |TL : : : |TR| -// | : : : | | -// + - - - - - - - - - ---+---- -// | : : : | | NR -// | : : : | | -// + - - - - - - - - - ---+---- -// | : : : | | -// | : : : | | -// +----------------------+ -// |BL | | | |BR| -// +----------------------+ +// +// | MR | +// +-----------------------------+----+ +// | : | | +// | TL : | TR | +// | : | | +// + - - - - - - -:- - - - - - - - - -+ ---- +// | : | | +// | : | | NR +// | : | | +// + - - - - - - -:- - - - - - - - - -+ ---- +// | : | | +// | : | | +// | : | | +// +-----------------------------+----+ +// | BL | | BR | +// +-----------------------------+----+ +// + unsafe fn inner_kernel( e: &E, m: usize, @@ -177,25 +183,25 @@ unsafe fn inner_kernel( C: MatrixMut, { - let n_left = n % NR; + let n_left = n % NR::DIM; let n_main = n - n_left; - let m_left = m % MR; + let m_left = m % MR::DIM; let m_main = m - m_left; if first_time { - e.execute(0, n_main, NR, move |j| - K::pack_row_b(b, pb, j)); + e.execute(0, n_main, NR::DIM, move |j| + K::pack_row_b(b, pb)); } - e.execute(0, m_main, MR, move |i| - K::pack_row_a(a, pa, i)); + e.execute(0, m_main, MR::DIM, move |i| + K::pack_row_a(a.sub_col(i), pa.sub_row(i))); e.synchronize(); - e.execute(0, n_main, NR, move |j| { + e.execute(0, n_main, NR::DIM, move |j| { // Section TL - for i in (0..m_main).step_by(MR) { + for i in (0..m_main).step_by(MR::DIM) { K::main_tl(alpha, pa.sub_row(i), pb.sub_row(j), @@ -214,19 +220,22 @@ unsafe fn inner_kernel( } }); - e.execute(n_main, n, 1, move |j| { - // Section BL - for i in (0..m_main).step_by(MR) { - K::sup_bl( - alpha, - pa.sub_row(i), - b.sub_row(j), - beta, - c.sub(j, i) - ); - } + e.execute(0, n_left * m_main, MR::DIM, move |ji| { + let j = n_main + ji / m_main; + let i = ji % m_main; - // Section BR + // Section BL + K::sup_bl( + alpha, + pa.sub_row(i), + b.sub_row(j), + beta, + c.sub(j, i) + ); + }); + + // Section BR + for j in n_main..n { for i in m_main..m { K::sup_br( k, @@ -236,7 +245,7 @@ unsafe fn inner_kernel( beta, c.sub(j, i)) } - }); + }; e.synchronize(); } \ No newline at end of file diff --git a/src/kernel/avx/l3s.rs b/src/kernel/avx/l3s.rs index 3f0e1a5..b66759f 100644 --- a/src/kernel/avx/l3s.rs +++ b/src/kernel/avx/l3s.rs @@ -136,7 +136,19 @@ pub(crate) unsafe fn sgemm_ukr_16x8>( // let ccol5 = c.row_mut(5); // let ccol6 = c.row_mut(6); // let ccol7 = c.add(ldc * 7); - + + mt00 = _mm256_add_ps(_mm256_loadu_ps(ccol0), mt00); + mt01 = _mm256_add_ps(_mm256_loadu_ps(ccol1), mt01); + mt02 = _mm256_add_ps(_mm256_loadu_ps(ccol2), mt02); + mt03 = _mm256_add_ps(_mm256_loadu_ps(ccol3), mt03); + mt04 = _mm256_add_ps(_mm256_loadu_ps(ccol4), mt04); + + mt10 = _mm256_add_ps(_mm256_loadu_ps(ccol0.add(8)), mt10); + mt11 = _mm256_add_ps(_mm256_loadu_ps(ccol1.add(8)), mt11); + mt12 = _mm256_add_ps(_mm256_loadu_ps(ccol2.add(8)), mt12); + mt13 = _mm256_add_ps(_mm256_loadu_ps(ccol3.add(8)), mt13); + mt14 = _mm256_add_ps(_mm256_loadu_ps(ccol4.add(8)), mt14); + // if beta != 0.0 { // let beta = _mm256_broadcast_ss(&beta); @@ -210,6 +222,9 @@ pub(crate) unsafe fn sgemm_sup_16x1, C: MatrixMut>( // mt0 = _mm256_mul_ps(alpha, mt0); // mt1 = _mm256_mul_ps(alpha, mt1); + let ccol0 = c.ptr_mut(); + let ccol1 = c.ptr_mut().add(8); + // if beta != 0.0 { // let beta = _mm256_broadcast_ss(&beta); @@ -217,8 +232,11 @@ pub(crate) unsafe fn sgemm_sup_16x1, C: MatrixMut>( // mt1 = fmadd_ps(beta, _mm256_loadu_ps(c.add(8)), mt1); // } - _mm256_storeu_ps(c.ptr_mut(), mt0); - _mm256_storeu_ps(c.col_mut(8), mt1); + mt0 = _mm256_add_ps(_mm256_loadu_ps(ccol0), mt0); + mt1 = _mm256_add_ps(_mm256_loadu_ps(ccol1), mt1); + + _mm256_storeu_ps(ccol0, mt0); + _mm256_storeu_ps(ccol1, mt1); } pub(crate) unsafe fn sgemm_pa_16x(k: usize, a: *const f32, lda: usize, pa: *mut f32) { diff --git a/src/kernel/avx/mod.rs b/src/kernel/avx/mod.rs index 2e4c6f3..723b550 100644 --- a/src/kernel/avx/mod.rs +++ b/src/kernel/avx/mod.rs @@ -63,17 +63,17 @@ impl GemmKernel for AvxKernel where I: GemmKernel { #[inline] - unsafe fn pack_row_a>(a: A, pa: MutMatrix, i: usize) { + unsafe fn pack_row_a>(a: A, pa: MutMatrix) { if a.is_transposed() { - I::pack_row_a(a, pa, i); + I::pack_row_a(a, pa); } else { - self::l3s::sgemm_pa_16x(pa.stride, a.col(i), a.stride(), pa.row_mut(i)); + self::l3s::sgemm_pa_16x(pa.stride, a.ptr(), a.stride(), pa.ptr_mut()); } } #[inline] - unsafe fn pack_row_b>(b: B, pb: MutMatrix, j: usize) { - I::pack_row_b(b, pb, j); + unsafe fn pack_row_b>(b: B, pb: MutMatrix) { + I::pack_row_b(b, pb); } #[inline] diff --git a/src/kernel/generic/l3s.rs b/src/kernel/generic/l3s.rs index 8dc7555..6f33b89 100644 --- a/src/kernel/generic/l3s.rs +++ b/src/kernel/generic/l3s.rs @@ -67,11 +67,11 @@ pub(crate) unsafe fn sgemm_sup_1x8, C: MatrixMut>( // // c7 += beta * *ccol7; // } - *ccol0 = c0; - *ccol1 = c1; - *ccol2 = c2; - *ccol3 = c3; - *ccol4 = c4; + *ccol0 += c0; + *ccol1 += c1; + *ccol2 += c2; + *ccol3 += c3; + *ccol4 += c4; // *ccol5 = c5; // *ccol6 = c6; // *ccol7 = c7; @@ -113,7 +113,6 @@ pub(crate) unsafe fn sgemm_pb_x8(k: usize, b: *const f32, ldb: usize, pb: *mut f } pub(crate) unsafe fn sgemm_pb_t(k: usize, b: *const f32, ldb: usize, pb: *mut f32) { - use crate::kernel::params::single::NR; let mut b = b; let mut pb = pb; diff --git a/src/kernel/generic/mod.rs b/src/kernel/generic/mod.rs index ce7ee60..0f82c4f 100644 --- a/src/kernel/generic/mod.rs +++ b/src/kernel/generic/mod.rs @@ -61,27 +61,27 @@ impl GemmKernelSup for GenericKernel { elem += beta * *c.ptr(); } - *c.ptr_mut() = elem; + *c.ptr_mut() += elem; } } impl GemmKernel for GenericKernel { #[inline] - unsafe fn pack_row_a>(a: A, pa: MutMatrix, i: usize) { + unsafe fn pack_row_a>(a: A, pa: MutMatrix) { if a.is_transposed() { - self::l3s::sgemm_pa_t(pa.stride, a.col(i), a.stride(), pa.row_mut(i)); + self::l3s::sgemm_pa_t(pa.stride, a.ptr(), a.stride(), pa.ptr_mut()); } else { unimplemented!() } } #[inline] - unsafe fn pack_row_b>(b: B, pb: MutMatrix, j: usize) { + unsafe fn pack_row_b>(b: B, pb: MutMatrix) { if b.is_transposed() { - self::l3s::sgemm_pb_t(pb.stride, b.row(j), b.stride(), pb.row_mut(j)); + self::l3s::sgemm_pb_t(pb.stride, b.ptr(), b.stride(), pb.ptr_mut()); } else { - self::l3s::sgemm_pb_x8(pb.stride, b.row(j), b.stride(), pb.row_mut(j)); + self::l3s::sgemm_pb_x8(pb.stride, b.ptr(), b.stride(), pb.ptr_mut()); } } diff --git a/src/kernel/mod.rs b/src/kernel/mod.rs index c12b9af..0c9615d 100644 --- a/src/kernel/mod.rs +++ b/src/kernel/mod.rs @@ -71,8 +71,8 @@ pub trait GemmKernelSupNr { } pub trait GemmKernel: GemmKernelSupMr + GemmKernelSupNr + GemmKernelSup { - unsafe fn pack_row_a>(a: A, pa: MutMatrix, i: usize); - unsafe fn pack_row_b>(b: B, pb: MutMatrix, j: usize); + unsafe fn pack_row_a>(a: A, pa: MutMatrix); + unsafe fn pack_row_b>(b: B, pb: MutMatrix); unsafe fn main_tl>( alpha: F, diff --git a/src/kernel/sse/mod.rs b/src/kernel/sse/mod.rs index dbc9170..2e464c6 100644 --- a/src/kernel/sse/mod.rs +++ b/src/kernel/sse/mod.rs @@ -64,13 +64,13 @@ impl GemmKernel for SseKernel where I: GemmKernel { #[inline] - unsafe fn pack_row_a>(a: A, pa: MutMatrix, i: usize) { - I::pack_row_a(a, pa, i); + unsafe fn pack_row_a>(a: A, pa: MutMatrix) { + I::pack_row_a(a, pa); } #[inline] - unsafe fn pack_row_b>(b: B, pb: MutMatrix, j: usize) { - I::pack_row_b(b, pb, j); + unsafe fn pack_row_b>(b: B, pb: MutMatrix) { + I::pack_row_b(b, pb); } #[inline] diff --git a/src/matrix.rs b/src/matrix.rs index 9ef97ea..4c91c14 100644 --- a/src/matrix.rs +++ b/src/matrix.rs @@ -1,5 +1,5 @@ -pub trait Number: Copy + Send + Sync + 'static {} +pub trait Number: core::fmt::Display + Copy + Send + Sync + 'static {} impl Number for f32 {} impl Number for f64 {} diff --git a/src/test/l3s.rs b/src/test/l3s.rs index e621063..4651529 100644 --- a/src/test/l3s.rs +++ b/src/test/l3s.rs @@ -1,8 +1,8 @@ use rand::Rng; -const M_LEN: usize = 8192; -const N_LEN: usize = 8192; -const K_LEN: usize = 8192; +const M_LEN: usize = 17; +const N_LEN: usize = 6; +const K_LEN: usize = 4; fn make_matrices() -> (Vec, Vec, Vec, Vec) { let mut rng = rand::thread_rng(); @@ -15,19 +15,23 @@ fn make_matrices() -> (Vec, Vec, Vec, Vec) { let mut b = vec![0.0; n * k]; let mut b_t = vec![0.0; n * k]; + let mut counter = 0; for row in 0..k { for col in 0..m { - let v = rng.gen(); + let v = counter as f32; a[row * m + col] = v; a_t[col * k + row] = v; + counter += 1; } } + let mut counter = 0; for row in 0..n { for col in 0..k { - let v = rng.gen(); + let v = counter as f32; b[row * k + col] = v; b_t[col * n + row] = v; + counter += 1; } } @@ -80,6 +84,17 @@ fn test_sgemm_nn() { let mut c = vec![0.0; m * n]; let mut cref = vec![0.0; m * n]; + for row in 0..N_LEN { + for col in 0..K_LEN { + let index = row * K_LEN + col; + + print!("{:>5} ", b[index]); + } + println!(); + } + + println!(); + unsafe { blas::sgemm( b'N', @@ -118,6 +133,27 @@ fn test_sgemm_nn() { ); } + + for row in 0..N_LEN { + for col in 0..M_LEN { + let index = row * M_LEN + col; + + print!("{:>5} ", c[index]); + } + println!(); + } + + println!(); + + for row in 0..N_LEN { + for col in 0..M_LEN { + let index = row * M_LEN + col; + + print!("{:>5} ", cref[index]); + } + println!(); + } + for row in 0..N_LEN { for col in 0..M_LEN { let index = row * M_LEN + col; @@ -127,9 +163,9 @@ fn test_sgemm_nn() { assert!(feq(a, b), "a != b, a[{}, {}]={}, b[{}, {}]={}", row, col, a, row, col, b); } } + } - #[test] fn test_sgemm_nt() { let (m, n, k) = (M_LEN, N_LEN, K_LEN); diff --git a/src/utils/mod.rs b/src/utils/mod.rs deleted file mode 100644 index e69de29..0000000