Working state
This commit is contained in:
commit
ffae99bafd
2
.cargo/config
Normal file
2
.cargo/config
Normal file
@ -0,0 +1,2 @@
|
||||
[build]
|
||||
rustflags = ["-C", "target-cpu=native"]
|
2
.gitignore
vendored
Normal file
2
.gitignore
vendored
Normal file
@ -0,0 +1,2 @@
|
||||
/target
|
||||
**/*.rs.bk
|
164
Cargo.lock
generated
Normal file
164
Cargo.lock
generated
Normal file
@ -0,0 +1,164 @@
|
||||
# This file is automatically @generated by Cargo.
|
||||
# It is not intended for manual editing.
|
||||
[[package]]
|
||||
name = "autocfg"
|
||||
version = "0.1.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[[package]]
|
||||
name = "blas"
|
||||
version = "0.20.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"blas-sys 0.7.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"libc 0.2.65 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"num-complex 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "blas-sys"
|
||||
version = "0.7.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"libc 0.2.65 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "c2-chacha"
|
||||
version = "0.2.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"ppv-lite86 0.2.6 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cfg-if"
|
||||
version = "0.1.10"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[[package]]
|
||||
name = "gemm"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"blas 0.20.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"matrixmultiply 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"openblas 0.0.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"rand 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "getrandom"
|
||||
version = "0.1.13"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"libc 0.2.65 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"wasi 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "libc"
|
||||
version = "0.2.65"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[[package]]
|
||||
name = "matrixmultiply"
|
||||
version = "0.2.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"rawpointer 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-complex"
|
||||
version = "0.2.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"autocfg 0.1.7 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"num-traits 0.2.10 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-traits"
|
||||
version = "0.2.10"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"autocfg 0.1.7 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "openblas"
|
||||
version = "0.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[[package]]
|
||||
name = "ppv-lite86"
|
||||
version = "0.2.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[[package]]
|
||||
name = "rand"
|
||||
version = "0.7.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"getrandom 0.1.13 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"libc 0.2.65 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"rand_chacha 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"rand_hc 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand_chacha"
|
||||
version = "0.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"c2-chacha 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand_core"
|
||||
version = "0.5.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"getrandom 0.1.13 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand_hc"
|
||||
version = "0.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rawpointer"
|
||||
version = "0.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[[package]]
|
||||
name = "wasi"
|
||||
version = "0.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[metadata]
|
||||
"checksum autocfg 0.1.7 (registry+https://github.com/rust-lang/crates.io-index)" = "1d49d90015b3c36167a20fe2810c5cd875ad504b39cff3d4eae7977e6b7c1cb2"
|
||||
"checksum blas 0.20.0 (registry+https://github.com/rust-lang/crates.io-index)" = "e4b3b6399f2fe6e70f780b06e278a64770aae273e0005c3860f94fc2f5f01ba7"
|
||||
"checksum blas-sys 0.7.1 (registry+https://github.com/rust-lang/crates.io-index)" = "13b1b279ceb25d7c4faaea95a5f7addbe7d8c34f9462044bd8e630cebcfc2440"
|
||||
"checksum c2-chacha 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "214238caa1bf3a496ec3392968969cab8549f96ff30652c9e56885329315f6bb"
|
||||
"checksum cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)" = "4785bdd1c96b2a846b2bd7cc02e86b6b3dbf14e7e53446c4f54c92a361040822"
|
||||
"checksum getrandom 0.1.13 (registry+https://github.com/rust-lang/crates.io-index)" = "e7db7ca94ed4cd01190ceee0d8a8052f08a247aa1b469a7f68c6a3b71afcf407"
|
||||
"checksum libc 0.2.65 (registry+https://github.com/rust-lang/crates.io-index)" = "1a31a0627fdf1f6a39ec0dd577e101440b7db22672c0901fe00a9a6fbb5c24e8"
|
||||
"checksum matrixmultiply 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "d4f7ec66360130972f34830bfad9ef05c6610a43938a467bcc9ab9369ab3478f"
|
||||
"checksum num-complex 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "fcb0cf31fb3ff77e6d2a6ebd6800df7fdcd106f2ad89113c9130bcd07f93dffc"
|
||||
"checksum num-traits 0.2.10 (registry+https://github.com/rust-lang/crates.io-index)" = "d4c81ffc11c212fa327657cb19dd85eb7419e163b5b076bede2bdb5c974c07e4"
|
||||
"checksum openblas 0.0.1 (registry+https://github.com/rust-lang/crates.io-index)" = "f005eb075fb562358060245c5d00a877757b4ddec8e8cd69cab96d8cc55998e8"
|
||||
"checksum ppv-lite86 0.2.6 (registry+https://github.com/rust-lang/crates.io-index)" = "74490b50b9fbe561ac330df47c08f3f33073d2d00c150f719147d7c54522fa1b"
|
||||
"checksum rand 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)" = "3ae1b169243eaf61759b8475a998f0a385e42042370f3a7dbaf35246eacc8412"
|
||||
"checksum rand_chacha 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "03a2a90da8c7523f554344f921aa97283eadf6ac484a6d2a7d0212fa7f8d6853"
|
||||
"checksum rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)" = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19"
|
||||
"checksum rand_hc 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ca3129af7b92a17112d59ad498c6f81eaf463253766b90396d39ea7a39d6613c"
|
||||
"checksum rawpointer 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3"
|
||||
"checksum wasi 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)" = "b89c3ce4ce14bdc6fb6beaf9ec7928ca331de5df7e5ea278375642a2f478570d"
|
14
Cargo.toml
Normal file
14
Cargo.toml
Normal file
@ -0,0 +1,14 @@
|
||||
[package]
|
||||
name = "gemm"
|
||||
version = "0.1.0"
|
||||
authors = ["Andrey Tkachenko <andreytkachenko64@gmail.com>"]
|
||||
edition = "2018"
|
||||
build = "build.rs"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
rand = "0.7.2"
|
||||
matrixmultiply = "0.2.3"
|
||||
blas = "0.20.0"
|
||||
openblas = "0.0.1"
|
7
build.rs
Normal file
7
build.rs
Normal file
@ -0,0 +1,7 @@
|
||||
fn main() {
|
||||
// Always linking openblas
|
||||
// Compiling blas just for testing is tedious -- install it on your system
|
||||
// and run this.
|
||||
println!("cargo:rerun-if-changed=build.rs");
|
||||
println!("cargo:rustc-link-lib={}=openblas", "dylib");
|
||||
}
|
31
src/aligned_alloc.rs
Normal file
31
src/aligned_alloc.rs
Normal file
@ -0,0 +1,31 @@
|
||||
use std::alloc;
|
||||
|
||||
pub struct Alloc {
|
||||
ptr: *mut u8,
|
||||
layout: alloc::Layout,
|
||||
}
|
||||
|
||||
impl Alloc {
|
||||
pub fn new(size: usize) -> Alloc {
|
||||
const ALIGN: usize = 32;
|
||||
let layout = alloc::Layout::from_size_align(size, ALIGN).unwrap();
|
||||
let ptr = unsafe { alloc::alloc(layout) };
|
||||
Alloc { ptr, layout }
|
||||
}
|
||||
|
||||
pub fn ptr_f32(&self) -> *mut f32 {
|
||||
self.ptr as *mut f32
|
||||
}
|
||||
|
||||
pub fn ptr(&self) -> *mut u8 {
|
||||
self.ptr
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for Alloc {
|
||||
fn drop(&mut self) {
|
||||
unsafe {
|
||||
alloc::dealloc(self.ptr, self.layout);
|
||||
}
|
||||
}
|
||||
}
|
47
src/kernel/avx/fma.rs
Normal file
47
src/kernel/avx/fma.rs
Normal file
@ -0,0 +1,47 @@
|
||||
use super::intrinsics::*;
|
||||
|
||||
#[cfg(target_feature = "fma")]
|
||||
#[inline(always)]
|
||||
pub unsafe fn fmadd_ps(a: __m256, b: __m256, c: __m256) -> __m256 {
|
||||
_mm256_fmadd_ps(a, b, c)
|
||||
}
|
||||
|
||||
#[cfg(not(target_feature = "fma"))]
|
||||
#[inline(always)]
|
||||
pub unsafe fn fmadd_ps(a: __m256, b: __m256, c: __m256) -> __m256 {
|
||||
_mm256_add_ps(_mm256_mul_ps(a, b), c)
|
||||
}
|
||||
|
||||
#[cfg(target_feature = "fma")]
|
||||
#[inline(always)]
|
||||
pub unsafe fn fmsub_ps(a: __m256, b: __m256, c: __m256) -> __m256 {
|
||||
_mm256_fmsub_ps(a, b, c)
|
||||
}
|
||||
|
||||
#[cfg(not(target_feature = "fma"))]
|
||||
pub unsafe fn fmsub_ps(a: __m256, b: __m256, c: __m256) -> __m256 {
|
||||
_mm256_sub_ps(_mm256_mul_ps(a, b), c)
|
||||
}
|
||||
|
||||
#[cfg(target_feature = "fma")]
|
||||
#[inline(always)]
|
||||
pub unsafe fn fmadd_pd(a: __m256d, b: __m256d, c: __m256d) -> __m256d {
|
||||
_mm256_fmadd_pd(a, b, c)
|
||||
}
|
||||
|
||||
#[cfg(not(target_feature = "fma"))]
|
||||
#[inline(always)]
|
||||
pub unsafe fn fmadd_pd(a: __m256d, b: __m256d, c: __m256d) -> __m256d {
|
||||
_mm256_add_pd(_mm256_mul_pd(a, b), c)
|
||||
}
|
||||
|
||||
#[cfg(target_feature = "fma")]
|
||||
#[inline(always)]
|
||||
pub unsafe fn fmsub_pd(a: __m256d, b: __m256d, c: __m256d) -> __m256d {
|
||||
_mm256_fmsub_pd(a, b, c)
|
||||
}
|
||||
|
||||
#[cfg(not(target_feature = "fma"))]
|
||||
pub unsafe fn fmsub_pd(a: __m256d, b: __m256d, c: __m256d) -> __m256d {
|
||||
_mm256_sub_pd(_mm256_mul_pd(a, b), c)
|
||||
}
|
23
src/kernel/avx/hsum.rs
Normal file
23
src/kernel/avx/hsum.rs
Normal file
@ -0,0 +1,23 @@
|
||||
use super::intrinsics::*;
|
||||
|
||||
#[inline(always)]
|
||||
pub unsafe fn hsum_ps(v: __m256) -> f32 {
|
||||
let qhigh = _mm256_extractf128_ps(v, 1);
|
||||
let qlow = _mm256_castps256_ps128(v);
|
||||
let qsum = _mm_add_ps(qhigh, qlow);
|
||||
let dhigh = _mm_movehl_ps(qsum, qsum);
|
||||
let dlow = qsum;
|
||||
let dsum = _mm_add_ps(dhigh, dlow);
|
||||
let high = _mm_shuffle_ps(dsum, dsum, 1);
|
||||
let low = dsum;
|
||||
_mm_cvtss_f32(_mm_add_ss(high, low))
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub unsafe fn hsum_pd(v: __m256d) -> f64 {
|
||||
let vhigh = _mm256_extractf128_pd(v, 1);
|
||||
let vlow = _mm256_castpd256_pd128(v);
|
||||
let vsum = _mm_add_pd(vlow, vhigh);
|
||||
let h64 = _mm_unpackhi_pd(vsum, vsum);
|
||||
_mm_cvtsd_f64(_mm_add_sd(vsum, h64))
|
||||
}
|
5
src/kernel/avx/intrinsics.rs
Normal file
5
src/kernel/avx/intrinsics.rs
Normal file
@ -0,0 +1,5 @@
|
||||
#[cfg(target_arch = "x86")]
|
||||
pub use std::arch::x86::*;
|
||||
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
pub use std::arch::x86_64::*;
|
396
src/kernel/avx/l1d.rs
Normal file
396
src/kernel/avx/l1d.rs
Normal file
@ -0,0 +1,396 @@
|
||||
use super::fma::{fmadd_pd, fmsub_pd};
|
||||
use super::hsum::hsum_pd;
|
||||
use super::intrinsics::*;
|
||||
|
||||
pub unsafe fn drot(
|
||||
n: usize,
|
||||
mut x: *mut f64,
|
||||
incx: usize,
|
||||
mut y: *mut f64,
|
||||
incy: usize,
|
||||
c: f64,
|
||||
s: f64,
|
||||
) {
|
||||
if incx == 1 && incy == 1 {
|
||||
let c0 = _mm256_broadcast_sd(&c);
|
||||
let s0 = _mm256_broadcast_sd(&s);
|
||||
|
||||
for _ in 0..n / 16 {
|
||||
let x0 = _mm256_loadu_pd(x);
|
||||
let y0 = _mm256_loadu_pd(y);
|
||||
let x1 = _mm256_loadu_pd(x.add(4));
|
||||
let y1 = _mm256_loadu_pd(y.add(4));
|
||||
let x2 = _mm256_loadu_pd(x.add(8));
|
||||
let y2 = _mm256_loadu_pd(y.add(8));
|
||||
let x3 = _mm256_loadu_pd(x.add(12));
|
||||
let y3 = _mm256_loadu_pd(y.add(12));
|
||||
|
||||
_mm256_storeu_pd(x, fmadd_pd(c0, x0, _mm256_mul_pd(s0, y0)));
|
||||
_mm256_storeu_pd(y, fmsub_pd(c0, y0, _mm256_mul_pd(s0, x0)));
|
||||
_mm256_storeu_pd(x.add(4), fmadd_pd(c0, x1, _mm256_mul_pd(s0, y1)));
|
||||
_mm256_storeu_pd(y.add(4), fmsub_pd(c0, y1, _mm256_mul_pd(s0, x1)));
|
||||
_mm256_storeu_pd(x.add(8), fmadd_pd(c0, x2, _mm256_mul_pd(s0, y2)));
|
||||
_mm256_storeu_pd(y.add(8), fmsub_pd(c0, y2, _mm256_mul_pd(s0, x2)));
|
||||
_mm256_storeu_pd(x.add(12), fmadd_pd(c0, x3, _mm256_mul_pd(s0, y3)));
|
||||
_mm256_storeu_pd(y.add(12), fmsub_pd(c0, y3, _mm256_mul_pd(s0, x3)));
|
||||
|
||||
x = x.add(16);
|
||||
y = y.add(16);
|
||||
}
|
||||
|
||||
for _ in 0..n % 16 {
|
||||
let x0 = *x;
|
||||
let y0 = *y;
|
||||
|
||||
*x = c * x0 + s * y0;
|
||||
*y = c * y0 - s * x0;
|
||||
|
||||
x = x.add(1);
|
||||
y = y.add(1);
|
||||
}
|
||||
} else {
|
||||
for _ in 0..n {
|
||||
let x0 = *x;
|
||||
let y0 = *y;
|
||||
|
||||
*x = c * x0 + s * y0;
|
||||
*y = c * y0 - s * x0;
|
||||
|
||||
x = x.add(incx);
|
||||
y = y.add(incy);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub unsafe fn dswap(n: usize, mut x: *mut f64, incx: usize, mut y: *mut f64, incy: usize) {
|
||||
if incx == 1 && incy == 1 {
|
||||
for _ in 0..n / 16 {
|
||||
let x0 = _mm256_loadu_pd(x);
|
||||
let y0 = _mm256_loadu_pd(y);
|
||||
let x1 = _mm256_loadu_pd(x.add(4));
|
||||
let y1 = _mm256_loadu_pd(y.add(4));
|
||||
let x2 = _mm256_loadu_pd(x.add(8));
|
||||
let y2 = _mm256_loadu_pd(y.add(8));
|
||||
let x3 = _mm256_loadu_pd(x.add(12));
|
||||
let y3 = _mm256_loadu_pd(y.add(12));
|
||||
|
||||
_mm256_storeu_pd(x, y0);
|
||||
_mm256_storeu_pd(y, x0);
|
||||
_mm256_storeu_pd(x.add(4), y1);
|
||||
_mm256_storeu_pd(y.add(4), x1);
|
||||
_mm256_storeu_pd(x.add(8), y2);
|
||||
_mm256_storeu_pd(y.add(8), x2);
|
||||
_mm256_storeu_pd(x.add(12), y3);
|
||||
_mm256_storeu_pd(y.add(12), x3);
|
||||
|
||||
x = x.add(16);
|
||||
y = y.add(16);
|
||||
}
|
||||
|
||||
for _ in 0..n % 16 {
|
||||
let x0 = *x;
|
||||
|
||||
*x = *y;
|
||||
*y = x0;
|
||||
|
||||
x = x.add(1);
|
||||
y = y.add(1);
|
||||
}
|
||||
} else {
|
||||
for _ in 0..n {
|
||||
let x0 = *x;
|
||||
|
||||
*x = *y;
|
||||
*y = x0;
|
||||
|
||||
x = x.add(incx);
|
||||
y = y.add(incy);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub unsafe fn dscal(n: usize, a: f64, mut x: *mut f64, incx: usize) {
|
||||
if incx == 1 {
|
||||
let a0 = _mm256_broadcast_sd(&a);
|
||||
for _ in 0..n / 32 {
|
||||
let mut x0 = _mm256_loadu_pd(x);
|
||||
let mut x1 = _mm256_loadu_pd(x.add(4));
|
||||
let mut x2 = _mm256_loadu_pd(x.add(8));
|
||||
let mut x3 = _mm256_loadu_pd(x.add(12));
|
||||
let mut x4 = _mm256_loadu_pd(x.add(16));
|
||||
let mut x5 = _mm256_loadu_pd(x.add(20));
|
||||
let mut x6 = _mm256_loadu_pd(x.add(24));
|
||||
let mut x7 = _mm256_loadu_pd(x.add(28));
|
||||
|
||||
x0 = _mm256_mul_pd(a0, x0);
|
||||
x1 = _mm256_mul_pd(a0, x1);
|
||||
x2 = _mm256_mul_pd(a0, x2);
|
||||
x3 = _mm256_mul_pd(a0, x3);
|
||||
x4 = _mm256_mul_pd(a0, x4);
|
||||
x5 = _mm256_mul_pd(a0, x5);
|
||||
x6 = _mm256_mul_pd(a0, x6);
|
||||
x7 = _mm256_mul_pd(a0, x7);
|
||||
|
||||
_mm256_storeu_pd(x, x0);
|
||||
_mm256_storeu_pd(x.add(4), x1);
|
||||
_mm256_storeu_pd(x.add(8), x2);
|
||||
_mm256_storeu_pd(x.add(12), x3);
|
||||
_mm256_storeu_pd(x.add(16), x4);
|
||||
_mm256_storeu_pd(x.add(20), x5);
|
||||
_mm256_storeu_pd(x.add(24), x6);
|
||||
_mm256_storeu_pd(x.add(28), x7);
|
||||
|
||||
x = x.add(32);
|
||||
}
|
||||
for _ in 0..n % 32 {
|
||||
*x *= a;
|
||||
x = x.add(1);
|
||||
}
|
||||
} else {
|
||||
for _ in 0..n {
|
||||
*x *= a;
|
||||
x = x.add(incx);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub unsafe fn dcopy(n: usize, mut x: *const f64, incx: usize, mut y: *mut f64, incy: usize) {
|
||||
if incx == 1 && incy == 1 {
|
||||
for _ in 0..n / 32 {
|
||||
let x0 = _mm256_loadu_pd(x);
|
||||
let x1 = _mm256_loadu_pd(x.add(4));
|
||||
let x2 = _mm256_loadu_pd(x.add(8));
|
||||
let x3 = _mm256_loadu_pd(x.add(12));
|
||||
let x4 = _mm256_loadu_pd(x.add(16));
|
||||
let x5 = _mm256_loadu_pd(x.add(20));
|
||||
let x6 = _mm256_loadu_pd(x.add(24));
|
||||
let x7 = _mm256_loadu_pd(x.add(28));
|
||||
|
||||
_mm256_storeu_pd(y, x0);
|
||||
_mm256_storeu_pd(y.add(4), x1);
|
||||
_mm256_storeu_pd(y.add(8), x2);
|
||||
_mm256_storeu_pd(y.add(12), x3);
|
||||
_mm256_storeu_pd(y.add(16), x4);
|
||||
_mm256_storeu_pd(y.add(20), x5);
|
||||
_mm256_storeu_pd(y.add(24), x6);
|
||||
_mm256_storeu_pd(y.add(28), x7);
|
||||
|
||||
x = x.add(32);
|
||||
y = y.add(32);
|
||||
}
|
||||
for _ in 0..n % 32 {
|
||||
*y = *x;
|
||||
x = x.add(1);
|
||||
y = y.add(1);
|
||||
}
|
||||
} else {
|
||||
for _ in 0..n {
|
||||
*y = *x;
|
||||
x = x.add(incx);
|
||||
y = y.add(incy);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub unsafe fn daxpy(
|
||||
n: usize,
|
||||
a: f64,
|
||||
mut x: *const f64,
|
||||
incx: usize,
|
||||
mut y: *mut f64,
|
||||
incy: usize,
|
||||
) {
|
||||
if incx == 1 && incy == 1 {
|
||||
let a0 = _mm256_broadcast_sd(&a);
|
||||
for _ in 0..n / 16 {
|
||||
let x0 = _mm256_loadu_pd(x);
|
||||
let y0 = _mm256_loadu_pd(y);
|
||||
let x1 = _mm256_loadu_pd(x.add(4));
|
||||
let y1 = _mm256_loadu_pd(y.add(4));
|
||||
let x2 = _mm256_loadu_pd(x.add(8));
|
||||
let y2 = _mm256_loadu_pd(y.add(8));
|
||||
let x3 = _mm256_loadu_pd(x.add(12));
|
||||
let y3 = _mm256_loadu_pd(y.add(12));
|
||||
|
||||
_mm256_storeu_pd(y, fmadd_pd(a0, x0, y0));
|
||||
_mm256_storeu_pd(y.add(4), fmadd_pd(a0, x1, y1));
|
||||
_mm256_storeu_pd(y.add(8), fmadd_pd(a0, x2, y2));
|
||||
_mm256_storeu_pd(y.add(12), fmadd_pd(a0, x3, y3));
|
||||
|
||||
x = x.add(16);
|
||||
y = y.add(16);
|
||||
}
|
||||
for _ in 0..n % 16 {
|
||||
*y += a * *x;
|
||||
x = x.add(1);
|
||||
y = y.add(1);
|
||||
}
|
||||
} else {
|
||||
for _ in 0..n {
|
||||
*y += a * *x;
|
||||
x = x.add(incx);
|
||||
y = y.add(incy);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub unsafe fn ddot(
|
||||
n: usize,
|
||||
mut x: *const f64,
|
||||
incx: usize,
|
||||
mut y: *const f64,
|
||||
incy: usize,
|
||||
) -> f64 {
|
||||
if incx == 1 && incy == 1 {
|
||||
let mut acc0 = _mm256_setzero_pd();
|
||||
let mut acc1 = _mm256_setzero_pd();
|
||||
let mut acc2 = _mm256_setzero_pd();
|
||||
let mut acc3 = _mm256_setzero_pd();
|
||||
for _ in 0..n / 16 {
|
||||
let x0 = _mm256_loadu_pd(x);
|
||||
let y0 = _mm256_loadu_pd(y);
|
||||
let x1 = _mm256_loadu_pd(x.add(4));
|
||||
let y1 = _mm256_loadu_pd(y.add(4));
|
||||
let x2 = _mm256_loadu_pd(x.add(8));
|
||||
let y2 = _mm256_loadu_pd(y.add(8));
|
||||
let x3 = _mm256_loadu_pd(x.add(12));
|
||||
let y3 = _mm256_loadu_pd(y.add(12));
|
||||
|
||||
acc0 = fmadd_pd(x0, y0, acc0);
|
||||
acc1 = fmadd_pd(x1, y1, acc1);
|
||||
acc2 = fmadd_pd(x2, y2, acc2);
|
||||
acc3 = fmadd_pd(x3, y3, acc3);
|
||||
|
||||
x = x.add(16);
|
||||
y = y.add(16);
|
||||
}
|
||||
acc0 = _mm256_add_pd(acc0, acc1);
|
||||
acc2 = _mm256_add_pd(acc2, acc3);
|
||||
acc0 = _mm256_add_pd(acc0, acc2);
|
||||
|
||||
let mut acc = hsum_pd(acc0);
|
||||
for _ in 0..n % 16 {
|
||||
acc += *x * *y;
|
||||
x = x.add(1);
|
||||
y = y.add(1);
|
||||
}
|
||||
acc
|
||||
} else {
|
||||
let mut acc = 0.0;
|
||||
for _ in 0..n {
|
||||
acc += *x * *y;
|
||||
x = x.add(incx);
|
||||
y = y.add(incy);
|
||||
}
|
||||
acc
|
||||
}
|
||||
}
|
||||
|
||||
pub unsafe fn dnrm2(n: usize, mut x: *const f64, incx: usize) -> f64 {
|
||||
if incx == 1 {
|
||||
let mut acc0 = _mm256_setzero_pd();
|
||||
let mut acc1 = _mm256_setzero_pd();
|
||||
let mut acc2 = _mm256_setzero_pd();
|
||||
let mut acc3 = _mm256_setzero_pd();
|
||||
for _ in 0..n / 16 {
|
||||
let x0 = _mm256_loadu_pd(x);
|
||||
let x1 = _mm256_loadu_pd(x.add(4));
|
||||
let x2 = _mm256_loadu_pd(x.add(8));
|
||||
let x3 = _mm256_loadu_pd(x.add(12));
|
||||
|
||||
acc0 = fmadd_pd(x0, x0, acc0);
|
||||
acc1 = fmadd_pd(x1, x1, acc1);
|
||||
acc2 = fmadd_pd(x2, x2, acc2);
|
||||
acc3 = fmadd_pd(x3, x3, acc3);
|
||||
|
||||
x = x.add(16);
|
||||
}
|
||||
acc0 = _mm256_add_pd(acc0, acc1);
|
||||
acc2 = _mm256_add_pd(acc2, acc3);
|
||||
acc0 = _mm256_add_pd(acc0, acc2);
|
||||
|
||||
let mut acc = hsum_pd(acc0);
|
||||
for _ in 0..n % 16 {
|
||||
let x0 = *x;
|
||||
acc += x0 * x0;
|
||||
x = x.add(1);
|
||||
}
|
||||
acc.sqrt()
|
||||
} else {
|
||||
let mut acc = 0.0;
|
||||
for _ in 0..n {
|
||||
let x0 = *x;
|
||||
acc += x0 * x0;
|
||||
x = x.add(incx);
|
||||
}
|
||||
acc.sqrt()
|
||||
}
|
||||
}
|
||||
|
||||
pub unsafe fn dasum(n: usize, mut x: *const f64, incx: usize) -> f64 {
|
||||
if incx == 1 {
|
||||
let mask = _mm256_broadcast_sd(&f64::from_bits(0x7FFF_FFFF_FFFF_FFFF));
|
||||
|
||||
let mut acc0 = _mm256_setzero_pd();
|
||||
let mut acc1 = _mm256_setzero_pd();
|
||||
let mut acc2 = _mm256_setzero_pd();
|
||||
let mut acc3 = _mm256_setzero_pd();
|
||||
let mut acc4 = _mm256_setzero_pd();
|
||||
let mut acc5 = _mm256_setzero_pd();
|
||||
let mut acc6 = _mm256_setzero_pd();
|
||||
let mut acc7 = _mm256_setzero_pd();
|
||||
for _ in 0..n / 32 {
|
||||
let mut x0 = _mm256_loadu_pd(x);
|
||||
let mut x1 = _mm256_loadu_pd(x.add(4));
|
||||
let mut x2 = _mm256_loadu_pd(x.add(8));
|
||||
let mut x3 = _mm256_loadu_pd(x.add(12));
|
||||
let mut x4 = _mm256_loadu_pd(x.add(16));
|
||||
let mut x5 = _mm256_loadu_pd(x.add(20));
|
||||
let mut x6 = _mm256_loadu_pd(x.add(24));
|
||||
let mut x7 = _mm256_loadu_pd(x.add(28));
|
||||
|
||||
x0 = _mm256_and_pd(mask, x0);
|
||||
x1 = _mm256_and_pd(mask, x1);
|
||||
x2 = _mm256_and_pd(mask, x2);
|
||||
x3 = _mm256_and_pd(mask, x3);
|
||||
x4 = _mm256_and_pd(mask, x4);
|
||||
x5 = _mm256_and_pd(mask, x5);
|
||||
x6 = _mm256_and_pd(mask, x6);
|
||||
x7 = _mm256_and_pd(mask, x7);
|
||||
|
||||
acc0 = _mm256_add_pd(acc0, x0);
|
||||
acc1 = _mm256_add_pd(acc1, x1);
|
||||
acc2 = _mm256_add_pd(acc2, x2);
|
||||
acc3 = _mm256_add_pd(acc3, x3);
|
||||
acc4 = _mm256_add_pd(acc4, x4);
|
||||
acc5 = _mm256_add_pd(acc5, x5);
|
||||
acc6 = _mm256_add_pd(acc6, x6);
|
||||
acc7 = _mm256_add_pd(acc7, x7);
|
||||
|
||||
x = x.add(32);
|
||||
}
|
||||
acc0 = _mm256_add_pd(acc0, acc1);
|
||||
acc2 = _mm256_add_pd(acc2, acc3);
|
||||
acc4 = _mm256_add_pd(acc4, acc5);
|
||||
acc6 = _mm256_add_pd(acc6, acc7);
|
||||
|
||||
acc0 = _mm256_add_pd(acc0, acc2);
|
||||
acc4 = _mm256_add_pd(acc4, acc6);
|
||||
|
||||
acc0 = _mm256_add_pd(acc0, acc4);
|
||||
|
||||
let mut acc = hsum_pd(acc0);
|
||||
for _ in 0..n % 32 {
|
||||
acc += (*x).abs();
|
||||
x = x.add(1);
|
||||
}
|
||||
acc
|
||||
} else {
|
||||
let mut acc = 0.0;
|
||||
for _ in 0..n {
|
||||
acc += (*x).abs();
|
||||
x = x.add(incx);
|
||||
}
|
||||
acc
|
||||
}
|
||||
}
|
396
src/kernel/avx/l1s.rs
Normal file
396
src/kernel/avx/l1s.rs
Normal file
@ -0,0 +1,396 @@
|
||||
use super::fma::{fmadd_ps, fmsub_ps};
|
||||
use super::hsum::hsum_ps;
|
||||
use super::intrinsics::*;
|
||||
|
||||
pub unsafe fn srot(
|
||||
n: usize,
|
||||
mut x: *mut f32,
|
||||
incx: usize,
|
||||
mut y: *mut f32,
|
||||
incy: usize,
|
||||
c: f32,
|
||||
s: f32,
|
||||
) {
|
||||
if incx == 1 && incy == 1 {
|
||||
let c0 = _mm256_broadcast_ss(&c);
|
||||
let s0 = _mm256_broadcast_ss(&s);
|
||||
|
||||
for _ in 0..n / 32 {
|
||||
let x0 = _mm256_loadu_ps(x);
|
||||
let y0 = _mm256_loadu_ps(y);
|
||||
let x1 = _mm256_loadu_ps(x.add(8));
|
||||
let y1 = _mm256_loadu_ps(y.add(8));
|
||||
let x2 = _mm256_loadu_ps(x.add(16));
|
||||
let y2 = _mm256_loadu_ps(y.add(16));
|
||||
let x3 = _mm256_loadu_ps(x.add(24));
|
||||
let y3 = _mm256_loadu_ps(y.add(24));
|
||||
|
||||
_mm256_storeu_ps(x, fmadd_ps(c0, x0, _mm256_mul_ps(s0, y0)));
|
||||
_mm256_storeu_ps(y, fmsub_ps(c0, y0, _mm256_mul_ps(s0, x0)));
|
||||
_mm256_storeu_ps(x.add(8), fmadd_ps(c0, x1, _mm256_mul_ps(s0, y1)));
|
||||
_mm256_storeu_ps(y.add(8), fmsub_ps(c0, y1, _mm256_mul_ps(s0, x1)));
|
||||
_mm256_storeu_ps(x.add(16), fmadd_ps(c0, x2, _mm256_mul_ps(s0, y2)));
|
||||
_mm256_storeu_ps(y.add(16), fmsub_ps(c0, y2, _mm256_mul_ps(s0, x2)));
|
||||
_mm256_storeu_ps(x.add(24), fmadd_ps(c0, x3, _mm256_mul_ps(s0, y3)));
|
||||
_mm256_storeu_ps(y.add(24), fmsub_ps(c0, y3, _mm256_mul_ps(s0, x3)));
|
||||
|
||||
x = x.add(32);
|
||||
y = y.add(32);
|
||||
}
|
||||
|
||||
for _ in 0..n % 32 {
|
||||
let x0 = *x;
|
||||
let y0 = *y;
|
||||
|
||||
*x = c * x0 + s * y0;
|
||||
*y = c * y0 - s * x0;
|
||||
|
||||
x = x.add(1);
|
||||
y = y.add(1);
|
||||
}
|
||||
} else {
|
||||
for _ in 0..n {
|
||||
let x0 = *x;
|
||||
let y0 = *y;
|
||||
|
||||
*x = c * x0 + s * y0;
|
||||
*y = c * y0 - s * x0;
|
||||
|
||||
x = x.add(incx);
|
||||
y = y.add(incy);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub unsafe fn sswap(n: usize, mut x: *mut f32, incx: usize, mut y: *mut f32, incy: usize) {
|
||||
if incx == 1 && incy == 1 {
|
||||
for _ in 0..n / 32 {
|
||||
let x0 = _mm256_loadu_ps(x);
|
||||
let y0 = _mm256_loadu_ps(y);
|
||||
let x1 = _mm256_loadu_ps(x.add(8));
|
||||
let y1 = _mm256_loadu_ps(y.add(8));
|
||||
let x2 = _mm256_loadu_ps(x.add(16));
|
||||
let y2 = _mm256_loadu_ps(y.add(16));
|
||||
let x3 = _mm256_loadu_ps(x.add(24));
|
||||
let y3 = _mm256_loadu_ps(y.add(24));
|
||||
|
||||
_mm256_storeu_ps(x, y0);
|
||||
_mm256_storeu_ps(y, x0);
|
||||
_mm256_storeu_ps(x.add(8), y1);
|
||||
_mm256_storeu_ps(y.add(8), x1);
|
||||
_mm256_storeu_ps(x.add(16), y2);
|
||||
_mm256_storeu_ps(y.add(16), x2);
|
||||
_mm256_storeu_ps(x.add(24), y3);
|
||||
_mm256_storeu_ps(y.add(24), x3);
|
||||
|
||||
x = x.add(32);
|
||||
y = y.add(32);
|
||||
}
|
||||
|
||||
for _ in 0..n % 32 {
|
||||
let x0 = *x;
|
||||
|
||||
*x = *y;
|
||||
*y = x0;
|
||||
|
||||
x = x.add(1);
|
||||
y = y.add(1);
|
||||
}
|
||||
} else {
|
||||
for _ in 0..n {
|
||||
let x0 = *x;
|
||||
|
||||
*x = *y;
|
||||
*y = x0;
|
||||
|
||||
x = x.add(incx);
|
||||
y = y.add(incy);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub unsafe fn sscal(n: usize, a: f32, mut x: *mut f32, incx: usize) {
|
||||
if incx == 1 {
|
||||
let a0 = _mm256_broadcast_ss(&a);
|
||||
for _ in 0..n / 64 {
|
||||
let mut x0 = _mm256_loadu_ps(x);
|
||||
let mut x1 = _mm256_loadu_ps(x.add(8));
|
||||
let mut x2 = _mm256_loadu_ps(x.add(16));
|
||||
let mut x3 = _mm256_loadu_ps(x.add(24));
|
||||
let mut x4 = _mm256_loadu_ps(x.add(32));
|
||||
let mut x5 = _mm256_loadu_ps(x.add(40));
|
||||
let mut x6 = _mm256_loadu_ps(x.add(48));
|
||||
let mut x7 = _mm256_loadu_ps(x.add(56));
|
||||
|
||||
x0 = _mm256_mul_ps(a0, x0);
|
||||
x1 = _mm256_mul_ps(a0, x1);
|
||||
x2 = _mm256_mul_ps(a0, x2);
|
||||
x3 = _mm256_mul_ps(a0, x3);
|
||||
x4 = _mm256_mul_ps(a0, x4);
|
||||
x5 = _mm256_mul_ps(a0, x5);
|
||||
x6 = _mm256_mul_ps(a0, x6);
|
||||
x7 = _mm256_mul_ps(a0, x7);
|
||||
|
||||
_mm256_storeu_ps(x, x0);
|
||||
_mm256_storeu_ps(x.add(8), x1);
|
||||
_mm256_storeu_ps(x.add(16), x2);
|
||||
_mm256_storeu_ps(x.add(24), x3);
|
||||
_mm256_storeu_ps(x.add(32), x4);
|
||||
_mm256_storeu_ps(x.add(40), x5);
|
||||
_mm256_storeu_ps(x.add(48), x6);
|
||||
_mm256_storeu_ps(x.add(56), x7);
|
||||
|
||||
x = x.add(64);
|
||||
}
|
||||
for _ in 0..n % 64 {
|
||||
*x *= a;
|
||||
x = x.add(1);
|
||||
}
|
||||
} else {
|
||||
for _ in 0..n {
|
||||
*x *= a;
|
||||
x = x.add(incx);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub unsafe fn scopy(n: usize, mut x: *const f32, incx: usize, mut y: *mut f32, incy: usize) {
|
||||
if incx == 1 && incy == 1 {
|
||||
for _ in 0..n / 64 {
|
||||
let x0 = _mm256_loadu_ps(x);
|
||||
let x1 = _mm256_loadu_ps(x.add(8));
|
||||
let x2 = _mm256_loadu_ps(x.add(16));
|
||||
let x3 = _mm256_loadu_ps(x.add(24));
|
||||
let x4 = _mm256_loadu_ps(x.add(32));
|
||||
let x5 = _mm256_loadu_ps(x.add(40));
|
||||
let x6 = _mm256_loadu_ps(x.add(48));
|
||||
let x7 = _mm256_loadu_ps(x.add(56));
|
||||
|
||||
_mm256_storeu_ps(y, x0);
|
||||
_mm256_storeu_ps(y.add(8), x1);
|
||||
_mm256_storeu_ps(y.add(16), x2);
|
||||
_mm256_storeu_ps(y.add(24), x3);
|
||||
_mm256_storeu_ps(y.add(32), x4);
|
||||
_mm256_storeu_ps(y.add(40), x5);
|
||||
_mm256_storeu_ps(y.add(48), x6);
|
||||
_mm256_storeu_ps(y.add(56), x7);
|
||||
|
||||
x = x.add(64);
|
||||
y = y.add(64);
|
||||
}
|
||||
for _ in 0..n % 64 {
|
||||
*y = *x;
|
||||
x = x.add(1);
|
||||
y = y.add(1);
|
||||
}
|
||||
} else {
|
||||
for _ in 0..n {
|
||||
*y = *x;
|
||||
x = x.add(incx);
|
||||
y = y.add(incy);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub unsafe fn saxpy(
|
||||
n: usize,
|
||||
a: f32,
|
||||
mut x: *const f32,
|
||||
incx: usize,
|
||||
mut y: *mut f32,
|
||||
incy: usize,
|
||||
) {
|
||||
if incx == 1 && incy == 1 {
|
||||
let a0 = _mm256_broadcast_ss(&a);
|
||||
for _ in 0..n / 32 {
|
||||
let x0 = _mm256_loadu_ps(x);
|
||||
let y0 = _mm256_loadu_ps(y);
|
||||
let x1 = _mm256_loadu_ps(x.add(8));
|
||||
let y1 = _mm256_loadu_ps(y.add(8));
|
||||
let x2 = _mm256_loadu_ps(x.add(16));
|
||||
let y2 = _mm256_loadu_ps(y.add(16));
|
||||
let x3 = _mm256_loadu_ps(x.add(24));
|
||||
let y3 = _mm256_loadu_ps(y.add(24));
|
||||
|
||||
_mm256_storeu_ps(y, fmadd_ps(a0, x0, y0));
|
||||
_mm256_storeu_ps(y.add(8), fmadd_ps(a0, x1, y1));
|
||||
_mm256_storeu_ps(y.add(16), fmadd_ps(a0, x2, y2));
|
||||
_mm256_storeu_ps(y.add(24), fmadd_ps(a0, x3, y3));
|
||||
|
||||
x = x.add(32);
|
||||
y = y.add(32);
|
||||
}
|
||||
for _ in 0..n % 32 {
|
||||
*y += a * *x;
|
||||
x = x.add(1);
|
||||
y = y.add(1);
|
||||
}
|
||||
} else {
|
||||
for _ in 0..n {
|
||||
*y += a * *x;
|
||||
x = x.add(incx);
|
||||
y = y.add(incy);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub unsafe fn sdot(
|
||||
n: usize,
|
||||
mut x: *const f32,
|
||||
incx: usize,
|
||||
mut y: *const f32,
|
||||
incy: usize,
|
||||
) -> f32 {
|
||||
if incx == 1 && incy == 1 {
|
||||
let mut acc0 = _mm256_setzero_ps();
|
||||
let mut acc1 = _mm256_setzero_ps();
|
||||
let mut acc2 = _mm256_setzero_ps();
|
||||
let mut acc3 = _mm256_setzero_ps();
|
||||
for _ in 0..n / 32 {
|
||||
let x0 = _mm256_loadu_ps(x);
|
||||
let y0 = _mm256_loadu_ps(y);
|
||||
let x1 = _mm256_loadu_ps(x.add(8));
|
||||
let y1 = _mm256_loadu_ps(y.add(8));
|
||||
let x2 = _mm256_loadu_ps(x.add(16));
|
||||
let y2 = _mm256_loadu_ps(y.add(16));
|
||||
let x3 = _mm256_loadu_ps(x.add(24));
|
||||
let y3 = _mm256_loadu_ps(y.add(24));
|
||||
|
||||
acc0 = fmadd_ps(x0, y0, acc0);
|
||||
acc1 = fmadd_ps(x1, y1, acc1);
|
||||
acc2 = fmadd_ps(x2, y2, acc2);
|
||||
acc3 = fmadd_ps(x3, y3, acc3);
|
||||
|
||||
x = x.add(32);
|
||||
y = y.add(32);
|
||||
}
|
||||
acc0 = _mm256_add_ps(acc0, acc1);
|
||||
acc2 = _mm256_add_ps(acc2, acc3);
|
||||
acc0 = _mm256_add_ps(acc0, acc2);
|
||||
|
||||
let mut acc = hsum_ps(acc0);
|
||||
for _ in 0..n % 32 {
|
||||
acc += *x * *y;
|
||||
x = x.add(1);
|
||||
y = y.add(1);
|
||||
}
|
||||
acc
|
||||
} else {
|
||||
let mut acc = 0.0;
|
||||
for _ in 0..n {
|
||||
acc += *x * *y;
|
||||
x = x.add(incx);
|
||||
y = y.add(incy);
|
||||
}
|
||||
acc
|
||||
}
|
||||
}
|
||||
|
||||
pub unsafe fn snrm2(n: usize, mut x: *const f32, incx: usize) -> f32 {
|
||||
if incx == 1 {
|
||||
let mut acc0 = _mm256_setzero_ps();
|
||||
let mut acc1 = _mm256_setzero_ps();
|
||||
let mut acc2 = _mm256_setzero_ps();
|
||||
let mut acc3 = _mm256_setzero_ps();
|
||||
for _ in 0..n / 32 {
|
||||
let x0 = _mm256_loadu_ps(x);
|
||||
let x1 = _mm256_loadu_ps(x.add(8));
|
||||
let x2 = _mm256_loadu_ps(x.add(16));
|
||||
let x3 = _mm256_loadu_ps(x.add(24));
|
||||
|
||||
acc0 = fmadd_ps(x0, x0, acc0);
|
||||
acc1 = fmadd_ps(x1, x1, acc1);
|
||||
acc2 = fmadd_ps(x2, x2, acc2);
|
||||
acc3 = fmadd_ps(x3, x3, acc3);
|
||||
|
||||
x = x.add(32);
|
||||
}
|
||||
acc0 = _mm256_add_ps(acc0, acc1);
|
||||
acc2 = _mm256_add_ps(acc2, acc3);
|
||||
acc0 = _mm256_add_ps(acc0, acc2);
|
||||
|
||||
let mut acc = hsum_ps(acc0);
|
||||
for _ in 0..n % 32 {
|
||||
let x0 = *x;
|
||||
acc += x0 * x0;
|
||||
x = x.add(1);
|
||||
}
|
||||
acc.sqrt()
|
||||
} else {
|
||||
let mut acc = 0.0;
|
||||
for _ in 0..n {
|
||||
let x0 = *x;
|
||||
acc += x0 * x0;
|
||||
x = x.add(incx);
|
||||
}
|
||||
acc.sqrt()
|
||||
}
|
||||
}
|
||||
|
||||
pub unsafe fn sasum(n: usize, mut x: *const f32, incx: usize) -> f32 {
|
||||
if incx == 1 {
|
||||
let mask = _mm256_broadcast_ss(&f32::from_bits(0x7FFF_FFFF));
|
||||
|
||||
let mut acc0 = _mm256_setzero_ps();
|
||||
let mut acc1 = _mm256_setzero_ps();
|
||||
let mut acc2 = _mm256_setzero_ps();
|
||||
let mut acc3 = _mm256_setzero_ps();
|
||||
let mut acc4 = _mm256_setzero_ps();
|
||||
let mut acc5 = _mm256_setzero_ps();
|
||||
let mut acc6 = _mm256_setzero_ps();
|
||||
let mut acc7 = _mm256_setzero_ps();
|
||||
for _ in 0..n / 64 {
|
||||
let mut x0 = _mm256_loadu_ps(x);
|
||||
let mut x1 = _mm256_loadu_ps(x.add(8));
|
||||
let mut x2 = _mm256_loadu_ps(x.add(16));
|
||||
let mut x3 = _mm256_loadu_ps(x.add(24));
|
||||
let mut x4 = _mm256_loadu_ps(x.add(32));
|
||||
let mut x5 = _mm256_loadu_ps(x.add(40));
|
||||
let mut x6 = _mm256_loadu_ps(x.add(48));
|
||||
let mut x7 = _mm256_loadu_ps(x.add(56));
|
||||
|
||||
x0 = _mm256_and_ps(mask, x0);
|
||||
x1 = _mm256_and_ps(mask, x1);
|
||||
x2 = _mm256_and_ps(mask, x2);
|
||||
x3 = _mm256_and_ps(mask, x3);
|
||||
x4 = _mm256_and_ps(mask, x4);
|
||||
x5 = _mm256_and_ps(mask, x5);
|
||||
x6 = _mm256_and_ps(mask, x6);
|
||||
x7 = _mm256_and_ps(mask, x7);
|
||||
|
||||
acc0 = _mm256_add_ps(acc0, x0);
|
||||
acc1 = _mm256_add_ps(acc1, x1);
|
||||
acc2 = _mm256_add_ps(acc2, x2);
|
||||
acc3 = _mm256_add_ps(acc3, x3);
|
||||
acc4 = _mm256_add_ps(acc4, x4);
|
||||
acc5 = _mm256_add_ps(acc5, x5);
|
||||
acc6 = _mm256_add_ps(acc6, x6);
|
||||
acc7 = _mm256_add_ps(acc7, x7);
|
||||
|
||||
x = x.add(64);
|
||||
}
|
||||
acc0 = _mm256_add_ps(acc0, acc1);
|
||||
acc2 = _mm256_add_ps(acc2, acc3);
|
||||
acc4 = _mm256_add_ps(acc4, acc5);
|
||||
acc6 = _mm256_add_ps(acc6, acc7);
|
||||
|
||||
acc0 = _mm256_add_ps(acc0, acc2);
|
||||
acc4 = _mm256_add_ps(acc4, acc6);
|
||||
|
||||
acc0 = _mm256_add_ps(acc0, acc4);
|
||||
|
||||
let mut acc = hsum_ps(acc0);
|
||||
for _ in 0..n % 64 {
|
||||
acc += (*x).abs();
|
||||
x = x.add(1);
|
||||
}
|
||||
acc
|
||||
} else {
|
||||
let mut acc = 0.0;
|
||||
for _ in 0..n {
|
||||
acc += (*x).abs();
|
||||
x = x.add(incx);
|
||||
}
|
||||
acc
|
||||
}
|
||||
}
|
140
src/kernel/avx/l3d.rs
Normal file
140
src/kernel/avx/l3d.rs
Normal file
@ -0,0 +1,140 @@
|
||||
use super::fma::fmadd_pd;
|
||||
use super::intrinsics::*;
|
||||
|
||||
pub(crate) unsafe fn dgemm_ukr_8x4(
|
||||
k: usize,
|
||||
alpha: f64,
|
||||
pa: *const f64,
|
||||
pb: *const f64,
|
||||
beta: f64,
|
||||
c: *mut f64,
|
||||
ldc: usize,
|
||||
) {
|
||||
let mut mt00 = _mm256_setzero_pd();
|
||||
let mut mt01 = _mm256_setzero_pd();
|
||||
let mut mt02 = _mm256_setzero_pd();
|
||||
let mut mt03 = _mm256_setzero_pd();
|
||||
let mut mt10 = _mm256_setzero_pd();
|
||||
let mut mt11 = _mm256_setzero_pd();
|
||||
let mut mt12 = _mm256_setzero_pd();
|
||||
let mut mt13 = _mm256_setzero_pd();
|
||||
|
||||
let mut pa = pa;
|
||||
let mut pb = pb;
|
||||
|
||||
for _ in 0..k {
|
||||
let a0 = _mm256_loadu_pd(pa);
|
||||
let a1 = _mm256_loadu_pd(pa.add(4));
|
||||
|
||||
let b0 = _mm256_broadcast_sd(&*pb);
|
||||
let b1 = _mm256_broadcast_sd(&*pb.add(1));
|
||||
let b2 = _mm256_broadcast_sd(&*pb.add(2));
|
||||
let b3 = _mm256_broadcast_sd(&*pb.add(3));
|
||||
|
||||
mt00 = fmadd_pd(a0, b0, mt00);
|
||||
mt01 = fmadd_pd(a0, b1, mt01);
|
||||
mt02 = fmadd_pd(a0, b2, mt02);
|
||||
mt03 = fmadd_pd(a0, b3, mt03);
|
||||
mt10 = fmadd_pd(a1, b0, mt10);
|
||||
mt11 = fmadd_pd(a1, b1, mt11);
|
||||
mt12 = fmadd_pd(a1, b2, mt12);
|
||||
mt13 = fmadd_pd(a1, b3, mt13);
|
||||
|
||||
pa = pa.add(8);
|
||||
pb = pb.add(4);
|
||||
}
|
||||
|
||||
let alpha = _mm256_broadcast_sd(&alpha);
|
||||
|
||||
mt00 = _mm256_mul_pd(alpha, mt00);
|
||||
mt01 = _mm256_mul_pd(alpha, mt01);
|
||||
mt02 = _mm256_mul_pd(alpha, mt02);
|
||||
mt03 = _mm256_mul_pd(alpha, mt03);
|
||||
mt10 = _mm256_mul_pd(alpha, mt10);
|
||||
mt11 = _mm256_mul_pd(alpha, mt11);
|
||||
mt12 = _mm256_mul_pd(alpha, mt12);
|
||||
mt13 = _mm256_mul_pd(alpha, mt13);
|
||||
|
||||
let ccol0 = c;
|
||||
let ccol1 = c.add(ldc);
|
||||
let ccol2 = c.add(ldc * 2);
|
||||
let ccol3 = c.add(ldc * 3);
|
||||
|
||||
if beta != 0.0 {
|
||||
let beta = _mm256_broadcast_sd(&beta);
|
||||
|
||||
mt00 = fmadd_pd(beta, _mm256_loadu_pd(ccol0), mt00);
|
||||
mt01 = fmadd_pd(beta, _mm256_loadu_pd(ccol1), mt01);
|
||||
mt02 = fmadd_pd(beta, _mm256_loadu_pd(ccol2), mt02);
|
||||
mt03 = fmadd_pd(beta, _mm256_loadu_pd(ccol3), mt03);
|
||||
mt10 = fmadd_pd(beta, _mm256_loadu_pd(ccol0.add(4)), mt10);
|
||||
mt11 = fmadd_pd(beta, _mm256_loadu_pd(ccol1.add(4)), mt11);
|
||||
mt12 = fmadd_pd(beta, _mm256_loadu_pd(ccol2.add(4)), mt12);
|
||||
mt13 = fmadd_pd(beta, _mm256_loadu_pd(ccol3.add(4)), mt13);
|
||||
}
|
||||
|
||||
_mm256_storeu_pd(ccol0, mt00);
|
||||
_mm256_storeu_pd(ccol1, mt01);
|
||||
_mm256_storeu_pd(ccol2, mt02);
|
||||
_mm256_storeu_pd(ccol3, mt03);
|
||||
_mm256_storeu_pd(ccol0.add(4), mt10);
|
||||
_mm256_storeu_pd(ccol1.add(4), mt11);
|
||||
_mm256_storeu_pd(ccol2.add(4), mt12);
|
||||
_mm256_storeu_pd(ccol3.add(4), mt13);
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn dgemm_sup_8x1(
|
||||
k: usize,
|
||||
alpha: f64,
|
||||
pa: *const f64,
|
||||
b: *const f64,
|
||||
beta: f64,
|
||||
c: *mut f64,
|
||||
) {
|
||||
let mut mt0 = _mm256_setzero_pd();
|
||||
let mut mt1 = _mm256_setzero_pd();
|
||||
|
||||
let mut pa = pa;
|
||||
let mut b = b;
|
||||
|
||||
for _ in 0..k {
|
||||
let a0 = _mm256_loadu_pd(pa);
|
||||
let a1 = _mm256_loadu_pd(pa.add(4));
|
||||
|
||||
let b0 = _mm256_broadcast_sd(&*b);
|
||||
|
||||
mt0 = fmadd_pd(a0, b0, mt0);
|
||||
mt1 = fmadd_pd(a1, b0, mt1);
|
||||
|
||||
pa = pa.add(8);
|
||||
b = b.add(1);
|
||||
}
|
||||
|
||||
let alpha = _mm256_broadcast_sd(&alpha);
|
||||
|
||||
mt0 = _mm256_mul_pd(alpha, mt0);
|
||||
mt1 = _mm256_mul_pd(alpha, mt1);
|
||||
|
||||
if beta != 0.0 {
|
||||
let beta = _mm256_broadcast_sd(&beta);
|
||||
|
||||
mt0 = fmadd_pd(beta, _mm256_loadu_pd(c), mt0);
|
||||
mt1 = fmadd_pd(beta, _mm256_loadu_pd(c.add(4)), mt1);
|
||||
}
|
||||
|
||||
_mm256_storeu_pd(c, mt0);
|
||||
_mm256_storeu_pd(c.add(4), mt1);
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn dgemm_pa_8x(k: usize, a: *const f64, lda: usize, pa: *mut f64) {
|
||||
let mut a = a;
|
||||
let mut pa = pa;
|
||||
|
||||
for _ in 0..k {
|
||||
_mm256_storeu_pd(pa, _mm256_loadu_pd(a));
|
||||
_mm256_storeu_pd(pa.add(4), _mm256_loadu_pd(a.add(4)));
|
||||
|
||||
pa = pa.add(8);
|
||||
a = a.add(lda);
|
||||
}
|
||||
}
|
185
src/kernel/avx/l3s.rs
Normal file
185
src/kernel/avx/l3s.rs
Normal file
@ -0,0 +1,185 @@
|
||||
use super::fma::fmadd_ps;
|
||||
use super::intrinsics::*;
|
||||
|
||||
pub(crate) unsafe fn sgemm_ukr_16x4(
|
||||
k: usize,
|
||||
alpha: f32,
|
||||
pa: *const f32,
|
||||
pb: *const f32,
|
||||
beta: f32,
|
||||
c: *mut f32,
|
||||
ldc: usize,
|
||||
) {
|
||||
let mut mt00 = _mm256_setzero_ps();
|
||||
let mut mt01 = _mm256_setzero_ps();
|
||||
let mut mt02 = _mm256_setzero_ps();
|
||||
let mut mt03 = _mm256_setzero_ps();
|
||||
|
||||
let mut mt10 = _mm256_setzero_ps();
|
||||
let mut mt11 = _mm256_setzero_ps();
|
||||
let mut mt12 = _mm256_setzero_ps();
|
||||
let mut mt13 = _mm256_setzero_ps();
|
||||
|
||||
let mut pa = pa;
|
||||
let mut pb = pb;
|
||||
|
||||
for _ in 0..k {
|
||||
let a0 = _mm256_load_ps(pa);
|
||||
let a1 = _mm256_load_ps(pa.add(8));
|
||||
|
||||
let b0 = _mm256_broadcast_ss(&*pb);
|
||||
let b1 = _mm256_broadcast_ss(&*pb.add(1));
|
||||
let b2 = _mm256_broadcast_ss(&*pb.add(2));
|
||||
let b3 = _mm256_broadcast_ss(&*pb.add(3));
|
||||
|
||||
mt00 = fmadd_ps(a0, b0, mt00);
|
||||
mt01 = fmadd_ps(a0, b1, mt01);
|
||||
mt02 = fmadd_ps(a0, b2, mt02);
|
||||
mt03 = fmadd_ps(a0, b3, mt03);
|
||||
mt10 = fmadd_ps(a1, b0, mt10);
|
||||
mt11 = fmadd_ps(a1, b1, mt11);
|
||||
mt12 = fmadd_ps(a1, b2, mt12);
|
||||
mt13 = fmadd_ps(a1, b3, mt13);
|
||||
|
||||
pa = pa.add(16);
|
||||
pb = pb.add(4);
|
||||
}
|
||||
|
||||
let alpha = _mm256_broadcast_ss(&alpha);
|
||||
|
||||
mt00 = _mm256_mul_ps(alpha, mt00);
|
||||
mt01 = _mm256_mul_ps(alpha, mt01);
|
||||
mt02 = _mm256_mul_ps(alpha, mt02);
|
||||
mt03 = _mm256_mul_ps(alpha, mt03);
|
||||
mt10 = _mm256_mul_ps(alpha, mt10);
|
||||
mt11 = _mm256_mul_ps(alpha, mt11);
|
||||
mt12 = _mm256_mul_ps(alpha, mt12);
|
||||
mt13 = _mm256_mul_ps(alpha, mt13);
|
||||
|
||||
let ccol0 = c;
|
||||
let ccol1 = c.add(ldc);
|
||||
let ccol2 = c.add(ldc * 2);
|
||||
let ccol3 = c.add(ldc * 3);
|
||||
|
||||
if beta != 0.0 {
|
||||
let beta = _mm256_broadcast_ss(&beta);
|
||||
|
||||
mt00 = fmadd_ps(beta, _mm256_loadu_ps(ccol0), mt00);
|
||||
mt01 = fmadd_ps(beta, _mm256_loadu_ps(ccol1), mt01);
|
||||
mt02 = fmadd_ps(beta, _mm256_loadu_ps(ccol2), mt02);
|
||||
mt03 = fmadd_ps(beta, _mm256_loadu_ps(ccol3), mt03);
|
||||
mt10 = fmadd_ps(beta, _mm256_loadu_ps(ccol0.add(8)), mt10);
|
||||
mt11 = fmadd_ps(beta, _mm256_loadu_ps(ccol1.add(8)), mt11);
|
||||
mt12 = fmadd_ps(beta, _mm256_loadu_ps(ccol2.add(8)), mt12);
|
||||
mt13 = fmadd_ps(beta, _mm256_loadu_ps(ccol3.add(8)), mt13);
|
||||
}
|
||||
|
||||
_mm256_storeu_ps(ccol0, mt00);
|
||||
_mm256_storeu_ps(ccol1, mt01);
|
||||
_mm256_storeu_ps(ccol2, mt02);
|
||||
_mm256_storeu_ps(ccol3, mt03);
|
||||
_mm256_storeu_ps(ccol0.add(8), mt10);
|
||||
_mm256_storeu_ps(ccol1.add(8), mt11);
|
||||
_mm256_storeu_ps(ccol2.add(8), mt12);
|
||||
_mm256_storeu_ps(ccol3.add(8), mt13);
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn sgemm_sup_16x1(
|
||||
k: usize,
|
||||
alpha: f32,
|
||||
pa: *const f32,
|
||||
b: *const f32,
|
||||
beta: f32,
|
||||
c: *mut f32,
|
||||
) {
|
||||
let mut mt0 = _mm256_setzero_ps();
|
||||
let mut mt1 = _mm256_setzero_ps();
|
||||
|
||||
let mut pa = pa;
|
||||
let mut b = b;
|
||||
|
||||
for _ in 0..k {
|
||||
let a0 = _mm256_load_ps(pa);
|
||||
let a1 = _mm256_load_ps(pa.add(8));
|
||||
|
||||
let b0 = _mm256_broadcast_ss(&*b);
|
||||
|
||||
mt0 = fmadd_ps(a0, b0, mt0);
|
||||
mt1 = fmadd_ps(a1, b0, mt1);
|
||||
|
||||
pa = pa.add(16);
|
||||
b = b.add(1);
|
||||
}
|
||||
|
||||
let alpha = _mm256_broadcast_ss(&alpha);
|
||||
|
||||
mt0 = _mm256_mul_ps(alpha, mt0);
|
||||
mt1 = _mm256_mul_ps(alpha, mt1);
|
||||
|
||||
if beta != 0.0 {
|
||||
let beta = _mm256_broadcast_ss(&beta);
|
||||
|
||||
mt0 = fmadd_ps(beta, _mm256_loadu_ps(c), mt0);
|
||||
mt1 = fmadd_ps(beta, _mm256_loadu_ps(c.add(8)), mt1);
|
||||
}
|
||||
|
||||
_mm256_storeu_ps(c, mt0);
|
||||
_mm256_storeu_ps(c.add(8), mt1);
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn sgemm_sup0_t(
|
||||
k: usize,
|
||||
alpha: f32,
|
||||
pa: *const f32,
|
||||
b: *const f32,
|
||||
ldb: usize,
|
||||
beta: f32,
|
||||
c: *mut f32,
|
||||
) {
|
||||
let mut mt0 = _mm256_setzero_ps();
|
||||
let mut mt1 = _mm256_setzero_ps();
|
||||
|
||||
let mut pa = pa;
|
||||
let mut b = b;
|
||||
|
||||
for _ in 0..k {
|
||||
let a0 = _mm256_load_ps(pa);
|
||||
let a1 = _mm256_load_ps(pa.add(8));
|
||||
|
||||
let b0 = _mm256_broadcast_ss(&*b);
|
||||
|
||||
mt0 = fmadd_ps(a0, b0, mt0);
|
||||
mt1 = fmadd_ps(a1, b0, mt1);
|
||||
|
||||
pa = pa.add(16);
|
||||
b = b.add(ldb);
|
||||
}
|
||||
|
||||
let alpha = _mm256_broadcast_ss(&alpha);
|
||||
|
||||
mt0 = _mm256_mul_ps(alpha, mt0);
|
||||
mt1 = _mm256_mul_ps(alpha, mt1);
|
||||
|
||||
if beta != 0.0 {
|
||||
let beta = _mm256_broadcast_ss(&beta);
|
||||
|
||||
mt0 = fmadd_ps(beta, _mm256_loadu_ps(c), mt0);
|
||||
mt1 = fmadd_ps(beta, _mm256_loadu_ps(c.add(8)), mt1);
|
||||
}
|
||||
|
||||
_mm256_storeu_ps(c, mt0);
|
||||
_mm256_storeu_ps(c.add(8), mt1);
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn sgemm_pa_16x(k: usize, a: *const f32, lda: usize, pa: *mut f32) {
|
||||
let mut a = a;
|
||||
let mut pa = pa;
|
||||
|
||||
for _ in 0..k {
|
||||
_mm256_store_ps(pa, _mm256_loadu_ps(a));
|
||||
_mm256_store_ps(pa.add(8), _mm256_loadu_ps(a.add(8)));
|
||||
|
||||
pa = pa.add(16);
|
||||
a = a.add(lda);
|
||||
}
|
||||
}
|
7
src/kernel/avx/mod.rs
Normal file
7
src/kernel/avx/mod.rs
Normal file
@ -0,0 +1,7 @@
|
||||
mod fma;
|
||||
mod hsum;
|
||||
mod intrinsics;
|
||||
pub mod l1d;
|
||||
pub mod l1s;
|
||||
pub mod l3d;
|
||||
pub mod l3s;
|
74
src/kernel/generic/l3d.rs
Normal file
74
src/kernel/generic/l3d.rs
Normal file
@ -0,0 +1,74 @@
|
||||
pub(crate) unsafe fn dgemm_sup_1x4(
|
||||
k: usize,
|
||||
alpha: f64,
|
||||
a: *const f64,
|
||||
lda: usize,
|
||||
pb: *const f64,
|
||||
beta: f64,
|
||||
c: *mut f64,
|
||||
ldc: usize,
|
||||
) {
|
||||
let mut c0 = 0.0;
|
||||
let mut c1 = 0.0;
|
||||
let mut c2 = 0.0;
|
||||
let mut c3 = 0.0;
|
||||
|
||||
let mut a = a;
|
||||
let mut pb = pb;
|
||||
|
||||
for _ in 0..k {
|
||||
let a0 = *a;
|
||||
|
||||
c0 += *pb * a0;
|
||||
c1 += *pb.add(1) * a0;
|
||||
c2 += *pb.add(2) * a0;
|
||||
c3 += *pb.add(3) * a0;
|
||||
|
||||
a = a.add(lda);
|
||||
pb = pb.add(4);
|
||||
}
|
||||
|
||||
c0 *= alpha;
|
||||
c1 *= alpha;
|
||||
c2 *= alpha;
|
||||
c3 *= alpha;
|
||||
|
||||
let ccol0 = c;
|
||||
let ccol1 = c.add(ldc);
|
||||
let ccol2 = c.add(ldc * 2);
|
||||
let ccol3 = c.add(ldc * 3);
|
||||
|
||||
if beta != 0.0 {
|
||||
c0 *= beta * *ccol0;
|
||||
c1 *= beta * *ccol1;
|
||||
c2 *= beta * *ccol2;
|
||||
c3 *= beta * *ccol3;
|
||||
}
|
||||
|
||||
*ccol0 = c0;
|
||||
*ccol1 = c1;
|
||||
*ccol2 = c2;
|
||||
*ccol3 = c3;
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn dgemm_pb_x4(k: usize, b: *const f64, ldb: usize, pb: *mut f64) {
|
||||
let mut bcol0 = b;
|
||||
let mut bcol1 = b.add(ldb);
|
||||
let mut bcol2 = b.add(ldb * 2);
|
||||
let mut bcol3 = b.add(ldb * 3);
|
||||
|
||||
let mut pb = pb;
|
||||
|
||||
for _ in 0..k {
|
||||
*pb = *bcol0;
|
||||
*pb.add(1) = *bcol1;
|
||||
*pb.add(2) = *bcol2;
|
||||
*pb.add(3) = *bcol3;
|
||||
|
||||
bcol0 = bcol0.add(1);
|
||||
bcol1 = bcol1.add(1);
|
||||
bcol2 = bcol2.add(1);
|
||||
bcol3 = bcol3.add(1);
|
||||
pb = pb.add(4);
|
||||
}
|
||||
}
|
169
src/kernel/generic/l3s.rs
Normal file
169
src/kernel/generic/l3s.rs
Normal file
@ -0,0 +1,169 @@
|
||||
pub(crate) unsafe fn sgemm_sup_1x4(
|
||||
k: usize,
|
||||
alpha: f32,
|
||||
a: *const f32,
|
||||
lda: usize,
|
||||
pb: *const f32,
|
||||
beta: f32,
|
||||
c: *mut f32,
|
||||
ldc: usize,
|
||||
) {
|
||||
let mut c0 = 0.0;
|
||||
let mut c1 = 0.0;
|
||||
let mut c2 = 0.0;
|
||||
let mut c3 = 0.0;
|
||||
|
||||
let mut a = a;
|
||||
let mut pb = pb;
|
||||
|
||||
for _ in 0..k {
|
||||
let a0 = *a;
|
||||
|
||||
c0 += *pb * a0;
|
||||
c1 += *pb.add(1) * a0;
|
||||
c2 += *pb.add(2) * a0;
|
||||
c3 += *pb.add(3) * a0;
|
||||
|
||||
a = a.add(lda);
|
||||
pb = pb.add(4);
|
||||
}
|
||||
|
||||
c0 *= alpha;
|
||||
c1 *= alpha;
|
||||
c2 *= alpha;
|
||||
c3 *= alpha;
|
||||
|
||||
let ccol0 = c;
|
||||
let ccol1 = c.add(ldc);
|
||||
let ccol2 = c.add(ldc * 2);
|
||||
let ccol3 = c.add(ldc * 3);
|
||||
|
||||
if beta != 0.0 {
|
||||
c0 += beta * *ccol0;
|
||||
c1 += beta * *ccol1;
|
||||
c2 += beta * *ccol2;
|
||||
c3 += beta * *ccol3;
|
||||
}
|
||||
|
||||
*ccol0 = c0;
|
||||
*ccol1 = c1;
|
||||
*ccol2 = c2;
|
||||
*ccol3 = c3;
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn sgemm_sup1_t(
|
||||
k: usize,
|
||||
alpha: f32,
|
||||
a: *const f32,
|
||||
lda: usize,
|
||||
pb: *const f32,
|
||||
beta: f32,
|
||||
c: *mut f32,
|
||||
ldc: usize,
|
||||
) {
|
||||
let mut c0 = 0.0;
|
||||
let mut c1 = 0.0;
|
||||
let mut c2 = 0.0;
|
||||
let mut c3 = 0.0;
|
||||
|
||||
let mut a = a;
|
||||
let mut pb = pb;
|
||||
|
||||
for _ in 0..k {
|
||||
let a0 = *a;
|
||||
|
||||
c0 += *pb * a0;
|
||||
c1 += *pb.add(1) * a0;
|
||||
c2 += *pb.add(2) * a0;
|
||||
c3 += *pb.add(3) * a0;
|
||||
|
||||
a = a.add(1);
|
||||
pb = pb.add(4);
|
||||
}
|
||||
|
||||
c0 *= alpha;
|
||||
c1 *= alpha;
|
||||
c2 *= alpha;
|
||||
c3 *= alpha;
|
||||
|
||||
let ccol0 = c;
|
||||
let ccol1 = c.add(ldc);
|
||||
let ccol2 = c.add(ldc * 2);
|
||||
let ccol3 = c.add(ldc * 3);
|
||||
|
||||
if beta != 0.0 {
|
||||
c0 += beta * *ccol0;
|
||||
c1 += beta * *ccol1;
|
||||
c2 += beta * *ccol2;
|
||||
c3 += beta * *ccol3;
|
||||
}
|
||||
|
||||
*ccol0 = c0;
|
||||
*ccol1 = c1;
|
||||
*ccol2 = c2;
|
||||
*ccol3 = c3;
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn sgemm_pb_x4(k: usize, b: *const f32, ldb: usize, pb: *mut f32) {
|
||||
let mut bcol0 = b;
|
||||
let mut bcol1 = b.add(ldb);
|
||||
let mut bcol2 = b.add(ldb * 2);
|
||||
let mut bcol3 = b.add(ldb * 3);
|
||||
|
||||
let mut pb = pb;
|
||||
|
||||
for _ in 0..k {
|
||||
*pb = *bcol0;
|
||||
*pb.add(1) = *bcol1;
|
||||
*pb.add(2) = *bcol2;
|
||||
*pb.add(3) = *bcol3;
|
||||
|
||||
bcol0 = bcol0.add(1);
|
||||
bcol1 = bcol1.add(1);
|
||||
bcol2 = bcol2.add(1);
|
||||
bcol3 = bcol3.add(1);
|
||||
pb = pb.add(4);
|
||||
}
|
||||
}
|
||||
|
||||
// pub(crate) unsafe fn sgemm_pa_n(k: usize, a: *const f32, lda: usize, pa: *mut f32) {
|
||||
// use crate::kernel::params::single::MR;
|
||||
// let mut a = a;
|
||||
// let mut pa = pa;
|
||||
|
||||
// for p in 0..k {
|
||||
// for j in 0..MR {
|
||||
// *pa.add(j * k + p) = *a.add(p * lda + j);
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
pub(crate) unsafe fn sgemm_pb_t(k: usize, b: *const f32, ldb: usize, pb: *mut f32) {
|
||||
use crate::kernel::params::single::NR;
|
||||
let mut b = b;
|
||||
let mut pb = pb;
|
||||
|
||||
for _ in 0..k {
|
||||
for j in 0..NR {
|
||||
*pb.add(j) = *b.add(j);
|
||||
}
|
||||
|
||||
pb = pb.add(NR);
|
||||
b = b.add(ldb);
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn sgemm_pa_t(k: usize, a: *const f32, lda: usize, pa: *mut f32) {
|
||||
use crate::kernel::params::single::MR;
|
||||
let mut a = a;
|
||||
let mut pa = pa;
|
||||
|
||||
for _ in 0..k {
|
||||
for j in 0..MR {
|
||||
*pa.add(j) = *a.add(j * lda);
|
||||
}
|
||||
|
||||
a = a.add(1);
|
||||
pa = pa.add(MR);
|
||||
}
|
||||
}
|
2
src/kernel/generic/mod.rs
Normal file
2
src/kernel/generic/mod.rs
Normal file
@ -0,0 +1,2 @@
|
||||
pub mod l3d;
|
||||
pub mod l3s;
|
60
src/kernel/mod.rs
Normal file
60
src/kernel/mod.rs
Normal file
@ -0,0 +1,60 @@
|
||||
#[cfg(all(
|
||||
any(target_arch = "x86_64", target_arch = "x86"),
|
||||
target_feature = "avx"
|
||||
))]
|
||||
mod avx;
|
||||
|
||||
#[cfg(all(
|
||||
any(target_arch = "x86_64", target_arch = "x86"),
|
||||
target_feature = "avx"
|
||||
))]
|
||||
pub use avx::{l1d::*, l1s::*};
|
||||
|
||||
#[cfg(all(
|
||||
any(target_arch = "x86_64", target_arch = "x86"),
|
||||
target_feature = "avx"
|
||||
))]
|
||||
pub(crate) use avx::l3s::{
|
||||
sgemm_pa_16x as sgemm_pa_n, sgemm_sup_16x1 as sgemm_sup0, sgemm_ukr_16x4 as sgemm_ukr, sgemm_sup0_t,
|
||||
};
|
||||
|
||||
#[cfg(all(
|
||||
any(target_arch = "x86_64", target_arch = "x86"),
|
||||
target_feature = "avx"
|
||||
))]
|
||||
pub(crate) use avx::l3d::{
|
||||
dgemm_pa_8x as dgemm_pa, dgemm_sup_8x1 as dgemm_sup0, dgemm_ukr_8x4 as dgemm_ukr,
|
||||
};
|
||||
|
||||
#[cfg(all(
|
||||
any(target_arch = "x86_64", target_arch = "x86"),
|
||||
target_feature = "avx"
|
||||
))]
|
||||
pub(crate) use generic::l3s::{sgemm_pb_x4 as sgemm_pb_n, sgemm_pb_t, sgemm_pa_t, sgemm_sup_1x4 as sgemm_sup1, sgemm_sup1_t};
|
||||
|
||||
#[cfg(all(
|
||||
any(target_arch = "x86_64", target_arch = "x86"),
|
||||
target_feature = "avx"
|
||||
))]
|
||||
pub(crate) use generic::l3d::{dgemm_pb_x4 as dgemm_pb, dgemm_sup_1x4 as dgemm_sup1};
|
||||
|
||||
|
||||
mod generic;
|
||||
|
||||
|
||||
pub mod params {
|
||||
pub mod single {
|
||||
pub const MC: usize = 256;
|
||||
pub const KC: usize = 128;
|
||||
pub const NC: usize = 1024;
|
||||
pub const MR: usize = 16;
|
||||
pub const NR: usize = 4;
|
||||
}
|
||||
pub mod double {
|
||||
pub const MC: usize = 256;
|
||||
pub const KC: usize = 512;
|
||||
pub const NC: usize = 4096;
|
||||
pub const MR: usize = 8;
|
||||
pub const NR: usize = 4;
|
||||
}
|
||||
}
|
498
src/main.rs
Normal file
498
src/main.rs
Normal file
@ -0,0 +1,498 @@
|
||||
mod matrix;
|
||||
mod sgemm;
|
||||
mod aligned_alloc;
|
||||
mod kernel;
|
||||
|
||||
extern crate blas;
|
||||
extern crate openblas;
|
||||
|
||||
use self::matrix::{ConstMatrix, MutMatrix, ConstTransposedMatrix, MutTransposedMatrix};
|
||||
|
||||
pub unsafe fn sgemm(
|
||||
transa: bool,
|
||||
transb: bool,
|
||||
transc: bool,
|
||||
m: usize,
|
||||
n: usize,
|
||||
k: usize,
|
||||
alpha: f32,
|
||||
a: *const f32,
|
||||
lda: usize,
|
||||
b: *const f32,
|
||||
ldb: usize,
|
||||
beta: f32,
|
||||
c: *mut f32,
|
||||
ldc: usize,
|
||||
) {
|
||||
match (transa, transb, transc) {
|
||||
(false, false, false) => sgemm::sgemm(
|
||||
m, n, k, alpha,
|
||||
ConstMatrix::new(a, lda),
|
||||
ConstMatrix::new(b, ldb),
|
||||
beta,
|
||||
MutMatrix::new(c, ldc)),
|
||||
|
||||
(false, false, true) => sgemm::sgemm(
|
||||
m, n, k, alpha,
|
||||
ConstMatrix::new(a, lda),
|
||||
ConstMatrix::new(b, ldb),
|
||||
beta,
|
||||
MutTransposedMatrix::new(c, ldc)),
|
||||
|
||||
(false, true, false) => sgemm::sgemm(
|
||||
m, n, k, alpha,
|
||||
ConstMatrix::new(a, lda),
|
||||
ConstTransposedMatrix::new(b, ldb),
|
||||
beta,
|
||||
MutMatrix::new(c, ldc)),
|
||||
|
||||
(false, true, true) => sgemm::sgemm(
|
||||
m, n, k, alpha,
|
||||
ConstMatrix::new(a, lda),
|
||||
ConstTransposedMatrix::new(b, ldb),
|
||||
beta,
|
||||
MutTransposedMatrix::new(c, ldc)),
|
||||
|
||||
(true, false, false) => sgemm::sgemm(
|
||||
m, n, k, alpha,
|
||||
ConstTransposedMatrix::new(a, lda),
|
||||
ConstMatrix::new(b, ldb),
|
||||
beta,
|
||||
MutMatrix::new(c, ldc)),
|
||||
|
||||
(true, false, true) => sgemm::sgemm(
|
||||
m, n, k, alpha,
|
||||
ConstTransposedMatrix::new(a, lda),
|
||||
ConstMatrix::new(b, ldb),
|
||||
beta,
|
||||
MutTransposedMatrix::new(c, ldc)),
|
||||
|
||||
|
||||
(true, true, false) => sgemm::sgemm(
|
||||
m, n, k, alpha,
|
||||
ConstTransposedMatrix::new(a, lda),
|
||||
ConstTransposedMatrix::new(b, ldb),
|
||||
beta,
|
||||
MutMatrix::new(c, ldc)),
|
||||
|
||||
|
||||
(true, true, true) => sgemm::sgemm(
|
||||
m, n, k, alpha,
|
||||
ConstTransposedMatrix::new(a, lda),
|
||||
ConstTransposedMatrix::new(b, ldb),
|
||||
beta,
|
||||
MutTransposedMatrix::new(c, ldc)),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
unsafe fn sgemm_ref_nn(
|
||||
m: usize,
|
||||
n: usize,
|
||||
k: usize,
|
||||
_alpha: f32,
|
||||
a: *const f32,
|
||||
lda: usize,
|
||||
b: *const f32,
|
||||
ldb: usize,
|
||||
_beta: f32,
|
||||
c: *mut f32,
|
||||
ldc: usize,
|
||||
) {
|
||||
for j in 0..n {
|
||||
for i in 0..m {
|
||||
let mut ci = *c.add(i + j * ldc);
|
||||
for p in 0..k {
|
||||
ci += *a.add(i + p * lda) * *b.add(p + j * ldb);
|
||||
}
|
||||
*c.add(i + j * ldc) = ci;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe fn sgemm_ref_nt(
|
||||
m: usize,
|
||||
n: usize,
|
||||
k: usize,
|
||||
_alpha: f32,
|
||||
a: *const f32,
|
||||
lda: usize,
|
||||
b: *const f32,
|
||||
ldb: usize,
|
||||
_beta: f32,
|
||||
c: *mut f32,
|
||||
ldc: usize,
|
||||
) {
|
||||
for j in 0..n {
|
||||
for i in 0..m {
|
||||
let mut ci = *c.add(i + j * ldc);
|
||||
|
||||
for p in 0..k {
|
||||
ci += *a.add(i + p * lda) * *b.add(j + p * ldb);
|
||||
}
|
||||
|
||||
*c.add(i + j * ldc) = ci;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn gemm_nn(m: usize, n: usize, k: usize, alpha: f32,
|
||||
a: &[f32], lda: usize,
|
||||
b: &[f32], ldb: usize,
|
||||
_bata: f32,
|
||||
c: &mut [f32], ldc: usize) {
|
||||
|
||||
let mk = m * k;
|
||||
let nk = n * k;
|
||||
let mn = m * n;
|
||||
let a = &a[0..mk];
|
||||
let b = &b[0..nk];
|
||||
let c = &mut c[0..mn];
|
||||
|
||||
for i_m in 0..m {
|
||||
for i_k in 0..k {
|
||||
let a_part = alpha * a[i_m * lda + i_k];
|
||||
|
||||
for i_n in 0..n {
|
||||
c[i_m * ldc + i_n] += a_part * b[i_k * ldb + i_n];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
fn main2() {
|
||||
use rand::Rng;
|
||||
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
const LEN: usize = 511;
|
||||
let (m, n, k) = (LEN, LEN, LEN);
|
||||
|
||||
let mut a = vec![0.5; m * k];
|
||||
let mut a_t = vec![0.5; m * k];
|
||||
let mut b = vec![0.5; n * k];
|
||||
let mut b_t = vec![0.5; n * k];
|
||||
let mut c_nn = vec![0.0; m * n];
|
||||
let mut c_nt = vec![0.0; m * n];
|
||||
let mut c_tn = vec![0.0; m * n];
|
||||
let mut c_tt = vec![0.0; m * n];
|
||||
let mut cref = vec![0.0; m * n];
|
||||
|
||||
for i in 0..m {
|
||||
for j in 0..k {
|
||||
let v = rng.gen();
|
||||
a[i + j * m] = v;
|
||||
a_t[j + i * m] = v;
|
||||
}
|
||||
}
|
||||
|
||||
for i in 0..n {
|
||||
for j in 0..k {
|
||||
let v = rng.gen();
|
||||
b[i + j * n] = v;
|
||||
b_t[j + i * n] = v;
|
||||
}
|
||||
}
|
||||
|
||||
// let time = std::time::Instant::now();
|
||||
// unsafe {
|
||||
// gemm_nn(
|
||||
// LEN,
|
||||
// LEN,
|
||||
// LEN,
|
||||
// 1.0,
|
||||
// a.as_slice(),
|
||||
// LEN,
|
||||
// b.as_slice(),
|
||||
// LEN,
|
||||
// 1.0,
|
||||
// cref1.as_mut_slice(),
|
||||
// LEN,
|
||||
// )
|
||||
// }
|
||||
|
||||
// println!("Naive (mine) {}", time.elapsed().as_millis());
|
||||
|
||||
let time = std::time::Instant::now();
|
||||
unsafe {
|
||||
sgemm_ref_nn(
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
|
||||
1.0,
|
||||
a.as_ptr(),
|
||||
m,
|
||||
|
||||
b.as_ptr(),
|
||||
k,
|
||||
|
||||
1.0,
|
||||
cref.as_mut_ptr(),
|
||||
m,
|
||||
)
|
||||
}
|
||||
|
||||
println!("Naive {}", time.elapsed().as_millis());
|
||||
|
||||
let time = std::time::Instant::now();
|
||||
unsafe {
|
||||
sgemm(
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
|
||||
1.0,
|
||||
a.as_ptr(),
|
||||
m,
|
||||
|
||||
b.as_ptr(),
|
||||
k,
|
||||
|
||||
1.0,
|
||||
c_nn.as_mut_ptr(),
|
||||
m,
|
||||
);
|
||||
}
|
||||
println!("[NN] Optimized {}", time.elapsed().as_millis());
|
||||
|
||||
for i in 0..LEN {
|
||||
for j in 0..LEN {
|
||||
let (a, b) = (c_nn[i + j * LEN], cref[i + j * LEN]);
|
||||
assert!(feq(a, b), "a != b, a[{}]={}, b[{}]={}", i, a, j, b);
|
||||
}
|
||||
}
|
||||
|
||||
let time = std::time::Instant::now();
|
||||
unsafe {
|
||||
sgemm(
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
|
||||
1.0,
|
||||
a.as_ptr(),
|
||||
m,
|
||||
|
||||
b_t.as_ptr(),
|
||||
n,
|
||||
|
||||
1.0,
|
||||
c_nt.as_mut_ptr(),
|
||||
m,
|
||||
);
|
||||
}
|
||||
|
||||
println!("[NT] Optimized {}", time.elapsed().as_millis());
|
||||
|
||||
for i in 0..LEN {
|
||||
for j in 0..LEN {
|
||||
let (a, b) = (c_nt[i + j * LEN], cref[i + j * LEN]);
|
||||
assert!(feq(a, b), "a != b, a[{}]={}, b[{}]={}", i, a, j, b);
|
||||
}
|
||||
}
|
||||
|
||||
let time = std::time::Instant::now();
|
||||
unsafe {
|
||||
sgemm(
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
|
||||
1.0,
|
||||
a_t.as_ptr(),
|
||||
k,
|
||||
|
||||
b.as_ptr(),
|
||||
k,
|
||||
|
||||
1.0,
|
||||
c_tn.as_mut_ptr(),
|
||||
m,
|
||||
);
|
||||
}
|
||||
|
||||
println!("[TN] Optimized {}", time.elapsed().as_millis());
|
||||
|
||||
for i in 0..LEN {
|
||||
for j in 0..LEN {
|
||||
let (a, b) = (c_tn[i + j * LEN], cref[i + j * LEN]);
|
||||
assert!(feq(a, b), "a != b, a[{}]={}, b[{}]={}", i, a, j, b);
|
||||
}
|
||||
}
|
||||
|
||||
let time = std::time::Instant::now();
|
||||
unsafe {
|
||||
sgemm(
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
|
||||
1.0,
|
||||
a_t.as_ptr(),
|
||||
k,
|
||||
|
||||
b_t.as_ptr(),
|
||||
n,
|
||||
|
||||
1.0,
|
||||
c_tt.as_mut_ptr(),
|
||||
m,
|
||||
);
|
||||
}
|
||||
println!("[TT] Optimized {}", time.elapsed().as_millis());
|
||||
|
||||
|
||||
for i in 0..LEN {
|
||||
for j in 0..LEN {
|
||||
let (a, b) = (c_tt[i + j * LEN], cref[i + j * LEN]);
|
||||
assert!(feq(a, b), "a != b, a[{}]={}, b[{}]={}", i, a, j, b);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
fn main() {
|
||||
use rand::Rng;
|
||||
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
const LEN: usize = 8192;
|
||||
let (m, n, k) = (LEN, LEN, LEN);
|
||||
|
||||
let mut a = vec![0.5; m * k];
|
||||
let mut b = vec![0.5; n * k];
|
||||
let mut c = vec![0.0; m * n];
|
||||
let mut cref1 = vec![0.0; m * n];
|
||||
let mut cref = vec![0.0; m * n];
|
||||
|
||||
for i in 0..m {
|
||||
for j in 0..k {
|
||||
a[i + j * m] = rng.gen();
|
||||
}
|
||||
}
|
||||
|
||||
for i in 0..n {
|
||||
for j in 0..k {
|
||||
b[i + j * n] = rng.gen();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
let time = std::time::Instant::now();
|
||||
unsafe {
|
||||
// blas::sgemm(
|
||||
// b'N',
|
||||
// b'N',
|
||||
// m as i32,
|
||||
// n as i32,
|
||||
// k as i32,
|
||||
// 1.0,
|
||||
// a.as_slice(),
|
||||
// m as i32,
|
||||
// b.as_slice(),
|
||||
// k as i32,
|
||||
// 1.0,
|
||||
// cref1.as_mut_slice(),
|
||||
// m as i32,
|
||||
// );
|
||||
// gemm_nn(
|
||||
// m,
|
||||
// n,
|
||||
// k,
|
||||
// 1.0,
|
||||
// a.as_slice(),
|
||||
// m,
|
||||
// b.as_slice(),
|
||||
// m,
|
||||
// 1.0,
|
||||
// cref1.as_mut_slice(),
|
||||
// m,
|
||||
// )
|
||||
}
|
||||
println!("Matrixmultiply (mine) {}", time.elapsed().as_millis());
|
||||
// println!("Naive (mine) {}", time.elapsed().as_millis());
|
||||
|
||||
// let time = std::time::Instant::now();
|
||||
// unsafe {
|
||||
// sgemm_ref_nn(
|
||||
// m,
|
||||
// n,
|
||||
// k,
|
||||
|
||||
// 1.0,
|
||||
// a.as_ptr(),
|
||||
// m,
|
||||
|
||||
// b.as_ptr(),
|
||||
// k,
|
||||
|
||||
// 1.0,
|
||||
// c.as_mut_ptr(),
|
||||
// m,
|
||||
// )
|
||||
// }
|
||||
|
||||
// println!("Naive {}", time.elapsed().as_millis());
|
||||
|
||||
let time = std::time::Instant::now();
|
||||
unsafe {
|
||||
sgemm(
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
|
||||
1.0,
|
||||
a.as_ptr(),
|
||||
m,
|
||||
|
||||
b.as_ptr(),
|
||||
k,
|
||||
|
||||
1.0,
|
||||
cref.as_mut_ptr(),
|
||||
m,
|
||||
);
|
||||
}
|
||||
|
||||
println!("Optimized {}", time.elapsed().as_millis());
|
||||
|
||||
|
||||
// for i in 0..LEN {
|
||||
// for j in 0..LEN {
|
||||
// let (a, b) = (c[i + j * LEN], cref[i + j * LEN]);
|
||||
// assert!(feq(a, b), "a != b, a[{}]={}, b[{}]={}", i, a, j, b);
|
||||
// }
|
||||
// }
|
||||
}
|
||||
|
||||
fn feq(a: f32, b: f32) -> bool {
|
||||
if a == b {
|
||||
true
|
||||
} else if a == 0.0 || b == 0.0 || (a.abs() + b.abs() < std::f32::MIN_POSITIVE) {
|
||||
(a - b).abs() < std::f32::EPSILON * 10.0 * std::f32::MIN_POSITIVE
|
||||
} else {
|
||||
(a - b).abs() / (a.abs() + b.abs()) < std::f32::EPSILON * 10.0
|
||||
}
|
||||
}
|
283
src/matrix.rs
Normal file
283
src/matrix.rs
Normal file
@ -0,0 +1,283 @@
|
||||
pub(crate) trait Matrix: Copy {
|
||||
unsafe fn sub(&self, row: usize, col: usize) -> Self;
|
||||
unsafe fn is_transposed(&self) -> bool;
|
||||
unsafe fn stride(&self) -> usize;
|
||||
unsafe fn get(&self, row: usize, col: usize) -> f32;
|
||||
unsafe fn index(&self, row: usize, col: usize) -> *const f32;
|
||||
unsafe fn row(&self, row: usize) -> *const f32;
|
||||
unsafe fn col(&self, col: usize) -> *const f32;
|
||||
}
|
||||
|
||||
pub(crate) trait MatrixMut: Matrix {
|
||||
unsafe fn set(&self, row: usize, col: usize, val: f32);
|
||||
unsafe fn index_mut(&self, row: usize, col: usize) -> *mut f32;
|
||||
unsafe fn row_mut(&self, row: usize) -> *mut f32;
|
||||
unsafe fn col_mut(&self, col: usize) -> *mut f32;
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
pub(crate) struct ConstMatrix {
|
||||
pub stride: usize,
|
||||
pub ptr: *const f32
|
||||
}
|
||||
|
||||
unsafe impl Send for ConstMatrix {}
|
||||
unsafe impl Sync for ConstMatrix {}
|
||||
|
||||
impl ConstMatrix {
|
||||
pub fn new(ptr: *const f32, stride: usize) -> Self {
|
||||
Self { ptr, stride }
|
||||
}
|
||||
}
|
||||
|
||||
impl Matrix for ConstMatrix {
|
||||
#[inline]
|
||||
unsafe fn sub(&self, row: usize, col: usize) -> Self {
|
||||
Self {
|
||||
ptr: self.index(row, col),
|
||||
stride: self.stride
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn is_transposed(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn stride(&self) -> usize {
|
||||
self.stride
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn get(&self, row: usize, col: usize) -> f32 {
|
||||
*self.ptr.add(row * self.stride + col)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn index(&self, row: usize, col: usize) -> *const f32 {
|
||||
self.ptr.add(row * self.stride + col)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn row(&self, row: usize) -> *const f32 {
|
||||
self.ptr.add(row * self.stride)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn col(&self, col: usize) -> *const f32 {
|
||||
self.ptr.add(col)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
pub(crate) struct ConstTransposedMatrix {
|
||||
pub stride: usize,
|
||||
pub ptr: *const f32
|
||||
}
|
||||
|
||||
unsafe impl Send for ConstTransposedMatrix {}
|
||||
unsafe impl Sync for ConstTransposedMatrix {}
|
||||
|
||||
impl ConstTransposedMatrix {
|
||||
pub fn new(ptr: *const f32, stride: usize) -> Self {
|
||||
Self { ptr, stride }
|
||||
}
|
||||
}
|
||||
|
||||
impl Matrix for ConstTransposedMatrix {
|
||||
#[inline]
|
||||
unsafe fn sub(&self, row: usize, col: usize) -> Self {
|
||||
Self {
|
||||
ptr: self.index(row, col),
|
||||
stride: self.stride
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn is_transposed(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn stride(&self) -> usize {
|
||||
self.stride
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn get(&self, row: usize, col: usize) -> f32 {
|
||||
*self.ptr.add(col * self.stride + row)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn index(&self, row: usize, col: usize) -> *const f32 {
|
||||
self.ptr.add(col * self.stride + row)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn row(&self, row: usize) -> *const f32 {
|
||||
self.ptr.add(row)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn col(&self, col: usize) -> *const f32 {
|
||||
self.ptr.add(col * self.stride)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
pub(crate) struct MutMatrix {
|
||||
pub stride: usize,
|
||||
pub ptr: *mut f32
|
||||
}
|
||||
|
||||
unsafe impl Send for MutMatrix {}
|
||||
unsafe impl Sync for MutMatrix {}
|
||||
|
||||
impl MutMatrix {
|
||||
pub fn new(ptr: *mut f32, stride: usize) -> Self {
|
||||
Self { ptr, stride }
|
||||
}
|
||||
}
|
||||
|
||||
impl Matrix for MutMatrix {
|
||||
#[inline]
|
||||
unsafe fn sub(&self, row: usize, col: usize) -> Self {
|
||||
Self {
|
||||
ptr: self.index_mut(row, col),
|
||||
stride: self.stride
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn is_transposed(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn stride(&self) -> usize {
|
||||
self.stride
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn get(&self, row: usize, col: usize) -> f32 {
|
||||
*self.ptr.add(row * self.stride + col)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn index(&self, row: usize, col: usize) -> *const f32 {
|
||||
self.ptr.add(row * self.stride + col)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn row(&self, row: usize) -> *const f32 {
|
||||
self.ptr.add(row * self.stride)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn col(&self, col: usize) -> *const f32 {
|
||||
self.ptr.add(col)
|
||||
}
|
||||
}
|
||||
|
||||
impl MatrixMut for MutMatrix {
|
||||
#[inline]
|
||||
unsafe fn set(&self, row: usize, col: usize, value: f32) {
|
||||
*self.ptr.add(row * self.stride + col) = value;
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn index_mut(&self, row: usize, col: usize) -> *mut f32 {
|
||||
self.ptr.add(row * self.stride + col)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn row_mut(&self, row: usize) -> *mut f32 {
|
||||
self.ptr.add(row * self.stride)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn col_mut(&self, col: usize) -> *mut f32 {
|
||||
self.ptr.add(col)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
pub(crate) struct MutTransposedMatrix {
|
||||
pub stride: usize,
|
||||
pub ptr: *mut f32
|
||||
}
|
||||
|
||||
unsafe impl Send for MutTransposedMatrix {}
|
||||
unsafe impl Sync for MutTransposedMatrix {}
|
||||
|
||||
impl MutTransposedMatrix {
|
||||
pub fn new(ptr: *mut f32, stride: usize) -> Self {
|
||||
Self { ptr, stride }
|
||||
}
|
||||
}
|
||||
|
||||
impl Matrix for MutTransposedMatrix {
|
||||
#[inline]
|
||||
unsafe fn sub(&self, row: usize, col: usize) -> Self {
|
||||
Self {
|
||||
ptr: self.index_mut(row, col),
|
||||
stride: self.stride
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn is_transposed(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn stride(&self) -> usize {
|
||||
self.stride
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn get(&self, row: usize, col: usize) -> f32 {
|
||||
*self.ptr.add(col * self.stride + row)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn index(&self, row: usize, col: usize) -> *const f32 {
|
||||
self.ptr.add(col * self.stride + row) as *const f32
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn row(&self, row: usize) -> *const f32 {
|
||||
self.ptr.add(row) as *const f32
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn col(&self, col: usize) -> *const f32 {
|
||||
self.ptr.add(col * self.stride) as *const f32
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
impl MatrixMut for MutTransposedMatrix {
|
||||
#[inline]
|
||||
unsafe fn set(&self, row: usize, col: usize, value: f32) {
|
||||
*self.ptr.add(col * self.stride + row) = value;
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn index_mut(&self, row: usize, col: usize) -> *mut f32 {
|
||||
self.ptr.add(col * self.stride + row)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn row_mut(&self, row: usize) -> *mut f32 {
|
||||
self.ptr.add(row)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn col_mut(&self, col: usize) -> *mut f32 {
|
||||
self.ptr.add(col * self.stride)
|
||||
}
|
||||
}
|
||||
|
||||
|
273
src/sgemm.rs
Normal file
273
src/sgemm.rs
Normal file
@ -0,0 +1,273 @@
|
||||
use crate::aligned_alloc;
|
||||
use crate::matrix::{Matrix, MatrixMut, MutMatrix};
|
||||
use crate::kernel::params::single::*;
|
||||
use crate::kernel;
|
||||
|
||||
|
||||
pub(crate) unsafe fn sgemm<A, B, C>(
|
||||
m: usize,
|
||||
n: usize,
|
||||
k: usize,
|
||||
alpha: f32,
|
||||
a: A,
|
||||
b: B,
|
||||
beta: f32,
|
||||
c: C
|
||||
)
|
||||
where A: Matrix,
|
||||
B: Matrix,
|
||||
C: MatrixMut,
|
||||
{
|
||||
let packed_a = aligned_alloc::Alloc::new(MC * KC * std::mem::size_of::<f32>());
|
||||
let packed_b = aligned_alloc::Alloc::new(KC * NC * std::mem::size_of::<f32>());
|
||||
|
||||
for j in (0..n).step_by(NC) {
|
||||
let j_b = std::cmp::min(n - j, NC);
|
||||
for p in (0..k).step_by(KC) {
|
||||
let p_b = std::cmp::min(k - p, KC);
|
||||
for i in (0..m).step_by(MC) {
|
||||
let i_b = std::cmp::min(m - i, MC);
|
||||
|
||||
let pa = MutMatrix::new(packed_a.ptr_f32(), p_b);
|
||||
let pb = MutMatrix::new(packed_b.ptr_f32(), p_b);
|
||||
|
||||
inner_kernel(
|
||||
i_b,
|
||||
j_b,
|
||||
p_b,
|
||||
alpha,
|
||||
a.sub(p, i),
|
||||
b.sub(j, p),
|
||||
beta,
|
||||
c.sub(j, i),
|
||||
pa,
|
||||
pb,
|
||||
i == 0
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// +----------------------+
|
||||
// |A : : : |B |
|
||||
// | : : : | |
|
||||
// + - - - - - - - - - ---+
|
||||
// | : : : | |
|
||||
// | : : : | |
|
||||
// + - - - - - - - - - ---+
|
||||
// | : : : | |
|
||||
// | : : : | |
|
||||
// +----------------------+
|
||||
// |C | | | |D |
|
||||
// +----------------------+
|
||||
|
||||
|
||||
#[inline]
|
||||
unsafe fn sgemm_pa<A: Matrix>(a: A, pa: MutMatrix, i: usize) {
|
||||
if a.is_transposed() {
|
||||
kernel::sgemm_pa_t(pa.stride, a.col(i), a.stride(), pa.row_mut(i));
|
||||
} else {
|
||||
kernel::sgemm_pa_n(pa.stride, a.col(i), a.stride(), pa.row_mut(i));
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn sgemm_pb<B: Matrix>(b: B, pb: MutMatrix, j: usize) {
|
||||
if b.is_transposed() {
|
||||
kernel::sgemm_pb_t(pb.stride, b.row(j), b.stride(), pb.row_mut(j));
|
||||
} else {
|
||||
kernel::sgemm_pb_n(pb.stride, b.row(j), b.stride(), pb.row_mut(j));
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn sgemm_ukr<C: MatrixMut>(
|
||||
i: usize, j: usize,
|
||||
alpha: f32,
|
||||
pa: MutMatrix,
|
||||
pb: MutMatrix,
|
||||
beta: f32,
|
||||
c: C)
|
||||
{
|
||||
if c.is_transposed() {
|
||||
unimplemented!()
|
||||
} else {
|
||||
kernel::sgemm_ukr(
|
||||
pa.stride,
|
||||
alpha,
|
||||
pa.row(i),
|
||||
pb.row(j),
|
||||
beta,
|
||||
c.index_mut(j, i),
|
||||
c.stride())
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn sgemm_sup1<A: Matrix, C: MatrixMut>(
|
||||
i: usize, j: usize,
|
||||
alpha: f32,
|
||||
a: A,
|
||||
pb: MutMatrix,
|
||||
beta: f32,
|
||||
c: C)
|
||||
{
|
||||
if c.is_transposed() {
|
||||
unimplemented!()
|
||||
} else {
|
||||
if a.is_transposed() {
|
||||
kernel::sgemm_sup1_t(
|
||||
pb.stride,
|
||||
alpha,
|
||||
a.col(i),
|
||||
a.stride(),
|
||||
pb.row(j),
|
||||
beta,
|
||||
c.index_mut(j, i),
|
||||
c.stride())
|
||||
} else {
|
||||
kernel::sgemm_sup1(
|
||||
pb.stride,
|
||||
alpha,
|
||||
a.col(i),
|
||||
a.stride(),
|
||||
pb.row(j),
|
||||
beta,
|
||||
c.index_mut(j, i),
|
||||
c.stride())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn sgemm_sup0<B: Matrix, C: MatrixMut>(
|
||||
i: usize, j: usize,
|
||||
alpha: f32,
|
||||
pa: MutMatrix,
|
||||
b: B,
|
||||
beta: f32,
|
||||
c: C)
|
||||
{
|
||||
if c.is_transposed() {
|
||||
unimplemented!()
|
||||
} else {
|
||||
if b.is_transposed() {
|
||||
kernel::sgemm_sup0_t(
|
||||
pa.stride,
|
||||
alpha,
|
||||
pa.row(i),
|
||||
b.row(j),
|
||||
b.stride(),
|
||||
beta,
|
||||
c.index_mut(j, i))
|
||||
} else {
|
||||
kernel::sgemm_sup0(
|
||||
pa.stride,
|
||||
alpha,
|
||||
pa.row(i),
|
||||
b.row(j),
|
||||
beta,
|
||||
c.index_mut(j, i))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe fn inner_kernel<A, B, C>(
|
||||
m: usize,
|
||||
n: usize,
|
||||
k: usize,
|
||||
alpha: f32,
|
||||
a: A,
|
||||
b: B,
|
||||
beta: f32,
|
||||
c: C,
|
||||
pa: MutMatrix,
|
||||
pb: MutMatrix,
|
||||
first_time: bool,
|
||||
)
|
||||
where A: Matrix,
|
||||
B: Matrix,
|
||||
C: MatrixMut,
|
||||
|
||||
{
|
||||
let n_left = n % NR;
|
||||
let n_main = n - n_left;
|
||||
|
||||
let m_left = m % MR;
|
||||
let m_main = m - m_left;
|
||||
|
||||
if first_time {
|
||||
for j in (0..n_main).step_by(NR) {
|
||||
sgemm_pb(b, pb, j);
|
||||
}
|
||||
}
|
||||
|
||||
for i in (0..m_main).step_by(MR) {
|
||||
sgemm_pa(a, pa, i);
|
||||
};
|
||||
|
||||
|
||||
for j in (0..n_main).step_by(NR) {
|
||||
|
||||
// Section A
|
||||
for i in (0..m_main).step_by(MR) {
|
||||
sgemm_ukr(
|
||||
i,
|
||||
j,
|
||||
alpha,
|
||||
pa,
|
||||
pb,
|
||||
beta,
|
||||
c,
|
||||
);
|
||||
}
|
||||
|
||||
// Section B
|
||||
for i in m_main..m {
|
||||
sgemm_sup1(
|
||||
i,
|
||||
j,
|
||||
alpha,
|
||||
a,
|
||||
pb,
|
||||
beta,
|
||||
c,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
for j in n_main..n {
|
||||
|
||||
// Section C
|
||||
for i in (0..m_main).step_by(MR) {
|
||||
sgemm_sup0(
|
||||
i,
|
||||
j,
|
||||
alpha,
|
||||
pa,
|
||||
b,
|
||||
beta,
|
||||
c,
|
||||
);
|
||||
}
|
||||
|
||||
// Section D
|
||||
for i in m_main..m {
|
||||
let mut elem = 0.0;
|
||||
|
||||
for p in 0..k {
|
||||
elem += a.get(p, i) * b.get(j, p);
|
||||
}
|
||||
|
||||
elem *= alpha;
|
||||
|
||||
if beta != 0.0 {
|
||||
elem += beta * c.get(j, i);
|
||||
}
|
||||
|
||||
c.set(j, i, elem);
|
||||
}
|
||||
}
|
||||
}
|
4
src/stats.sh
Normal file
4
src/stats.sh
Normal file
@ -0,0 +1,4 @@
|
||||
#!/bin/bash
|
||||
|
||||
cargo build --release;
|
||||
perf stat -B -e cache-references,cache-misses,cycles,instructions,branches,branch-misses,faults,migrations,context-switches ./target/release/gemm
|
0
src/utils/mod.rs
Normal file
0
src/utils/mod.rs
Normal file
5
stats.sh
Executable file
5
stats.sh
Executable file
@ -0,0 +1,5 @@
|
||||
#!/bin/bash
|
||||
# export RUSTFLAGS="-C target-feature=-mmx,-sse,-sse2,-sse3,-ssse3,-sse4.1,-sse4.2,-3dnow,-3dnowa,-avx,-avx2"
|
||||
export RUSTFLAGS="-C target-feature=-mmx,-sse,-sse2,-sse3,-ssse3,-sse4.1,-sse4.2,-3dnow,-3dnowa,+avx,+avx2"
|
||||
cargo clean && cargo build --release
|
||||
perf stat -B -e cache-references,cache-misses,cycles,instructions,branches,branch-misses,faults,migrations,context-switches ./target/release/gemm
|
Loading…
Reference in New Issue
Block a user