Fixed bug with calculation
This commit is contained in:
parent
aaf8ec3203
commit
a699078f73
1
.gitignore
vendored
1
.gitignore
vendored
@ -1,2 +1,3 @@
|
||||
/target
|
||||
/demo/target
|
||||
**/*.rs.bk
|
||||
|
314
demo/Cargo.lock
generated
Normal file
314
demo/Cargo.lock
generated
Normal file
@ -0,0 +1,314 @@
|
||||
# This file is automatically @generated by Cargo.
|
||||
# It is not intended for manual editing.
|
||||
[[package]]
|
||||
name = "autocfg"
|
||||
version = "0.1.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[[package]]
|
||||
name = "blas"
|
||||
version = "0.20.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"blas-sys 0.7.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"libc 0.2.66 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"num-complex 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "blas-sys"
|
||||
version = "0.7.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"libc 0.2.66 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "c2-chacha"
|
||||
version = "0.2.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"ppv-lite86 0.2.6 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cfg-if"
|
||||
version = "0.1.10"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[[package]]
|
||||
name = "crossbeam-deque"
|
||||
version = "0.7.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"crossbeam-epoch 0.8.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"crossbeam-utils 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crossbeam-epoch"
|
||||
version = "0.8.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"autocfg 0.1.7 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"crossbeam-utils 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"memoffset 0.5.3 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"scopeguard 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crossbeam-queue"
|
||||
version = "0.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"crossbeam-utils 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crossbeam-utils"
|
||||
version = "0.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"autocfg 0.1.7 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crunchy"
|
||||
version = "0.2.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[[package]]
|
||||
name = "demo"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"blas 0.20.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"gemm 0.1.0",
|
||||
"openblas 0.0.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"rand 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "either"
|
||||
version = "1.5.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[[package]]
|
||||
name = "gemm"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"blas 0.20.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"crunchy 0.2.2 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"openblas 0.0.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"rand 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"rayon 1.2.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"threadpool 1.7.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "getrandom"
|
||||
version = "0.1.13"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"libc 0.2.66 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"wasi 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hermit-abi"
|
||||
version = "0.1.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"libc 0.2.66 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "lazy_static"
|
||||
version = "1.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[[package]]
|
||||
name = "libc"
|
||||
version = "0.2.66"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[[package]]
|
||||
name = "memoffset"
|
||||
version = "0.5.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"rustc_version 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-complex"
|
||||
version = "0.2.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"autocfg 0.1.7 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"num-traits 0.2.10 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-traits"
|
||||
version = "0.2.10"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"autocfg 0.1.7 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num_cpus"
|
||||
version = "1.11.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"hermit-abi 0.1.3 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"libc 0.2.66 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "openblas"
|
||||
version = "0.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[[package]]
|
||||
name = "ppv-lite86"
|
||||
version = "0.2.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[[package]]
|
||||
name = "rand"
|
||||
version = "0.7.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"getrandom 0.1.13 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"libc 0.2.66 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"rand_chacha 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"rand_hc 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand_chacha"
|
||||
version = "0.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"c2-chacha 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand_core"
|
||||
version = "0.5.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"getrandom 0.1.13 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand_hc"
|
||||
version = "0.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rayon"
|
||||
version = "1.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"crossbeam-deque 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"either 1.5.3 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"rayon-core 1.6.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rayon-core"
|
||||
version = "1.6.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"crossbeam-deque 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"crossbeam-queue 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"crossbeam-utils 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"num_cpus 1.11.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustc_version"
|
||||
version = "0.2.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"semver 0.9.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "scopeguard"
|
||||
version = "1.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[[package]]
|
||||
name = "semver"
|
||||
version = "0.9.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"semver-parser 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "semver-parser"
|
||||
version = "0.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[[package]]
|
||||
name = "threadpool"
|
||||
version = "1.7.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"num_cpus 1.11.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wasi"
|
||||
version = "0.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[metadata]
|
||||
"checksum autocfg 0.1.7 (registry+https://github.com/rust-lang/crates.io-index)" = "1d49d90015b3c36167a20fe2810c5cd875ad504b39cff3d4eae7977e6b7c1cb2"
|
||||
"checksum blas 0.20.0 (registry+https://github.com/rust-lang/crates.io-index)" = "e4b3b6399f2fe6e70f780b06e278a64770aae273e0005c3860f94fc2f5f01ba7"
|
||||
"checksum blas-sys 0.7.1 (registry+https://github.com/rust-lang/crates.io-index)" = "13b1b279ceb25d7c4faaea95a5f7addbe7d8c34f9462044bd8e630cebcfc2440"
|
||||
"checksum c2-chacha 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "214238caa1bf3a496ec3392968969cab8549f96ff30652c9e56885329315f6bb"
|
||||
"checksum cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)" = "4785bdd1c96b2a846b2bd7cc02e86b6b3dbf14e7e53446c4f54c92a361040822"
|
||||
"checksum crossbeam-deque 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)" = "c3aa945d63861bfe624b55d153a39684da1e8c0bc8fba932f7ee3a3c16cea3ca"
|
||||
"checksum crossbeam-epoch 0.8.0 (registry+https://github.com/rust-lang/crates.io-index)" = "5064ebdbf05ce3cb95e45c8b086f72263f4166b29b97f6baff7ef7fe047b55ac"
|
||||
"checksum crossbeam-queue 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "dfd6515864a82d2f877b42813d4553292c6659498c9a2aa31bab5a15243c2700"
|
||||
"checksum crossbeam-utils 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ce446db02cdc3165b94ae73111e570793400d0794e46125cc4056c81cbb039f4"
|
||||
"checksum crunchy 0.2.2 (registry+https://github.com/rust-lang/crates.io-index)" = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7"
|
||||
"checksum either 1.5.3 (registry+https://github.com/rust-lang/crates.io-index)" = "bb1f6b1ce1c140482ea30ddd3335fc0024ac7ee112895426e0a629a6c20adfe3"
|
||||
"checksum getrandom 0.1.13 (registry+https://github.com/rust-lang/crates.io-index)" = "e7db7ca94ed4cd01190ceee0d8a8052f08a247aa1b469a7f68c6a3b71afcf407"
|
||||
"checksum hermit-abi 0.1.3 (registry+https://github.com/rust-lang/crates.io-index)" = "307c3c9f937f38e3534b1d6447ecf090cafcc9744e4a6360e8b037b2cf5af120"
|
||||
"checksum lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
|
||||
"checksum libc 0.2.66 (registry+https://github.com/rust-lang/crates.io-index)" = "d515b1f41455adea1313a4a2ac8a8a477634fbae63cc6100e3aebb207ce61558"
|
||||
"checksum memoffset 0.5.3 (registry+https://github.com/rust-lang/crates.io-index)" = "75189eb85871ea5c2e2c15abbdd541185f63b408415e5051f5cac122d8c774b9"
|
||||
"checksum num-complex 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "fcb0cf31fb3ff77e6d2a6ebd6800df7fdcd106f2ad89113c9130bcd07f93dffc"
|
||||
"checksum num-traits 0.2.10 (registry+https://github.com/rust-lang/crates.io-index)" = "d4c81ffc11c212fa327657cb19dd85eb7419e163b5b076bede2bdb5c974c07e4"
|
||||
"checksum num_cpus 1.11.1 (registry+https://github.com/rust-lang/crates.io-index)" = "76dac5ed2a876980778b8b85f75a71b6cbf0db0b1232ee12f826bccb00d09d72"
|
||||
"checksum openblas 0.0.1 (registry+https://github.com/rust-lang/crates.io-index)" = "f005eb075fb562358060245c5d00a877757b4ddec8e8cd69cab96d8cc55998e8"
|
||||
"checksum ppv-lite86 0.2.6 (registry+https://github.com/rust-lang/crates.io-index)" = "74490b50b9fbe561ac330df47c08f3f33073d2d00c150f719147d7c54522fa1b"
|
||||
"checksum rand 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)" = "3ae1b169243eaf61759b8475a998f0a385e42042370f3a7dbaf35246eacc8412"
|
||||
"checksum rand_chacha 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "03a2a90da8c7523f554344f921aa97283eadf6ac484a6d2a7d0212fa7f8d6853"
|
||||
"checksum rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)" = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19"
|
||||
"checksum rand_hc 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ca3129af7b92a17112d59ad498c6f81eaf463253766b90396d39ea7a39d6613c"
|
||||
"checksum rayon 1.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "43739f8831493b276363637423d3622d4bd6394ab6f0a9c4a552e208aeb7fddd"
|
||||
"checksum rayon-core 1.6.1 (registry+https://github.com/rust-lang/crates.io-index)" = "f8bf17de6f23b05473c437eb958b9c850bfc8af0961fe17b4cc92d5a627b4791"
|
||||
"checksum rustc_version 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "138e3e0acb6c9fb258b19b67cb8abd63c00679d2851805ea151465464fe9030a"
|
||||
"checksum scopeguard 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)" = "b42e15e59b18a828bbf5c58ea01debb36b9b096346de35d941dcb89009f24a0d"
|
||||
"checksum semver 0.9.0 (registry+https://github.com/rust-lang/crates.io-index)" = "1d7eb9ef2c18661902cc47e535f9bc51b78acd254da71d375c2f6720d9a40403"
|
||||
"checksum semver-parser 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)" = "388a1df253eca08550bef6c72392cfe7c30914bf41df5269b68cbd6ff8f570a3"
|
||||
"checksum threadpool 1.7.1 (registry+https://github.com/rust-lang/crates.io-index)" = "e2f0c90a5f3459330ac8bc0d2f879c693bb7a2f59689c1083fc4ef83834da865"
|
||||
"checksum wasi 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)" = "b89c3ce4ce14bdc6fb6beaf9ec7928ca331de5df7e5ea278375642a2f478570d"
|
14
demo/Cargo.toml
Normal file
14
demo/Cargo.toml
Normal file
@ -0,0 +1,14 @@
|
||||
[package]
|
||||
name = "demo"
|
||||
version = "0.1.0"
|
||||
authors = ["Andrey Tkachenko <andreytkachenko64@gmail.com>"]
|
||||
edition = "2018"
|
||||
build = "build.rs"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
gemm = {path="../"}
|
||||
openblas = "0.0.1"
|
||||
blas = "0.20.0"
|
||||
rand = "0.7.2"
|
7
demo/build.rs
Normal file
7
demo/build.rs
Normal file
@ -0,0 +1,7 @@
|
||||
fn main() {
|
||||
// Always linking openblas
|
||||
// Compiling blas just for testing is tedious -- install it on your system
|
||||
// and run this.
|
||||
println!("cargo:rerun-if-changed=build.rs");
|
||||
println!("cargo:rustc-link-lib={}=openblas", "dylib");
|
||||
}
|
@ -1,16 +1,5 @@
|
||||
mod matrix;
|
||||
mod gemm;
|
||||
mod sgemm;
|
||||
mod aligned_alloc;
|
||||
mod kernel;
|
||||
mod dim;
|
||||
mod executor;
|
||||
|
||||
|
||||
extern crate blas;
|
||||
extern crate openblas;
|
||||
|
||||
use crate::sgemm::sgemm;
|
||||
extern crate blas;
|
||||
|
||||
fn main() {
|
||||
use rand::Rng;
|
||||
@ -97,11 +86,12 @@ fn main() {
|
||||
|
||||
let time = std::time::Instant::now();
|
||||
unsafe {
|
||||
sgemm(
|
||||
&executor::RayonExecutor,
|
||||
gemm::sgemm(
|
||||
&gemm::executor::DefaultExecutor,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
|
||||
m,
|
||||
n,
|
||||
k,
|
@ -7,7 +7,7 @@ pub trait Executor {
|
||||
f: F,
|
||||
);
|
||||
|
||||
fn synchronize(&self);
|
||||
fn synchronize(&self) {}
|
||||
}
|
||||
|
||||
pub struct DefaultExecutor;
|
||||
@ -24,10 +24,6 @@ impl Executor for DefaultExecutor {
|
||||
f(i);
|
||||
}
|
||||
}
|
||||
|
||||
fn synchronize(&self) {
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
pub struct RayonExecutor;
|
||||
@ -48,10 +44,6 @@ impl Executor for RayonExecutor {
|
||||
.map(|x| x * step)
|
||||
.for_each(|x|f(x));
|
||||
}
|
||||
|
||||
fn synchronize(&self) {
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
79
src/gemm.rs
79
src/gemm.rs
@ -1,6 +1,6 @@
|
||||
use crate::aligned_alloc;
|
||||
use crate::matrix::{Number, Matrix, MatrixMut, MutMatrix, ConstMatrix, ConstTransposedMatrix, MutTransposedMatrix};
|
||||
use crate::kernel::params::single::*;
|
||||
use crate::kernel::params::single::{NC, MC, KC};
|
||||
use crate::kernel;
|
||||
use crate::kernel::GemmKernel;
|
||||
use crate::dim::Dim;
|
||||
@ -140,19 +140,25 @@ where E: Executor,
|
||||
}
|
||||
}
|
||||
|
||||
// | MR |
|
||||
// +----------------------+
|
||||
// |TL : : : |TR|
|
||||
// | : : : | |
|
||||
// + - - - - - - - - - ---+----
|
||||
// | : : : | | NR
|
||||
// | : : : | |
|
||||
// + - - - - - - - - - ---+----
|
||||
// | : : : | |
|
||||
// | : : : | |
|
||||
// +----------------------+
|
||||
// |BL | | | |BR|
|
||||
// +----------------------+
|
||||
//
|
||||
// | MR |
|
||||
// +-----------------------------+----+
|
||||
// | : | |
|
||||
// | TL : | TR |
|
||||
// | : | |
|
||||
// + - - - - - - -:- - - - - - - - - -+ ----
|
||||
// | : | |
|
||||
// | : | | NR
|
||||
// | : | |
|
||||
// + - - - - - - -:- - - - - - - - - -+ ----
|
||||
// | : | |
|
||||
// | : | |
|
||||
// | : | |
|
||||
// +-----------------------------+----+
|
||||
// | BL | | BR |
|
||||
// +-----------------------------+----+
|
||||
//
|
||||
|
||||
unsafe fn inner_kernel<E, F, K, MR, NR, A, B, C>(
|
||||
e: &E,
|
||||
m: usize,
|
||||
@ -177,25 +183,25 @@ unsafe fn inner_kernel<E, F, K, MR, NR, A, B, C>(
|
||||
C: MatrixMut<F>,
|
||||
|
||||
{
|
||||
let n_left = n % NR;
|
||||
let n_left = n % NR::DIM;
|
||||
let n_main = n - n_left;
|
||||
|
||||
let m_left = m % MR;
|
||||
let m_left = m % MR::DIM;
|
||||
let m_main = m - m_left;
|
||||
|
||||
if first_time {
|
||||
e.execute(0, n_main, NR, move |j|
|
||||
K::pack_row_b(b, pb, j));
|
||||
e.execute(0, n_main, NR::DIM, move |j|
|
||||
K::pack_row_b(b, pb));
|
||||
}
|
||||
|
||||
e.execute(0, m_main, MR, move |i|
|
||||
K::pack_row_a(a, pa, i));
|
||||
e.execute(0, m_main, MR::DIM, move |i|
|
||||
K::pack_row_a(a.sub_col(i), pa.sub_row(i)));
|
||||
|
||||
e.synchronize();
|
||||
|
||||
e.execute(0, n_main, NR, move |j| {
|
||||
e.execute(0, n_main, NR::DIM, move |j| {
|
||||
// Section TL
|
||||
for i in (0..m_main).step_by(MR) {
|
||||
for i in (0..m_main).step_by(MR::DIM) {
|
||||
K::main_tl(alpha,
|
||||
pa.sub_row(i),
|
||||
pb.sub_row(j),
|
||||
@ -214,19 +220,22 @@ unsafe fn inner_kernel<E, F, K, MR, NR, A, B, C>(
|
||||
}
|
||||
});
|
||||
|
||||
e.execute(n_main, n, 1, move |j| {
|
||||
// Section BL
|
||||
for i in (0..m_main).step_by(MR) {
|
||||
K::sup_bl(
|
||||
alpha,
|
||||
pa.sub_row(i),
|
||||
b.sub_row(j),
|
||||
beta,
|
||||
c.sub(j, i)
|
||||
);
|
||||
}
|
||||
e.execute(0, n_left * m_main, MR::DIM, move |ji| {
|
||||
let j = n_main + ji / m_main;
|
||||
let i = ji % m_main;
|
||||
|
||||
// Section BR
|
||||
// Section BL
|
||||
K::sup_bl(
|
||||
alpha,
|
||||
pa.sub_row(i),
|
||||
b.sub_row(j),
|
||||
beta,
|
||||
c.sub(j, i)
|
||||
);
|
||||
});
|
||||
|
||||
// Section BR
|
||||
for j in n_main..n {
|
||||
for i in m_main..m {
|
||||
K::sup_br(
|
||||
k,
|
||||
@ -236,7 +245,7 @@ unsafe fn inner_kernel<E, F, K, MR, NR, A, B, C>(
|
||||
beta,
|
||||
c.sub(j, i))
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
e.synchronize();
|
||||
}
|
@ -136,7 +136,19 @@ pub(crate) unsafe fn sgemm_ukr_16x8<C: MatrixMut<f32>>(
|
||||
// let ccol5 = c.row_mut(5);
|
||||
// let ccol6 = c.row_mut(6);
|
||||
// let ccol7 = c.add(ldc * 7);
|
||||
|
||||
|
||||
mt00 = _mm256_add_ps(_mm256_loadu_ps(ccol0), mt00);
|
||||
mt01 = _mm256_add_ps(_mm256_loadu_ps(ccol1), mt01);
|
||||
mt02 = _mm256_add_ps(_mm256_loadu_ps(ccol2), mt02);
|
||||
mt03 = _mm256_add_ps(_mm256_loadu_ps(ccol3), mt03);
|
||||
mt04 = _mm256_add_ps(_mm256_loadu_ps(ccol4), mt04);
|
||||
|
||||
mt10 = _mm256_add_ps(_mm256_loadu_ps(ccol0.add(8)), mt10);
|
||||
mt11 = _mm256_add_ps(_mm256_loadu_ps(ccol1.add(8)), mt11);
|
||||
mt12 = _mm256_add_ps(_mm256_loadu_ps(ccol2.add(8)), mt12);
|
||||
mt13 = _mm256_add_ps(_mm256_loadu_ps(ccol3.add(8)), mt13);
|
||||
mt14 = _mm256_add_ps(_mm256_loadu_ps(ccol4.add(8)), mt14);
|
||||
|
||||
// if beta != 0.0 {
|
||||
// let beta = _mm256_broadcast_ss(&beta);
|
||||
|
||||
@ -210,6 +222,9 @@ pub(crate) unsafe fn sgemm_sup_16x1<B: Matrix<f32>, C: MatrixMut<f32>>(
|
||||
// mt0 = _mm256_mul_ps(alpha, mt0);
|
||||
// mt1 = _mm256_mul_ps(alpha, mt1);
|
||||
|
||||
let ccol0 = c.ptr_mut();
|
||||
let ccol1 = c.ptr_mut().add(8);
|
||||
|
||||
// if beta != 0.0 {
|
||||
// let beta = _mm256_broadcast_ss(&beta);
|
||||
|
||||
@ -217,8 +232,11 @@ pub(crate) unsafe fn sgemm_sup_16x1<B: Matrix<f32>, C: MatrixMut<f32>>(
|
||||
// mt1 = fmadd_ps(beta, _mm256_loadu_ps(c.add(8)), mt1);
|
||||
// }
|
||||
|
||||
_mm256_storeu_ps(c.ptr_mut(), mt0);
|
||||
_mm256_storeu_ps(c.col_mut(8), mt1);
|
||||
mt0 = _mm256_add_ps(_mm256_loadu_ps(ccol0), mt0);
|
||||
mt1 = _mm256_add_ps(_mm256_loadu_ps(ccol1), mt1);
|
||||
|
||||
_mm256_storeu_ps(ccol0, mt0);
|
||||
_mm256_storeu_ps(ccol1, mt1);
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn sgemm_pa_16x(k: usize, a: *const f32, lda: usize, pa: *mut f32) {
|
||||
|
@ -63,17 +63,17 @@ impl<I> GemmKernel<f32, A16, A5> for AvxKernel<f32, I>
|
||||
where I: GemmKernel<f32, A16, A5>
|
||||
{
|
||||
#[inline]
|
||||
unsafe fn pack_row_a<A: Matrix<f32>>(a: A, pa: MutMatrix<f32>, i: usize) {
|
||||
unsafe fn pack_row_a<A: Matrix<f32>>(a: A, pa: MutMatrix<f32>) {
|
||||
if a.is_transposed() {
|
||||
I::pack_row_a(a, pa, i);
|
||||
I::pack_row_a(a, pa);
|
||||
} else {
|
||||
self::l3s::sgemm_pa_16x(pa.stride, a.col(i), a.stride(), pa.row_mut(i));
|
||||
self::l3s::sgemm_pa_16x(pa.stride, a.ptr(), a.stride(), pa.ptr_mut());
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn pack_row_b<B: Matrix<f32>>(b: B, pb: MutMatrix<f32>, j: usize) {
|
||||
I::pack_row_b(b, pb, j);
|
||||
unsafe fn pack_row_b<B: Matrix<f32>>(b: B, pb: MutMatrix<f32>) {
|
||||
I::pack_row_b(b, pb);
|
||||
}
|
||||
|
||||
#[inline]
|
||||
|
@ -67,11 +67,11 @@ pub(crate) unsafe fn sgemm_sup_1x8<A: Matrix<f32>, C: MatrixMut<f32>>(
|
||||
// // c7 += beta * *ccol7;
|
||||
// }
|
||||
|
||||
*ccol0 = c0;
|
||||
*ccol1 = c1;
|
||||
*ccol2 = c2;
|
||||
*ccol3 = c3;
|
||||
*ccol4 = c4;
|
||||
*ccol0 += c0;
|
||||
*ccol1 += c1;
|
||||
*ccol2 += c2;
|
||||
*ccol3 += c3;
|
||||
*ccol4 += c4;
|
||||
// *ccol5 = c5;
|
||||
// *ccol6 = c6;
|
||||
// *ccol7 = c7;
|
||||
@ -113,7 +113,6 @@ pub(crate) unsafe fn sgemm_pb_x8(k: usize, b: *const f32, ldb: usize, pb: *mut f
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn sgemm_pb_t(k: usize, b: *const f32, ldb: usize, pb: *mut f32) {
|
||||
use crate::kernel::params::single::NR;
|
||||
let mut b = b;
|
||||
let mut pb = pb;
|
||||
|
||||
|
@ -61,27 +61,27 @@ impl GemmKernelSup<f32> for GenericKernel {
|
||||
elem += beta * *c.ptr();
|
||||
}
|
||||
|
||||
*c.ptr_mut() = elem;
|
||||
*c.ptr_mut() += elem;
|
||||
}
|
||||
}
|
||||
|
||||
impl GemmKernel<f32, A16, A5> for GenericKernel {
|
||||
|
||||
#[inline]
|
||||
unsafe fn pack_row_a<A: Matrix<f32>>(a: A, pa: MutMatrix<f32>, i: usize) {
|
||||
unsafe fn pack_row_a<A: Matrix<f32>>(a: A, pa: MutMatrix<f32>) {
|
||||
if a.is_transposed() {
|
||||
self::l3s::sgemm_pa_t(pa.stride, a.col(i), a.stride(), pa.row_mut(i));
|
||||
self::l3s::sgemm_pa_t(pa.stride, a.ptr(), a.stride(), pa.ptr_mut());
|
||||
} else {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn pack_row_b<B: Matrix<f32>>(b: B, pb: MutMatrix<f32>, j: usize) {
|
||||
unsafe fn pack_row_b<B: Matrix<f32>>(b: B, pb: MutMatrix<f32>) {
|
||||
if b.is_transposed() {
|
||||
self::l3s::sgemm_pb_t(pb.stride, b.row(j), b.stride(), pb.row_mut(j));
|
||||
self::l3s::sgemm_pb_t(pb.stride, b.ptr(), b.stride(), pb.ptr_mut());
|
||||
} else {
|
||||
self::l3s::sgemm_pb_x8(pb.stride, b.row(j), b.stride(), pb.row_mut(j));
|
||||
self::l3s::sgemm_pb_x8(pb.stride, b.ptr(), b.stride(), pb.ptr_mut());
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -71,8 +71,8 @@ pub trait GemmKernelSupNr<F: Number, NR: Dim> {
|
||||
}
|
||||
|
||||
pub trait GemmKernel<F: Number, MR: Dim, NR: Dim>: GemmKernelSupMr<F, MR> + GemmKernelSupNr<F, NR> + GemmKernelSup<F> {
|
||||
unsafe fn pack_row_a<A: Matrix<F>>(a: A, pa: MutMatrix<F>, i: usize);
|
||||
unsafe fn pack_row_b<B: Matrix<F>>(b: B, pb: MutMatrix<F>, j: usize);
|
||||
unsafe fn pack_row_a<A: Matrix<F>>(a: A, pa: MutMatrix<F>);
|
||||
unsafe fn pack_row_b<B: Matrix<F>>(b: B, pb: MutMatrix<F>);
|
||||
|
||||
unsafe fn main_tl<C: MatrixMut<F>>(
|
||||
alpha: F,
|
||||
|
@ -64,13 +64,13 @@ impl<I> GemmKernel<f32, A16, A5> for SseKernel<f32, I>
|
||||
where I: GemmKernel<f32, A16, A5>
|
||||
{
|
||||
#[inline]
|
||||
unsafe fn pack_row_a<A: Matrix<f32>>(a: A, pa: MutMatrix<f32>, i: usize) {
|
||||
I::pack_row_a(a, pa, i);
|
||||
unsafe fn pack_row_a<A: Matrix<f32>>(a: A, pa: MutMatrix<f32>) {
|
||||
I::pack_row_a(a, pa);
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn pack_row_b<B: Matrix<f32>>(b: B, pb: MutMatrix<f32>, j: usize) {
|
||||
I::pack_row_b(b, pb, j);
|
||||
unsafe fn pack_row_b<B: Matrix<f32>>(b: B, pb: MutMatrix<f32>) {
|
||||
I::pack_row_b(b, pb);
|
||||
}
|
||||
|
||||
#[inline]
|
||||
|
@ -1,5 +1,5 @@
|
||||
|
||||
pub trait Number: Copy + Send + Sync + 'static {}
|
||||
pub trait Number: core::fmt::Display + Copy + Send + Sync + 'static {}
|
||||
|
||||
impl Number for f32 {}
|
||||
impl Number for f64 {}
|
||||
|
@ -1,8 +1,8 @@
|
||||
use rand::Rng;
|
||||
|
||||
const M_LEN: usize = 8192;
|
||||
const N_LEN: usize = 8192;
|
||||
const K_LEN: usize = 8192;
|
||||
const M_LEN: usize = 17;
|
||||
const N_LEN: usize = 6;
|
||||
const K_LEN: usize = 4;
|
||||
|
||||
fn make_matrices() -> (Vec<f32>, Vec<f32>, Vec<f32>, Vec<f32>) {
|
||||
let mut rng = rand::thread_rng();
|
||||
@ -15,19 +15,23 @@ fn make_matrices() -> (Vec<f32>, Vec<f32>, Vec<f32>, Vec<f32>) {
|
||||
let mut b = vec![0.0; n * k];
|
||||
let mut b_t = vec![0.0; n * k];
|
||||
|
||||
let mut counter = 0;
|
||||
for row in 0..k {
|
||||
for col in 0..m {
|
||||
let v = rng.gen();
|
||||
let v = counter as f32;
|
||||
a[row * m + col] = v;
|
||||
a_t[col * k + row] = v;
|
||||
counter += 1;
|
||||
}
|
||||
}
|
||||
|
||||
let mut counter = 0;
|
||||
for row in 0..n {
|
||||
for col in 0..k {
|
||||
let v = rng.gen();
|
||||
let v = counter as f32;
|
||||
b[row * k + col] = v;
|
||||
b_t[col * n + row] = v;
|
||||
counter += 1;
|
||||
}
|
||||
}
|
||||
|
||||
@ -80,6 +84,17 @@ fn test_sgemm_nn() {
|
||||
let mut c = vec![0.0; m * n];
|
||||
let mut cref = vec![0.0; m * n];
|
||||
|
||||
for row in 0..N_LEN {
|
||||
for col in 0..K_LEN {
|
||||
let index = row * K_LEN + col;
|
||||
|
||||
print!("{:>5} ", b[index]);
|
||||
}
|
||||
println!();
|
||||
}
|
||||
|
||||
println!();
|
||||
|
||||
unsafe {
|
||||
blas::sgemm(
|
||||
b'N',
|
||||
@ -118,6 +133,27 @@ fn test_sgemm_nn() {
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
for row in 0..N_LEN {
|
||||
for col in 0..M_LEN {
|
||||
let index = row * M_LEN + col;
|
||||
|
||||
print!("{:>5} ", c[index]);
|
||||
}
|
||||
println!();
|
||||
}
|
||||
|
||||
println!();
|
||||
|
||||
for row in 0..N_LEN {
|
||||
for col in 0..M_LEN {
|
||||
let index = row * M_LEN + col;
|
||||
|
||||
print!("{:>5} ", cref[index]);
|
||||
}
|
||||
println!();
|
||||
}
|
||||
|
||||
for row in 0..N_LEN {
|
||||
for col in 0..M_LEN {
|
||||
let index = row * M_LEN + col;
|
||||
@ -127,9 +163,9 @@ fn test_sgemm_nn() {
|
||||
assert!(feq(a, b), "a != b, a[{}, {}]={}, b[{}, {}]={}", row, col, a, row, col, b);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_sgemm_nt() {
|
||||
let (m, n, k) = (M_LEN, N_LEN, K_LEN);
|
||||
|
Loading…
Reference in New Issue
Block a user