Generalization + SSE
This commit is contained in:
parent
ffae99bafd
commit
aaf8ec3203
156
Cargo.lock
generated
156
Cargo.lock
generated
@ -36,14 +36,66 @@ name = "cfg-if"
|
|||||||
version = "0.1.10"
|
version = "0.1.10"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "crossbeam-deque"
|
||||||
|
version = "0.7.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
dependencies = [
|
||||||
|
"crossbeam-epoch 0.8.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
|
"crossbeam-utils 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "crossbeam-epoch"
|
||||||
|
version = "0.8.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
dependencies = [
|
||||||
|
"autocfg 0.1.7 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
|
"cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
|
"crossbeam-utils 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
|
"lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
|
"memoffset 0.5.3 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
|
"scopeguard 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "crossbeam-queue"
|
||||||
|
version = "0.2.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
dependencies = [
|
||||||
|
"crossbeam-utils 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "crossbeam-utils"
|
||||||
|
version = "0.7.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
dependencies = [
|
||||||
|
"autocfg 0.1.7 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
|
"cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
|
"lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "crunchy"
|
||||||
|
version = "0.2.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "either"
|
||||||
|
version = "1.5.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "gemm"
|
name = "gemm"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"blas 0.20.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
"blas 0.20.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
"matrixmultiply 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)",
|
"crunchy 0.2.2 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
"openblas 0.0.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
"openblas 0.0.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
"rand 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)",
|
"rand 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
|
"rayon 1.2.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
|
"threadpool 1.7.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@ -56,17 +108,30 @@ dependencies = [
|
|||||||
"wasi 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
"wasi 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "hermit-abi"
|
||||||
|
version = "0.1.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
dependencies = [
|
||||||
|
"libc 0.2.65 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "lazy_static"
|
||||||
|
version = "1.4.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "libc"
|
name = "libc"
|
||||||
version = "0.2.65"
|
version = "0.2.65"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "matrixmultiply"
|
name = "memoffset"
|
||||||
version = "0.2.3"
|
version = "0.5.3"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"rawpointer 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
"rustc_version 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@ -86,6 +151,15 @@ dependencies = [
|
|||||||
"autocfg 0.1.7 (registry+https://github.com/rust-lang/crates.io-index)",
|
"autocfg 0.1.7 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "num_cpus"
|
||||||
|
version = "1.11.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
dependencies = [
|
||||||
|
"hermit-abi 0.1.3 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
|
"libc 0.2.65 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "openblas"
|
name = "openblas"
|
||||||
version = "0.0.1"
|
version = "0.0.1"
|
||||||
@ -134,9 +208,60 @@ dependencies = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "rawpointer"
|
name = "rayon"
|
||||||
version = "0.2.1"
|
version = "1.2.1"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
dependencies = [
|
||||||
|
"crossbeam-deque 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
|
"either 1.5.3 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
|
"rayon-core 1.6.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "rayon-core"
|
||||||
|
version = "1.6.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
dependencies = [
|
||||||
|
"crossbeam-deque 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
|
"crossbeam-queue 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
|
"crossbeam-utils 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
|
"lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
|
"num_cpus 1.11.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "rustc_version"
|
||||||
|
version = "0.2.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
dependencies = [
|
||||||
|
"semver 0.9.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "scopeguard"
|
||||||
|
version = "1.0.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "semver"
|
||||||
|
version = "0.9.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
dependencies = [
|
||||||
|
"semver-parser 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "semver-parser"
|
||||||
|
version = "0.7.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "threadpool"
|
||||||
|
version = "1.7.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
dependencies = [
|
||||||
|
"num_cpus 1.11.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "wasi"
|
name = "wasi"
|
||||||
@ -149,16 +274,31 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||||||
"checksum blas-sys 0.7.1 (registry+https://github.com/rust-lang/crates.io-index)" = "13b1b279ceb25d7c4faaea95a5f7addbe7d8c34f9462044bd8e630cebcfc2440"
|
"checksum blas-sys 0.7.1 (registry+https://github.com/rust-lang/crates.io-index)" = "13b1b279ceb25d7c4faaea95a5f7addbe7d8c34f9462044bd8e630cebcfc2440"
|
||||||
"checksum c2-chacha 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "214238caa1bf3a496ec3392968969cab8549f96ff30652c9e56885329315f6bb"
|
"checksum c2-chacha 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "214238caa1bf3a496ec3392968969cab8549f96ff30652c9e56885329315f6bb"
|
||||||
"checksum cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)" = "4785bdd1c96b2a846b2bd7cc02e86b6b3dbf14e7e53446c4f54c92a361040822"
|
"checksum cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)" = "4785bdd1c96b2a846b2bd7cc02e86b6b3dbf14e7e53446c4f54c92a361040822"
|
||||||
|
"checksum crossbeam-deque 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)" = "c3aa945d63861bfe624b55d153a39684da1e8c0bc8fba932f7ee3a3c16cea3ca"
|
||||||
|
"checksum crossbeam-epoch 0.8.0 (registry+https://github.com/rust-lang/crates.io-index)" = "5064ebdbf05ce3cb95e45c8b086f72263f4166b29b97f6baff7ef7fe047b55ac"
|
||||||
|
"checksum crossbeam-queue 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "dfd6515864a82d2f877b42813d4553292c6659498c9a2aa31bab5a15243c2700"
|
||||||
|
"checksum crossbeam-utils 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ce446db02cdc3165b94ae73111e570793400d0794e46125cc4056c81cbb039f4"
|
||||||
|
"checksum crunchy 0.2.2 (registry+https://github.com/rust-lang/crates.io-index)" = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7"
|
||||||
|
"checksum either 1.5.3 (registry+https://github.com/rust-lang/crates.io-index)" = "bb1f6b1ce1c140482ea30ddd3335fc0024ac7ee112895426e0a629a6c20adfe3"
|
||||||
"checksum getrandom 0.1.13 (registry+https://github.com/rust-lang/crates.io-index)" = "e7db7ca94ed4cd01190ceee0d8a8052f08a247aa1b469a7f68c6a3b71afcf407"
|
"checksum getrandom 0.1.13 (registry+https://github.com/rust-lang/crates.io-index)" = "e7db7ca94ed4cd01190ceee0d8a8052f08a247aa1b469a7f68c6a3b71afcf407"
|
||||||
|
"checksum hermit-abi 0.1.3 (registry+https://github.com/rust-lang/crates.io-index)" = "307c3c9f937f38e3534b1d6447ecf090cafcc9744e4a6360e8b037b2cf5af120"
|
||||||
|
"checksum lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
|
||||||
"checksum libc 0.2.65 (registry+https://github.com/rust-lang/crates.io-index)" = "1a31a0627fdf1f6a39ec0dd577e101440b7db22672c0901fe00a9a6fbb5c24e8"
|
"checksum libc 0.2.65 (registry+https://github.com/rust-lang/crates.io-index)" = "1a31a0627fdf1f6a39ec0dd577e101440b7db22672c0901fe00a9a6fbb5c24e8"
|
||||||
"checksum matrixmultiply 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "d4f7ec66360130972f34830bfad9ef05c6610a43938a467bcc9ab9369ab3478f"
|
"checksum memoffset 0.5.3 (registry+https://github.com/rust-lang/crates.io-index)" = "75189eb85871ea5c2e2c15abbdd541185f63b408415e5051f5cac122d8c774b9"
|
||||||
"checksum num-complex 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "fcb0cf31fb3ff77e6d2a6ebd6800df7fdcd106f2ad89113c9130bcd07f93dffc"
|
"checksum num-complex 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "fcb0cf31fb3ff77e6d2a6ebd6800df7fdcd106f2ad89113c9130bcd07f93dffc"
|
||||||
"checksum num-traits 0.2.10 (registry+https://github.com/rust-lang/crates.io-index)" = "d4c81ffc11c212fa327657cb19dd85eb7419e163b5b076bede2bdb5c974c07e4"
|
"checksum num-traits 0.2.10 (registry+https://github.com/rust-lang/crates.io-index)" = "d4c81ffc11c212fa327657cb19dd85eb7419e163b5b076bede2bdb5c974c07e4"
|
||||||
|
"checksum num_cpus 1.11.1 (registry+https://github.com/rust-lang/crates.io-index)" = "76dac5ed2a876980778b8b85f75a71b6cbf0db0b1232ee12f826bccb00d09d72"
|
||||||
"checksum openblas 0.0.1 (registry+https://github.com/rust-lang/crates.io-index)" = "f005eb075fb562358060245c5d00a877757b4ddec8e8cd69cab96d8cc55998e8"
|
"checksum openblas 0.0.1 (registry+https://github.com/rust-lang/crates.io-index)" = "f005eb075fb562358060245c5d00a877757b4ddec8e8cd69cab96d8cc55998e8"
|
||||||
"checksum ppv-lite86 0.2.6 (registry+https://github.com/rust-lang/crates.io-index)" = "74490b50b9fbe561ac330df47c08f3f33073d2d00c150f719147d7c54522fa1b"
|
"checksum ppv-lite86 0.2.6 (registry+https://github.com/rust-lang/crates.io-index)" = "74490b50b9fbe561ac330df47c08f3f33073d2d00c150f719147d7c54522fa1b"
|
||||||
"checksum rand 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)" = "3ae1b169243eaf61759b8475a998f0a385e42042370f3a7dbaf35246eacc8412"
|
"checksum rand 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)" = "3ae1b169243eaf61759b8475a998f0a385e42042370f3a7dbaf35246eacc8412"
|
||||||
"checksum rand_chacha 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "03a2a90da8c7523f554344f921aa97283eadf6ac484a6d2a7d0212fa7f8d6853"
|
"checksum rand_chacha 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "03a2a90da8c7523f554344f921aa97283eadf6ac484a6d2a7d0212fa7f8d6853"
|
||||||
"checksum rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)" = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19"
|
"checksum rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)" = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19"
|
||||||
"checksum rand_hc 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ca3129af7b92a17112d59ad498c6f81eaf463253766b90396d39ea7a39d6613c"
|
"checksum rand_hc 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ca3129af7b92a17112d59ad498c6f81eaf463253766b90396d39ea7a39d6613c"
|
||||||
"checksum rawpointer 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3"
|
"checksum rayon 1.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "43739f8831493b276363637423d3622d4bd6394ab6f0a9c4a552e208aeb7fddd"
|
||||||
|
"checksum rayon-core 1.6.1 (registry+https://github.com/rust-lang/crates.io-index)" = "f8bf17de6f23b05473c437eb958b9c850bfc8af0961fe17b4cc92d5a627b4791"
|
||||||
|
"checksum rustc_version 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "138e3e0acb6c9fb258b19b67cb8abd63c00679d2851805ea151465464fe9030a"
|
||||||
|
"checksum scopeguard 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)" = "b42e15e59b18a828bbf5c58ea01debb36b9b096346de35d941dcb89009f24a0d"
|
||||||
|
"checksum semver 0.9.0 (registry+https://github.com/rust-lang/crates.io-index)" = "1d7eb9ef2c18661902cc47e535f9bc51b78acd254da71d375c2f6720d9a40403"
|
||||||
|
"checksum semver-parser 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)" = "388a1df253eca08550bef6c72392cfe7c30914bf41df5269b68cbd6ff8f570a3"
|
||||||
|
"checksum threadpool 1.7.1 (registry+https://github.com/rust-lang/crates.io-index)" = "e2f0c90a5f3459330ac8bc0d2f879c693bb7a2f59689c1083fc4ef83834da865"
|
||||||
"checksum wasi 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)" = "b89c3ce4ce14bdc6fb6beaf9ec7928ca331de5df7e5ea278375642a2f478570d"
|
"checksum wasi 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)" = "b89c3ce4ce14bdc6fb6beaf9ec7928ca331de5df7e5ea278375642a2f478570d"
|
||||||
|
@ -9,6 +9,11 @@ build = "build.rs"
|
|||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
rand = "0.7.2"
|
rand = "0.7.2"
|
||||||
matrixmultiply = "0.2.3"
|
|
||||||
blas = "0.20.0"
|
blas = "0.20.0"
|
||||||
openblas = "0.0.1"
|
openblas = "0.0.1"
|
||||||
|
crunchy = "0.2.2"
|
||||||
|
rayon = "1.2.1"
|
||||||
|
threadpool = "1.7.1"
|
||||||
|
|
||||||
|
[profile.release]
|
||||||
|
debug = true
|
||||||
|
@ -13,12 +13,8 @@ impl Alloc {
|
|||||||
Alloc { ptr, layout }
|
Alloc { ptr, layout }
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn ptr_f32(&self) -> *mut f32 {
|
pub fn ptr<F>(&self) -> *mut F {
|
||||||
self.ptr as *mut f32
|
self.ptr as *mut F
|
||||||
}
|
|
||||||
|
|
||||||
pub fn ptr(&self) -> *mut u8 {
|
|
||||||
self.ptr
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
75
src/dim.rs
Normal file
75
src/dim.rs
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
pub trait Dim {
|
||||||
|
const DIM: usize;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct A1;
|
||||||
|
impl Dim for A1 {
|
||||||
|
const DIM: usize = 1;
|
||||||
|
}
|
||||||
|
pub struct A2;
|
||||||
|
impl Dim for A2 {
|
||||||
|
const DIM: usize = 2;
|
||||||
|
}
|
||||||
|
pub struct A3;
|
||||||
|
impl Dim for A3 {
|
||||||
|
const DIM: usize = 3;
|
||||||
|
}
|
||||||
|
pub struct A4;
|
||||||
|
impl Dim for A4 {
|
||||||
|
const DIM: usize = 4;
|
||||||
|
}
|
||||||
|
pub struct A5;
|
||||||
|
impl Dim for A5 {
|
||||||
|
const DIM: usize = 5;
|
||||||
|
}
|
||||||
|
pub struct A6;
|
||||||
|
impl Dim for A6 {
|
||||||
|
const DIM: usize = 6;
|
||||||
|
}
|
||||||
|
pub struct A7;
|
||||||
|
impl Dim for A7 {
|
||||||
|
const DIM: usize = 7;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct A8;
|
||||||
|
impl Dim for A8 {
|
||||||
|
const DIM: usize = 8;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct A16;
|
||||||
|
impl Dim for A16 {
|
||||||
|
const DIM: usize = 16;
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
macro_rules! gen {
|
||||||
|
(($name: ident, $val: lit)) => {
|
||||||
|
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
gen! {
|
||||||
|
(A0, 0),
|
||||||
|
(A1, 1),
|
||||||
|
(A2, 2),
|
||||||
|
(A3, 3),
|
||||||
|
(A4, 4),
|
||||||
|
(A5, 5),
|
||||||
|
(A6, 6),
|
||||||
|
(A7, 7),
|
||||||
|
(A8, 8),
|
||||||
|
(A9, 9),
|
||||||
|
(A10, 10),
|
||||||
|
(A11, 11),
|
||||||
|
(A12, 12),
|
||||||
|
(A14, 14),
|
||||||
|
(A16, 16),
|
||||||
|
(A18, 18),
|
||||||
|
(A20, 20),
|
||||||
|
(A22, 22),
|
||||||
|
(A24, 24),
|
||||||
|
(A26, 26),
|
||||||
|
(A28, 28),
|
||||||
|
(A30, 30),
|
||||||
|
(A32, 32),
|
||||||
|
}*/
|
111
src/executor.rs
Normal file
111
src/executor.rs
Normal file
@ -0,0 +1,111 @@
|
|||||||
|
pub trait Executor {
|
||||||
|
fn execute<F: FnOnce(usize) + Send + 'static + Copy + Sync>(
|
||||||
|
&self,
|
||||||
|
start: usize,
|
||||||
|
end: usize,
|
||||||
|
step: usize,
|
||||||
|
f: F,
|
||||||
|
);
|
||||||
|
|
||||||
|
fn synchronize(&self);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct DefaultExecutor;
|
||||||
|
impl Executor for DefaultExecutor {
|
||||||
|
#[inline]
|
||||||
|
fn execute<F: FnOnce(usize) + Send + 'static + Copy + Sync>(
|
||||||
|
&self,
|
||||||
|
start: usize,
|
||||||
|
end: usize,
|
||||||
|
step: usize,
|
||||||
|
f: F,
|
||||||
|
) {
|
||||||
|
for i in (start..end).step_by(step) {
|
||||||
|
f(i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn synchronize(&self) {
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct RayonExecutor;
|
||||||
|
impl Executor for RayonExecutor {
|
||||||
|
#[inline]
|
||||||
|
fn execute<F: FnOnce(usize) + Send + 'static + Copy + Sync>(
|
||||||
|
&self,
|
||||||
|
start: usize,
|
||||||
|
end: usize,
|
||||||
|
step: usize,
|
||||||
|
f: F,
|
||||||
|
) {
|
||||||
|
let end = 1 + ((end - 1) / step);
|
||||||
|
|
||||||
|
use rayon::prelude::*;
|
||||||
|
(start..end)
|
||||||
|
.into_par_iter()
|
||||||
|
.map(|x| x * step)
|
||||||
|
.for_each(|x|f(x));
|
||||||
|
}
|
||||||
|
|
||||||
|
fn synchronize(&self) {
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
pub struct ThreadPoolExecutor {
|
||||||
|
thread_pool: threadpool::ThreadPool
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ThreadPoolExecutor {
|
||||||
|
pub fn new() -> ThreadPoolExecutor {
|
||||||
|
ThreadPoolExecutor {
|
||||||
|
thread_pool: threadpool::Builder::new().build()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Executor for ThreadPoolExecutor {
|
||||||
|
#[inline]
|
||||||
|
fn execute<F: FnOnce(usize) + Send + 'static + Copy>(
|
||||||
|
&self,
|
||||||
|
start: usize,
|
||||||
|
end: usize,
|
||||||
|
step: usize,
|
||||||
|
f: F,
|
||||||
|
) {
|
||||||
|
let thread_count = self.thread_pool.max_count();
|
||||||
|
|
||||||
|
let len = end - start;
|
||||||
|
let num_steps = len / step;
|
||||||
|
|
||||||
|
let mut left_steps = num_steps % thread_count;
|
||||||
|
let main_steps = num_steps - left_steps;
|
||||||
|
|
||||||
|
let job_size = main_steps / thread_count;
|
||||||
|
|
||||||
|
let mut prev_end = 0;
|
||||||
|
|
||||||
|
for _ in 0..thread_count {
|
||||||
|
let mut now_end = prev_end + job_size;
|
||||||
|
if left_steps > 0 {
|
||||||
|
now_end += 1;
|
||||||
|
left_steps -= 1;
|
||||||
|
}
|
||||||
|
self.thread_pool.execute(move || {
|
||||||
|
for j in prev_end..now_end {
|
||||||
|
f(start + j * step);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
prev_end = now_end
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
fn synchronize(&self) {
|
||||||
|
self.thread_pool.join();
|
||||||
|
}
|
||||||
|
}
|
242
src/gemm.rs
Normal file
242
src/gemm.rs
Normal file
@ -0,0 +1,242 @@
|
|||||||
|
use crate::aligned_alloc;
|
||||||
|
use crate::matrix::{Number, Matrix, MatrixMut, MutMatrix, ConstMatrix, ConstTransposedMatrix, MutTransposedMatrix};
|
||||||
|
use crate::kernel::params::single::*;
|
||||||
|
use crate::kernel;
|
||||||
|
use crate::kernel::GemmKernel;
|
||||||
|
use crate::dim::Dim;
|
||||||
|
use crate::executor::Executor;
|
||||||
|
|
||||||
|
pub unsafe fn gemm<E, F, K, MR, NR>(
|
||||||
|
e: &E,
|
||||||
|
transa: bool,
|
||||||
|
transb: bool,
|
||||||
|
transc: bool,
|
||||||
|
m: usize,
|
||||||
|
n: usize,
|
||||||
|
k: usize,
|
||||||
|
alpha: F,
|
||||||
|
a: *const F,
|
||||||
|
lda: usize,
|
||||||
|
b: *const F,
|
||||||
|
ldb: usize,
|
||||||
|
beta: F,
|
||||||
|
c: *mut F,
|
||||||
|
ldc: usize,
|
||||||
|
)
|
||||||
|
where E: Executor,
|
||||||
|
F: Number,
|
||||||
|
MR: Dim, NR: Dim,
|
||||||
|
K: GemmKernel<F, MR, NR>,
|
||||||
|
{
|
||||||
|
match (transa, transb, transc) {
|
||||||
|
(false, false, false) => gemm_template::<E, F, K, MR, NR, _, _, _>(
|
||||||
|
e, m, n, k, alpha,
|
||||||
|
ConstMatrix::new(a, lda),
|
||||||
|
ConstMatrix::new(b, ldb),
|
||||||
|
beta,
|
||||||
|
MutMatrix::new(c, ldc)),
|
||||||
|
|
||||||
|
(false, false, true) => gemm_template::<E, F, K, MR, NR, _, _, _>(
|
||||||
|
e, m, n, k, alpha,
|
||||||
|
ConstMatrix::new(a, lda),
|
||||||
|
ConstMatrix::new(b, ldb),
|
||||||
|
beta,
|
||||||
|
MutTransposedMatrix::new(c, ldc)),
|
||||||
|
|
||||||
|
(false, true, false) => gemm_template::<E, F, K, MR, NR, _, _, _>(
|
||||||
|
e, m, n, k, alpha,
|
||||||
|
ConstMatrix::new(a, lda),
|
||||||
|
ConstTransposedMatrix::new(b, ldb),
|
||||||
|
beta,
|
||||||
|
MutMatrix::new(c, ldc)),
|
||||||
|
|
||||||
|
(false, true, true) => gemm_template::<E, F, K, MR, NR, _, _, _>(
|
||||||
|
e, m, n, k, alpha,
|
||||||
|
ConstMatrix::new(a, lda),
|
||||||
|
ConstTransposedMatrix::new(b, ldb),
|
||||||
|
beta,
|
||||||
|
MutTransposedMatrix::new(c, ldc)),
|
||||||
|
|
||||||
|
(true, false, false) => gemm_template::<E, F, K, MR, NR, _, _, _>(
|
||||||
|
e, m, n, k, alpha,
|
||||||
|
ConstTransposedMatrix::new(a, lda),
|
||||||
|
ConstMatrix::new(b, ldb),
|
||||||
|
beta,
|
||||||
|
MutMatrix::new(c, ldc)),
|
||||||
|
|
||||||
|
(true, false, true) => gemm_template::<E, F, K, MR, NR, _, _, _>(
|
||||||
|
e, m, n, k, alpha,
|
||||||
|
ConstTransposedMatrix::new(a, lda),
|
||||||
|
ConstMatrix::new(b, ldb),
|
||||||
|
beta,
|
||||||
|
MutTransposedMatrix::new(c, ldc)),
|
||||||
|
|
||||||
|
|
||||||
|
(true, true, false) => gemm_template::<E, F, K, MR, NR, _, _, _>(
|
||||||
|
e, m, n, k, alpha,
|
||||||
|
ConstTransposedMatrix::new(a, lda),
|
||||||
|
ConstTransposedMatrix::new(b, ldb),
|
||||||
|
beta,
|
||||||
|
MutMatrix::new(c, ldc)),
|
||||||
|
|
||||||
|
|
||||||
|
(true, true, true) => gemm_template::<E, F, K, MR, NR, _, _, _>(
|
||||||
|
e, m, n, k, alpha,
|
||||||
|
ConstTransposedMatrix::new(a, lda),
|
||||||
|
ConstTransposedMatrix::new(b, ldb),
|
||||||
|
beta,
|
||||||
|
MutTransposedMatrix::new(c, ldc)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
unsafe fn gemm_template<E, F, K, MR, NR, A, B, C>(
|
||||||
|
e: &E,
|
||||||
|
m: usize,
|
||||||
|
n: usize,
|
||||||
|
k: usize,
|
||||||
|
alpha: F,
|
||||||
|
a: A,
|
||||||
|
b: B,
|
||||||
|
beta: F,
|
||||||
|
c: C
|
||||||
|
)
|
||||||
|
where E: Executor,
|
||||||
|
F: Number,
|
||||||
|
MR: Dim, NR: Dim,
|
||||||
|
K: GemmKernel<F, MR, NR>,
|
||||||
|
A: Matrix<F>,
|
||||||
|
B: Matrix<F>,
|
||||||
|
C: MatrixMut<F>,
|
||||||
|
{
|
||||||
|
let packed_a = aligned_alloc::Alloc::new(MC * KC * std::mem::size_of::<F>());
|
||||||
|
let packed_b = aligned_alloc::Alloc::new(KC * NC * std::mem::size_of::<F>());
|
||||||
|
|
||||||
|
for j in (0..n).step_by(NC) {
|
||||||
|
let j_b = std::cmp::min(n - j, NC);
|
||||||
|
for p in (0..k).step_by(KC) {
|
||||||
|
let p_b = std::cmp::min(k - p, KC);
|
||||||
|
for i in (0..m).step_by(MC) {
|
||||||
|
let i_b = std::cmp::min(m - i, MC);
|
||||||
|
|
||||||
|
let pa = MutMatrix::new(packed_a.ptr::<F>(), p_b);
|
||||||
|
let pb = MutMatrix::new(packed_b.ptr::<F>(), p_b);
|
||||||
|
|
||||||
|
inner_kernel::<E, F, K, MR, NR, _, _, _>(
|
||||||
|
e,
|
||||||
|
i_b,
|
||||||
|
j_b,
|
||||||
|
p_b,
|
||||||
|
alpha,
|
||||||
|
a.sub(p, i),
|
||||||
|
b.sub(j, p),
|
||||||
|
beta,
|
||||||
|
c.sub(j, i),
|
||||||
|
pa,
|
||||||
|
pb,
|
||||||
|
i == 0
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// | MR |
|
||||||
|
// +----------------------+
|
||||||
|
// |TL : : : |TR|
|
||||||
|
// | : : : | |
|
||||||
|
// + - - - - - - - - - ---+----
|
||||||
|
// | : : : | | NR
|
||||||
|
// | : : : | |
|
||||||
|
// + - - - - - - - - - ---+----
|
||||||
|
// | : : : | |
|
||||||
|
// | : : : | |
|
||||||
|
// +----------------------+
|
||||||
|
// |BL | | | |BR|
|
||||||
|
// +----------------------+
|
||||||
|
unsafe fn inner_kernel<E, F, K, MR, NR, A, B, C>(
|
||||||
|
e: &E,
|
||||||
|
m: usize,
|
||||||
|
n: usize,
|
||||||
|
k: usize,
|
||||||
|
alpha: F,
|
||||||
|
a: A,
|
||||||
|
b: B,
|
||||||
|
beta: F,
|
||||||
|
c: C,
|
||||||
|
pa: MutMatrix<F>,
|
||||||
|
pb: MutMatrix<F>,
|
||||||
|
first_time: bool,
|
||||||
|
)
|
||||||
|
where E: Executor,
|
||||||
|
F: Number,
|
||||||
|
MR: Dim,
|
||||||
|
NR: Dim,
|
||||||
|
K: kernel::GemmKernel<F, MR, NR>,
|
||||||
|
A: Matrix<F>,
|
||||||
|
B: Matrix<F>,
|
||||||
|
C: MatrixMut<F>,
|
||||||
|
|
||||||
|
{
|
||||||
|
let n_left = n % NR;
|
||||||
|
let n_main = n - n_left;
|
||||||
|
|
||||||
|
let m_left = m % MR;
|
||||||
|
let m_main = m - m_left;
|
||||||
|
|
||||||
|
if first_time {
|
||||||
|
e.execute(0, n_main, NR, move |j|
|
||||||
|
K::pack_row_b(b, pb, j));
|
||||||
|
}
|
||||||
|
|
||||||
|
e.execute(0, m_main, MR, move |i|
|
||||||
|
K::pack_row_a(a, pa, i));
|
||||||
|
|
||||||
|
e.synchronize();
|
||||||
|
|
||||||
|
e.execute(0, n_main, NR, move |j| {
|
||||||
|
// Section TL
|
||||||
|
for i in (0..m_main).step_by(MR) {
|
||||||
|
K::main_tl(alpha,
|
||||||
|
pa.sub_row(i),
|
||||||
|
pb.sub_row(j),
|
||||||
|
beta,
|
||||||
|
c.sub(j, i));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Section TR
|
||||||
|
for i in m_main..m {
|
||||||
|
K::sup_tr(
|
||||||
|
alpha,
|
||||||
|
a.sub_col(i),
|
||||||
|
pb.sub_row(j),
|
||||||
|
beta,
|
||||||
|
c.sub(j, i));
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
e.execute(n_main, n, 1, move |j| {
|
||||||
|
// Section BL
|
||||||
|
for i in (0..m_main).step_by(MR) {
|
||||||
|
K::sup_bl(
|
||||||
|
alpha,
|
||||||
|
pa.sub_row(i),
|
||||||
|
b.sub_row(j),
|
||||||
|
beta,
|
||||||
|
c.sub(j, i)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Section BR
|
||||||
|
for i in m_main..m {
|
||||||
|
K::sup_br(
|
||||||
|
k,
|
||||||
|
alpha,
|
||||||
|
a.sub_col(i),
|
||||||
|
b.sub_row(j),
|
||||||
|
beta,
|
||||||
|
c.sub(j, i))
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
e.synchronize();
|
||||||
|
}
|
@ -246,6 +246,7 @@ pub unsafe fn sdot(
|
|||||||
let mut acc1 = _mm256_setzero_ps();
|
let mut acc1 = _mm256_setzero_ps();
|
||||||
let mut acc2 = _mm256_setzero_ps();
|
let mut acc2 = _mm256_setzero_ps();
|
||||||
let mut acc3 = _mm256_setzero_ps();
|
let mut acc3 = _mm256_setzero_ps();
|
||||||
|
|
||||||
for _ in 0..n / 32 {
|
for _ in 0..n / 32 {
|
||||||
let x0 = _mm256_loadu_ps(x);
|
let x0 = _mm256_loadu_ps(x);
|
||||||
let y0 = _mm256_loadu_ps(y);
|
let y0 = _mm256_loadu_ps(y);
|
||||||
|
@ -1,29 +1,82 @@
|
|||||||
use super::fma::fmadd_ps;
|
use super::fma::fmadd_ps;
|
||||||
use super::intrinsics::*;
|
use super::intrinsics::*;
|
||||||
|
use crate::kernel::params::single::{MR, NR};
|
||||||
|
use crate::matrix::{Matrix, MutMatrix, MatrixMut};
|
||||||
|
use crunchy::unroll;
|
||||||
|
|
||||||
pub(crate) unsafe fn sgemm_ukr_16x4(
|
#[inline]
|
||||||
|
pub(crate) unsafe fn sgemm_ukr_16x8<C: MatrixMut<f32>>(
|
||||||
k: usize,
|
k: usize,
|
||||||
alpha: f32,
|
alpha: f32,
|
||||||
pa: *const f32,
|
pa: MutMatrix<f32>,
|
||||||
pb: *const f32,
|
pb: MutMatrix<f32>,
|
||||||
beta: f32,
|
beta: f32,
|
||||||
c: *mut f32,
|
c: C,
|
||||||
ldc: usize,
|
|
||||||
) {
|
) {
|
||||||
let mut mt00 = _mm256_setzero_ps();
|
let mut mt00 = _mm256_setzero_ps();
|
||||||
let mut mt01 = _mm256_setzero_ps();
|
let mut mt01 = _mm256_setzero_ps();
|
||||||
let mut mt02 = _mm256_setzero_ps();
|
let mut mt02 = _mm256_setzero_ps();
|
||||||
let mut mt03 = _mm256_setzero_ps();
|
let mut mt03 = _mm256_setzero_ps();
|
||||||
|
let mut mt04 = _mm256_setzero_ps();
|
||||||
|
// let mut mt05 = _mm256_setzero_ps();
|
||||||
|
// let mut mt06 = _mm256_setzero_ps();
|
||||||
|
// let mut mt07 = _mm256_setzero_ps();
|
||||||
|
|
||||||
let mut mt10 = _mm256_setzero_ps();
|
let mut mt10 = _mm256_setzero_ps();
|
||||||
let mut mt11 = _mm256_setzero_ps();
|
let mut mt11 = _mm256_setzero_ps();
|
||||||
let mut mt12 = _mm256_setzero_ps();
|
let mut mt12 = _mm256_setzero_ps();
|
||||||
let mut mt13 = _mm256_setzero_ps();
|
let mut mt13 = _mm256_setzero_ps();
|
||||||
|
let mut mt14 = _mm256_setzero_ps();
|
||||||
|
// let mut mt15 = _mm256_setzero_ps();
|
||||||
|
// let mut mt16 = _mm256_setzero_ps();
|
||||||
|
// let mut mt17 = _mm256_setzero_ps();
|
||||||
|
|
||||||
|
let mut pa = pa.ptr();
|
||||||
|
let mut pb = pb.ptr();
|
||||||
|
|
||||||
let mut pa = pa;
|
const BATCH: usize = 16;
|
||||||
let mut pb = pb;
|
|
||||||
|
|
||||||
for _ in 0..k {
|
let k_right = k % BATCH;
|
||||||
|
let k_main = k - k_right;
|
||||||
|
|
||||||
|
for _ in (0..k_main).step_by(BATCH) {
|
||||||
|
unroll! {
|
||||||
|
for i in 0..16 {
|
||||||
|
let a0 = _mm256_load_ps(pa.add(i * MR));
|
||||||
|
let a1 = _mm256_load_ps(pa.add(i * MR + 8));
|
||||||
|
|
||||||
|
let b0 = _mm256_broadcast_ss(&*pb.add(i * NR));
|
||||||
|
let b1 = _mm256_broadcast_ss(&*pb.add(i * NR + 1));
|
||||||
|
let b2 = _mm256_broadcast_ss(&*pb.add(i * NR + 2));
|
||||||
|
let b3 = _mm256_broadcast_ss(&*pb.add(i * NR + 3));
|
||||||
|
let b4 = _mm256_broadcast_ss(&*pb.add(i * NR + 4));
|
||||||
|
// let b5 = _mm256_broadcast_ss(&*pb.add(5));
|
||||||
|
// let b6 = _mm256_broadcast_ss(&*pb.add(6));
|
||||||
|
// let b7 = _mm256_broadcast_ss(&*pb.add(7));
|
||||||
|
|
||||||
|
mt00 = fmadd_ps(a0, b0, mt00);
|
||||||
|
mt01 = fmadd_ps(a0, b1, mt01);
|
||||||
|
mt02 = fmadd_ps(a0, b2, mt02);
|
||||||
|
mt03 = fmadd_ps(a0, b3, mt03);
|
||||||
|
mt04 = fmadd_ps(a0, b4, mt04);
|
||||||
|
// mt05 = fmadd_ps(a0, b5, mt05);
|
||||||
|
// mt06 = fmadd_ps(a0, b6, mt06);
|
||||||
|
// mt07 = fmadd_ps(a0, b7, mt07);
|
||||||
|
|
||||||
|
mt10 = fmadd_ps(a1, b0, mt10);
|
||||||
|
mt11 = fmadd_ps(a1, b1, mt11);
|
||||||
|
mt12 = fmadd_ps(a1, b2, mt12);
|
||||||
|
mt13 = fmadd_ps(a1, b3, mt13);
|
||||||
|
mt14 = fmadd_ps(a1, b4, mt14);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pa = pa.add(BATCH * MR);
|
||||||
|
pb = pb.add(BATCH * NR);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
for _ in k_main..k {
|
||||||
let a0 = _mm256_load_ps(pa);
|
let a0 = _mm256_load_ps(pa);
|
||||||
let a1 = _mm256_load_ps(pa.add(8));
|
let a1 = _mm256_load_ps(pa.add(8));
|
||||||
|
|
||||||
@ -31,66 +84,107 @@ pub(crate) unsafe fn sgemm_ukr_16x4(
|
|||||||
let b1 = _mm256_broadcast_ss(&*pb.add(1));
|
let b1 = _mm256_broadcast_ss(&*pb.add(1));
|
||||||
let b2 = _mm256_broadcast_ss(&*pb.add(2));
|
let b2 = _mm256_broadcast_ss(&*pb.add(2));
|
||||||
let b3 = _mm256_broadcast_ss(&*pb.add(3));
|
let b3 = _mm256_broadcast_ss(&*pb.add(3));
|
||||||
|
let b4 = _mm256_broadcast_ss(&*pb.add(4));
|
||||||
|
// let b5 = _mm256_broadcast_ss(&*pb.add(5));
|
||||||
|
// let b6 = _mm256_broadcast_ss(&*pb.add(6));
|
||||||
|
// let b7 = _mm256_broadcast_ss(&*pb.add(7));
|
||||||
|
|
||||||
mt00 = fmadd_ps(a0, b0, mt00);
|
mt00 = fmadd_ps(a0, b0, mt00);
|
||||||
mt01 = fmadd_ps(a0, b1, mt01);
|
mt01 = fmadd_ps(a0, b1, mt01);
|
||||||
mt02 = fmadd_ps(a0, b2, mt02);
|
mt02 = fmadd_ps(a0, b2, mt02);
|
||||||
mt03 = fmadd_ps(a0, b3, mt03);
|
mt03 = fmadd_ps(a0, b3, mt03);
|
||||||
|
mt04 = fmadd_ps(a0, b4, mt04);
|
||||||
|
// mt05 = fmadd_ps(a0, b5, mt05);
|
||||||
|
// mt06 = fmadd_ps(a0, b6, mt06);
|
||||||
|
// mt07 = fmadd_ps(a0, b7, mt07);
|
||||||
|
|
||||||
mt10 = fmadd_ps(a1, b0, mt10);
|
mt10 = fmadd_ps(a1, b0, mt10);
|
||||||
mt11 = fmadd_ps(a1, b1, mt11);
|
mt11 = fmadd_ps(a1, b1, mt11);
|
||||||
mt12 = fmadd_ps(a1, b2, mt12);
|
mt12 = fmadd_ps(a1, b2, mt12);
|
||||||
mt13 = fmadd_ps(a1, b3, mt13);
|
mt13 = fmadd_ps(a1, b3, mt13);
|
||||||
|
mt14 = fmadd_ps(a1, b4, mt14);
|
||||||
|
|
||||||
pa = pa.add(16);
|
pa = pa.add(MR);
|
||||||
pb = pb.add(4);
|
pb = pb.add(NR);
|
||||||
}
|
}
|
||||||
|
|
||||||
let alpha = _mm256_broadcast_ss(&alpha);
|
// let alpha = _mm256_broadcast_ss(&alpha);
|
||||||
|
|
||||||
mt00 = _mm256_mul_ps(alpha, mt00);
|
// mt00 = _mm256_mul_ps(alpha, mt00);
|
||||||
mt01 = _mm256_mul_ps(alpha, mt01);
|
// mt01 = _mm256_mul_ps(alpha, mt01);
|
||||||
mt02 = _mm256_mul_ps(alpha, mt02);
|
// mt02 = _mm256_mul_ps(alpha, mt02);
|
||||||
mt03 = _mm256_mul_ps(alpha, mt03);
|
// mt03 = _mm256_mul_ps(alpha, mt03);
|
||||||
mt10 = _mm256_mul_ps(alpha, mt10);
|
// mt04 = _mm256_mul_ps(alpha, mt04);
|
||||||
mt11 = _mm256_mul_ps(alpha, mt11);
|
// // mt05 = _mm256_mul_ps(alpha, mt05);
|
||||||
mt12 = _mm256_mul_ps(alpha, mt12);
|
// // mt06 = _mm256_mul_ps(alpha, mt06);
|
||||||
mt13 = _mm256_mul_ps(alpha, mt13);
|
// // mt07 = _mm256_mul_ps(alpha, mt07);
|
||||||
|
|
||||||
let ccol0 = c;
|
// mt10 = _mm256_mul_ps(alpha, mt10);
|
||||||
let ccol1 = c.add(ldc);
|
// mt11 = _mm256_mul_ps(alpha, mt11);
|
||||||
let ccol2 = c.add(ldc * 2);
|
// mt12 = _mm256_mul_ps(alpha, mt12);
|
||||||
let ccol3 = c.add(ldc * 3);
|
// mt13 = _mm256_mul_ps(alpha, mt13);
|
||||||
|
// mt14 = _mm256_mul_ps(alpha, mt14);
|
||||||
|
// // mt15 = _mm256_mul_ps(alpha, mt15);
|
||||||
|
// // mt16 = _mm256_mul_ps(alpha, mt16);
|
||||||
|
// // mt17 = _mm256_mul_ps(alpha, mt17);
|
||||||
|
|
||||||
if beta != 0.0 {
|
let ccol0 = c.ptr_mut();
|
||||||
let beta = _mm256_broadcast_ss(&beta);
|
let ccol1 = c.row_mut(1);
|
||||||
|
let ccol2 = c.row_mut(2);
|
||||||
|
let ccol3 = c.row_mut(3);
|
||||||
|
let ccol4 = c.row_mut(4);
|
||||||
|
// let ccol5 = c.row_mut(5);
|
||||||
|
// let ccol6 = c.row_mut(6);
|
||||||
|
// let ccol7 = c.add(ldc * 7);
|
||||||
|
|
||||||
|
// if beta != 0.0 {
|
||||||
|
// let beta = _mm256_broadcast_ss(&beta);
|
||||||
|
|
||||||
mt00 = fmadd_ps(beta, _mm256_loadu_ps(ccol0), mt00);
|
// mt00 = fmadd_ps(beta, _mm256_loadu_ps(ccol0), mt00);
|
||||||
mt01 = fmadd_ps(beta, _mm256_loadu_ps(ccol1), mt01);
|
// mt01 = fmadd_ps(beta, _mm256_loadu_ps(ccol1), mt01);
|
||||||
mt02 = fmadd_ps(beta, _mm256_loadu_ps(ccol2), mt02);
|
// mt02 = fmadd_ps(beta, _mm256_loadu_ps(ccol2), mt02);
|
||||||
mt03 = fmadd_ps(beta, _mm256_loadu_ps(ccol3), mt03);
|
// mt03 = fmadd_ps(beta, _mm256_loadu_ps(ccol3), mt03);
|
||||||
mt10 = fmadd_ps(beta, _mm256_loadu_ps(ccol0.add(8)), mt10);
|
// mt04 = fmadd_ps(beta, _mm256_loadu_ps(ccol4), mt04);
|
||||||
mt11 = fmadd_ps(beta, _mm256_loadu_ps(ccol1.add(8)), mt11);
|
// // mt05 = fmadd_ps(beta, _mm256_loadu_ps(ccol5), mt05);
|
||||||
mt12 = fmadd_ps(beta, _mm256_loadu_ps(ccol2.add(8)), mt12);
|
// // mt06 = fmadd_ps(beta, _mm256_loadu_ps(ccol6), mt06);
|
||||||
mt13 = fmadd_ps(beta, _mm256_loadu_ps(ccol3.add(8)), mt13);
|
// // mt07 = fmadd_ps(beta, _mm256_loadu_ps(ccol7), mt07);
|
||||||
}
|
|
||||||
|
// mt10 = fmadd_ps(beta, _mm256_loadu_ps(ccol0.add(8)), mt10);
|
||||||
|
// mt11 = fmadd_ps(beta, _mm256_loadu_ps(ccol1.add(8)), mt11);
|
||||||
|
// mt12 = fmadd_ps(beta, _mm256_loadu_ps(ccol2.add(8)), mt12);
|
||||||
|
// mt13 = fmadd_ps(beta, _mm256_loadu_ps(ccol3.add(8)), mt13);
|
||||||
|
// mt14 = fmadd_ps(beta, _mm256_loadu_ps(ccol4.add(8)), mt14);
|
||||||
|
// // mt15 = fmadd_ps(beta, _mm256_loadu_ps(ccol5.add(8)), mt15);
|
||||||
|
// // mt16 = fmadd_ps(beta, _mm256_loadu_ps(ccol6.add(8)), mt16);
|
||||||
|
// // mt17 = fmadd_ps(beta, _mm256_loadu_ps(ccol7.add(8)), mt17);
|
||||||
|
// }
|
||||||
|
|
||||||
_mm256_storeu_ps(ccol0, mt00);
|
_mm256_storeu_ps(ccol0, mt00);
|
||||||
_mm256_storeu_ps(ccol1, mt01);
|
_mm256_storeu_ps(ccol1, mt01);
|
||||||
_mm256_storeu_ps(ccol2, mt02);
|
_mm256_storeu_ps(ccol2, mt02);
|
||||||
_mm256_storeu_ps(ccol3, mt03);
|
_mm256_storeu_ps(ccol3, mt03);
|
||||||
|
_mm256_storeu_ps(ccol4, mt04);
|
||||||
|
// _mm256_storeu_ps(ccol5, mt05);
|
||||||
|
// _mm256_storeu_ps(ccol6, mt06);
|
||||||
|
// _mm256_storeu_ps(ccol7, mt07);
|
||||||
|
|
||||||
_mm256_storeu_ps(ccol0.add(8), mt10);
|
_mm256_storeu_ps(ccol0.add(8), mt10);
|
||||||
_mm256_storeu_ps(ccol1.add(8), mt11);
|
_mm256_storeu_ps(ccol1.add(8), mt11);
|
||||||
_mm256_storeu_ps(ccol2.add(8), mt12);
|
_mm256_storeu_ps(ccol2.add(8), mt12);
|
||||||
_mm256_storeu_ps(ccol3.add(8), mt13);
|
_mm256_storeu_ps(ccol3.add(8), mt13);
|
||||||
|
_mm256_storeu_ps(ccol4.add(8), mt14);
|
||||||
|
// _mm256_storeu_ps(ccol5.add(8), mt15);
|
||||||
|
// _mm256_storeu_ps(ccol6.add(8), mt16);
|
||||||
|
// _mm256_storeu_ps(ccol7.add(8), mt17);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) unsafe fn sgemm_sup_16x1(
|
pub(crate) unsafe fn sgemm_sup_16x1<B: Matrix<f32>, C: MatrixMut<f32>>(
|
||||||
k: usize,
|
k: usize,
|
||||||
alpha: f32,
|
alpha: f32,
|
||||||
pa: *const f32,
|
pa: MutMatrix<f32>,
|
||||||
b: *const f32,
|
b: B,
|
||||||
beta: f32,
|
beta: f32,
|
||||||
c: *mut f32,
|
c: C,
|
||||||
) {
|
) {
|
||||||
let mut mt0 = _mm256_setzero_ps();
|
let mut mt0 = _mm256_setzero_ps();
|
||||||
let mut mt1 = _mm256_setzero_ps();
|
let mut mt1 = _mm256_setzero_ps();
|
||||||
@ -99,76 +193,32 @@ pub(crate) unsafe fn sgemm_sup_16x1(
|
|||||||
let mut b = b;
|
let mut b = b;
|
||||||
|
|
||||||
for _ in 0..k {
|
for _ in 0..k {
|
||||||
let a0 = _mm256_load_ps(pa);
|
let a0 = _mm256_load_ps(pa.ptr());
|
||||||
let a1 = _mm256_load_ps(pa.add(8));
|
let a1 = _mm256_load_ps(pa.col(8));
|
||||||
|
|
||||||
let b0 = _mm256_broadcast_ss(&*b);
|
let b0 = _mm256_broadcast_ss(&*b.ptr());
|
||||||
|
|
||||||
mt0 = fmadd_ps(a0, b0, mt0);
|
mt0 = fmadd_ps(a0, b0, mt0);
|
||||||
mt1 = fmadd_ps(a1, b0, mt1);
|
mt1 = fmadd_ps(a1, b0, mt1);
|
||||||
|
|
||||||
pa = pa.add(16);
|
pa.shift_col(16);
|
||||||
b = b.add(1);
|
b.inc_col();
|
||||||
}
|
}
|
||||||
|
|
||||||
let alpha = _mm256_broadcast_ss(&alpha);
|
// let alpha = _mm256_broadcast_ss(&alpha);
|
||||||
|
|
||||||
mt0 = _mm256_mul_ps(alpha, mt0);
|
// mt0 = _mm256_mul_ps(alpha, mt0);
|
||||||
mt1 = _mm256_mul_ps(alpha, mt1);
|
// mt1 = _mm256_mul_ps(alpha, mt1);
|
||||||
|
|
||||||
if beta != 0.0 {
|
// if beta != 0.0 {
|
||||||
let beta = _mm256_broadcast_ss(&beta);
|
// let beta = _mm256_broadcast_ss(&beta);
|
||||||
|
|
||||||
mt0 = fmadd_ps(beta, _mm256_loadu_ps(c), mt0);
|
// mt0 = fmadd_ps(beta, _mm256_loadu_ps(c), mt0);
|
||||||
mt1 = fmadd_ps(beta, _mm256_loadu_ps(c.add(8)), mt1);
|
// mt1 = fmadd_ps(beta, _mm256_loadu_ps(c.add(8)), mt1);
|
||||||
}
|
// }
|
||||||
|
|
||||||
_mm256_storeu_ps(c, mt0);
|
_mm256_storeu_ps(c.ptr_mut(), mt0);
|
||||||
_mm256_storeu_ps(c.add(8), mt1);
|
_mm256_storeu_ps(c.col_mut(8), mt1);
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) unsafe fn sgemm_sup0_t(
|
|
||||||
k: usize,
|
|
||||||
alpha: f32,
|
|
||||||
pa: *const f32,
|
|
||||||
b: *const f32,
|
|
||||||
ldb: usize,
|
|
||||||
beta: f32,
|
|
||||||
c: *mut f32,
|
|
||||||
) {
|
|
||||||
let mut mt0 = _mm256_setzero_ps();
|
|
||||||
let mut mt1 = _mm256_setzero_ps();
|
|
||||||
|
|
||||||
let mut pa = pa;
|
|
||||||
let mut b = b;
|
|
||||||
|
|
||||||
for _ in 0..k {
|
|
||||||
let a0 = _mm256_load_ps(pa);
|
|
||||||
let a1 = _mm256_load_ps(pa.add(8));
|
|
||||||
|
|
||||||
let b0 = _mm256_broadcast_ss(&*b);
|
|
||||||
|
|
||||||
mt0 = fmadd_ps(a0, b0, mt0);
|
|
||||||
mt1 = fmadd_ps(a1, b0, mt1);
|
|
||||||
|
|
||||||
pa = pa.add(16);
|
|
||||||
b = b.add(ldb);
|
|
||||||
}
|
|
||||||
|
|
||||||
let alpha = _mm256_broadcast_ss(&alpha);
|
|
||||||
|
|
||||||
mt0 = _mm256_mul_ps(alpha, mt0);
|
|
||||||
mt1 = _mm256_mul_ps(alpha, mt1);
|
|
||||||
|
|
||||||
if beta != 0.0 {
|
|
||||||
let beta = _mm256_broadcast_ss(&beta);
|
|
||||||
|
|
||||||
mt0 = fmadd_ps(beta, _mm256_loadu_ps(c), mt0);
|
|
||||||
mt1 = fmadd_ps(beta, _mm256_loadu_ps(c.add(8)), mt1);
|
|
||||||
}
|
|
||||||
|
|
||||||
_mm256_storeu_ps(c, mt0);
|
|
||||||
_mm256_storeu_ps(c.add(8), mt1);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) unsafe fn sgemm_pa_16x(k: usize, a: *const f32, lda: usize, pa: *mut f32) {
|
pub(crate) unsafe fn sgemm_pa_16x(k: usize, a: *const f32, lda: usize, pa: *mut f32) {
|
||||||
|
@ -5,3 +5,85 @@ pub mod l1d;
|
|||||||
pub mod l1s;
|
pub mod l1s;
|
||||||
pub mod l3d;
|
pub mod l3d;
|
||||||
pub mod l3s;
|
pub mod l3s;
|
||||||
|
|
||||||
|
use core::marker::PhantomData;
|
||||||
|
use crate::matrix::{Number, MutMatrix, Matrix, MatrixMut};
|
||||||
|
use crate::kernel::{GemmKernel, GemmKernelSupNr, GemmKernelSupMr, GemmKernelSup};
|
||||||
|
use crate::dim::*;
|
||||||
|
|
||||||
|
pub struct AvxKernel<F: Number, I>(PhantomData<fn(F, I)>);
|
||||||
|
|
||||||
|
impl<I> GemmKernelSupNr<f32, A5> for AvxKernel<f32, I>
|
||||||
|
where I: GemmKernelSupNr<f32, A5>
|
||||||
|
{
|
||||||
|
#[inline]
|
||||||
|
unsafe fn sup_tr<A: Matrix<f32>, C: MatrixMut<f32>>(
|
||||||
|
alpha: f32,
|
||||||
|
a: A,
|
||||||
|
pb: MutMatrix<f32>,
|
||||||
|
beta: f32,
|
||||||
|
c: C,
|
||||||
|
) {
|
||||||
|
I::sup_tr(alpha, a, pb, beta, c);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<I> GemmKernelSupMr<f32, A16> for AvxKernel<f32, I>
|
||||||
|
where I: GemmKernelSupMr<f32, A16>
|
||||||
|
{
|
||||||
|
#[inline]
|
||||||
|
unsafe fn sup_bl<B: Matrix<f32>, C: MatrixMut<f32>>(
|
||||||
|
alpha: f32,
|
||||||
|
pa: MutMatrix<f32>,
|
||||||
|
b: B,
|
||||||
|
beta: f32,
|
||||||
|
c: C,
|
||||||
|
) {
|
||||||
|
self::l3s::sgemm_sup_16x1(pa.stride, alpha, pa, b, beta, c);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<I> GemmKernelSup<f32> for AvxKernel<f32, I>
|
||||||
|
where I: GemmKernelSup<f32>
|
||||||
|
{
|
||||||
|
#[inline]
|
||||||
|
unsafe fn sup_br<A: Matrix<f32>, B: Matrix<f32>, C: MatrixMut<f32>>(
|
||||||
|
k: usize,
|
||||||
|
alpha: f32,
|
||||||
|
a: A,
|
||||||
|
b: B,
|
||||||
|
beta: f32,
|
||||||
|
c: C,
|
||||||
|
) {
|
||||||
|
I::sup_br(k, alpha, a, b, beta, c);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<I> GemmKernel<f32, A16, A5> for AvxKernel<f32, I>
|
||||||
|
where I: GemmKernel<f32, A16, A5>
|
||||||
|
{
|
||||||
|
#[inline]
|
||||||
|
unsafe fn pack_row_a<A: Matrix<f32>>(a: A, pa: MutMatrix<f32>, i: usize) {
|
||||||
|
if a.is_transposed() {
|
||||||
|
I::pack_row_a(a, pa, i);
|
||||||
|
} else {
|
||||||
|
self::l3s::sgemm_pa_16x(pa.stride, a.col(i), a.stride(), pa.row_mut(i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
unsafe fn pack_row_b<B: Matrix<f32>>(b: B, pb: MutMatrix<f32>, j: usize) {
|
||||||
|
I::pack_row_b(b, pb, j);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
unsafe fn main_tl<C: MatrixMut<f32>>(
|
||||||
|
alpha: f32,
|
||||||
|
pa: MutMatrix<f32>,
|
||||||
|
pb: MutMatrix<f32>,
|
||||||
|
beta: f32,
|
||||||
|
c: C,
|
||||||
|
) {
|
||||||
|
self::l3s::sgemm_ukr_16x8(pa.stride, alpha, pa, pb, beta, c);
|
||||||
|
}
|
||||||
|
}
|
@ -1,114 +1,91 @@
|
|||||||
pub(crate) unsafe fn sgemm_sup_1x4(
|
use crate::kernel::params::single::NR;
|
||||||
|
use crate::matrix::{MutMatrix, Matrix, MatrixMut};
|
||||||
|
|
||||||
|
|
||||||
|
pub(crate) unsafe fn sgemm_sup_1x8<A: Matrix<f32>, C: MatrixMut<f32>>(
|
||||||
k: usize,
|
k: usize,
|
||||||
alpha: f32,
|
alpha: f32,
|
||||||
a: *const f32,
|
a: A,
|
||||||
lda: usize,
|
pb: MutMatrix<f32>,
|
||||||
pb: *const f32,
|
|
||||||
beta: f32,
|
beta: f32,
|
||||||
c: *mut f32,
|
c: C,
|
||||||
ldc: usize,
|
|
||||||
) {
|
) {
|
||||||
let mut c0 = 0.0;
|
let mut c0 = 0.0f32;
|
||||||
let mut c1 = 0.0;
|
let mut c1 = 0.0f32;
|
||||||
let mut c2 = 0.0;
|
let mut c2 = 0.0f32;
|
||||||
let mut c3 = 0.0;
|
let mut c3 = 0.0f32;
|
||||||
|
let mut c4 = 0.0f32;
|
||||||
|
// let mut c5 = 0.0f32;
|
||||||
|
// let mut c6 = 0.0f32;
|
||||||
|
// let mut c7 = 0.0f32;
|
||||||
|
|
||||||
let mut a = a;
|
let mut a = a;
|
||||||
let mut pb = pb;
|
let mut pb = pb;
|
||||||
|
|
||||||
for _ in 0..k {
|
for _ in 0..k {
|
||||||
let a0 = *a;
|
let a0 = *a.ptr();
|
||||||
|
|
||||||
c0 += *pb * a0;
|
c0 += *pb.ptr() * a0;
|
||||||
c1 += *pb.add(1) * a0;
|
c1 += *pb.col(1) * a0;
|
||||||
c2 += *pb.add(2) * a0;
|
c2 += *pb.col(2) * a0;
|
||||||
c3 += *pb.add(3) * a0;
|
c3 += *pb.col(3) * a0;
|
||||||
|
c4 += *pb.col(4) * a0;
|
||||||
|
// c5 += *pb.col(5) * a0;
|
||||||
|
// c6 += *pb.col(6) * a0;
|
||||||
|
// c7 += *pb.col(7) * a0;
|
||||||
|
|
||||||
a = a.add(lda);
|
a.inc_row();
|
||||||
pb = pb.add(4);
|
pb.shift_col(NR);
|
||||||
}
|
}
|
||||||
|
|
||||||
c0 *= alpha;
|
// c0 *= alpha;
|
||||||
c1 *= alpha;
|
// c1 *= alpha;
|
||||||
c2 *= alpha;
|
// c2 *= alpha;
|
||||||
c3 *= alpha;
|
// c3 *= alpha;
|
||||||
|
// c4 *= alpha;
|
||||||
|
// // c5 *= alpha;
|
||||||
|
// // c6 *= alpha;
|
||||||
|
// // c7 *= alpha;
|
||||||
|
|
||||||
|
let ccol0 = c.ptr_mut();
|
||||||
|
let ccol1 = c.row_mut(1);
|
||||||
|
let ccol2 = c.row_mut(2);
|
||||||
|
let ccol3 = c.row_mut(3);
|
||||||
|
let ccol4 = c.row_mut(4);
|
||||||
|
// let ccol5 = c.row_mut(5);
|
||||||
|
// let ccol6 = c.row_mut(6);
|
||||||
|
// let ccol7 = c.add(ldc * 7);
|
||||||
|
|
||||||
let ccol0 = c;
|
// if beta != 0.0 {
|
||||||
let ccol1 = c.add(ldc);
|
// c0 += beta * *ccol0;
|
||||||
let ccol2 = c.add(ldc * 2);
|
// c1 += beta * *ccol1;
|
||||||
let ccol3 = c.add(ldc * 3);
|
// c2 += beta * *ccol2;
|
||||||
|
// c3 += beta * *ccol3;
|
||||||
if beta != 0.0 {
|
// c4 += beta * *ccol4;
|
||||||
c0 += beta * *ccol0;
|
// // c5 += beta * *ccol5;
|
||||||
c1 += beta * *ccol1;
|
// // c6 += beta * *ccol6;
|
||||||
c2 += beta * *ccol2;
|
// // c7 += beta * *ccol7;
|
||||||
c3 += beta * *ccol3;
|
// }
|
||||||
}
|
|
||||||
|
|
||||||
*ccol0 = c0;
|
*ccol0 = c0;
|
||||||
*ccol1 = c1;
|
*ccol1 = c1;
|
||||||
*ccol2 = c2;
|
*ccol2 = c2;
|
||||||
*ccol3 = c3;
|
*ccol3 = c3;
|
||||||
|
*ccol4 = c4;
|
||||||
|
// *ccol5 = c5;
|
||||||
|
// *ccol6 = c6;
|
||||||
|
// *ccol7 = c7;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) unsafe fn sgemm_sup1_t(
|
pub(crate) unsafe fn sgemm_pb_x8(k: usize, b: *const f32, ldb: usize, pb: *mut f32) {
|
||||||
k: usize,
|
|
||||||
alpha: f32,
|
|
||||||
a: *const f32,
|
|
||||||
lda: usize,
|
|
||||||
pb: *const f32,
|
|
||||||
beta: f32,
|
|
||||||
c: *mut f32,
|
|
||||||
ldc: usize,
|
|
||||||
) {
|
|
||||||
let mut c0 = 0.0;
|
|
||||||
let mut c1 = 0.0;
|
|
||||||
let mut c2 = 0.0;
|
|
||||||
let mut c3 = 0.0;
|
|
||||||
|
|
||||||
let mut a = a;
|
|
||||||
let mut pb = pb;
|
|
||||||
|
|
||||||
for _ in 0..k {
|
|
||||||
let a0 = *a;
|
|
||||||
|
|
||||||
c0 += *pb * a0;
|
|
||||||
c1 += *pb.add(1) * a0;
|
|
||||||
c2 += *pb.add(2) * a0;
|
|
||||||
c3 += *pb.add(3) * a0;
|
|
||||||
|
|
||||||
a = a.add(1);
|
|
||||||
pb = pb.add(4);
|
|
||||||
}
|
|
||||||
|
|
||||||
c0 *= alpha;
|
|
||||||
c1 *= alpha;
|
|
||||||
c2 *= alpha;
|
|
||||||
c3 *= alpha;
|
|
||||||
|
|
||||||
let ccol0 = c;
|
|
||||||
let ccol1 = c.add(ldc);
|
|
||||||
let ccol2 = c.add(ldc * 2);
|
|
||||||
let ccol3 = c.add(ldc * 3);
|
|
||||||
|
|
||||||
if beta != 0.0 {
|
|
||||||
c0 += beta * *ccol0;
|
|
||||||
c1 += beta * *ccol1;
|
|
||||||
c2 += beta * *ccol2;
|
|
||||||
c3 += beta * *ccol3;
|
|
||||||
}
|
|
||||||
|
|
||||||
*ccol0 = c0;
|
|
||||||
*ccol1 = c1;
|
|
||||||
*ccol2 = c2;
|
|
||||||
*ccol3 = c3;
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) unsafe fn sgemm_pb_x4(k: usize, b: *const f32, ldb: usize, pb: *mut f32) {
|
|
||||||
let mut bcol0 = b;
|
let mut bcol0 = b;
|
||||||
let mut bcol1 = b.add(ldb);
|
let mut bcol1 = b.add(ldb);
|
||||||
let mut bcol2 = b.add(ldb * 2);
|
let mut bcol2 = b.add(ldb * 2);
|
||||||
let mut bcol3 = b.add(ldb * 3);
|
let mut bcol3 = b.add(ldb * 3);
|
||||||
|
let mut bcol4 = b.add(ldb * 4);
|
||||||
|
// let mut bcol5 = b.add(ldb * 5);
|
||||||
|
// let mut bcol6 = b.add(ldb * 6);
|
||||||
|
// let mut bcol7 = b.row(7);
|
||||||
|
|
||||||
let mut pb = pb;
|
let mut pb = pb;
|
||||||
|
|
||||||
@ -117,27 +94,24 @@ pub(crate) unsafe fn sgemm_pb_x4(k: usize, b: *const f32, ldb: usize, pb: *mut f
|
|||||||
*pb.add(1) = *bcol1;
|
*pb.add(1) = *bcol1;
|
||||||
*pb.add(2) = *bcol2;
|
*pb.add(2) = *bcol2;
|
||||||
*pb.add(3) = *bcol3;
|
*pb.add(3) = *bcol3;
|
||||||
|
*pb.add(4) = *bcol4;
|
||||||
|
// *pb.add(5) = *bcol5;
|
||||||
|
// *pb.add(6) = *bcol6;
|
||||||
|
// *pb.col(7) = *bcol7;
|
||||||
|
|
||||||
bcol0 = bcol0.add(1);
|
bcol0 = bcol0.add(1);
|
||||||
bcol1 = bcol1.add(1);
|
bcol1 = bcol1.add(1);
|
||||||
bcol2 = bcol2.add(1);
|
bcol2 = bcol2.add(1);
|
||||||
bcol3 = bcol3.add(1);
|
bcol3 = bcol3.add(1);
|
||||||
pb = pb.add(4);
|
bcol4 = bcol4.add(1);
|
||||||
|
// bcol5 = bcol5.add(1);
|
||||||
|
// bcol6 = bcol6.add(1);
|
||||||
|
// bcol7 = bcol7.add(1);
|
||||||
|
|
||||||
|
pb = pb.add(NR);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// pub(crate) unsafe fn sgemm_pa_n(k: usize, a: *const f32, lda: usize, pa: *mut f32) {
|
|
||||||
// use crate::kernel::params::single::MR;
|
|
||||||
// let mut a = a;
|
|
||||||
// let mut pa = pa;
|
|
||||||
|
|
||||||
// for p in 0..k {
|
|
||||||
// for j in 0..MR {
|
|
||||||
// *pa.add(j * k + p) = *a.add(p * lda + j);
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
pub(crate) unsafe fn sgemm_pb_t(k: usize, b: *const f32, ldb: usize, pb: *mut f32) {
|
pub(crate) unsafe fn sgemm_pb_t(k: usize, b: *const f32, ldb: usize, pb: *mut f32) {
|
||||||
use crate::kernel::params::single::NR;
|
use crate::kernel::params::single::NR;
|
||||||
let mut b = b;
|
let mut b = b;
|
||||||
|
@ -1,2 +1,98 @@
|
|||||||
pub mod l3d;
|
pub mod l3d;
|
||||||
pub mod l3s;
|
pub mod l3s;
|
||||||
|
|
||||||
|
use crate::matrix::{Matrix, MatrixMut, MutMatrix};
|
||||||
|
use crate::kernel::{GemmKernel, GemmKernelSupNr, GemmKernelSupMr, GemmKernelSup};
|
||||||
|
use crate::dim::*;
|
||||||
|
|
||||||
|
pub struct GenericKernel;
|
||||||
|
|
||||||
|
impl GemmKernelSupNr<f32, A5> for GenericKernel {
|
||||||
|
#[inline]
|
||||||
|
unsafe fn sup_tr<A: Matrix<f32>, C: MatrixMut<f32>>(
|
||||||
|
alpha: f32,
|
||||||
|
a: A,
|
||||||
|
pb: MutMatrix<f32>,
|
||||||
|
beta: f32,
|
||||||
|
c: C,
|
||||||
|
) {
|
||||||
|
self::l3s::sgemm_sup_1x8(
|
||||||
|
pb.stride,
|
||||||
|
alpha,
|
||||||
|
a,
|
||||||
|
pb,
|
||||||
|
beta,
|
||||||
|
c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl GemmKernelSupMr<f32, A16> for GenericKernel {
|
||||||
|
#[inline]
|
||||||
|
unsafe fn sup_bl<B: Matrix<f32>, C: MatrixMut<f32>>(
|
||||||
|
alpha: f32,
|
||||||
|
pa: MutMatrix<f32>,
|
||||||
|
b: B,
|
||||||
|
beta: f32,
|
||||||
|
c: C,
|
||||||
|
) {
|
||||||
|
unimplemented!()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl GemmKernelSup<f32> for GenericKernel {
|
||||||
|
#[inline]
|
||||||
|
unsafe fn sup_br<A: Matrix<f32>, B: Matrix<f32>, C: MatrixMut<f32>>(
|
||||||
|
k: usize,
|
||||||
|
alpha: f32,
|
||||||
|
a: A,
|
||||||
|
b: B,
|
||||||
|
beta: f32,
|
||||||
|
c: C,
|
||||||
|
) {
|
||||||
|
let mut elem = 0.0;
|
||||||
|
|
||||||
|
for p in 0..k {
|
||||||
|
elem += *a.row(p) * *b.col(p);
|
||||||
|
}
|
||||||
|
|
||||||
|
elem *= alpha;
|
||||||
|
|
||||||
|
if beta != 0.0 {
|
||||||
|
elem += beta * *c.ptr();
|
||||||
|
}
|
||||||
|
|
||||||
|
*c.ptr_mut() = elem;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl GemmKernel<f32, A16, A5> for GenericKernel {
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
unsafe fn pack_row_a<A: Matrix<f32>>(a: A, pa: MutMatrix<f32>, i: usize) {
|
||||||
|
if a.is_transposed() {
|
||||||
|
self::l3s::sgemm_pa_t(pa.stride, a.col(i), a.stride(), pa.row_mut(i));
|
||||||
|
} else {
|
||||||
|
unimplemented!()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
unsafe fn pack_row_b<B: Matrix<f32>>(b: B, pb: MutMatrix<f32>, j: usize) {
|
||||||
|
if b.is_transposed() {
|
||||||
|
self::l3s::sgemm_pb_t(pb.stride, b.row(j), b.stride(), pb.row_mut(j));
|
||||||
|
} else {
|
||||||
|
self::l3s::sgemm_pb_x8(pb.stride, b.row(j), b.stride(), pb.row_mut(j));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
unsafe fn main_tl<C: MatrixMut<f32>>(
|
||||||
|
alpha: f32,
|
||||||
|
pa: MutMatrix<f32>,
|
||||||
|
pb: MutMatrix<f32>,
|
||||||
|
beta: f32,
|
||||||
|
c: C,
|
||||||
|
) {
|
||||||
|
unimplemented!()
|
||||||
|
}
|
||||||
|
}
|
@ -2,53 +2,20 @@
|
|||||||
any(target_arch = "x86_64", target_arch = "x86"),
|
any(target_arch = "x86_64", target_arch = "x86"),
|
||||||
target_feature = "avx"
|
target_feature = "avx"
|
||||||
))]
|
))]
|
||||||
mod avx;
|
pub mod avx;
|
||||||
|
pub mod generic;
|
||||||
#[cfg(all(
|
pub mod sse;
|
||||||
any(target_arch = "x86_64", target_arch = "x86"),
|
|
||||||
target_feature = "avx"
|
|
||||||
))]
|
|
||||||
pub use avx::{l1d::*, l1s::*};
|
|
||||||
|
|
||||||
#[cfg(all(
|
|
||||||
any(target_arch = "x86_64", target_arch = "x86"),
|
|
||||||
target_feature = "avx"
|
|
||||||
))]
|
|
||||||
pub(crate) use avx::l3s::{
|
|
||||||
sgemm_pa_16x as sgemm_pa_n, sgemm_sup_16x1 as sgemm_sup0, sgemm_ukr_16x4 as sgemm_ukr, sgemm_sup0_t,
|
|
||||||
};
|
|
||||||
|
|
||||||
#[cfg(all(
|
|
||||||
any(target_arch = "x86_64", target_arch = "x86"),
|
|
||||||
target_feature = "avx"
|
|
||||||
))]
|
|
||||||
pub(crate) use avx::l3d::{
|
|
||||||
dgemm_pa_8x as dgemm_pa, dgemm_sup_8x1 as dgemm_sup0, dgemm_ukr_8x4 as dgemm_ukr,
|
|
||||||
};
|
|
||||||
|
|
||||||
#[cfg(all(
|
|
||||||
any(target_arch = "x86_64", target_arch = "x86"),
|
|
||||||
target_feature = "avx"
|
|
||||||
))]
|
|
||||||
pub(crate) use generic::l3s::{sgemm_pb_x4 as sgemm_pb_n, sgemm_pb_t, sgemm_pa_t, sgemm_sup_1x4 as sgemm_sup1, sgemm_sup1_t};
|
|
||||||
|
|
||||||
#[cfg(all(
|
|
||||||
any(target_arch = "x86_64", target_arch = "x86"),
|
|
||||||
target_feature = "avx"
|
|
||||||
))]
|
|
||||||
pub(crate) use generic::l3d::{dgemm_pb_x4 as dgemm_pb, dgemm_sup_1x4 as dgemm_sup1};
|
|
||||||
|
|
||||||
|
|
||||||
mod generic;
|
|
||||||
|
|
||||||
|
use crate::matrix::{Number, Matrix, MatrixMut, MutMatrix};
|
||||||
|
use crate::dim::Dim;
|
||||||
|
|
||||||
pub mod params {
|
pub mod params {
|
||||||
pub mod single {
|
pub mod single {
|
||||||
pub const MC: usize = 256;
|
pub const MC: usize = 128;
|
||||||
pub const KC: usize = 128;
|
pub const KC: usize = 256;
|
||||||
pub const NC: usize = 1024;
|
pub const NC: usize = 8*1024;
|
||||||
pub const MR: usize = 16;
|
pub const MR: usize = 16;
|
||||||
pub const NR: usize = 4;
|
pub const NR: usize = 5;
|
||||||
}
|
}
|
||||||
pub mod double {
|
pub mod double {
|
||||||
pub const MC: usize = 256;
|
pub const MC: usize = 256;
|
||||||
@ -57,4 +24,61 @@ pub mod params {
|
|||||||
pub const MR: usize = 8;
|
pub const MR: usize = 8;
|
||||||
pub const NR: usize = 4;
|
pub const NR: usize = 4;
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// +----------------------+
|
||||||
|
// |TL : : : |TR|
|
||||||
|
// | : : : | |
|
||||||
|
// + - - - - - - - - - ---+
|
||||||
|
// | : : : | |
|
||||||
|
// | : : : | |
|
||||||
|
// + - - - - - - - - - ---+
|
||||||
|
// | : : : | |
|
||||||
|
// | : : : | |
|
||||||
|
// +----------------------+
|
||||||
|
// |BL | | | |BR|
|
||||||
|
// +----------------------+
|
||||||
|
|
||||||
|
pub trait GemmKernelSup<F: Number> {
|
||||||
|
unsafe fn sup_br<A: Matrix<F>, B: Matrix<F>, C: MatrixMut<F>>(
|
||||||
|
k: usize,
|
||||||
|
alpha: F,
|
||||||
|
a: A,
|
||||||
|
b: B,
|
||||||
|
beta: F,
|
||||||
|
c: C,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait GemmKernelSupMr<F: Number, MR: Dim> {
|
||||||
|
unsafe fn sup_bl<B: Matrix<F>, C: MatrixMut<F>>(
|
||||||
|
alpha: F,
|
||||||
|
pa: MutMatrix<F>,
|
||||||
|
b: B,
|
||||||
|
beta: F,
|
||||||
|
c: C,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait GemmKernelSupNr<F: Number, NR: Dim> {
|
||||||
|
unsafe fn sup_tr<A: Matrix<F>, C: MatrixMut<F>>(
|
||||||
|
alpha: F,
|
||||||
|
a: A,
|
||||||
|
pb: MutMatrix<F>,
|
||||||
|
beta: F,
|
||||||
|
c: C,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait GemmKernel<F: Number, MR: Dim, NR: Dim>: GemmKernelSupMr<F, MR> + GemmKernelSupNr<F, NR> + GemmKernelSup<F> {
|
||||||
|
unsafe fn pack_row_a<A: Matrix<F>>(a: A, pa: MutMatrix<F>, i: usize);
|
||||||
|
unsafe fn pack_row_b<B: Matrix<F>>(b: B, pb: MutMatrix<F>, j: usize);
|
||||||
|
|
||||||
|
unsafe fn main_tl<C: MatrixMut<F>>(
|
||||||
|
alpha: F,
|
||||||
|
pa: MutMatrix<F>,
|
||||||
|
pb: MutMatrix<F>,
|
||||||
|
beta: F,
|
||||||
|
c: C,
|
||||||
|
);
|
||||||
}
|
}
|
47
src/kernel/sse/fma.rs
Normal file
47
src/kernel/sse/fma.rs
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
use super::intrinsics::*;
|
||||||
|
|
||||||
|
#[cfg(target_feature = "fma")]
|
||||||
|
#[inline(always)]
|
||||||
|
pub unsafe fn fmadd_ps(a: __m128, b: __m128, c: __m128) -> __m128 {
|
||||||
|
_mm_fmadd_ps(a, b, c)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(target_feature = "fma"))]
|
||||||
|
#[inline(always)]
|
||||||
|
pub unsafe fn fmadd_ps(a: __m128, b: __m128, c: __m128) -> __m128 {
|
||||||
|
_mm_add_ps(_mm_mul_ps(a, b), c)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(target_feature = "fma")]
|
||||||
|
#[inline(always)]
|
||||||
|
pub unsafe fn fmsub_ps(a: __m128, b: __m128, c: __m128) -> __m128 {
|
||||||
|
_mm_fmsub_ps(a, b, c)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(target_feature = "fma"))]
|
||||||
|
pub unsafe fn fmsub_ps(a: __m128, b: __m128, c: __m128) -> __m128 {
|
||||||
|
_mm_sub_ps(_mm_mul_ps(a, b), c)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(target_feature = "fma")]
|
||||||
|
#[inline(always)]
|
||||||
|
pub unsafe fn fmadd_pd(a: __m128d, b: __m128d, c: __m128d) -> __m128d {
|
||||||
|
_mm_fmadd_pd(a, b, c)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(target_feature = "fma"))]
|
||||||
|
#[inline(always)]
|
||||||
|
pub unsafe fn fmadd_pd(a: __m128d, b: __m128d, c: __m128d) -> __m128d {
|
||||||
|
_mm_add_pd(_mm_mul_pd(a, b), c)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(target_feature = "fma")]
|
||||||
|
#[inline(always)]
|
||||||
|
pub unsafe fn fmsub_pd(a: __m128d, b: __m128d, c: __m128d) -> __m128d {
|
||||||
|
_mm_fmsub_pd(a, b, c)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(target_feature = "fma"))]
|
||||||
|
pub unsafe fn fmsub_pd(a: __m128d, b: __m128d, c: __m128d) -> __m128d {
|
||||||
|
_mm_sub_pd(_mm_mul_pd(a, b), c)
|
||||||
|
}
|
23
src/kernel/sse/hsum.rs
Normal file
23
src/kernel/sse/hsum.rs
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
use super::intrinsics::*;
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
pub unsafe fn hsum_ps(v: __m256) -> f32 {
|
||||||
|
let qhigh = _mm256_extractf128_ps(v, 1);
|
||||||
|
let qlow = _mm256_castps256_ps128(v);
|
||||||
|
let qsum = _mm_add_ps(qhigh, qlow);
|
||||||
|
let dhigh = _mm_movehl_ps(qsum, qsum);
|
||||||
|
let dlow = qsum;
|
||||||
|
let dsum = _mm_add_ps(dhigh, dlow);
|
||||||
|
let high = _mm_shuffle_ps(dsum, dsum, 1);
|
||||||
|
let low = dsum;
|
||||||
|
_mm_cvtss_f32(_mm_add_ss(high, low))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
pub unsafe fn hsum_pd(v: __m256d) -> f64 {
|
||||||
|
let vhigh = _mm256_extractf128_pd(v, 1);
|
||||||
|
let vlow = _mm256_castpd256_pd128(v);
|
||||||
|
let vsum = _mm_add_pd(vlow, vhigh);
|
||||||
|
let h64 = _mm_unpackhi_pd(vsum, vsum);
|
||||||
|
_mm_cvtsd_f64(_mm_add_sd(vsum, h64))
|
||||||
|
}
|
5
src/kernel/sse/intrinsics.rs
Normal file
5
src/kernel/sse/intrinsics.rs
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
#[cfg(target_arch = "x86")]
|
||||||
|
pub use std::arch::x86::*;
|
||||||
|
|
||||||
|
#[cfg(target_arch = "x86_64")]
|
||||||
|
pub use std::arch::x86_64::*;
|
396
src/kernel/sse/l1d.rs
Normal file
396
src/kernel/sse/l1d.rs
Normal file
@ -0,0 +1,396 @@
|
|||||||
|
use super::fma::{fmadd_pd, fmsub_pd};
|
||||||
|
use super::hsum::hsum_pd;
|
||||||
|
use super::intrinsics::*;
|
||||||
|
|
||||||
|
pub unsafe fn drot(
|
||||||
|
n: usize,
|
||||||
|
mut x: *mut f64,
|
||||||
|
incx: usize,
|
||||||
|
mut y: *mut f64,
|
||||||
|
incy: usize,
|
||||||
|
c: f64,
|
||||||
|
s: f64,
|
||||||
|
) {
|
||||||
|
if incx == 1 && incy == 1 {
|
||||||
|
let c0 = _mm256_broadcast_sd(&c);
|
||||||
|
let s0 = _mm256_broadcast_sd(&s);
|
||||||
|
|
||||||
|
for _ in 0..n / 16 {
|
||||||
|
let x0 = _mm256_loadu_pd(x);
|
||||||
|
let y0 = _mm256_loadu_pd(y);
|
||||||
|
let x1 = _mm256_loadu_pd(x.add(4));
|
||||||
|
let y1 = _mm256_loadu_pd(y.add(4));
|
||||||
|
let x2 = _mm256_loadu_pd(x.add(8));
|
||||||
|
let y2 = _mm256_loadu_pd(y.add(8));
|
||||||
|
let x3 = _mm256_loadu_pd(x.add(12));
|
||||||
|
let y3 = _mm256_loadu_pd(y.add(12));
|
||||||
|
|
||||||
|
_mm256_storeu_pd(x, fmadd_pd(c0, x0, _mm256_mul_pd(s0, y0)));
|
||||||
|
_mm256_storeu_pd(y, fmsub_pd(c0, y0, _mm256_mul_pd(s0, x0)));
|
||||||
|
_mm256_storeu_pd(x.add(4), fmadd_pd(c0, x1, _mm256_mul_pd(s0, y1)));
|
||||||
|
_mm256_storeu_pd(y.add(4), fmsub_pd(c0, y1, _mm256_mul_pd(s0, x1)));
|
||||||
|
_mm256_storeu_pd(x.add(8), fmadd_pd(c0, x2, _mm256_mul_pd(s0, y2)));
|
||||||
|
_mm256_storeu_pd(y.add(8), fmsub_pd(c0, y2, _mm256_mul_pd(s0, x2)));
|
||||||
|
_mm256_storeu_pd(x.add(12), fmadd_pd(c0, x3, _mm256_mul_pd(s0, y3)));
|
||||||
|
_mm256_storeu_pd(y.add(12), fmsub_pd(c0, y3, _mm256_mul_pd(s0, x3)));
|
||||||
|
|
||||||
|
x = x.add(16);
|
||||||
|
y = y.add(16);
|
||||||
|
}
|
||||||
|
|
||||||
|
for _ in 0..n % 16 {
|
||||||
|
let x0 = *x;
|
||||||
|
let y0 = *y;
|
||||||
|
|
||||||
|
*x = c * x0 + s * y0;
|
||||||
|
*y = c * y0 - s * x0;
|
||||||
|
|
||||||
|
x = x.add(1);
|
||||||
|
y = y.add(1);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for _ in 0..n {
|
||||||
|
let x0 = *x;
|
||||||
|
let y0 = *y;
|
||||||
|
|
||||||
|
*x = c * x0 + s * y0;
|
||||||
|
*y = c * y0 - s * x0;
|
||||||
|
|
||||||
|
x = x.add(incx);
|
||||||
|
y = y.add(incy);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub unsafe fn dswap(n: usize, mut x: *mut f64, incx: usize, mut y: *mut f64, incy: usize) {
|
||||||
|
if incx == 1 && incy == 1 {
|
||||||
|
for _ in 0..n / 16 {
|
||||||
|
let x0 = _mm256_loadu_pd(x);
|
||||||
|
let y0 = _mm256_loadu_pd(y);
|
||||||
|
let x1 = _mm256_loadu_pd(x.add(4));
|
||||||
|
let y1 = _mm256_loadu_pd(y.add(4));
|
||||||
|
let x2 = _mm256_loadu_pd(x.add(8));
|
||||||
|
let y2 = _mm256_loadu_pd(y.add(8));
|
||||||
|
let x3 = _mm256_loadu_pd(x.add(12));
|
||||||
|
let y3 = _mm256_loadu_pd(y.add(12));
|
||||||
|
|
||||||
|
_mm256_storeu_pd(x, y0);
|
||||||
|
_mm256_storeu_pd(y, x0);
|
||||||
|
_mm256_storeu_pd(x.add(4), y1);
|
||||||
|
_mm256_storeu_pd(y.add(4), x1);
|
||||||
|
_mm256_storeu_pd(x.add(8), y2);
|
||||||
|
_mm256_storeu_pd(y.add(8), x2);
|
||||||
|
_mm256_storeu_pd(x.add(12), y3);
|
||||||
|
_mm256_storeu_pd(y.add(12), x3);
|
||||||
|
|
||||||
|
x = x.add(16);
|
||||||
|
y = y.add(16);
|
||||||
|
}
|
||||||
|
|
||||||
|
for _ in 0..n % 16 {
|
||||||
|
let x0 = *x;
|
||||||
|
|
||||||
|
*x = *y;
|
||||||
|
*y = x0;
|
||||||
|
|
||||||
|
x = x.add(1);
|
||||||
|
y = y.add(1);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for _ in 0..n {
|
||||||
|
let x0 = *x;
|
||||||
|
|
||||||
|
*x = *y;
|
||||||
|
*y = x0;
|
||||||
|
|
||||||
|
x = x.add(incx);
|
||||||
|
y = y.add(incy);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub unsafe fn dscal(n: usize, a: f64, mut x: *mut f64, incx: usize) {
|
||||||
|
if incx == 1 {
|
||||||
|
let a0 = _mm256_broadcast_sd(&a);
|
||||||
|
for _ in 0..n / 32 {
|
||||||
|
let mut x0 = _mm256_loadu_pd(x);
|
||||||
|
let mut x1 = _mm256_loadu_pd(x.add(4));
|
||||||
|
let mut x2 = _mm256_loadu_pd(x.add(8));
|
||||||
|
let mut x3 = _mm256_loadu_pd(x.add(12));
|
||||||
|
let mut x4 = _mm256_loadu_pd(x.add(16));
|
||||||
|
let mut x5 = _mm256_loadu_pd(x.add(20));
|
||||||
|
let mut x6 = _mm256_loadu_pd(x.add(24));
|
||||||
|
let mut x7 = _mm256_loadu_pd(x.add(28));
|
||||||
|
|
||||||
|
x0 = _mm256_mul_pd(a0, x0);
|
||||||
|
x1 = _mm256_mul_pd(a0, x1);
|
||||||
|
x2 = _mm256_mul_pd(a0, x2);
|
||||||
|
x3 = _mm256_mul_pd(a0, x3);
|
||||||
|
x4 = _mm256_mul_pd(a0, x4);
|
||||||
|
x5 = _mm256_mul_pd(a0, x5);
|
||||||
|
x6 = _mm256_mul_pd(a0, x6);
|
||||||
|
x7 = _mm256_mul_pd(a0, x7);
|
||||||
|
|
||||||
|
_mm256_storeu_pd(x, x0);
|
||||||
|
_mm256_storeu_pd(x.add(4), x1);
|
||||||
|
_mm256_storeu_pd(x.add(8), x2);
|
||||||
|
_mm256_storeu_pd(x.add(12), x3);
|
||||||
|
_mm256_storeu_pd(x.add(16), x4);
|
||||||
|
_mm256_storeu_pd(x.add(20), x5);
|
||||||
|
_mm256_storeu_pd(x.add(24), x6);
|
||||||
|
_mm256_storeu_pd(x.add(28), x7);
|
||||||
|
|
||||||
|
x = x.add(32);
|
||||||
|
}
|
||||||
|
for _ in 0..n % 32 {
|
||||||
|
*x *= a;
|
||||||
|
x = x.add(1);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for _ in 0..n {
|
||||||
|
*x *= a;
|
||||||
|
x = x.add(incx);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub unsafe fn dcopy(n: usize, mut x: *const f64, incx: usize, mut y: *mut f64, incy: usize) {
|
||||||
|
if incx == 1 && incy == 1 {
|
||||||
|
for _ in 0..n / 32 {
|
||||||
|
let x0 = _mm256_loadu_pd(x);
|
||||||
|
let x1 = _mm256_loadu_pd(x.add(4));
|
||||||
|
let x2 = _mm256_loadu_pd(x.add(8));
|
||||||
|
let x3 = _mm256_loadu_pd(x.add(12));
|
||||||
|
let x4 = _mm256_loadu_pd(x.add(16));
|
||||||
|
let x5 = _mm256_loadu_pd(x.add(20));
|
||||||
|
let x6 = _mm256_loadu_pd(x.add(24));
|
||||||
|
let x7 = _mm256_loadu_pd(x.add(28));
|
||||||
|
|
||||||
|
_mm256_storeu_pd(y, x0);
|
||||||
|
_mm256_storeu_pd(y.add(4), x1);
|
||||||
|
_mm256_storeu_pd(y.add(8), x2);
|
||||||
|
_mm256_storeu_pd(y.add(12), x3);
|
||||||
|
_mm256_storeu_pd(y.add(16), x4);
|
||||||
|
_mm256_storeu_pd(y.add(20), x5);
|
||||||
|
_mm256_storeu_pd(y.add(24), x6);
|
||||||
|
_mm256_storeu_pd(y.add(28), x7);
|
||||||
|
|
||||||
|
x = x.add(32);
|
||||||
|
y = y.add(32);
|
||||||
|
}
|
||||||
|
for _ in 0..n % 32 {
|
||||||
|
*y = *x;
|
||||||
|
x = x.add(1);
|
||||||
|
y = y.add(1);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for _ in 0..n {
|
||||||
|
*y = *x;
|
||||||
|
x = x.add(incx);
|
||||||
|
y = y.add(incy);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub unsafe fn daxpy(
|
||||||
|
n: usize,
|
||||||
|
a: f64,
|
||||||
|
mut x: *const f64,
|
||||||
|
incx: usize,
|
||||||
|
mut y: *mut f64,
|
||||||
|
incy: usize,
|
||||||
|
) {
|
||||||
|
if incx == 1 && incy == 1 {
|
||||||
|
let a0 = _mm256_broadcast_sd(&a);
|
||||||
|
for _ in 0..n / 16 {
|
||||||
|
let x0 = _mm256_loadu_pd(x);
|
||||||
|
let y0 = _mm256_loadu_pd(y);
|
||||||
|
let x1 = _mm256_loadu_pd(x.add(4));
|
||||||
|
let y1 = _mm256_loadu_pd(y.add(4));
|
||||||
|
let x2 = _mm256_loadu_pd(x.add(8));
|
||||||
|
let y2 = _mm256_loadu_pd(y.add(8));
|
||||||
|
let x3 = _mm256_loadu_pd(x.add(12));
|
||||||
|
let y3 = _mm256_loadu_pd(y.add(12));
|
||||||
|
|
||||||
|
_mm256_storeu_pd(y, fmadd_pd(a0, x0, y0));
|
||||||
|
_mm256_storeu_pd(y.add(4), fmadd_pd(a0, x1, y1));
|
||||||
|
_mm256_storeu_pd(y.add(8), fmadd_pd(a0, x2, y2));
|
||||||
|
_mm256_storeu_pd(y.add(12), fmadd_pd(a0, x3, y3));
|
||||||
|
|
||||||
|
x = x.add(16);
|
||||||
|
y = y.add(16);
|
||||||
|
}
|
||||||
|
for _ in 0..n % 16 {
|
||||||
|
*y += a * *x;
|
||||||
|
x = x.add(1);
|
||||||
|
y = y.add(1);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for _ in 0..n {
|
||||||
|
*y += a * *x;
|
||||||
|
x = x.add(incx);
|
||||||
|
y = y.add(incy);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub unsafe fn ddot(
|
||||||
|
n: usize,
|
||||||
|
mut x: *const f64,
|
||||||
|
incx: usize,
|
||||||
|
mut y: *const f64,
|
||||||
|
incy: usize,
|
||||||
|
) -> f64 {
|
||||||
|
if incx == 1 && incy == 1 {
|
||||||
|
let mut acc0 = _mm256_setzero_pd();
|
||||||
|
let mut acc1 = _mm256_setzero_pd();
|
||||||
|
let mut acc2 = _mm256_setzero_pd();
|
||||||
|
let mut acc3 = _mm256_setzero_pd();
|
||||||
|
for _ in 0..n / 16 {
|
||||||
|
let x0 = _mm256_loadu_pd(x);
|
||||||
|
let y0 = _mm256_loadu_pd(y);
|
||||||
|
let x1 = _mm256_loadu_pd(x.add(4));
|
||||||
|
let y1 = _mm256_loadu_pd(y.add(4));
|
||||||
|
let x2 = _mm256_loadu_pd(x.add(8));
|
||||||
|
let y2 = _mm256_loadu_pd(y.add(8));
|
||||||
|
let x3 = _mm256_loadu_pd(x.add(12));
|
||||||
|
let y3 = _mm256_loadu_pd(y.add(12));
|
||||||
|
|
||||||
|
acc0 = fmadd_pd(x0, y0, acc0);
|
||||||
|
acc1 = fmadd_pd(x1, y1, acc1);
|
||||||
|
acc2 = fmadd_pd(x2, y2, acc2);
|
||||||
|
acc3 = fmadd_pd(x3, y3, acc3);
|
||||||
|
|
||||||
|
x = x.add(16);
|
||||||
|
y = y.add(16);
|
||||||
|
}
|
||||||
|
acc0 = _mm256_add_pd(acc0, acc1);
|
||||||
|
acc2 = _mm256_add_pd(acc2, acc3);
|
||||||
|
acc0 = _mm256_add_pd(acc0, acc2);
|
||||||
|
|
||||||
|
let mut acc = hsum_pd(acc0);
|
||||||
|
for _ in 0..n % 16 {
|
||||||
|
acc += *x * *y;
|
||||||
|
x = x.add(1);
|
||||||
|
y = y.add(1);
|
||||||
|
}
|
||||||
|
acc
|
||||||
|
} else {
|
||||||
|
let mut acc = 0.0;
|
||||||
|
for _ in 0..n {
|
||||||
|
acc += *x * *y;
|
||||||
|
x = x.add(incx);
|
||||||
|
y = y.add(incy);
|
||||||
|
}
|
||||||
|
acc
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub unsafe fn dnrm2(n: usize, mut x: *const f64, incx: usize) -> f64 {
|
||||||
|
if incx == 1 {
|
||||||
|
let mut acc0 = _mm256_setzero_pd();
|
||||||
|
let mut acc1 = _mm256_setzero_pd();
|
||||||
|
let mut acc2 = _mm256_setzero_pd();
|
||||||
|
let mut acc3 = _mm256_setzero_pd();
|
||||||
|
for _ in 0..n / 16 {
|
||||||
|
let x0 = _mm256_loadu_pd(x);
|
||||||
|
let x1 = _mm256_loadu_pd(x.add(4));
|
||||||
|
let x2 = _mm256_loadu_pd(x.add(8));
|
||||||
|
let x3 = _mm256_loadu_pd(x.add(12));
|
||||||
|
|
||||||
|
acc0 = fmadd_pd(x0, x0, acc0);
|
||||||
|
acc1 = fmadd_pd(x1, x1, acc1);
|
||||||
|
acc2 = fmadd_pd(x2, x2, acc2);
|
||||||
|
acc3 = fmadd_pd(x3, x3, acc3);
|
||||||
|
|
||||||
|
x = x.add(16);
|
||||||
|
}
|
||||||
|
acc0 = _mm256_add_pd(acc0, acc1);
|
||||||
|
acc2 = _mm256_add_pd(acc2, acc3);
|
||||||
|
acc0 = _mm256_add_pd(acc0, acc2);
|
||||||
|
|
||||||
|
let mut acc = hsum_pd(acc0);
|
||||||
|
for _ in 0..n % 16 {
|
||||||
|
let x0 = *x;
|
||||||
|
acc += x0 * x0;
|
||||||
|
x = x.add(1);
|
||||||
|
}
|
||||||
|
acc.sqrt()
|
||||||
|
} else {
|
||||||
|
let mut acc = 0.0;
|
||||||
|
for _ in 0..n {
|
||||||
|
let x0 = *x;
|
||||||
|
acc += x0 * x0;
|
||||||
|
x = x.add(incx);
|
||||||
|
}
|
||||||
|
acc.sqrt()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub unsafe fn dasum(n: usize, mut x: *const f64, incx: usize) -> f64 {
|
||||||
|
if incx == 1 {
|
||||||
|
let mask = _mm256_broadcast_sd(&f64::from_bits(0x7FFF_FFFF_FFFF_FFFF));
|
||||||
|
|
||||||
|
let mut acc0 = _mm256_setzero_pd();
|
||||||
|
let mut acc1 = _mm256_setzero_pd();
|
||||||
|
let mut acc2 = _mm256_setzero_pd();
|
||||||
|
let mut acc3 = _mm256_setzero_pd();
|
||||||
|
let mut acc4 = _mm256_setzero_pd();
|
||||||
|
let mut acc5 = _mm256_setzero_pd();
|
||||||
|
let mut acc6 = _mm256_setzero_pd();
|
||||||
|
let mut acc7 = _mm256_setzero_pd();
|
||||||
|
for _ in 0..n / 32 {
|
||||||
|
let mut x0 = _mm256_loadu_pd(x);
|
||||||
|
let mut x1 = _mm256_loadu_pd(x.add(4));
|
||||||
|
let mut x2 = _mm256_loadu_pd(x.add(8));
|
||||||
|
let mut x3 = _mm256_loadu_pd(x.add(12));
|
||||||
|
let mut x4 = _mm256_loadu_pd(x.add(16));
|
||||||
|
let mut x5 = _mm256_loadu_pd(x.add(20));
|
||||||
|
let mut x6 = _mm256_loadu_pd(x.add(24));
|
||||||
|
let mut x7 = _mm256_loadu_pd(x.add(28));
|
||||||
|
|
||||||
|
x0 = _mm256_and_pd(mask, x0);
|
||||||
|
x1 = _mm256_and_pd(mask, x1);
|
||||||
|
x2 = _mm256_and_pd(mask, x2);
|
||||||
|
x3 = _mm256_and_pd(mask, x3);
|
||||||
|
x4 = _mm256_and_pd(mask, x4);
|
||||||
|
x5 = _mm256_and_pd(mask, x5);
|
||||||
|
x6 = _mm256_and_pd(mask, x6);
|
||||||
|
x7 = _mm256_and_pd(mask, x7);
|
||||||
|
|
||||||
|
acc0 = _mm256_add_pd(acc0, x0);
|
||||||
|
acc1 = _mm256_add_pd(acc1, x1);
|
||||||
|
acc2 = _mm256_add_pd(acc2, x2);
|
||||||
|
acc3 = _mm256_add_pd(acc3, x3);
|
||||||
|
acc4 = _mm256_add_pd(acc4, x4);
|
||||||
|
acc5 = _mm256_add_pd(acc5, x5);
|
||||||
|
acc6 = _mm256_add_pd(acc6, x6);
|
||||||
|
acc7 = _mm256_add_pd(acc7, x7);
|
||||||
|
|
||||||
|
x = x.add(32);
|
||||||
|
}
|
||||||
|
acc0 = _mm256_add_pd(acc0, acc1);
|
||||||
|
acc2 = _mm256_add_pd(acc2, acc3);
|
||||||
|
acc4 = _mm256_add_pd(acc4, acc5);
|
||||||
|
acc6 = _mm256_add_pd(acc6, acc7);
|
||||||
|
|
||||||
|
acc0 = _mm256_add_pd(acc0, acc2);
|
||||||
|
acc4 = _mm256_add_pd(acc4, acc6);
|
||||||
|
|
||||||
|
acc0 = _mm256_add_pd(acc0, acc4);
|
||||||
|
|
||||||
|
let mut acc = hsum_pd(acc0);
|
||||||
|
for _ in 0..n % 32 {
|
||||||
|
acc += (*x).abs();
|
||||||
|
x = x.add(1);
|
||||||
|
}
|
||||||
|
acc
|
||||||
|
} else {
|
||||||
|
let mut acc = 0.0;
|
||||||
|
for _ in 0..n {
|
||||||
|
acc += (*x).abs();
|
||||||
|
x = x.add(incx);
|
||||||
|
}
|
||||||
|
acc
|
||||||
|
}
|
||||||
|
}
|
396
src/kernel/sse/l1s.rs
Normal file
396
src/kernel/sse/l1s.rs
Normal file
@ -0,0 +1,396 @@
|
|||||||
|
use super::fma::{fmadd_ps, fmsub_ps};
|
||||||
|
use super::hsum::hsum_ps;
|
||||||
|
use super::intrinsics::*;
|
||||||
|
|
||||||
|
pub unsafe fn srot(
|
||||||
|
n: usize,
|
||||||
|
mut x: *mut f32,
|
||||||
|
incx: usize,
|
||||||
|
mut y: *mut f32,
|
||||||
|
incy: usize,
|
||||||
|
c: f32,
|
||||||
|
s: f32,
|
||||||
|
) {
|
||||||
|
if incx == 1 && incy == 1 {
|
||||||
|
let c0 = _mm256_broadcast_ss(&c);
|
||||||
|
let s0 = _mm256_broadcast_ss(&s);
|
||||||
|
|
||||||
|
for _ in 0..n / 32 {
|
||||||
|
let x0 = _mm256_loadu_ps(x);
|
||||||
|
let y0 = _mm256_loadu_ps(y);
|
||||||
|
let x1 = _mm256_loadu_ps(x.add(8));
|
||||||
|
let y1 = _mm256_loadu_ps(y.add(8));
|
||||||
|
let x2 = _mm256_loadu_ps(x.add(16));
|
||||||
|
let y2 = _mm256_loadu_ps(y.add(16));
|
||||||
|
let x3 = _mm256_loadu_ps(x.add(24));
|
||||||
|
let y3 = _mm256_loadu_ps(y.add(24));
|
||||||
|
|
||||||
|
_mm256_storeu_ps(x, fmadd_ps(c0, x0, _mm256_mul_ps(s0, y0)));
|
||||||
|
_mm256_storeu_ps(y, fmsub_ps(c0, y0, _mm256_mul_ps(s0, x0)));
|
||||||
|
_mm256_storeu_ps(x.add(8), fmadd_ps(c0, x1, _mm256_mul_ps(s0, y1)));
|
||||||
|
_mm256_storeu_ps(y.add(8), fmsub_ps(c0, y1, _mm256_mul_ps(s0, x1)));
|
||||||
|
_mm256_storeu_ps(x.add(16), fmadd_ps(c0, x2, _mm256_mul_ps(s0, y2)));
|
||||||
|
_mm256_storeu_ps(y.add(16), fmsub_ps(c0, y2, _mm256_mul_ps(s0, x2)));
|
||||||
|
_mm256_storeu_ps(x.add(24), fmadd_ps(c0, x3, _mm256_mul_ps(s0, y3)));
|
||||||
|
_mm256_storeu_ps(y.add(24), fmsub_ps(c0, y3, _mm256_mul_ps(s0, x3)));
|
||||||
|
|
||||||
|
x = x.add(32);
|
||||||
|
y = y.add(32);
|
||||||
|
}
|
||||||
|
|
||||||
|
for _ in 0..n % 32 {
|
||||||
|
let x0 = *x;
|
||||||
|
let y0 = *y;
|
||||||
|
|
||||||
|
*x = c * x0 + s * y0;
|
||||||
|
*y = c * y0 - s * x0;
|
||||||
|
|
||||||
|
x = x.add(1);
|
||||||
|
y = y.add(1);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for _ in 0..n {
|
||||||
|
let x0 = *x;
|
||||||
|
let y0 = *y;
|
||||||
|
|
||||||
|
*x = c * x0 + s * y0;
|
||||||
|
*y = c * y0 - s * x0;
|
||||||
|
|
||||||
|
x = x.add(incx);
|
||||||
|
y = y.add(incy);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub unsafe fn sswap(n: usize, mut x: *mut f32, incx: usize, mut y: *mut f32, incy: usize) {
|
||||||
|
if incx == 1 && incy == 1 {
|
||||||
|
for _ in 0..n / 32 {
|
||||||
|
let x0 = _mm256_loadu_ps(x);
|
||||||
|
let y0 = _mm256_loadu_ps(y);
|
||||||
|
let x1 = _mm256_loadu_ps(x.add(8));
|
||||||
|
let y1 = _mm256_loadu_ps(y.add(8));
|
||||||
|
let x2 = _mm256_loadu_ps(x.add(16));
|
||||||
|
let y2 = _mm256_loadu_ps(y.add(16));
|
||||||
|
let x3 = _mm256_loadu_ps(x.add(24));
|
||||||
|
let y3 = _mm256_loadu_ps(y.add(24));
|
||||||
|
|
||||||
|
_mm256_storeu_ps(x, y0);
|
||||||
|
_mm256_storeu_ps(y, x0);
|
||||||
|
_mm256_storeu_ps(x.add(8), y1);
|
||||||
|
_mm256_storeu_ps(y.add(8), x1);
|
||||||
|
_mm256_storeu_ps(x.add(16), y2);
|
||||||
|
_mm256_storeu_ps(y.add(16), x2);
|
||||||
|
_mm256_storeu_ps(x.add(24), y3);
|
||||||
|
_mm256_storeu_ps(y.add(24), x3);
|
||||||
|
|
||||||
|
x = x.add(32);
|
||||||
|
y = y.add(32);
|
||||||
|
}
|
||||||
|
|
||||||
|
for _ in 0..n % 32 {
|
||||||
|
let x0 = *x;
|
||||||
|
|
||||||
|
*x = *y;
|
||||||
|
*y = x0;
|
||||||
|
|
||||||
|
x = x.add(1);
|
||||||
|
y = y.add(1);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for _ in 0..n {
|
||||||
|
let x0 = *x;
|
||||||
|
|
||||||
|
*x = *y;
|
||||||
|
*y = x0;
|
||||||
|
|
||||||
|
x = x.add(incx);
|
||||||
|
y = y.add(incy);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub unsafe fn sscal(n: usize, a: f32, mut x: *mut f32, incx: usize) {
|
||||||
|
if incx == 1 {
|
||||||
|
let a0 = _mm256_broadcast_ss(&a);
|
||||||
|
for _ in 0..n / 64 {
|
||||||
|
let mut x0 = _mm256_loadu_ps(x);
|
||||||
|
let mut x1 = _mm256_loadu_ps(x.add(8));
|
||||||
|
let mut x2 = _mm256_loadu_ps(x.add(16));
|
||||||
|
let mut x3 = _mm256_loadu_ps(x.add(24));
|
||||||
|
let mut x4 = _mm256_loadu_ps(x.add(32));
|
||||||
|
let mut x5 = _mm256_loadu_ps(x.add(40));
|
||||||
|
let mut x6 = _mm256_loadu_ps(x.add(48));
|
||||||
|
let mut x7 = _mm256_loadu_ps(x.add(56));
|
||||||
|
|
||||||
|
x0 = _mm256_mul_ps(a0, x0);
|
||||||
|
x1 = _mm256_mul_ps(a0, x1);
|
||||||
|
x2 = _mm256_mul_ps(a0, x2);
|
||||||
|
x3 = _mm256_mul_ps(a0, x3);
|
||||||
|
x4 = _mm256_mul_ps(a0, x4);
|
||||||
|
x5 = _mm256_mul_ps(a0, x5);
|
||||||
|
x6 = _mm256_mul_ps(a0, x6);
|
||||||
|
x7 = _mm256_mul_ps(a0, x7);
|
||||||
|
|
||||||
|
_mm256_storeu_ps(x, x0);
|
||||||
|
_mm256_storeu_ps(x.add(8), x1);
|
||||||
|
_mm256_storeu_ps(x.add(16), x2);
|
||||||
|
_mm256_storeu_ps(x.add(24), x3);
|
||||||
|
_mm256_storeu_ps(x.add(32), x4);
|
||||||
|
_mm256_storeu_ps(x.add(40), x5);
|
||||||
|
_mm256_storeu_ps(x.add(48), x6);
|
||||||
|
_mm256_storeu_ps(x.add(56), x7);
|
||||||
|
|
||||||
|
x = x.add(64);
|
||||||
|
}
|
||||||
|
for _ in 0..n % 64 {
|
||||||
|
*x *= a;
|
||||||
|
x = x.add(1);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for _ in 0..n {
|
||||||
|
*x *= a;
|
||||||
|
x = x.add(incx);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub unsafe fn scopy(n: usize, mut x: *const f32, incx: usize, mut y: *mut f32, incy: usize) {
|
||||||
|
if incx == 1 && incy == 1 {
|
||||||
|
for _ in 0..n / 64 {
|
||||||
|
let x0 = _mm256_loadu_ps(x);
|
||||||
|
let x1 = _mm256_loadu_ps(x.add(8));
|
||||||
|
let x2 = _mm256_loadu_ps(x.add(16));
|
||||||
|
let x3 = _mm256_loadu_ps(x.add(24));
|
||||||
|
let x4 = _mm256_loadu_ps(x.add(32));
|
||||||
|
let x5 = _mm256_loadu_ps(x.add(40));
|
||||||
|
let x6 = _mm256_loadu_ps(x.add(48));
|
||||||
|
let x7 = _mm256_loadu_ps(x.add(56));
|
||||||
|
|
||||||
|
_mm256_storeu_ps(y, x0);
|
||||||
|
_mm256_storeu_ps(y.add(8), x1);
|
||||||
|
_mm256_storeu_ps(y.add(16), x2);
|
||||||
|
_mm256_storeu_ps(y.add(24), x3);
|
||||||
|
_mm256_storeu_ps(y.add(32), x4);
|
||||||
|
_mm256_storeu_ps(y.add(40), x5);
|
||||||
|
_mm256_storeu_ps(y.add(48), x6);
|
||||||
|
_mm256_storeu_ps(y.add(56), x7);
|
||||||
|
|
||||||
|
x = x.add(64);
|
||||||
|
y = y.add(64);
|
||||||
|
}
|
||||||
|
for _ in 0..n % 64 {
|
||||||
|
*y = *x;
|
||||||
|
x = x.add(1);
|
||||||
|
y = y.add(1);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for _ in 0..n {
|
||||||
|
*y = *x;
|
||||||
|
x = x.add(incx);
|
||||||
|
y = y.add(incy);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub unsafe fn saxpy(
|
||||||
|
n: usize,
|
||||||
|
a: f32,
|
||||||
|
mut x: *const f32,
|
||||||
|
incx: usize,
|
||||||
|
mut y: *mut f32,
|
||||||
|
incy: usize,
|
||||||
|
) {
|
||||||
|
if incx == 1 && incy == 1 {
|
||||||
|
let a0 = _mm256_broadcast_ss(&a);
|
||||||
|
for _ in 0..n / 32 {
|
||||||
|
let x0 = _mm256_loadu_ps(x);
|
||||||
|
let y0 = _mm256_loadu_ps(y);
|
||||||
|
let x1 = _mm256_loadu_ps(x.add(8));
|
||||||
|
let y1 = _mm256_loadu_ps(y.add(8));
|
||||||
|
let x2 = _mm256_loadu_ps(x.add(16));
|
||||||
|
let y2 = _mm256_loadu_ps(y.add(16));
|
||||||
|
let x3 = _mm256_loadu_ps(x.add(24));
|
||||||
|
let y3 = _mm256_loadu_ps(y.add(24));
|
||||||
|
|
||||||
|
_mm256_storeu_ps(y, fmadd_ps(a0, x0, y0));
|
||||||
|
_mm256_storeu_ps(y.add(8), fmadd_ps(a0, x1, y1));
|
||||||
|
_mm256_storeu_ps(y.add(16), fmadd_ps(a0, x2, y2));
|
||||||
|
_mm256_storeu_ps(y.add(24), fmadd_ps(a0, x3, y3));
|
||||||
|
|
||||||
|
x = x.add(32);
|
||||||
|
y = y.add(32);
|
||||||
|
}
|
||||||
|
for _ in 0..n % 32 {
|
||||||
|
*y += a * *x;
|
||||||
|
x = x.add(1);
|
||||||
|
y = y.add(1);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for _ in 0..n {
|
||||||
|
*y += a * *x;
|
||||||
|
x = x.add(incx);
|
||||||
|
y = y.add(incy);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub unsafe fn sdot(
|
||||||
|
n: usize,
|
||||||
|
mut x: *const f32,
|
||||||
|
incx: usize,
|
||||||
|
mut y: *const f32,
|
||||||
|
incy: usize,
|
||||||
|
) -> f32 {
|
||||||
|
if incx == 1 && incy == 1 {
|
||||||
|
let mut acc0 = _mm256_setzero_ps();
|
||||||
|
let mut acc1 = _mm256_setzero_ps();
|
||||||
|
let mut acc2 = _mm256_setzero_ps();
|
||||||
|
let mut acc3 = _mm256_setzero_ps();
|
||||||
|
for _ in 0..n / 32 {
|
||||||
|
let x0 = _mm256_loadu_ps(x);
|
||||||
|
let y0 = _mm256_loadu_ps(y);
|
||||||
|
let x1 = _mm256_loadu_ps(x.add(8));
|
||||||
|
let y1 = _mm256_loadu_ps(y.add(8));
|
||||||
|
let x2 = _mm256_loadu_ps(x.add(16));
|
||||||
|
let y2 = _mm256_loadu_ps(y.add(16));
|
||||||
|
let x3 = _mm256_loadu_ps(x.add(24));
|
||||||
|
let y3 = _mm256_loadu_ps(y.add(24));
|
||||||
|
|
||||||
|
acc0 = fmadd_ps(x0, y0, acc0);
|
||||||
|
acc1 = fmadd_ps(x1, y1, acc1);
|
||||||
|
acc2 = fmadd_ps(x2, y2, acc2);
|
||||||
|
acc3 = fmadd_ps(x3, y3, acc3);
|
||||||
|
|
||||||
|
x = x.add(32);
|
||||||
|
y = y.add(32);
|
||||||
|
}
|
||||||
|
acc0 = _mm256_add_ps(acc0, acc1);
|
||||||
|
acc2 = _mm256_add_ps(acc2, acc3);
|
||||||
|
acc0 = _mm256_add_ps(acc0, acc2);
|
||||||
|
|
||||||
|
let mut acc = hsum_ps(acc0);
|
||||||
|
for _ in 0..n % 32 {
|
||||||
|
acc += *x * *y;
|
||||||
|
x = x.add(1);
|
||||||
|
y = y.add(1);
|
||||||
|
}
|
||||||
|
acc
|
||||||
|
} else {
|
||||||
|
let mut acc = 0.0;
|
||||||
|
for _ in 0..n {
|
||||||
|
acc += *x * *y;
|
||||||
|
x = x.add(incx);
|
||||||
|
y = y.add(incy);
|
||||||
|
}
|
||||||
|
acc
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub unsafe fn snrm2(n: usize, mut x: *const f32, incx: usize) -> f32 {
|
||||||
|
if incx == 1 {
|
||||||
|
let mut acc0 = _mm256_setzero_ps();
|
||||||
|
let mut acc1 = _mm256_setzero_ps();
|
||||||
|
let mut acc2 = _mm256_setzero_ps();
|
||||||
|
let mut acc3 = _mm256_setzero_ps();
|
||||||
|
for _ in 0..n / 32 {
|
||||||
|
let x0 = _mm256_loadu_ps(x);
|
||||||
|
let x1 = _mm256_loadu_ps(x.add(8));
|
||||||
|
let x2 = _mm256_loadu_ps(x.add(16));
|
||||||
|
let x3 = _mm256_loadu_ps(x.add(24));
|
||||||
|
|
||||||
|
acc0 = fmadd_ps(x0, x0, acc0);
|
||||||
|
acc1 = fmadd_ps(x1, x1, acc1);
|
||||||
|
acc2 = fmadd_ps(x2, x2, acc2);
|
||||||
|
acc3 = fmadd_ps(x3, x3, acc3);
|
||||||
|
|
||||||
|
x = x.add(32);
|
||||||
|
}
|
||||||
|
acc0 = _mm256_add_ps(acc0, acc1);
|
||||||
|
acc2 = _mm256_add_ps(acc2, acc3);
|
||||||
|
acc0 = _mm256_add_ps(acc0, acc2);
|
||||||
|
|
||||||
|
let mut acc = hsum_ps(acc0);
|
||||||
|
for _ in 0..n % 32 {
|
||||||
|
let x0 = *x;
|
||||||
|
acc += x0 * x0;
|
||||||
|
x = x.add(1);
|
||||||
|
}
|
||||||
|
acc.sqrt()
|
||||||
|
} else {
|
||||||
|
let mut acc = 0.0;
|
||||||
|
for _ in 0..n {
|
||||||
|
let x0 = *x;
|
||||||
|
acc += x0 * x0;
|
||||||
|
x = x.add(incx);
|
||||||
|
}
|
||||||
|
acc.sqrt()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub unsafe fn sasum(n: usize, mut x: *const f32, incx: usize) -> f32 {
|
||||||
|
if incx == 1 {
|
||||||
|
let mask = _mm256_broadcast_ss(&f32::from_bits(0x7FFF_FFFF));
|
||||||
|
|
||||||
|
let mut acc0 = _mm256_setzero_ps();
|
||||||
|
let mut acc1 = _mm256_setzero_ps();
|
||||||
|
let mut acc2 = _mm256_setzero_ps();
|
||||||
|
let mut acc3 = _mm256_setzero_ps();
|
||||||
|
let mut acc4 = _mm256_setzero_ps();
|
||||||
|
let mut acc5 = _mm256_setzero_ps();
|
||||||
|
let mut acc6 = _mm256_setzero_ps();
|
||||||
|
let mut acc7 = _mm256_setzero_ps();
|
||||||
|
for _ in 0..n / 64 {
|
||||||
|
let mut x0 = _mm256_loadu_ps(x);
|
||||||
|
let mut x1 = _mm256_loadu_ps(x.add(8));
|
||||||
|
let mut x2 = _mm256_loadu_ps(x.add(16));
|
||||||
|
let mut x3 = _mm256_loadu_ps(x.add(24));
|
||||||
|
let mut x4 = _mm256_loadu_ps(x.add(32));
|
||||||
|
let mut x5 = _mm256_loadu_ps(x.add(40));
|
||||||
|
let mut x6 = _mm256_loadu_ps(x.add(48));
|
||||||
|
let mut x7 = _mm256_loadu_ps(x.add(56));
|
||||||
|
|
||||||
|
x0 = _mm256_and_ps(mask, x0);
|
||||||
|
x1 = _mm256_and_ps(mask, x1);
|
||||||
|
x2 = _mm256_and_ps(mask, x2);
|
||||||
|
x3 = _mm256_and_ps(mask, x3);
|
||||||
|
x4 = _mm256_and_ps(mask, x4);
|
||||||
|
x5 = _mm256_and_ps(mask, x5);
|
||||||
|
x6 = _mm256_and_ps(mask, x6);
|
||||||
|
x7 = _mm256_and_ps(mask, x7);
|
||||||
|
|
||||||
|
acc0 = _mm256_add_ps(acc0, x0);
|
||||||
|
acc1 = _mm256_add_ps(acc1, x1);
|
||||||
|
acc2 = _mm256_add_ps(acc2, x2);
|
||||||
|
acc3 = _mm256_add_ps(acc3, x3);
|
||||||
|
acc4 = _mm256_add_ps(acc4, x4);
|
||||||
|
acc5 = _mm256_add_ps(acc5, x5);
|
||||||
|
acc6 = _mm256_add_ps(acc6, x6);
|
||||||
|
acc7 = _mm256_add_ps(acc7, x7);
|
||||||
|
|
||||||
|
x = x.add(64);
|
||||||
|
}
|
||||||
|
acc0 = _mm256_add_ps(acc0, acc1);
|
||||||
|
acc2 = _mm256_add_ps(acc2, acc3);
|
||||||
|
acc4 = _mm256_add_ps(acc4, acc5);
|
||||||
|
acc6 = _mm256_add_ps(acc6, acc7);
|
||||||
|
|
||||||
|
acc0 = _mm256_add_ps(acc0, acc2);
|
||||||
|
acc4 = _mm256_add_ps(acc4, acc6);
|
||||||
|
|
||||||
|
acc0 = _mm256_add_ps(acc0, acc4);
|
||||||
|
|
||||||
|
let mut acc = hsum_ps(acc0);
|
||||||
|
for _ in 0..n % 64 {
|
||||||
|
acc += (*x).abs();
|
||||||
|
x = x.add(1);
|
||||||
|
}
|
||||||
|
acc
|
||||||
|
} else {
|
||||||
|
let mut acc = 0.0;
|
||||||
|
for _ in 0..n {
|
||||||
|
acc += (*x).abs();
|
||||||
|
x = x.add(incx);
|
||||||
|
}
|
||||||
|
acc
|
||||||
|
}
|
||||||
|
}
|
140
src/kernel/sse/l3d.rs
Normal file
140
src/kernel/sse/l3d.rs
Normal file
@ -0,0 +1,140 @@
|
|||||||
|
use super::fma::fmadd_pd;
|
||||||
|
use super::intrinsics::*;
|
||||||
|
|
||||||
|
pub(crate) unsafe fn dgemm_ukr_8x4(
|
||||||
|
k: usize,
|
||||||
|
alpha: f64,
|
||||||
|
pa: *const f64,
|
||||||
|
pb: *const f64,
|
||||||
|
beta: f64,
|
||||||
|
c: *mut f64,
|
||||||
|
ldc: usize,
|
||||||
|
) {
|
||||||
|
let mut mt00 = _mm256_setzero_pd();
|
||||||
|
let mut mt01 = _mm256_setzero_pd();
|
||||||
|
let mut mt02 = _mm256_setzero_pd();
|
||||||
|
let mut mt03 = _mm256_setzero_pd();
|
||||||
|
let mut mt10 = _mm256_setzero_pd();
|
||||||
|
let mut mt11 = _mm256_setzero_pd();
|
||||||
|
let mut mt12 = _mm256_setzero_pd();
|
||||||
|
let mut mt13 = _mm256_setzero_pd();
|
||||||
|
|
||||||
|
let mut pa = pa;
|
||||||
|
let mut pb = pb;
|
||||||
|
|
||||||
|
for _ in 0..k {
|
||||||
|
let a0 = _mm256_loadu_pd(pa);
|
||||||
|
let a1 = _mm256_loadu_pd(pa.add(4));
|
||||||
|
|
||||||
|
let b0 = _mm256_broadcast_sd(&*pb);
|
||||||
|
let b1 = _mm256_broadcast_sd(&*pb.add(1));
|
||||||
|
let b2 = _mm256_broadcast_sd(&*pb.add(2));
|
||||||
|
let b3 = _mm256_broadcast_sd(&*pb.add(3));
|
||||||
|
|
||||||
|
mt00 = fmadd_pd(a0, b0, mt00);
|
||||||
|
mt01 = fmadd_pd(a0, b1, mt01);
|
||||||
|
mt02 = fmadd_pd(a0, b2, mt02);
|
||||||
|
mt03 = fmadd_pd(a0, b3, mt03);
|
||||||
|
mt10 = fmadd_pd(a1, b0, mt10);
|
||||||
|
mt11 = fmadd_pd(a1, b1, mt11);
|
||||||
|
mt12 = fmadd_pd(a1, b2, mt12);
|
||||||
|
mt13 = fmadd_pd(a1, b3, mt13);
|
||||||
|
|
||||||
|
pa = pa.add(8);
|
||||||
|
pb = pb.add(4);
|
||||||
|
}
|
||||||
|
|
||||||
|
let alpha = _mm256_broadcast_sd(&alpha);
|
||||||
|
|
||||||
|
mt00 = _mm256_mul_pd(alpha, mt00);
|
||||||
|
mt01 = _mm256_mul_pd(alpha, mt01);
|
||||||
|
mt02 = _mm256_mul_pd(alpha, mt02);
|
||||||
|
mt03 = _mm256_mul_pd(alpha, mt03);
|
||||||
|
mt10 = _mm256_mul_pd(alpha, mt10);
|
||||||
|
mt11 = _mm256_mul_pd(alpha, mt11);
|
||||||
|
mt12 = _mm256_mul_pd(alpha, mt12);
|
||||||
|
mt13 = _mm256_mul_pd(alpha, mt13);
|
||||||
|
|
||||||
|
let ccol0 = c;
|
||||||
|
let ccol1 = c.add(ldc);
|
||||||
|
let ccol2 = c.add(ldc * 2);
|
||||||
|
let ccol3 = c.add(ldc * 3);
|
||||||
|
|
||||||
|
if beta != 0.0 {
|
||||||
|
let beta = _mm256_broadcast_sd(&beta);
|
||||||
|
|
||||||
|
mt00 = fmadd_pd(beta, _mm256_loadu_pd(ccol0), mt00);
|
||||||
|
mt01 = fmadd_pd(beta, _mm256_loadu_pd(ccol1), mt01);
|
||||||
|
mt02 = fmadd_pd(beta, _mm256_loadu_pd(ccol2), mt02);
|
||||||
|
mt03 = fmadd_pd(beta, _mm256_loadu_pd(ccol3), mt03);
|
||||||
|
mt10 = fmadd_pd(beta, _mm256_loadu_pd(ccol0.add(4)), mt10);
|
||||||
|
mt11 = fmadd_pd(beta, _mm256_loadu_pd(ccol1.add(4)), mt11);
|
||||||
|
mt12 = fmadd_pd(beta, _mm256_loadu_pd(ccol2.add(4)), mt12);
|
||||||
|
mt13 = fmadd_pd(beta, _mm256_loadu_pd(ccol3.add(4)), mt13);
|
||||||
|
}
|
||||||
|
|
||||||
|
_mm256_storeu_pd(ccol0, mt00);
|
||||||
|
_mm256_storeu_pd(ccol1, mt01);
|
||||||
|
_mm256_storeu_pd(ccol2, mt02);
|
||||||
|
_mm256_storeu_pd(ccol3, mt03);
|
||||||
|
_mm256_storeu_pd(ccol0.add(4), mt10);
|
||||||
|
_mm256_storeu_pd(ccol1.add(4), mt11);
|
||||||
|
_mm256_storeu_pd(ccol2.add(4), mt12);
|
||||||
|
_mm256_storeu_pd(ccol3.add(4), mt13);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) unsafe fn dgemm_sup_8x1(
|
||||||
|
k: usize,
|
||||||
|
alpha: f64,
|
||||||
|
pa: *const f64,
|
||||||
|
b: *const f64,
|
||||||
|
beta: f64,
|
||||||
|
c: *mut f64,
|
||||||
|
) {
|
||||||
|
let mut mt0 = _mm256_setzero_pd();
|
||||||
|
let mut mt1 = _mm256_setzero_pd();
|
||||||
|
|
||||||
|
let mut pa = pa;
|
||||||
|
let mut b = b;
|
||||||
|
|
||||||
|
for _ in 0..k {
|
||||||
|
let a0 = _mm256_loadu_pd(pa);
|
||||||
|
let a1 = _mm256_loadu_pd(pa.add(4));
|
||||||
|
|
||||||
|
let b0 = _mm256_broadcast_sd(&*b);
|
||||||
|
|
||||||
|
mt0 = fmadd_pd(a0, b0, mt0);
|
||||||
|
mt1 = fmadd_pd(a1, b0, mt1);
|
||||||
|
|
||||||
|
pa = pa.add(8);
|
||||||
|
b = b.add(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
let alpha = _mm256_broadcast_sd(&alpha);
|
||||||
|
|
||||||
|
mt0 = _mm256_mul_pd(alpha, mt0);
|
||||||
|
mt1 = _mm256_mul_pd(alpha, mt1);
|
||||||
|
|
||||||
|
if beta != 0.0 {
|
||||||
|
let beta = _mm256_broadcast_sd(&beta);
|
||||||
|
|
||||||
|
mt0 = fmadd_pd(beta, _mm256_loadu_pd(c), mt0);
|
||||||
|
mt1 = fmadd_pd(beta, _mm256_loadu_pd(c.add(4)), mt1);
|
||||||
|
}
|
||||||
|
|
||||||
|
_mm256_storeu_pd(c, mt0);
|
||||||
|
_mm256_storeu_pd(c.add(4), mt1);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) unsafe fn dgemm_pa_8x(k: usize, a: *const f64, lda: usize, pa: *mut f64) {
|
||||||
|
let mut a = a;
|
||||||
|
let mut pa = pa;
|
||||||
|
|
||||||
|
for _ in 0..k {
|
||||||
|
_mm256_storeu_pd(pa, _mm256_loadu_pd(a));
|
||||||
|
_mm256_storeu_pd(pa.add(4), _mm256_loadu_pd(a.add(4)));
|
||||||
|
|
||||||
|
pa = pa.add(8);
|
||||||
|
a = a.add(lda);
|
||||||
|
}
|
||||||
|
}
|
178
src/kernel/sse/l3s.rs
Normal file
178
src/kernel/sse/l3s.rs
Normal file
@ -0,0 +1,178 @@
|
|||||||
|
use super::fma::fmadd_ps;
|
||||||
|
use super::intrinsics::*;
|
||||||
|
use crate::kernel::params::single::NR;
|
||||||
|
use crate::matrix::{Matrix, MatrixMut, MutMatrix};
|
||||||
|
|
||||||
|
pub(crate) unsafe fn sgemm_sup_1x8<A: Matrix<f32>, C: MatrixMut<f32>>(
|
||||||
|
k: usize,
|
||||||
|
alpha: f32,
|
||||||
|
a: A,
|
||||||
|
pb: MutMatrix<f32>,
|
||||||
|
beta: f32,
|
||||||
|
c: C,
|
||||||
|
) {
|
||||||
|
let mut c0_3 = _mm_setzero_ps();
|
||||||
|
let mut c4 = 0.0f32;
|
||||||
|
// let mut c5 = 0.0f32;
|
||||||
|
// let mut c6 = 0.0f32;
|
||||||
|
// let mut c7 = 0.0f32;
|
||||||
|
|
||||||
|
let mut a = a;
|
||||||
|
let mut pb = pb;
|
||||||
|
|
||||||
|
for _ in 0..k {
|
||||||
|
let a0 = *a.ptr();
|
||||||
|
let a0_simd = _mm_broadcast_ss(&*a.ptr());
|
||||||
|
|
||||||
|
c0_3 = fmadd_ps(_mm_loadu_ps(pb.ptr()), a0_simd, c0_3);
|
||||||
|
c4 += *pb.col(4) * a0;
|
||||||
|
// c5 += *pb.add(5) * a0;
|
||||||
|
// c6 += *pb.add(6) * a0;
|
||||||
|
// c7 += *pb.add(7) * a0;
|
||||||
|
|
||||||
|
a.inc_row();
|
||||||
|
pb.shift_col(NR);
|
||||||
|
}
|
||||||
|
|
||||||
|
// c0 *= alpha;
|
||||||
|
// c1 *= alpha;
|
||||||
|
// c2 *= alpha;
|
||||||
|
// c3 *= alpha;
|
||||||
|
// c4 *= alpha;
|
||||||
|
// c5 *= alpha;
|
||||||
|
// c6 *= alpha;
|
||||||
|
// c7 *= alpha;
|
||||||
|
|
||||||
|
let ccol0_3 = c.ptr_mut();
|
||||||
|
let ccol4 = c.row_mut(4);
|
||||||
|
// let ccol5 = c.add(ldc * 5);
|
||||||
|
// let ccol6 = c.add(ldc * 6);
|
||||||
|
// let ccol7 = c.add(ldc * 7);
|
||||||
|
|
||||||
|
// if beta != 0.0 {
|
||||||
|
// c0 += beta * *ccol0;
|
||||||
|
// c1 += beta * *ccol1;
|
||||||
|
// c2 += beta * *ccol2;
|
||||||
|
// c3 += beta * *ccol3;
|
||||||
|
// c4 += beta * *ccol4;
|
||||||
|
// // c5 += beta * *ccol5;
|
||||||
|
// // c6 += beta * *ccol6;
|
||||||
|
// // c7 += beta * *ccol7;
|
||||||
|
// }
|
||||||
|
|
||||||
|
_mm_storeu_ps(ccol0_3, c0_3);
|
||||||
|
*ccol4 = c4;
|
||||||
|
// *ccol5 = c5;
|
||||||
|
// *ccol6 = c6;
|
||||||
|
// *ccol7 = c7;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) unsafe fn sgemm_sup_1x8_t(
|
||||||
|
k: usize,
|
||||||
|
alpha: f32,
|
||||||
|
a: *const f32,
|
||||||
|
lda: usize,
|
||||||
|
pb: *const f32,
|
||||||
|
beta: f32,
|
||||||
|
c: *mut f32,
|
||||||
|
ldc: usize,
|
||||||
|
) {
|
||||||
|
let mut c03 = _mm_setzero_ps();
|
||||||
|
let mut c4 = 0.0f32;
|
||||||
|
// let mut c5 = 0.0f32;
|
||||||
|
// let mut c6 = 0.0f32;
|
||||||
|
// let mut c7 = 0.0f32;
|
||||||
|
|
||||||
|
let mut a = a;
|
||||||
|
let mut pb = pb;
|
||||||
|
|
||||||
|
for _ in 0..k {
|
||||||
|
let a0 = *a;
|
||||||
|
let a0_simd = _mm_set_ps1(a0);
|
||||||
|
|
||||||
|
c03 = fmadd_ps(_mm_loadu_ps(pb), a0_simd, c03);
|
||||||
|
c4 += *pb.add(4) * a0;
|
||||||
|
// c5 += *pb.add(5) * a0;
|
||||||
|
// c6 += *pb.add(6) * a0;
|
||||||
|
// c7 += *pb.add(7) * a0;
|
||||||
|
|
||||||
|
a = a.add(1);
|
||||||
|
pb = pb.add(NR);
|
||||||
|
}
|
||||||
|
|
||||||
|
// c0 *= alpha;
|
||||||
|
// c1 *= alpha;
|
||||||
|
// c2 *= alpha;
|
||||||
|
// c3 *= alpha;
|
||||||
|
// c4 *= alpha;
|
||||||
|
// c5 *= alpha;
|
||||||
|
// c6 *= alpha;
|
||||||
|
// c7 *= alpha;
|
||||||
|
|
||||||
|
// let ccol0 = c;
|
||||||
|
// let ccol1 = c.add(ldc);
|
||||||
|
// let ccol2 = c.add(ldc * 2);
|
||||||
|
// let ccol3 = c.add(ldc * 3);
|
||||||
|
let ccol4 = c.add(ldc * 4);
|
||||||
|
// let ccol5 = c.add(ldc * 5);
|
||||||
|
// let ccol6 = c.add(ldc * 6);
|
||||||
|
// let ccol7 = c.add(ldc * 7);
|
||||||
|
|
||||||
|
// if beta != 0.0 {
|
||||||
|
// c0 += beta * *ccol0;
|
||||||
|
// c1 += beta * *ccol1;
|
||||||
|
// c2 += beta * *ccol2;
|
||||||
|
// c3 += beta * *ccol3;
|
||||||
|
// c4 += beta * *ccol4;
|
||||||
|
// // c5 += beta * *ccol5;
|
||||||
|
// // c6 += beta * *ccol6;
|
||||||
|
// // c7 += beta * *ccol7;
|
||||||
|
// }
|
||||||
|
|
||||||
|
_mm_storeu_ps(c, c03);
|
||||||
|
// *ccol0 = c0;
|
||||||
|
// *ccol1 = c1;
|
||||||
|
// *ccol2 = c2;
|
||||||
|
// *ccol3 = c3;
|
||||||
|
*ccol4 = c4;
|
||||||
|
// *ccol5 = c5;
|
||||||
|
// *ccol6 = c6;
|
||||||
|
// *ccol7 = c7;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) unsafe fn sgemm_pb_x8(k: usize, b: *const f32, ldb: usize, pb: *mut f32) {
|
||||||
|
let mut bcol0 = b;
|
||||||
|
let mut bcol1 = b.add(ldb);
|
||||||
|
let mut bcol2 = b.add(ldb * 2);
|
||||||
|
let mut bcol3 = b.add(ldb * 3);
|
||||||
|
let mut bcol4 = b.add(ldb * 4);
|
||||||
|
// let mut bcol5 = b.add(ldb * 5);
|
||||||
|
// let mut bcol6 = b.add(ldb * 6);
|
||||||
|
// let mut bcol7 = b.add(ldb * 7);
|
||||||
|
|
||||||
|
let mut pb = pb;
|
||||||
|
|
||||||
|
for _ in 0..k {
|
||||||
|
_mm_storeu_ps(pb, _mm_set_ps(*bcol3, *bcol2, *bcol1, *bcol0));
|
||||||
|
|
||||||
|
// *pb = *bcol0;
|
||||||
|
// *pb.add(1) = *bcol1;
|
||||||
|
// *pb.add(2) = *bcol2;
|
||||||
|
// *pb.add(3) = *bcol3;
|
||||||
|
*pb.add(4) = *bcol4;
|
||||||
|
// *pb.add(5) = *bcol5;
|
||||||
|
// *pb.add(6) = *bcol6;
|
||||||
|
// *pb.add(7) = *bcol7;
|
||||||
|
|
||||||
|
bcol0 = bcol0.add(1);
|
||||||
|
bcol1 = bcol1.add(1);
|
||||||
|
bcol2 = bcol2.add(1);
|
||||||
|
bcol3 = bcol3.add(1);
|
||||||
|
bcol4 = bcol4.add(1);
|
||||||
|
// bcol5 = bcol5.add(1);
|
||||||
|
// bcol6 = bcol6.add(1);
|
||||||
|
// bcol7 = bcol7.add(1);
|
||||||
|
|
||||||
|
pb = pb.add(NR);
|
||||||
|
}
|
||||||
|
}
|
86
src/kernel/sse/mod.rs
Normal file
86
src/kernel/sse/mod.rs
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
mod fma;
|
||||||
|
// mod hsum;
|
||||||
|
mod intrinsics;
|
||||||
|
// pub mod l1d;
|
||||||
|
// pub mod l1s;
|
||||||
|
// pub mod l3d;
|
||||||
|
pub mod l3s;
|
||||||
|
|
||||||
|
use core::marker::PhantomData;
|
||||||
|
use crate::matrix::{Number, MutMatrix, Matrix, MatrixMut};
|
||||||
|
use crate::kernel::{GemmKernel, GemmKernelSupNr, GemmKernelSupMr, GemmKernelSup};
|
||||||
|
use crate::dim::*;
|
||||||
|
|
||||||
|
|
||||||
|
pub struct SseKernel<F: Number, I>(PhantomData<fn(F, I)>);
|
||||||
|
|
||||||
|
impl<I> GemmKernelSupNr<f32, A5> for SseKernel<f32, I>
|
||||||
|
where I: GemmKernelSupNr<f32, A5>
|
||||||
|
{
|
||||||
|
#[inline]
|
||||||
|
unsafe fn sup_tr<A: Matrix<f32>, C: MatrixMut<f32>>(
|
||||||
|
alpha: f32,
|
||||||
|
a: A,
|
||||||
|
pb: MutMatrix<f32>,
|
||||||
|
beta: f32,
|
||||||
|
c: C,
|
||||||
|
) {
|
||||||
|
I::sup_tr(alpha, a, pb, beta, c);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<I> GemmKernelSupMr<f32, A16> for SseKernel<f32, I>
|
||||||
|
where I: GemmKernelSupMr<f32, A16>
|
||||||
|
{
|
||||||
|
#[inline]
|
||||||
|
unsafe fn sup_bl<B: Matrix<f32>, C: MatrixMut<f32>>(
|
||||||
|
alpha: f32,
|
||||||
|
pa: MutMatrix<f32>,
|
||||||
|
b: B,
|
||||||
|
beta: f32,
|
||||||
|
c: C,
|
||||||
|
) {
|
||||||
|
I::sup_bl(alpha, pa, b, beta, c);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<I> GemmKernelSup<f32> for SseKernel<f32, I>
|
||||||
|
where I: GemmKernelSup<f32>
|
||||||
|
{
|
||||||
|
#[inline]
|
||||||
|
unsafe fn sup_br<A: Matrix<f32>, B: Matrix<f32>, C: MatrixMut<f32>>(
|
||||||
|
k: usize,
|
||||||
|
alpha: f32,
|
||||||
|
a: A,
|
||||||
|
b: B,
|
||||||
|
beta: f32,
|
||||||
|
c: C,
|
||||||
|
) {
|
||||||
|
I::sup_br(k, alpha, a, b, beta, c);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<I> GemmKernel<f32, A16, A5> for SseKernel<f32, I>
|
||||||
|
where I: GemmKernel<f32, A16, A5>
|
||||||
|
{
|
||||||
|
#[inline]
|
||||||
|
unsafe fn pack_row_a<A: Matrix<f32>>(a: A, pa: MutMatrix<f32>, i: usize) {
|
||||||
|
I::pack_row_a(a, pa, i);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
unsafe fn pack_row_b<B: Matrix<f32>>(b: B, pb: MutMatrix<f32>, j: usize) {
|
||||||
|
I::pack_row_b(b, pb, j);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
unsafe fn main_tl<C: MatrixMut<f32>>(
|
||||||
|
alpha: f32,
|
||||||
|
pa: MutMatrix<f32>,
|
||||||
|
pb: MutMatrix<f32>,
|
||||||
|
beta: f32,
|
||||||
|
c: C,
|
||||||
|
) {
|
||||||
|
I::main_tl(alpha, pa, pb, beta, c);
|
||||||
|
}
|
||||||
|
}
|
17
src/lib.rs
Normal file
17
src/lib.rs
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
pub mod matrix;
|
||||||
|
pub mod gemm;
|
||||||
|
mod sgemm;
|
||||||
|
mod aligned_alloc;
|
||||||
|
pub mod kernel;
|
||||||
|
pub mod dim;
|
||||||
|
pub mod executor;
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
extern crate blas;
|
||||||
|
#[cfg(test)]
|
||||||
|
extern crate openblas;
|
||||||
|
#[cfg(test)]
|
||||||
|
mod test;
|
||||||
|
|
||||||
|
|
||||||
|
pub use crate::sgemm::sgemm;
|
418
src/main.rs
418
src/main.rs
@ -1,373 +1,16 @@
|
|||||||
mod matrix;
|
mod matrix;
|
||||||
|
mod gemm;
|
||||||
mod sgemm;
|
mod sgemm;
|
||||||
mod aligned_alloc;
|
mod aligned_alloc;
|
||||||
mod kernel;
|
mod kernel;
|
||||||
|
mod dim;
|
||||||
|
mod executor;
|
||||||
|
|
||||||
|
|
||||||
extern crate blas;
|
extern crate blas;
|
||||||
extern crate openblas;
|
extern crate openblas;
|
||||||
|
|
||||||
use self::matrix::{ConstMatrix, MutMatrix, ConstTransposedMatrix, MutTransposedMatrix};
|
use crate::sgemm::sgemm;
|
||||||
|
|
||||||
pub unsafe fn sgemm(
|
|
||||||
transa: bool,
|
|
||||||
transb: bool,
|
|
||||||
transc: bool,
|
|
||||||
m: usize,
|
|
||||||
n: usize,
|
|
||||||
k: usize,
|
|
||||||
alpha: f32,
|
|
||||||
a: *const f32,
|
|
||||||
lda: usize,
|
|
||||||
b: *const f32,
|
|
||||||
ldb: usize,
|
|
||||||
beta: f32,
|
|
||||||
c: *mut f32,
|
|
||||||
ldc: usize,
|
|
||||||
) {
|
|
||||||
match (transa, transb, transc) {
|
|
||||||
(false, false, false) => sgemm::sgemm(
|
|
||||||
m, n, k, alpha,
|
|
||||||
ConstMatrix::new(a, lda),
|
|
||||||
ConstMatrix::new(b, ldb),
|
|
||||||
beta,
|
|
||||||
MutMatrix::new(c, ldc)),
|
|
||||||
|
|
||||||
(false, false, true) => sgemm::sgemm(
|
|
||||||
m, n, k, alpha,
|
|
||||||
ConstMatrix::new(a, lda),
|
|
||||||
ConstMatrix::new(b, ldb),
|
|
||||||
beta,
|
|
||||||
MutTransposedMatrix::new(c, ldc)),
|
|
||||||
|
|
||||||
(false, true, false) => sgemm::sgemm(
|
|
||||||
m, n, k, alpha,
|
|
||||||
ConstMatrix::new(a, lda),
|
|
||||||
ConstTransposedMatrix::new(b, ldb),
|
|
||||||
beta,
|
|
||||||
MutMatrix::new(c, ldc)),
|
|
||||||
|
|
||||||
(false, true, true) => sgemm::sgemm(
|
|
||||||
m, n, k, alpha,
|
|
||||||
ConstMatrix::new(a, lda),
|
|
||||||
ConstTransposedMatrix::new(b, ldb),
|
|
||||||
beta,
|
|
||||||
MutTransposedMatrix::new(c, ldc)),
|
|
||||||
|
|
||||||
(true, false, false) => sgemm::sgemm(
|
|
||||||
m, n, k, alpha,
|
|
||||||
ConstTransposedMatrix::new(a, lda),
|
|
||||||
ConstMatrix::new(b, ldb),
|
|
||||||
beta,
|
|
||||||
MutMatrix::new(c, ldc)),
|
|
||||||
|
|
||||||
(true, false, true) => sgemm::sgemm(
|
|
||||||
m, n, k, alpha,
|
|
||||||
ConstTransposedMatrix::new(a, lda),
|
|
||||||
ConstMatrix::new(b, ldb),
|
|
||||||
beta,
|
|
||||||
MutTransposedMatrix::new(c, ldc)),
|
|
||||||
|
|
||||||
|
|
||||||
(true, true, false) => sgemm::sgemm(
|
|
||||||
m, n, k, alpha,
|
|
||||||
ConstTransposedMatrix::new(a, lda),
|
|
||||||
ConstTransposedMatrix::new(b, ldb),
|
|
||||||
beta,
|
|
||||||
MutMatrix::new(c, ldc)),
|
|
||||||
|
|
||||||
|
|
||||||
(true, true, true) => sgemm::sgemm(
|
|
||||||
m, n, k, alpha,
|
|
||||||
ConstTransposedMatrix::new(a, lda),
|
|
||||||
ConstTransposedMatrix::new(b, ldb),
|
|
||||||
beta,
|
|
||||||
MutTransposedMatrix::new(c, ldc)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
unsafe fn sgemm_ref_nn(
|
|
||||||
m: usize,
|
|
||||||
n: usize,
|
|
||||||
k: usize,
|
|
||||||
_alpha: f32,
|
|
||||||
a: *const f32,
|
|
||||||
lda: usize,
|
|
||||||
b: *const f32,
|
|
||||||
ldb: usize,
|
|
||||||
_beta: f32,
|
|
||||||
c: *mut f32,
|
|
||||||
ldc: usize,
|
|
||||||
) {
|
|
||||||
for j in 0..n {
|
|
||||||
for i in 0..m {
|
|
||||||
let mut ci = *c.add(i + j * ldc);
|
|
||||||
for p in 0..k {
|
|
||||||
ci += *a.add(i + p * lda) * *b.add(p + j * ldb);
|
|
||||||
}
|
|
||||||
*c.add(i + j * ldc) = ci;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
unsafe fn sgemm_ref_nt(
|
|
||||||
m: usize,
|
|
||||||
n: usize,
|
|
||||||
k: usize,
|
|
||||||
_alpha: f32,
|
|
||||||
a: *const f32,
|
|
||||||
lda: usize,
|
|
||||||
b: *const f32,
|
|
||||||
ldb: usize,
|
|
||||||
_beta: f32,
|
|
||||||
c: *mut f32,
|
|
||||||
ldc: usize,
|
|
||||||
) {
|
|
||||||
for j in 0..n {
|
|
||||||
for i in 0..m {
|
|
||||||
let mut ci = *c.add(i + j * ldc);
|
|
||||||
|
|
||||||
for p in 0..k {
|
|
||||||
ci += *a.add(i + p * lda) * *b.add(j + p * ldb);
|
|
||||||
}
|
|
||||||
|
|
||||||
*c.add(i + j * ldc) = ci;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[inline(never)]
|
|
||||||
pub fn gemm_nn(m: usize, n: usize, k: usize, alpha: f32,
|
|
||||||
a: &[f32], lda: usize,
|
|
||||||
b: &[f32], ldb: usize,
|
|
||||||
_bata: f32,
|
|
||||||
c: &mut [f32], ldc: usize) {
|
|
||||||
|
|
||||||
let mk = m * k;
|
|
||||||
let nk = n * k;
|
|
||||||
let mn = m * n;
|
|
||||||
let a = &a[0..mk];
|
|
||||||
let b = &b[0..nk];
|
|
||||||
let c = &mut c[0..mn];
|
|
||||||
|
|
||||||
for i_m in 0..m {
|
|
||||||
for i_k in 0..k {
|
|
||||||
let a_part = alpha * a[i_m * lda + i_k];
|
|
||||||
|
|
||||||
for i_n in 0..n {
|
|
||||||
c[i_m * ldc + i_n] += a_part * b[i_k * ldb + i_n];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
fn main2() {
|
|
||||||
use rand::Rng;
|
|
||||||
|
|
||||||
let mut rng = rand::thread_rng();
|
|
||||||
|
|
||||||
const LEN: usize = 511;
|
|
||||||
let (m, n, k) = (LEN, LEN, LEN);
|
|
||||||
|
|
||||||
let mut a = vec![0.5; m * k];
|
|
||||||
let mut a_t = vec![0.5; m * k];
|
|
||||||
let mut b = vec![0.5; n * k];
|
|
||||||
let mut b_t = vec![0.5; n * k];
|
|
||||||
let mut c_nn = vec![0.0; m * n];
|
|
||||||
let mut c_nt = vec![0.0; m * n];
|
|
||||||
let mut c_tn = vec![0.0; m * n];
|
|
||||||
let mut c_tt = vec![0.0; m * n];
|
|
||||||
let mut cref = vec![0.0; m * n];
|
|
||||||
|
|
||||||
for i in 0..m {
|
|
||||||
for j in 0..k {
|
|
||||||
let v = rng.gen();
|
|
||||||
a[i + j * m] = v;
|
|
||||||
a_t[j + i * m] = v;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for i in 0..n {
|
|
||||||
for j in 0..k {
|
|
||||||
let v = rng.gen();
|
|
||||||
b[i + j * n] = v;
|
|
||||||
b_t[j + i * n] = v;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// let time = std::time::Instant::now();
|
|
||||||
// unsafe {
|
|
||||||
// gemm_nn(
|
|
||||||
// LEN,
|
|
||||||
// LEN,
|
|
||||||
// LEN,
|
|
||||||
// 1.0,
|
|
||||||
// a.as_slice(),
|
|
||||||
// LEN,
|
|
||||||
// b.as_slice(),
|
|
||||||
// LEN,
|
|
||||||
// 1.0,
|
|
||||||
// cref1.as_mut_slice(),
|
|
||||||
// LEN,
|
|
||||||
// )
|
|
||||||
// }
|
|
||||||
|
|
||||||
// println!("Naive (mine) {}", time.elapsed().as_millis());
|
|
||||||
|
|
||||||
let time = std::time::Instant::now();
|
|
||||||
unsafe {
|
|
||||||
sgemm_ref_nn(
|
|
||||||
m,
|
|
||||||
n,
|
|
||||||
k,
|
|
||||||
|
|
||||||
1.0,
|
|
||||||
a.as_ptr(),
|
|
||||||
m,
|
|
||||||
|
|
||||||
b.as_ptr(),
|
|
||||||
k,
|
|
||||||
|
|
||||||
1.0,
|
|
||||||
cref.as_mut_ptr(),
|
|
||||||
m,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
println!("Naive {}", time.elapsed().as_millis());
|
|
||||||
|
|
||||||
let time = std::time::Instant::now();
|
|
||||||
unsafe {
|
|
||||||
sgemm(
|
|
||||||
false,
|
|
||||||
false,
|
|
||||||
false,
|
|
||||||
m,
|
|
||||||
n,
|
|
||||||
k,
|
|
||||||
|
|
||||||
1.0,
|
|
||||||
a.as_ptr(),
|
|
||||||
m,
|
|
||||||
|
|
||||||
b.as_ptr(),
|
|
||||||
k,
|
|
||||||
|
|
||||||
1.0,
|
|
||||||
c_nn.as_mut_ptr(),
|
|
||||||
m,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
println!("[NN] Optimized {}", time.elapsed().as_millis());
|
|
||||||
|
|
||||||
for i in 0..LEN {
|
|
||||||
for j in 0..LEN {
|
|
||||||
let (a, b) = (c_nn[i + j * LEN], cref[i + j * LEN]);
|
|
||||||
assert!(feq(a, b), "a != b, a[{}]={}, b[{}]={}", i, a, j, b);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let time = std::time::Instant::now();
|
|
||||||
unsafe {
|
|
||||||
sgemm(
|
|
||||||
false,
|
|
||||||
true,
|
|
||||||
false,
|
|
||||||
m,
|
|
||||||
n,
|
|
||||||
k,
|
|
||||||
|
|
||||||
1.0,
|
|
||||||
a.as_ptr(),
|
|
||||||
m,
|
|
||||||
|
|
||||||
b_t.as_ptr(),
|
|
||||||
n,
|
|
||||||
|
|
||||||
1.0,
|
|
||||||
c_nt.as_mut_ptr(),
|
|
||||||
m,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
println!("[NT] Optimized {}", time.elapsed().as_millis());
|
|
||||||
|
|
||||||
for i in 0..LEN {
|
|
||||||
for j in 0..LEN {
|
|
||||||
let (a, b) = (c_nt[i + j * LEN], cref[i + j * LEN]);
|
|
||||||
assert!(feq(a, b), "a != b, a[{}]={}, b[{}]={}", i, a, j, b);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let time = std::time::Instant::now();
|
|
||||||
unsafe {
|
|
||||||
sgemm(
|
|
||||||
true,
|
|
||||||
false,
|
|
||||||
false,
|
|
||||||
m,
|
|
||||||
n,
|
|
||||||
k,
|
|
||||||
|
|
||||||
1.0,
|
|
||||||
a_t.as_ptr(),
|
|
||||||
k,
|
|
||||||
|
|
||||||
b.as_ptr(),
|
|
||||||
k,
|
|
||||||
|
|
||||||
1.0,
|
|
||||||
c_tn.as_mut_ptr(),
|
|
||||||
m,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
println!("[TN] Optimized {}", time.elapsed().as_millis());
|
|
||||||
|
|
||||||
for i in 0..LEN {
|
|
||||||
for j in 0..LEN {
|
|
||||||
let (a, b) = (c_tn[i + j * LEN], cref[i + j * LEN]);
|
|
||||||
assert!(feq(a, b), "a != b, a[{}]={}, b[{}]={}", i, a, j, b);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let time = std::time::Instant::now();
|
|
||||||
unsafe {
|
|
||||||
sgemm(
|
|
||||||
true,
|
|
||||||
true,
|
|
||||||
false,
|
|
||||||
m,
|
|
||||||
n,
|
|
||||||
k,
|
|
||||||
|
|
||||||
1.0,
|
|
||||||
a_t.as_ptr(),
|
|
||||||
k,
|
|
||||||
|
|
||||||
b_t.as_ptr(),
|
|
||||||
n,
|
|
||||||
|
|
||||||
1.0,
|
|
||||||
c_tt.as_mut_ptr(),
|
|
||||||
m,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
println!("[TT] Optimized {}", time.elapsed().as_millis());
|
|
||||||
|
|
||||||
|
|
||||||
for i in 0..LEN {
|
|
||||||
for j in 0..LEN {
|
|
||||||
let (a, b) = (c_tt[i + j * LEN], cref[i + j * LEN]);
|
|
||||||
assert!(feq(a, b), "a != b, a[{}]={}, b[{}]={}", i, a, j, b);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
fn main() {
|
fn main() {
|
||||||
use rand::Rng;
|
use rand::Rng;
|
||||||
@ -395,25 +38,24 @@ fn main() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
let time = std::time::Instant::now();
|
let time = std::time::Instant::now();
|
||||||
unsafe {
|
unsafe {
|
||||||
// blas::sgemm(
|
blas::sgemm(
|
||||||
// b'N',
|
b'N',
|
||||||
// b'N',
|
b'N',
|
||||||
// m as i32,
|
m as i32,
|
||||||
// n as i32,
|
n as i32,
|
||||||
// k as i32,
|
k as i32,
|
||||||
// 1.0,
|
1.0,
|
||||||
// a.as_slice(),
|
a.as_slice(),
|
||||||
// m as i32,
|
m as i32,
|
||||||
// b.as_slice(),
|
b.as_slice(),
|
||||||
// k as i32,
|
k as i32,
|
||||||
// 1.0,
|
1.0,
|
||||||
// cref1.as_mut_slice(),
|
cref1.as_mut_slice(),
|
||||||
// m as i32,
|
m as i32,
|
||||||
// );
|
);
|
||||||
|
|
||||||
// gemm_nn(
|
// gemm_nn(
|
||||||
// m,
|
// m,
|
||||||
// n,
|
// n,
|
||||||
@ -456,6 +98,7 @@ fn main() {
|
|||||||
let time = std::time::Instant::now();
|
let time = std::time::Instant::now();
|
||||||
unsafe {
|
unsafe {
|
||||||
sgemm(
|
sgemm(
|
||||||
|
&executor::RayonExecutor,
|
||||||
false,
|
false,
|
||||||
false,
|
false,
|
||||||
false,
|
false,
|
||||||
@ -477,22 +120,5 @@ fn main() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
println!("Optimized {}", time.elapsed().as_millis());
|
println!("Optimized {}", time.elapsed().as_millis());
|
||||||
|
|
||||||
|
|
||||||
// for i in 0..LEN {
|
|
||||||
// for j in 0..LEN {
|
|
||||||
// let (a, b) = (c[i + j * LEN], cref[i + j * LEN]);
|
|
||||||
// assert!(feq(a, b), "a != b, a[{}]={}, b[{}]={}", i, a, j, b);
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn feq(a: f32, b: f32) -> bool {
|
|
||||||
if a == b {
|
|
||||||
true
|
|
||||||
} else if a == 0.0 || b == 0.0 || (a.abs() + b.abs() < std::f32::MIN_POSITIVE) {
|
|
||||||
(a - b).abs() < std::f32::EPSILON * 10.0 * std::f32::MIN_POSITIVE
|
|
||||||
} else {
|
|
||||||
(a - b).abs() / (a.abs() + b.abs()) < std::f32::EPSILON * 10.0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
322
src/matrix.rs
322
src/matrix.rs
@ -1,36 +1,50 @@
|
|||||||
pub(crate) trait Matrix: Copy {
|
|
||||||
|
pub trait Number: Copy + Send + Sync + 'static {}
|
||||||
|
|
||||||
|
impl Number for f32 {}
|
||||||
|
impl Number for f64 {}
|
||||||
|
|
||||||
|
pub trait Matrix<F: Number>: Copy + Send + Sync + 'static {
|
||||||
unsafe fn sub(&self, row: usize, col: usize) -> Self;
|
unsafe fn sub(&self, row: usize, col: usize) -> Self;
|
||||||
|
unsafe fn sub_col(&self, col: usize) -> Self;
|
||||||
|
unsafe fn sub_row(&self, row: usize) -> Self;
|
||||||
unsafe fn is_transposed(&self) -> bool;
|
unsafe fn is_transposed(&self) -> bool;
|
||||||
unsafe fn stride(&self) -> usize;
|
unsafe fn stride(&self) -> usize;
|
||||||
unsafe fn get(&self, row: usize, col: usize) -> f32;
|
unsafe fn get(&self, row: usize, col: usize) -> F;
|
||||||
unsafe fn index(&self, row: usize, col: usize) -> *const f32;
|
unsafe fn index(&self, row: usize, col: usize) -> *const F;
|
||||||
unsafe fn row(&self, row: usize) -> *const f32;
|
unsafe fn row(&self, row: usize) -> *const F;
|
||||||
unsafe fn col(&self, col: usize) -> *const f32;
|
unsafe fn col(&self, col: usize) -> *const F;
|
||||||
|
unsafe fn ptr(&self) -> *const F;
|
||||||
|
unsafe fn inc_row(&mut self);
|
||||||
|
unsafe fn inc_col(&mut self);
|
||||||
|
unsafe fn shift_row(&mut self, rows: usize);
|
||||||
|
unsafe fn shift_col(&mut self, cols: usize);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) trait MatrixMut: Matrix {
|
pub trait MatrixMut<F: Number>: Matrix<F> {
|
||||||
unsafe fn set(&self, row: usize, col: usize, val: f32);
|
unsafe fn set(&self, row: usize, col: usize, val: F);
|
||||||
unsafe fn index_mut(&self, row: usize, col: usize) -> *mut f32;
|
unsafe fn index_mut(&self, row: usize, col: usize) -> *mut F;
|
||||||
unsafe fn row_mut(&self, row: usize) -> *mut f32;
|
unsafe fn row_mut(&self, row: usize) -> *mut F;
|
||||||
unsafe fn col_mut(&self, col: usize) -> *mut f32;
|
unsafe fn col_mut(&self, col: usize) -> *mut F;
|
||||||
|
unsafe fn ptr_mut(&self) -> *mut F;
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Copy, Clone)]
|
#[derive(Copy, Clone)]
|
||||||
pub(crate) struct ConstMatrix {
|
pub struct ConstMatrix<F: Number> {
|
||||||
pub stride: usize,
|
pub stride: usize,
|
||||||
pub ptr: *const f32
|
pub ptr: *const F
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe impl Send for ConstMatrix {}
|
unsafe impl<F: Number> Send for ConstMatrix<F> {}
|
||||||
unsafe impl Sync for ConstMatrix {}
|
unsafe impl<F: Number> Sync for ConstMatrix<F> {}
|
||||||
|
|
||||||
impl ConstMatrix {
|
impl<F: Number> ConstMatrix<F> {
|
||||||
pub fn new(ptr: *const f32, stride: usize) -> Self {
|
pub fn new(ptr: *const F, stride: usize) -> Self {
|
||||||
Self { ptr, stride }
|
Self { ptr, stride }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Matrix for ConstMatrix {
|
impl<F: Number> Matrix<F> for ConstMatrix<F> {
|
||||||
#[inline]
|
#[inline]
|
||||||
unsafe fn sub(&self, row: usize, col: usize) -> Self {
|
unsafe fn sub(&self, row: usize, col: usize) -> Self {
|
||||||
Self {
|
Self {
|
||||||
@ -39,6 +53,22 @@ impl Matrix for ConstMatrix {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
unsafe fn sub_col(&self, col: usize) -> Self {
|
||||||
|
Self {
|
||||||
|
ptr: self.col(col),
|
||||||
|
stride: self.stride
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
unsafe fn sub_row(&self, row: usize) -> Self {
|
||||||
|
Self {
|
||||||
|
ptr: self.row(row),
|
||||||
|
stride: self.stride
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
unsafe fn is_transposed(&self) -> bool {
|
unsafe fn is_transposed(&self) -> bool {
|
||||||
false
|
false
|
||||||
@ -50,42 +80,67 @@ impl Matrix for ConstMatrix {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
unsafe fn get(&self, row: usize, col: usize) -> f32 {
|
unsafe fn get(&self, row: usize, col: usize) -> F {
|
||||||
*self.ptr.add(row * self.stride + col)
|
*self.ptr.add(row * self.stride + col)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
unsafe fn index(&self, row: usize, col: usize) -> *const f32 {
|
unsafe fn index(&self, row: usize, col: usize) -> *const F {
|
||||||
self.ptr.add(row * self.stride + col)
|
self.ptr.add(row * self.stride + col)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
unsafe fn row(&self, row: usize) -> *const f32 {
|
unsafe fn row(&self, row: usize) -> *const F {
|
||||||
self.ptr.add(row * self.stride)
|
self.ptr.add(row * self.stride)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
unsafe fn col(&self, col: usize) -> *const f32 {
|
unsafe fn col(&self, col: usize) -> *const F {
|
||||||
self.ptr.add(col)
|
self.ptr.add(col)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
unsafe fn ptr(&self) -> *const F {
|
||||||
|
self.ptr
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
unsafe fn inc_row(&mut self) {
|
||||||
|
self.ptr = self.row(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
unsafe fn inc_col(&mut self) {
|
||||||
|
self.ptr = self.col(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
unsafe fn shift_row(&mut self, rows: usize) {
|
||||||
|
self.ptr = self.row(rows);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
unsafe fn shift_col(&mut self, cols: usize) {
|
||||||
|
self.ptr = self.col(cols);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Copy, Clone)]
|
#[derive(Copy, Clone)]
|
||||||
pub(crate) struct ConstTransposedMatrix {
|
pub struct ConstTransposedMatrix<F: Number> {
|
||||||
pub stride: usize,
|
pub stride: usize,
|
||||||
pub ptr: *const f32
|
pub ptr: *const F
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe impl Send for ConstTransposedMatrix {}
|
unsafe impl<F: Number> Send for ConstTransposedMatrix<F> {}
|
||||||
unsafe impl Sync for ConstTransposedMatrix {}
|
unsafe impl<F: Number> Sync for ConstTransposedMatrix<F> {}
|
||||||
|
|
||||||
impl ConstTransposedMatrix {
|
impl<F: Number> ConstTransposedMatrix<F> {
|
||||||
pub fn new(ptr: *const f32, stride: usize) -> Self {
|
pub fn new(ptr: *const F, stride: usize) -> Self {
|
||||||
Self { ptr, stride }
|
Self { ptr, stride }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Matrix for ConstTransposedMatrix {
|
impl<F: Number> Matrix<F> for ConstTransposedMatrix<F> {
|
||||||
#[inline]
|
#[inline]
|
||||||
unsafe fn sub(&self, row: usize, col: usize) -> Self {
|
unsafe fn sub(&self, row: usize, col: usize) -> Self {
|
||||||
Self {
|
Self {
|
||||||
@ -94,6 +149,22 @@ impl Matrix for ConstTransposedMatrix {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
unsafe fn sub_col(&self, col: usize) -> Self {
|
||||||
|
Self {
|
||||||
|
ptr: self.col(col),
|
||||||
|
stride: self.stride
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
unsafe fn sub_row(&self, row: usize) -> Self {
|
||||||
|
Self {
|
||||||
|
ptr: self.row(row),
|
||||||
|
stride: self.stride
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
unsafe fn is_transposed(&self) -> bool {
|
unsafe fn is_transposed(&self) -> bool {
|
||||||
true
|
true
|
||||||
@ -105,42 +176,67 @@ impl Matrix for ConstTransposedMatrix {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
unsafe fn get(&self, row: usize, col: usize) -> f32 {
|
unsafe fn get(&self, row: usize, col: usize) -> F {
|
||||||
*self.ptr.add(col * self.stride + row)
|
*self.ptr.add(col * self.stride + row)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
unsafe fn index(&self, row: usize, col: usize) -> *const f32 {
|
unsafe fn index(&self, row: usize, col: usize) -> *const F {
|
||||||
self.ptr.add(col * self.stride + row)
|
self.ptr.add(col * self.stride + row)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
unsafe fn row(&self, row: usize) -> *const f32 {
|
unsafe fn row(&self, row: usize) -> *const F {
|
||||||
self.ptr.add(row)
|
self.ptr.add(row)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
unsafe fn col(&self, col: usize) -> *const f32 {
|
unsafe fn col(&self, col: usize) -> *const F {
|
||||||
self.ptr.add(col * self.stride)
|
self.ptr.add(col * self.stride)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
unsafe fn ptr(&self) -> *const F {
|
||||||
|
self.ptr
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
unsafe fn inc_row(&mut self) {
|
||||||
|
self.ptr = self.row(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
unsafe fn inc_col(&mut self) {
|
||||||
|
self.ptr = self.col(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
unsafe fn shift_row(&mut self, rows: usize) {
|
||||||
|
self.ptr = self.row(rows);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
unsafe fn shift_col(&mut self, cols: usize) {
|
||||||
|
self.ptr = self.col(cols);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Copy, Clone)]
|
#[derive(Copy, Clone)]
|
||||||
pub(crate) struct MutMatrix {
|
pub struct MutMatrix<F: Number> {
|
||||||
pub stride: usize,
|
pub stride: usize,
|
||||||
pub ptr: *mut f32
|
pub ptr: *mut F
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe impl Send for MutMatrix {}
|
unsafe impl<F: Number> Send for MutMatrix<F> {}
|
||||||
unsafe impl Sync for MutMatrix {}
|
unsafe impl<F: Number> Sync for MutMatrix<F> {}
|
||||||
|
|
||||||
impl MutMatrix {
|
impl<F: Number> MutMatrix<F> {
|
||||||
pub fn new(ptr: *mut f32, stride: usize) -> Self {
|
pub fn new(ptr: *mut F, stride: usize) -> Self {
|
||||||
Self { ptr, stride }
|
Self { ptr, stride }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Matrix for MutMatrix {
|
impl<F: Number> Matrix<F> for MutMatrix<F> {
|
||||||
#[inline]
|
#[inline]
|
||||||
unsafe fn sub(&self, row: usize, col: usize) -> Self {
|
unsafe fn sub(&self, row: usize, col: usize) -> Self {
|
||||||
Self {
|
Self {
|
||||||
@ -149,6 +245,22 @@ impl Matrix for MutMatrix {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
unsafe fn sub_col(&self, col: usize) -> Self {
|
||||||
|
Self {
|
||||||
|
ptr: self.col_mut(col),
|
||||||
|
stride: self.stride
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
unsafe fn sub_row(&self, row: usize) -> Self {
|
||||||
|
Self {
|
||||||
|
ptr: self.row_mut(row),
|
||||||
|
stride: self.stride
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
unsafe fn is_transposed(&self) -> bool {
|
unsafe fn is_transposed(&self) -> bool {
|
||||||
false
|
false
|
||||||
@ -160,64 +272,94 @@ impl Matrix for MutMatrix {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
unsafe fn get(&self, row: usize, col: usize) -> f32 {
|
unsafe fn get(&self, row: usize, col: usize) -> F {
|
||||||
*self.ptr.add(row * self.stride + col)
|
*self.ptr.add(row * self.stride + col)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
unsafe fn index(&self, row: usize, col: usize) -> *const f32 {
|
unsafe fn index(&self, row: usize, col: usize) -> *const F {
|
||||||
self.ptr.add(row * self.stride + col)
|
self.ptr.add(row * self.stride + col)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
unsafe fn row(&self, row: usize) -> *const f32 {
|
unsafe fn row(&self, row: usize) -> *const F {
|
||||||
self.ptr.add(row * self.stride)
|
self.ptr.add(row * self.stride)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
unsafe fn col(&self, col: usize) -> *const f32 {
|
unsafe fn col(&self, col: usize) -> *const F {
|
||||||
self.ptr.add(col)
|
self.ptr.add(col)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
unsafe fn ptr(&self) -> *const F {
|
||||||
|
self.ptr
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
unsafe fn inc_row(&mut self) {
|
||||||
|
self.ptr = self.row_mut(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
unsafe fn inc_col(&mut self) {
|
||||||
|
self.ptr = self.col_mut(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
unsafe fn shift_row(&mut self, rows: usize) {
|
||||||
|
self.ptr = self.row_mut(rows);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
unsafe fn shift_col(&mut self, cols: usize) {
|
||||||
|
self.ptr = self.col_mut(cols);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl MatrixMut for MutMatrix {
|
impl<F: Number> MatrixMut<F> for MutMatrix<F> {
|
||||||
#[inline]
|
#[inline]
|
||||||
unsafe fn set(&self, row: usize, col: usize, value: f32) {
|
unsafe fn set(&self, row: usize, col: usize, value: F) {
|
||||||
*self.ptr.add(row * self.stride + col) = value;
|
*self.ptr.add(row * self.stride + col) = value;
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
unsafe fn index_mut(&self, row: usize, col: usize) -> *mut f32 {
|
unsafe fn index_mut(&self, row: usize, col: usize) -> *mut F {
|
||||||
self.ptr.add(row * self.stride + col)
|
self.ptr.add(row * self.stride + col)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
unsafe fn row_mut(&self, row: usize) -> *mut f32 {
|
unsafe fn row_mut(&self, row: usize) -> *mut F {
|
||||||
self.ptr.add(row * self.stride)
|
self.ptr.add(row * self.stride)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
unsafe fn col_mut(&self, col: usize) -> *mut f32 {
|
unsafe fn col_mut(&self, col: usize) -> *mut F {
|
||||||
self.ptr.add(col)
|
self.ptr.add(col)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
unsafe fn ptr_mut(&self) -> *mut F {
|
||||||
|
self.ptr
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Copy, Clone)]
|
#[derive(Copy, Clone)]
|
||||||
pub(crate) struct MutTransposedMatrix {
|
pub struct MutTransposedMatrix<F: Number> {
|
||||||
pub stride: usize,
|
pub stride: usize,
|
||||||
pub ptr: *mut f32
|
pub ptr: *mut F
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe impl Send for MutTransposedMatrix {}
|
unsafe impl<F: Number> Send for MutTransposedMatrix<F> {}
|
||||||
unsafe impl Sync for MutTransposedMatrix {}
|
unsafe impl<F: Number> Sync for MutTransposedMatrix<F> {}
|
||||||
|
|
||||||
impl MutTransposedMatrix {
|
impl<F: Number> MutTransposedMatrix<F> {
|
||||||
pub fn new(ptr: *mut f32, stride: usize) -> Self {
|
pub fn new(ptr: *mut F, stride: usize) -> Self {
|
||||||
Self { ptr, stride }
|
Self { ptr, stride }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Matrix for MutTransposedMatrix {
|
impl<F: Number> Matrix<F> for MutTransposedMatrix<F> {
|
||||||
#[inline]
|
#[inline]
|
||||||
unsafe fn sub(&self, row: usize, col: usize) -> Self {
|
unsafe fn sub(&self, row: usize, col: usize) -> Self {
|
||||||
Self {
|
Self {
|
||||||
@ -226,6 +368,22 @@ impl Matrix for MutTransposedMatrix {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
unsafe fn sub_col(&self, col: usize) -> Self {
|
||||||
|
Self {
|
||||||
|
ptr: self.col_mut(col),
|
||||||
|
stride: self.stride
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
unsafe fn sub_row(&self, row: usize) -> Self {
|
||||||
|
Self {
|
||||||
|
ptr: self.row_mut(row),
|
||||||
|
stride: self.stride
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
unsafe fn is_transposed(&self) -> bool {
|
unsafe fn is_transposed(&self) -> bool {
|
||||||
true
|
true
|
||||||
@ -237,47 +395,77 @@ impl Matrix for MutTransposedMatrix {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
unsafe fn get(&self, row: usize, col: usize) -> f32 {
|
unsafe fn get(&self, row: usize, col: usize) -> F {
|
||||||
*self.ptr.add(col * self.stride + row)
|
*self.ptr.add(col * self.stride + row)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
unsafe fn index(&self, row: usize, col: usize) -> *const f32 {
|
unsafe fn index(&self, row: usize, col: usize) -> *const F {
|
||||||
self.ptr.add(col * self.stride + row) as *const f32
|
self.ptr.add(col * self.stride + row) as *const F
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
unsafe fn row(&self, row: usize) -> *const f32 {
|
unsafe fn row(&self, row: usize) -> *const F {
|
||||||
self.ptr.add(row) as *const f32
|
self.ptr.add(row) as *const F
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
unsafe fn col(&self, col: usize) -> *const f32 {
|
unsafe fn col(&self, col: usize) -> *const F {
|
||||||
self.ptr.add(col * self.stride) as *const f32
|
self.ptr.add(col * self.stride) as *const F
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
unsafe fn ptr(&self) -> *const F {
|
||||||
|
self.ptr
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
unsafe fn inc_row(&mut self) {
|
||||||
|
self.ptr = self.row_mut(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
unsafe fn inc_col(&mut self) {
|
||||||
|
self.ptr = self.col_mut(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
unsafe fn shift_row(&mut self, rows: usize) {
|
||||||
|
self.ptr = self.row_mut(rows);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
unsafe fn shift_col(&mut self, cols: usize) {
|
||||||
|
self.ptr = self.col_mut(cols);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
impl MatrixMut for MutTransposedMatrix {
|
impl<F: Number> MatrixMut<F> for MutTransposedMatrix<F> {
|
||||||
#[inline]
|
#[inline]
|
||||||
unsafe fn set(&self, row: usize, col: usize, value: f32) {
|
unsafe fn set(&self, row: usize, col: usize, value: F) {
|
||||||
*self.ptr.add(col * self.stride + row) = value;
|
*self.ptr.add(col * self.stride + row) = value;
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
unsafe fn index_mut(&self, row: usize, col: usize) -> *mut f32 {
|
unsafe fn index_mut(&self, row: usize, col: usize) -> *mut F {
|
||||||
self.ptr.add(col * self.stride + row)
|
self.ptr.add(col * self.stride + row)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
unsafe fn row_mut(&self, row: usize) -> *mut f32 {
|
unsafe fn row_mut(&self, row: usize) -> *mut F {
|
||||||
self.ptr.add(row)
|
self.ptr.add(row)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
unsafe fn col_mut(&self, col: usize) -> *mut f32 {
|
unsafe fn col_mut(&self, col: usize) -> *mut F {
|
||||||
self.ptr.add(col * self.stride)
|
self.ptr.add(col * self.stride)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
unsafe fn ptr_mut(&self) -> *mut F {
|
||||||
|
self.ptr
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
287
src/sgemm.rs
287
src/sgemm.rs
@ -1,273 +1,30 @@
|
|||||||
use crate::aligned_alloc;
|
|
||||||
use crate::matrix::{Matrix, MatrixMut, MutMatrix};
|
|
||||||
use crate::kernel::params::single::*;
|
|
||||||
use crate::kernel;
|
|
||||||
|
|
||||||
|
use crate::kernel::avx::AvxKernel;
|
||||||
|
use crate::kernel::generic::GenericKernel;
|
||||||
|
use crate::dim::*;
|
||||||
|
use crate::executor::Executor;
|
||||||
|
|
||||||
pub(crate) unsafe fn sgemm<A, B, C>(
|
pub unsafe fn sgemm<E: Executor>(
|
||||||
|
e: &E,
|
||||||
|
transa: bool,
|
||||||
|
transb: bool,
|
||||||
|
transc: bool,
|
||||||
m: usize,
|
m: usize,
|
||||||
n: usize,
|
n: usize,
|
||||||
k: usize,
|
k: usize,
|
||||||
alpha: f32,
|
alpha: f32,
|
||||||
a: A,
|
a: *const f32,
|
||||||
b: B,
|
lda: usize,
|
||||||
|
b: *const f32,
|
||||||
|
ldb: usize,
|
||||||
beta: f32,
|
beta: f32,
|
||||||
c: C
|
c: *mut f32,
|
||||||
)
|
ldc: usize,
|
||||||
where A: Matrix,
|
) {
|
||||||
B: Matrix,
|
crate::gemm::gemm::<E, f32, AvxKernel<f32, GenericKernel>, A16, A5>(
|
||||||
C: MatrixMut,
|
e,
|
||||||
{
|
transa, transb, transc,
|
||||||
let packed_a = aligned_alloc::Alloc::new(MC * KC * std::mem::size_of::<f32>());
|
m, n, k, alpha,
|
||||||
let packed_b = aligned_alloc::Alloc::new(KC * NC * std::mem::size_of::<f32>());
|
a, lda, b, ldb,
|
||||||
|
beta, c, ldc);
|
||||||
for j in (0..n).step_by(NC) {
|
|
||||||
let j_b = std::cmp::min(n - j, NC);
|
|
||||||
for p in (0..k).step_by(KC) {
|
|
||||||
let p_b = std::cmp::min(k - p, KC);
|
|
||||||
for i in (0..m).step_by(MC) {
|
|
||||||
let i_b = std::cmp::min(m - i, MC);
|
|
||||||
|
|
||||||
let pa = MutMatrix::new(packed_a.ptr_f32(), p_b);
|
|
||||||
let pb = MutMatrix::new(packed_b.ptr_f32(), p_b);
|
|
||||||
|
|
||||||
inner_kernel(
|
|
||||||
i_b,
|
|
||||||
j_b,
|
|
||||||
p_b,
|
|
||||||
alpha,
|
|
||||||
a.sub(p, i),
|
|
||||||
b.sub(j, p),
|
|
||||||
beta,
|
|
||||||
c.sub(j, i),
|
|
||||||
pa,
|
|
||||||
pb,
|
|
||||||
i == 0
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
// +----------------------+
|
|
||||||
// |A : : : |B |
|
|
||||||
// | : : : | |
|
|
||||||
// + - - - - - - - - - ---+
|
|
||||||
// | : : : | |
|
|
||||||
// | : : : | |
|
|
||||||
// + - - - - - - - - - ---+
|
|
||||||
// | : : : | |
|
|
||||||
// | : : : | |
|
|
||||||
// +----------------------+
|
|
||||||
// |C | | | |D |
|
|
||||||
// +----------------------+
|
|
||||||
|
|
||||||
|
|
||||||
#[inline]
|
|
||||||
unsafe fn sgemm_pa<A: Matrix>(a: A, pa: MutMatrix, i: usize) {
|
|
||||||
if a.is_transposed() {
|
|
||||||
kernel::sgemm_pa_t(pa.stride, a.col(i), a.stride(), pa.row_mut(i));
|
|
||||||
} else {
|
|
||||||
kernel::sgemm_pa_n(pa.stride, a.col(i), a.stride(), pa.row_mut(i));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[inline]
|
|
||||||
unsafe fn sgemm_pb<B: Matrix>(b: B, pb: MutMatrix, j: usize) {
|
|
||||||
if b.is_transposed() {
|
|
||||||
kernel::sgemm_pb_t(pb.stride, b.row(j), b.stride(), pb.row_mut(j));
|
|
||||||
} else {
|
|
||||||
kernel::sgemm_pb_n(pb.stride, b.row(j), b.stride(), pb.row_mut(j));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[inline]
|
|
||||||
unsafe fn sgemm_ukr<C: MatrixMut>(
|
|
||||||
i: usize, j: usize,
|
|
||||||
alpha: f32,
|
|
||||||
pa: MutMatrix,
|
|
||||||
pb: MutMatrix,
|
|
||||||
beta: f32,
|
|
||||||
c: C)
|
|
||||||
{
|
|
||||||
if c.is_transposed() {
|
|
||||||
unimplemented!()
|
|
||||||
} else {
|
|
||||||
kernel::sgemm_ukr(
|
|
||||||
pa.stride,
|
|
||||||
alpha,
|
|
||||||
pa.row(i),
|
|
||||||
pb.row(j),
|
|
||||||
beta,
|
|
||||||
c.index_mut(j, i),
|
|
||||||
c.stride())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[inline]
|
|
||||||
unsafe fn sgemm_sup1<A: Matrix, C: MatrixMut>(
|
|
||||||
i: usize, j: usize,
|
|
||||||
alpha: f32,
|
|
||||||
a: A,
|
|
||||||
pb: MutMatrix,
|
|
||||||
beta: f32,
|
|
||||||
c: C)
|
|
||||||
{
|
|
||||||
if c.is_transposed() {
|
|
||||||
unimplemented!()
|
|
||||||
} else {
|
|
||||||
if a.is_transposed() {
|
|
||||||
kernel::sgemm_sup1_t(
|
|
||||||
pb.stride,
|
|
||||||
alpha,
|
|
||||||
a.col(i),
|
|
||||||
a.stride(),
|
|
||||||
pb.row(j),
|
|
||||||
beta,
|
|
||||||
c.index_mut(j, i),
|
|
||||||
c.stride())
|
|
||||||
} else {
|
|
||||||
kernel::sgemm_sup1(
|
|
||||||
pb.stride,
|
|
||||||
alpha,
|
|
||||||
a.col(i),
|
|
||||||
a.stride(),
|
|
||||||
pb.row(j),
|
|
||||||
beta,
|
|
||||||
c.index_mut(j, i),
|
|
||||||
c.stride())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[inline]
|
|
||||||
unsafe fn sgemm_sup0<B: Matrix, C: MatrixMut>(
|
|
||||||
i: usize, j: usize,
|
|
||||||
alpha: f32,
|
|
||||||
pa: MutMatrix,
|
|
||||||
b: B,
|
|
||||||
beta: f32,
|
|
||||||
c: C)
|
|
||||||
{
|
|
||||||
if c.is_transposed() {
|
|
||||||
unimplemented!()
|
|
||||||
} else {
|
|
||||||
if b.is_transposed() {
|
|
||||||
kernel::sgemm_sup0_t(
|
|
||||||
pa.stride,
|
|
||||||
alpha,
|
|
||||||
pa.row(i),
|
|
||||||
b.row(j),
|
|
||||||
b.stride(),
|
|
||||||
beta,
|
|
||||||
c.index_mut(j, i))
|
|
||||||
} else {
|
|
||||||
kernel::sgemm_sup0(
|
|
||||||
pa.stride,
|
|
||||||
alpha,
|
|
||||||
pa.row(i),
|
|
||||||
b.row(j),
|
|
||||||
beta,
|
|
||||||
c.index_mut(j, i))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
unsafe fn inner_kernel<A, B, C>(
|
|
||||||
m: usize,
|
|
||||||
n: usize,
|
|
||||||
k: usize,
|
|
||||||
alpha: f32,
|
|
||||||
a: A,
|
|
||||||
b: B,
|
|
||||||
beta: f32,
|
|
||||||
c: C,
|
|
||||||
pa: MutMatrix,
|
|
||||||
pb: MutMatrix,
|
|
||||||
first_time: bool,
|
|
||||||
)
|
|
||||||
where A: Matrix,
|
|
||||||
B: Matrix,
|
|
||||||
C: MatrixMut,
|
|
||||||
|
|
||||||
{
|
|
||||||
let n_left = n % NR;
|
|
||||||
let n_main = n - n_left;
|
|
||||||
|
|
||||||
let m_left = m % MR;
|
|
||||||
let m_main = m - m_left;
|
|
||||||
|
|
||||||
if first_time {
|
|
||||||
for j in (0..n_main).step_by(NR) {
|
|
||||||
sgemm_pb(b, pb, j);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for i in (0..m_main).step_by(MR) {
|
|
||||||
sgemm_pa(a, pa, i);
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
for j in (0..n_main).step_by(NR) {
|
|
||||||
|
|
||||||
// Section A
|
|
||||||
for i in (0..m_main).step_by(MR) {
|
|
||||||
sgemm_ukr(
|
|
||||||
i,
|
|
||||||
j,
|
|
||||||
alpha,
|
|
||||||
pa,
|
|
||||||
pb,
|
|
||||||
beta,
|
|
||||||
c,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Section B
|
|
||||||
for i in m_main..m {
|
|
||||||
sgemm_sup1(
|
|
||||||
i,
|
|
||||||
j,
|
|
||||||
alpha,
|
|
||||||
a,
|
|
||||||
pb,
|
|
||||||
beta,
|
|
||||||
c,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for j in n_main..n {
|
|
||||||
|
|
||||||
// Section C
|
|
||||||
for i in (0..m_main).step_by(MR) {
|
|
||||||
sgemm_sup0(
|
|
||||||
i,
|
|
||||||
j,
|
|
||||||
alpha,
|
|
||||||
pa,
|
|
||||||
b,
|
|
||||||
beta,
|
|
||||||
c,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Section D
|
|
||||||
for i in m_main..m {
|
|
||||||
let mut elem = 0.0;
|
|
||||||
|
|
||||||
for p in 0..k {
|
|
||||||
elem += a.get(p, i) * b.get(j, p);
|
|
||||||
}
|
|
||||||
|
|
||||||
elem *= alpha;
|
|
||||||
|
|
||||||
if beta != 0.0 {
|
|
||||||
elem += beta * c.get(j, i);
|
|
||||||
}
|
|
||||||
|
|
||||||
c.set(j, i, elem);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
308
src/test/l3s.rs
Normal file
308
src/test/l3s.rs
Normal file
@ -0,0 +1,308 @@
|
|||||||
|
use rand::Rng;
|
||||||
|
|
||||||
|
const M_LEN: usize = 8192;
|
||||||
|
const N_LEN: usize = 8192;
|
||||||
|
const K_LEN: usize = 8192;
|
||||||
|
|
||||||
|
fn make_matrices() -> (Vec<f32>, Vec<f32>, Vec<f32>, Vec<f32>) {
|
||||||
|
let mut rng = rand::thread_rng();
|
||||||
|
|
||||||
|
let (m, n, k) = (M_LEN, N_LEN, K_LEN);
|
||||||
|
|
||||||
|
let mut a = vec![0.0; m * k];
|
||||||
|
let mut a_t = vec![0.0; m * k];
|
||||||
|
|
||||||
|
let mut b = vec![0.0; n * k];
|
||||||
|
let mut b_t = vec![0.0; n * k];
|
||||||
|
|
||||||
|
for row in 0..k {
|
||||||
|
for col in 0..m {
|
||||||
|
let v = rng.gen();
|
||||||
|
a[row * m + col] = v;
|
||||||
|
a_t[col * k + row] = v;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for row in 0..n {
|
||||||
|
for col in 0..k {
|
||||||
|
let v = rng.gen();
|
||||||
|
b[row * k + col] = v;
|
||||||
|
b_t[col * n + row] = v;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
(a, a_t, b, b_t)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
unsafe fn sgemm_ref(
|
||||||
|
m: usize,
|
||||||
|
n: usize,
|
||||||
|
k: usize,
|
||||||
|
_alpha: f32,
|
||||||
|
a: *const f32,
|
||||||
|
lda: usize,
|
||||||
|
b: *const f32,
|
||||||
|
ldb: usize,
|
||||||
|
_beta: f32,
|
||||||
|
c: *mut f32,
|
||||||
|
ldc: usize,
|
||||||
|
) {
|
||||||
|
for i_m in 0..m {
|
||||||
|
for i_k in 0..k {
|
||||||
|
let a_part = *a.add(i_m * lda + i_k);
|
||||||
|
|
||||||
|
for i_n in 0..n {
|
||||||
|
*c.add(i_m * ldc + i_n) += a_part * *b.add(i_k * ldb + i_n);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// for j in 0..n {
|
||||||
|
// for i in 0..m {
|
||||||
|
// let mut ci = *c.add(i + j * ldc);
|
||||||
|
|
||||||
|
// for p in 0..k {
|
||||||
|
// ci += *a.add(i + p * lda) * *b.add(p + j * ldb);
|
||||||
|
// }
|
||||||
|
|
||||||
|
// *c.add(i + j * ldc) = ci;
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_sgemm_nn() {
|
||||||
|
let (m, n, k) = (M_LEN, N_LEN, K_LEN);
|
||||||
|
let (a, _, b, _) = make_matrices();
|
||||||
|
|
||||||
|
let mut c = vec![0.0; m * n];
|
||||||
|
let mut cref = vec![0.0; m * n];
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
blas::sgemm(
|
||||||
|
b'N',
|
||||||
|
b'N',
|
||||||
|
m as i32,
|
||||||
|
n as i32,
|
||||||
|
k as i32,
|
||||||
|
1.0,
|
||||||
|
a.as_slice(),
|
||||||
|
m as i32,
|
||||||
|
b.as_slice(),
|
||||||
|
k as i32,
|
||||||
|
0.0,
|
||||||
|
cref.as_mut_slice(),
|
||||||
|
m as i32,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
crate::sgemm(
|
||||||
|
&crate::executor::DefaultExecutor,
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
m,
|
||||||
|
n,
|
||||||
|
k,
|
||||||
|
1.0,
|
||||||
|
a.as_ptr(),
|
||||||
|
m,
|
||||||
|
b.as_ptr(),
|
||||||
|
k,
|
||||||
|
0.0,
|
||||||
|
c.as_mut_ptr(),
|
||||||
|
m,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
for row in 0..N_LEN {
|
||||||
|
for col in 0..M_LEN {
|
||||||
|
let index = row * M_LEN + col;
|
||||||
|
|
||||||
|
let (a, b) = (c[index], cref[index]);
|
||||||
|
|
||||||
|
assert!(feq(a, b), "a != b, a[{}, {}]={}, b[{}, {}]={}", row, col, a, row, col, b);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_sgemm_nt() {
|
||||||
|
let (m, n, k) = (M_LEN, N_LEN, K_LEN);
|
||||||
|
let (a, _, b, b_t) = make_matrices();
|
||||||
|
|
||||||
|
let mut c = vec![0.0; m * n];
|
||||||
|
let mut cref = vec![0.0; m * n];
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
blas::sgemm(
|
||||||
|
b'N',
|
||||||
|
b'T',
|
||||||
|
m as i32,
|
||||||
|
n as i32,
|
||||||
|
k as i32,
|
||||||
|
1.0,
|
||||||
|
a.as_slice(),
|
||||||
|
m as i32,
|
||||||
|
b_t.as_slice(),
|
||||||
|
n as i32,
|
||||||
|
0.0,
|
||||||
|
cref.as_mut_slice(),
|
||||||
|
m as i32,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
crate::sgemm(
|
||||||
|
&crate::executor::DefaultExecutor,
|
||||||
|
false,
|
||||||
|
true,
|
||||||
|
false,
|
||||||
|
m,
|
||||||
|
n,
|
||||||
|
k,
|
||||||
|
1.0,
|
||||||
|
a.as_ptr(),
|
||||||
|
m,
|
||||||
|
b_t.as_ptr(),
|
||||||
|
n,
|
||||||
|
0.0,
|
||||||
|
c.as_mut_ptr(),
|
||||||
|
m,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
for row in 0..N_LEN {
|
||||||
|
for col in 0..M_LEN {
|
||||||
|
let index = row * M_LEN + col;
|
||||||
|
let (a, b) = (c[index], cref[index]);
|
||||||
|
assert!(feq(a, b), "a != b, a[{}, {}]={}, b[{}, {}]={}", row, col, a, row, col, b);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_sgemm_tn() {
|
||||||
|
let (m, n, k) = (M_LEN, N_LEN, K_LEN);
|
||||||
|
let (a, a_t, b, _) = make_matrices();
|
||||||
|
|
||||||
|
let mut c = vec![0.0; m * n];
|
||||||
|
let mut cref = vec![0.0; m * n];
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
blas::sgemm(
|
||||||
|
b'T',
|
||||||
|
b'N',
|
||||||
|
m as i32,
|
||||||
|
n as i32,
|
||||||
|
k as i32,
|
||||||
|
1.0,
|
||||||
|
a_t.as_slice(),
|
||||||
|
k as i32,
|
||||||
|
b.as_slice(),
|
||||||
|
k as i32,
|
||||||
|
0.0,
|
||||||
|
cref.as_mut_slice(),
|
||||||
|
m as i32,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
crate::sgemm(
|
||||||
|
&crate::executor::DefaultExecutor,
|
||||||
|
true,
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
m,
|
||||||
|
n,
|
||||||
|
k,
|
||||||
|
1.0,
|
||||||
|
a_t.as_ptr(),
|
||||||
|
k,
|
||||||
|
b.as_ptr(),
|
||||||
|
k,
|
||||||
|
0.0,
|
||||||
|
c.as_mut_ptr(),
|
||||||
|
m,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
for row in 0..N_LEN {
|
||||||
|
for col in 0..M_LEN {
|
||||||
|
let index = row * M_LEN + col;
|
||||||
|
let (a, b) = (c[index], cref[index]);
|
||||||
|
assert!(feq(a, b), "a != b, a[{}, {}]={}, b[{}, {}]={}", row, col, a, row, col, b);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_sgemm_tt() {
|
||||||
|
let (m, n, k) = (M_LEN, N_LEN, K_LEN);
|
||||||
|
let (a, a_t, b, b_t) = make_matrices();
|
||||||
|
|
||||||
|
let mut c = vec![0.0; m * n];
|
||||||
|
let mut cref = vec![0.0; m * n];
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
blas::sgemm(
|
||||||
|
b'T',
|
||||||
|
b'T',
|
||||||
|
m as i32,
|
||||||
|
n as i32,
|
||||||
|
k as i32,
|
||||||
|
1.0,
|
||||||
|
a_t.as_slice(),
|
||||||
|
k as i32,
|
||||||
|
b_t.as_slice(),
|
||||||
|
n as i32,
|
||||||
|
0.0,
|
||||||
|
cref.as_mut_slice(),
|
||||||
|
m as i32,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
crate::sgemm(
|
||||||
|
&crate::executor::DefaultExecutor,
|
||||||
|
true,
|
||||||
|
true,
|
||||||
|
false,
|
||||||
|
m,
|
||||||
|
n,
|
||||||
|
k,
|
||||||
|
1.0,
|
||||||
|
a_t.as_ptr(),
|
||||||
|
k,
|
||||||
|
b_t.as_ptr(),
|
||||||
|
n,
|
||||||
|
0.0,
|
||||||
|
c.as_mut_ptr(),
|
||||||
|
m,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
for row in 0..N_LEN {
|
||||||
|
for col in 0..M_LEN {
|
||||||
|
let index = row * M_LEN + col;
|
||||||
|
let (a, b) = (c[index], cref[index]);
|
||||||
|
assert!(feq(a, b), "a != b, a[{}, {}]={}, b[{}, {}]={}", row, col, a, row, col, b);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn feq(a: f32, b: f32) -> bool {
|
||||||
|
if a == b {
|
||||||
|
true
|
||||||
|
} else if a == 0.0 || b == 0.0 || (a.abs() + b.abs() < std::f32::MIN_POSITIVE) {
|
||||||
|
(a - b).abs() < std::f32::EPSILON * 10.0 * std::f32::MIN_POSITIVE
|
||||||
|
} else {
|
||||||
|
(a - b).abs() / (a.abs() + b.abs()) < std::f32::EPSILON * 10.0
|
||||||
|
}
|
||||||
|
}
|
1
src/test/mod.rs
Normal file
1
src/test/mod.rs
Normal file
@ -0,0 +1 @@
|
|||||||
|
mod l3s;
|
2
stats.sh
2
stats.sh
@ -1,5 +1,5 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
# export RUSTFLAGS="-C target-feature=-mmx,-sse,-sse2,-sse3,-ssse3,-sse4.1,-sse4.2,-3dnow,-3dnowa,-avx,-avx2"
|
# export RUSTFLAGS="-C target-feature=-mmx,-sse,-sse2,-sse3,-ssse3,-sse4.1,-sse4.2,-3dnow,-3dnowa,-avx,-avx2"
|
||||||
export RUSTFLAGS="-C target-feature=-mmx,-sse,-sse2,-sse3,-ssse3,-sse4.1,-sse4.2,-3dnow,-3dnowa,+avx,+avx2"
|
export RUSTFLAGS="-C target-feature=-mmx,-sse,-sse2,-sse3,-ssse3,-sse4.1,-sse4.2,-3dnow,-3dnowa,+avx,+avx2"
|
||||||
cargo clean && cargo build --release
|
#cargo clean && cargo build --release
|
||||||
perf stat -B -e cache-references,cache-misses,cycles,instructions,branches,branch-misses,faults,migrations,context-switches ./target/release/gemm
|
perf stat -B -e cache-references,cache-misses,cycles,instructions,branches,branch-misses,faults,migrations,context-switches ./target/release/gemm
|
||||||
|
Loading…
Reference in New Issue
Block a user