Working state

This commit is contained in:
Andrey Tkachenko 2019-11-29 23:28:01 +04:00
commit ffae99bafd
24 changed files with 2787 additions and 0 deletions

2
.cargo/config Normal file
View File

@ -0,0 +1,2 @@
[build]
rustflags = ["-C", "target-cpu=native"]

2
.gitignore vendored Normal file
View File

@ -0,0 +1,2 @@
/target
**/*.rs.bk

164
Cargo.lock generated Normal file
View File

@ -0,0 +1,164 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
[[package]]
name = "autocfg"
version = "0.1.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "blas"
version = "0.20.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"blas-sys 0.7.1 (registry+https://github.com/rust-lang/crates.io-index)",
"libc 0.2.65 (registry+https://github.com/rust-lang/crates.io-index)",
"num-complex 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "blas-sys"
version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"libc 0.2.65 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "c2-chacha"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"ppv-lite86 0.2.6 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "cfg-if"
version = "0.1.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "gemm"
version = "0.1.0"
dependencies = [
"blas 0.20.0 (registry+https://github.com/rust-lang/crates.io-index)",
"matrixmultiply 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)",
"openblas 0.0.1 (registry+https://github.com/rust-lang/crates.io-index)",
"rand 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "getrandom"
version = "0.1.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)",
"libc 0.2.65 (registry+https://github.com/rust-lang/crates.io-index)",
"wasi 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "libc"
version = "0.2.65"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "matrixmultiply"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"rawpointer 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "num-complex"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"autocfg 0.1.7 (registry+https://github.com/rust-lang/crates.io-index)",
"num-traits 0.2.10 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "num-traits"
version = "0.2.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"autocfg 0.1.7 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "openblas"
version = "0.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "ppv-lite86"
version = "0.2.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "rand"
version = "0.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"getrandom 0.1.13 (registry+https://github.com/rust-lang/crates.io-index)",
"libc 0.2.65 (registry+https://github.com/rust-lang/crates.io-index)",
"rand_chacha 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)",
"rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)",
"rand_hc 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "rand_chacha"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"c2-chacha 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)",
"rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "rand_core"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"getrandom 0.1.13 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "rand_hc"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "rawpointer"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "wasi"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
[metadata]
"checksum autocfg 0.1.7 (registry+https://github.com/rust-lang/crates.io-index)" = "1d49d90015b3c36167a20fe2810c5cd875ad504b39cff3d4eae7977e6b7c1cb2"
"checksum blas 0.20.0 (registry+https://github.com/rust-lang/crates.io-index)" = "e4b3b6399f2fe6e70f780b06e278a64770aae273e0005c3860f94fc2f5f01ba7"
"checksum blas-sys 0.7.1 (registry+https://github.com/rust-lang/crates.io-index)" = "13b1b279ceb25d7c4faaea95a5f7addbe7d8c34f9462044bd8e630cebcfc2440"
"checksum c2-chacha 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "214238caa1bf3a496ec3392968969cab8549f96ff30652c9e56885329315f6bb"
"checksum cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)" = "4785bdd1c96b2a846b2bd7cc02e86b6b3dbf14e7e53446c4f54c92a361040822"
"checksum getrandom 0.1.13 (registry+https://github.com/rust-lang/crates.io-index)" = "e7db7ca94ed4cd01190ceee0d8a8052f08a247aa1b469a7f68c6a3b71afcf407"
"checksum libc 0.2.65 (registry+https://github.com/rust-lang/crates.io-index)" = "1a31a0627fdf1f6a39ec0dd577e101440b7db22672c0901fe00a9a6fbb5c24e8"
"checksum matrixmultiply 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "d4f7ec66360130972f34830bfad9ef05c6610a43938a467bcc9ab9369ab3478f"
"checksum num-complex 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "fcb0cf31fb3ff77e6d2a6ebd6800df7fdcd106f2ad89113c9130bcd07f93dffc"
"checksum num-traits 0.2.10 (registry+https://github.com/rust-lang/crates.io-index)" = "d4c81ffc11c212fa327657cb19dd85eb7419e163b5b076bede2bdb5c974c07e4"
"checksum openblas 0.0.1 (registry+https://github.com/rust-lang/crates.io-index)" = "f005eb075fb562358060245c5d00a877757b4ddec8e8cd69cab96d8cc55998e8"
"checksum ppv-lite86 0.2.6 (registry+https://github.com/rust-lang/crates.io-index)" = "74490b50b9fbe561ac330df47c08f3f33073d2d00c150f719147d7c54522fa1b"
"checksum rand 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)" = "3ae1b169243eaf61759b8475a998f0a385e42042370f3a7dbaf35246eacc8412"
"checksum rand_chacha 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "03a2a90da8c7523f554344f921aa97283eadf6ac484a6d2a7d0212fa7f8d6853"
"checksum rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)" = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19"
"checksum rand_hc 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ca3129af7b92a17112d59ad498c6f81eaf463253766b90396d39ea7a39d6613c"
"checksum rawpointer 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3"
"checksum wasi 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)" = "b89c3ce4ce14bdc6fb6beaf9ec7928ca331de5df7e5ea278375642a2f478570d"

14
Cargo.toml Normal file
View File

@ -0,0 +1,14 @@
[package]
name = "gemm"
version = "0.1.0"
authors = ["Andrey Tkachenko <andreytkachenko64@gmail.com>"]
edition = "2018"
build = "build.rs"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
rand = "0.7.2"
matrixmultiply = "0.2.3"
blas = "0.20.0"
openblas = "0.0.1"

7
build.rs Normal file
View File

@ -0,0 +1,7 @@
fn main() {
// Always linking openblas
// Compiling blas just for testing is tedious -- install it on your system
// and run this.
println!("cargo:rerun-if-changed=build.rs");
println!("cargo:rustc-link-lib={}=openblas", "dylib");
}

31
src/aligned_alloc.rs Normal file
View File

@ -0,0 +1,31 @@
use std::alloc;
pub struct Alloc {
ptr: *mut u8,
layout: alloc::Layout,
}
impl Alloc {
pub fn new(size: usize) -> Alloc {
const ALIGN: usize = 32;
let layout = alloc::Layout::from_size_align(size, ALIGN).unwrap();
let ptr = unsafe { alloc::alloc(layout) };
Alloc { ptr, layout }
}
pub fn ptr_f32(&self) -> *mut f32 {
self.ptr as *mut f32
}
pub fn ptr(&self) -> *mut u8 {
self.ptr
}
}
impl Drop for Alloc {
fn drop(&mut self) {
unsafe {
alloc::dealloc(self.ptr, self.layout);
}
}
}

47
src/kernel/avx/fma.rs Normal file
View File

@ -0,0 +1,47 @@
use super::intrinsics::*;
#[cfg(target_feature = "fma")]
#[inline(always)]
pub unsafe fn fmadd_ps(a: __m256, b: __m256, c: __m256) -> __m256 {
_mm256_fmadd_ps(a, b, c)
}
#[cfg(not(target_feature = "fma"))]
#[inline(always)]
pub unsafe fn fmadd_ps(a: __m256, b: __m256, c: __m256) -> __m256 {
_mm256_add_ps(_mm256_mul_ps(a, b), c)
}
#[cfg(target_feature = "fma")]
#[inline(always)]
pub unsafe fn fmsub_ps(a: __m256, b: __m256, c: __m256) -> __m256 {
_mm256_fmsub_ps(a, b, c)
}
#[cfg(not(target_feature = "fma"))]
pub unsafe fn fmsub_ps(a: __m256, b: __m256, c: __m256) -> __m256 {
_mm256_sub_ps(_mm256_mul_ps(a, b), c)
}
#[cfg(target_feature = "fma")]
#[inline(always)]
pub unsafe fn fmadd_pd(a: __m256d, b: __m256d, c: __m256d) -> __m256d {
_mm256_fmadd_pd(a, b, c)
}
#[cfg(not(target_feature = "fma"))]
#[inline(always)]
pub unsafe fn fmadd_pd(a: __m256d, b: __m256d, c: __m256d) -> __m256d {
_mm256_add_pd(_mm256_mul_pd(a, b), c)
}
#[cfg(target_feature = "fma")]
#[inline(always)]
pub unsafe fn fmsub_pd(a: __m256d, b: __m256d, c: __m256d) -> __m256d {
_mm256_fmsub_pd(a, b, c)
}
#[cfg(not(target_feature = "fma"))]
pub unsafe fn fmsub_pd(a: __m256d, b: __m256d, c: __m256d) -> __m256d {
_mm256_sub_pd(_mm256_mul_pd(a, b), c)
}

23
src/kernel/avx/hsum.rs Normal file
View File

@ -0,0 +1,23 @@
use super::intrinsics::*;
#[inline(always)]
pub unsafe fn hsum_ps(v: __m256) -> f32 {
let qhigh = _mm256_extractf128_ps(v, 1);
let qlow = _mm256_castps256_ps128(v);
let qsum = _mm_add_ps(qhigh, qlow);
let dhigh = _mm_movehl_ps(qsum, qsum);
let dlow = qsum;
let dsum = _mm_add_ps(dhigh, dlow);
let high = _mm_shuffle_ps(dsum, dsum, 1);
let low = dsum;
_mm_cvtss_f32(_mm_add_ss(high, low))
}
#[inline(always)]
pub unsafe fn hsum_pd(v: __m256d) -> f64 {
let vhigh = _mm256_extractf128_pd(v, 1);
let vlow = _mm256_castpd256_pd128(v);
let vsum = _mm_add_pd(vlow, vhigh);
let h64 = _mm_unpackhi_pd(vsum, vsum);
_mm_cvtsd_f64(_mm_add_sd(vsum, h64))
}

View File

@ -0,0 +1,5 @@
#[cfg(target_arch = "x86")]
pub use std::arch::x86::*;
#[cfg(target_arch = "x86_64")]
pub use std::arch::x86_64::*;

396
src/kernel/avx/l1d.rs Normal file
View File

@ -0,0 +1,396 @@
use super::fma::{fmadd_pd, fmsub_pd};
use super::hsum::hsum_pd;
use super::intrinsics::*;
pub unsafe fn drot(
n: usize,
mut x: *mut f64,
incx: usize,
mut y: *mut f64,
incy: usize,
c: f64,
s: f64,
) {
if incx == 1 && incy == 1 {
let c0 = _mm256_broadcast_sd(&c);
let s0 = _mm256_broadcast_sd(&s);
for _ in 0..n / 16 {
let x0 = _mm256_loadu_pd(x);
let y0 = _mm256_loadu_pd(y);
let x1 = _mm256_loadu_pd(x.add(4));
let y1 = _mm256_loadu_pd(y.add(4));
let x2 = _mm256_loadu_pd(x.add(8));
let y2 = _mm256_loadu_pd(y.add(8));
let x3 = _mm256_loadu_pd(x.add(12));
let y3 = _mm256_loadu_pd(y.add(12));
_mm256_storeu_pd(x, fmadd_pd(c0, x0, _mm256_mul_pd(s0, y0)));
_mm256_storeu_pd(y, fmsub_pd(c0, y0, _mm256_mul_pd(s0, x0)));
_mm256_storeu_pd(x.add(4), fmadd_pd(c0, x1, _mm256_mul_pd(s0, y1)));
_mm256_storeu_pd(y.add(4), fmsub_pd(c0, y1, _mm256_mul_pd(s0, x1)));
_mm256_storeu_pd(x.add(8), fmadd_pd(c0, x2, _mm256_mul_pd(s0, y2)));
_mm256_storeu_pd(y.add(8), fmsub_pd(c0, y2, _mm256_mul_pd(s0, x2)));
_mm256_storeu_pd(x.add(12), fmadd_pd(c0, x3, _mm256_mul_pd(s0, y3)));
_mm256_storeu_pd(y.add(12), fmsub_pd(c0, y3, _mm256_mul_pd(s0, x3)));
x = x.add(16);
y = y.add(16);
}
for _ in 0..n % 16 {
let x0 = *x;
let y0 = *y;
*x = c * x0 + s * y0;
*y = c * y0 - s * x0;
x = x.add(1);
y = y.add(1);
}
} else {
for _ in 0..n {
let x0 = *x;
let y0 = *y;
*x = c * x0 + s * y0;
*y = c * y0 - s * x0;
x = x.add(incx);
y = y.add(incy);
}
}
}
pub unsafe fn dswap(n: usize, mut x: *mut f64, incx: usize, mut y: *mut f64, incy: usize) {
if incx == 1 && incy == 1 {
for _ in 0..n / 16 {
let x0 = _mm256_loadu_pd(x);
let y0 = _mm256_loadu_pd(y);
let x1 = _mm256_loadu_pd(x.add(4));
let y1 = _mm256_loadu_pd(y.add(4));
let x2 = _mm256_loadu_pd(x.add(8));
let y2 = _mm256_loadu_pd(y.add(8));
let x3 = _mm256_loadu_pd(x.add(12));
let y3 = _mm256_loadu_pd(y.add(12));
_mm256_storeu_pd(x, y0);
_mm256_storeu_pd(y, x0);
_mm256_storeu_pd(x.add(4), y1);
_mm256_storeu_pd(y.add(4), x1);
_mm256_storeu_pd(x.add(8), y2);
_mm256_storeu_pd(y.add(8), x2);
_mm256_storeu_pd(x.add(12), y3);
_mm256_storeu_pd(y.add(12), x3);
x = x.add(16);
y = y.add(16);
}
for _ in 0..n % 16 {
let x0 = *x;
*x = *y;
*y = x0;
x = x.add(1);
y = y.add(1);
}
} else {
for _ in 0..n {
let x0 = *x;
*x = *y;
*y = x0;
x = x.add(incx);
y = y.add(incy);
}
}
}
pub unsafe fn dscal(n: usize, a: f64, mut x: *mut f64, incx: usize) {
if incx == 1 {
let a0 = _mm256_broadcast_sd(&a);
for _ in 0..n / 32 {
let mut x0 = _mm256_loadu_pd(x);
let mut x1 = _mm256_loadu_pd(x.add(4));
let mut x2 = _mm256_loadu_pd(x.add(8));
let mut x3 = _mm256_loadu_pd(x.add(12));
let mut x4 = _mm256_loadu_pd(x.add(16));
let mut x5 = _mm256_loadu_pd(x.add(20));
let mut x6 = _mm256_loadu_pd(x.add(24));
let mut x7 = _mm256_loadu_pd(x.add(28));
x0 = _mm256_mul_pd(a0, x0);
x1 = _mm256_mul_pd(a0, x1);
x2 = _mm256_mul_pd(a0, x2);
x3 = _mm256_mul_pd(a0, x3);
x4 = _mm256_mul_pd(a0, x4);
x5 = _mm256_mul_pd(a0, x5);
x6 = _mm256_mul_pd(a0, x6);
x7 = _mm256_mul_pd(a0, x7);
_mm256_storeu_pd(x, x0);
_mm256_storeu_pd(x.add(4), x1);
_mm256_storeu_pd(x.add(8), x2);
_mm256_storeu_pd(x.add(12), x3);
_mm256_storeu_pd(x.add(16), x4);
_mm256_storeu_pd(x.add(20), x5);
_mm256_storeu_pd(x.add(24), x6);
_mm256_storeu_pd(x.add(28), x7);
x = x.add(32);
}
for _ in 0..n % 32 {
*x *= a;
x = x.add(1);
}
} else {
for _ in 0..n {
*x *= a;
x = x.add(incx);
}
}
}
pub unsafe fn dcopy(n: usize, mut x: *const f64, incx: usize, mut y: *mut f64, incy: usize) {
if incx == 1 && incy == 1 {
for _ in 0..n / 32 {
let x0 = _mm256_loadu_pd(x);
let x1 = _mm256_loadu_pd(x.add(4));
let x2 = _mm256_loadu_pd(x.add(8));
let x3 = _mm256_loadu_pd(x.add(12));
let x4 = _mm256_loadu_pd(x.add(16));
let x5 = _mm256_loadu_pd(x.add(20));
let x6 = _mm256_loadu_pd(x.add(24));
let x7 = _mm256_loadu_pd(x.add(28));
_mm256_storeu_pd(y, x0);
_mm256_storeu_pd(y.add(4), x1);
_mm256_storeu_pd(y.add(8), x2);
_mm256_storeu_pd(y.add(12), x3);
_mm256_storeu_pd(y.add(16), x4);
_mm256_storeu_pd(y.add(20), x5);
_mm256_storeu_pd(y.add(24), x6);
_mm256_storeu_pd(y.add(28), x7);
x = x.add(32);
y = y.add(32);
}
for _ in 0..n % 32 {
*y = *x;
x = x.add(1);
y = y.add(1);
}
} else {
for _ in 0..n {
*y = *x;
x = x.add(incx);
y = y.add(incy);
}
}
}
pub unsafe fn daxpy(
n: usize,
a: f64,
mut x: *const f64,
incx: usize,
mut y: *mut f64,
incy: usize,
) {
if incx == 1 && incy == 1 {
let a0 = _mm256_broadcast_sd(&a);
for _ in 0..n / 16 {
let x0 = _mm256_loadu_pd(x);
let y0 = _mm256_loadu_pd(y);
let x1 = _mm256_loadu_pd(x.add(4));
let y1 = _mm256_loadu_pd(y.add(4));
let x2 = _mm256_loadu_pd(x.add(8));
let y2 = _mm256_loadu_pd(y.add(8));
let x3 = _mm256_loadu_pd(x.add(12));
let y3 = _mm256_loadu_pd(y.add(12));
_mm256_storeu_pd(y, fmadd_pd(a0, x0, y0));
_mm256_storeu_pd(y.add(4), fmadd_pd(a0, x1, y1));
_mm256_storeu_pd(y.add(8), fmadd_pd(a0, x2, y2));
_mm256_storeu_pd(y.add(12), fmadd_pd(a0, x3, y3));
x = x.add(16);
y = y.add(16);
}
for _ in 0..n % 16 {
*y += a * *x;
x = x.add(1);
y = y.add(1);
}
} else {
for _ in 0..n {
*y += a * *x;
x = x.add(incx);
y = y.add(incy);
}
}
}
pub unsafe fn ddot(
n: usize,
mut x: *const f64,
incx: usize,
mut y: *const f64,
incy: usize,
) -> f64 {
if incx == 1 && incy == 1 {
let mut acc0 = _mm256_setzero_pd();
let mut acc1 = _mm256_setzero_pd();
let mut acc2 = _mm256_setzero_pd();
let mut acc3 = _mm256_setzero_pd();
for _ in 0..n / 16 {
let x0 = _mm256_loadu_pd(x);
let y0 = _mm256_loadu_pd(y);
let x1 = _mm256_loadu_pd(x.add(4));
let y1 = _mm256_loadu_pd(y.add(4));
let x2 = _mm256_loadu_pd(x.add(8));
let y2 = _mm256_loadu_pd(y.add(8));
let x3 = _mm256_loadu_pd(x.add(12));
let y3 = _mm256_loadu_pd(y.add(12));
acc0 = fmadd_pd(x0, y0, acc0);
acc1 = fmadd_pd(x1, y1, acc1);
acc2 = fmadd_pd(x2, y2, acc2);
acc3 = fmadd_pd(x3, y3, acc3);
x = x.add(16);
y = y.add(16);
}
acc0 = _mm256_add_pd(acc0, acc1);
acc2 = _mm256_add_pd(acc2, acc3);
acc0 = _mm256_add_pd(acc0, acc2);
let mut acc = hsum_pd(acc0);
for _ in 0..n % 16 {
acc += *x * *y;
x = x.add(1);
y = y.add(1);
}
acc
} else {
let mut acc = 0.0;
for _ in 0..n {
acc += *x * *y;
x = x.add(incx);
y = y.add(incy);
}
acc
}
}
pub unsafe fn dnrm2(n: usize, mut x: *const f64, incx: usize) -> f64 {
if incx == 1 {
let mut acc0 = _mm256_setzero_pd();
let mut acc1 = _mm256_setzero_pd();
let mut acc2 = _mm256_setzero_pd();
let mut acc3 = _mm256_setzero_pd();
for _ in 0..n / 16 {
let x0 = _mm256_loadu_pd(x);
let x1 = _mm256_loadu_pd(x.add(4));
let x2 = _mm256_loadu_pd(x.add(8));
let x3 = _mm256_loadu_pd(x.add(12));
acc0 = fmadd_pd(x0, x0, acc0);
acc1 = fmadd_pd(x1, x1, acc1);
acc2 = fmadd_pd(x2, x2, acc2);
acc3 = fmadd_pd(x3, x3, acc3);
x = x.add(16);
}
acc0 = _mm256_add_pd(acc0, acc1);
acc2 = _mm256_add_pd(acc2, acc3);
acc0 = _mm256_add_pd(acc0, acc2);
let mut acc = hsum_pd(acc0);
for _ in 0..n % 16 {
let x0 = *x;
acc += x0 * x0;
x = x.add(1);
}
acc.sqrt()
} else {
let mut acc = 0.0;
for _ in 0..n {
let x0 = *x;
acc += x0 * x0;
x = x.add(incx);
}
acc.sqrt()
}
}
pub unsafe fn dasum(n: usize, mut x: *const f64, incx: usize) -> f64 {
if incx == 1 {
let mask = _mm256_broadcast_sd(&f64::from_bits(0x7FFF_FFFF_FFFF_FFFF));
let mut acc0 = _mm256_setzero_pd();
let mut acc1 = _mm256_setzero_pd();
let mut acc2 = _mm256_setzero_pd();
let mut acc3 = _mm256_setzero_pd();
let mut acc4 = _mm256_setzero_pd();
let mut acc5 = _mm256_setzero_pd();
let mut acc6 = _mm256_setzero_pd();
let mut acc7 = _mm256_setzero_pd();
for _ in 0..n / 32 {
let mut x0 = _mm256_loadu_pd(x);
let mut x1 = _mm256_loadu_pd(x.add(4));
let mut x2 = _mm256_loadu_pd(x.add(8));
let mut x3 = _mm256_loadu_pd(x.add(12));
let mut x4 = _mm256_loadu_pd(x.add(16));
let mut x5 = _mm256_loadu_pd(x.add(20));
let mut x6 = _mm256_loadu_pd(x.add(24));
let mut x7 = _mm256_loadu_pd(x.add(28));
x0 = _mm256_and_pd(mask, x0);
x1 = _mm256_and_pd(mask, x1);
x2 = _mm256_and_pd(mask, x2);
x3 = _mm256_and_pd(mask, x3);
x4 = _mm256_and_pd(mask, x4);
x5 = _mm256_and_pd(mask, x5);
x6 = _mm256_and_pd(mask, x6);
x7 = _mm256_and_pd(mask, x7);
acc0 = _mm256_add_pd(acc0, x0);
acc1 = _mm256_add_pd(acc1, x1);
acc2 = _mm256_add_pd(acc2, x2);
acc3 = _mm256_add_pd(acc3, x3);
acc4 = _mm256_add_pd(acc4, x4);
acc5 = _mm256_add_pd(acc5, x5);
acc6 = _mm256_add_pd(acc6, x6);
acc7 = _mm256_add_pd(acc7, x7);
x = x.add(32);
}
acc0 = _mm256_add_pd(acc0, acc1);
acc2 = _mm256_add_pd(acc2, acc3);
acc4 = _mm256_add_pd(acc4, acc5);
acc6 = _mm256_add_pd(acc6, acc7);
acc0 = _mm256_add_pd(acc0, acc2);
acc4 = _mm256_add_pd(acc4, acc6);
acc0 = _mm256_add_pd(acc0, acc4);
let mut acc = hsum_pd(acc0);
for _ in 0..n % 32 {
acc += (*x).abs();
x = x.add(1);
}
acc
} else {
let mut acc = 0.0;
for _ in 0..n {
acc += (*x).abs();
x = x.add(incx);
}
acc
}
}

396
src/kernel/avx/l1s.rs Normal file
View File

@ -0,0 +1,396 @@
use super::fma::{fmadd_ps, fmsub_ps};
use super::hsum::hsum_ps;
use super::intrinsics::*;
pub unsafe fn srot(
n: usize,
mut x: *mut f32,
incx: usize,
mut y: *mut f32,
incy: usize,
c: f32,
s: f32,
) {
if incx == 1 && incy == 1 {
let c0 = _mm256_broadcast_ss(&c);
let s0 = _mm256_broadcast_ss(&s);
for _ in 0..n / 32 {
let x0 = _mm256_loadu_ps(x);
let y0 = _mm256_loadu_ps(y);
let x1 = _mm256_loadu_ps(x.add(8));
let y1 = _mm256_loadu_ps(y.add(8));
let x2 = _mm256_loadu_ps(x.add(16));
let y2 = _mm256_loadu_ps(y.add(16));
let x3 = _mm256_loadu_ps(x.add(24));
let y3 = _mm256_loadu_ps(y.add(24));
_mm256_storeu_ps(x, fmadd_ps(c0, x0, _mm256_mul_ps(s0, y0)));
_mm256_storeu_ps(y, fmsub_ps(c0, y0, _mm256_mul_ps(s0, x0)));
_mm256_storeu_ps(x.add(8), fmadd_ps(c0, x1, _mm256_mul_ps(s0, y1)));
_mm256_storeu_ps(y.add(8), fmsub_ps(c0, y1, _mm256_mul_ps(s0, x1)));
_mm256_storeu_ps(x.add(16), fmadd_ps(c0, x2, _mm256_mul_ps(s0, y2)));
_mm256_storeu_ps(y.add(16), fmsub_ps(c0, y2, _mm256_mul_ps(s0, x2)));
_mm256_storeu_ps(x.add(24), fmadd_ps(c0, x3, _mm256_mul_ps(s0, y3)));
_mm256_storeu_ps(y.add(24), fmsub_ps(c0, y3, _mm256_mul_ps(s0, x3)));
x = x.add(32);
y = y.add(32);
}
for _ in 0..n % 32 {
let x0 = *x;
let y0 = *y;
*x = c * x0 + s * y0;
*y = c * y0 - s * x0;
x = x.add(1);
y = y.add(1);
}
} else {
for _ in 0..n {
let x0 = *x;
let y0 = *y;
*x = c * x0 + s * y0;
*y = c * y0 - s * x0;
x = x.add(incx);
y = y.add(incy);
}
}
}
pub unsafe fn sswap(n: usize, mut x: *mut f32, incx: usize, mut y: *mut f32, incy: usize) {
if incx == 1 && incy == 1 {
for _ in 0..n / 32 {
let x0 = _mm256_loadu_ps(x);
let y0 = _mm256_loadu_ps(y);
let x1 = _mm256_loadu_ps(x.add(8));
let y1 = _mm256_loadu_ps(y.add(8));
let x2 = _mm256_loadu_ps(x.add(16));
let y2 = _mm256_loadu_ps(y.add(16));
let x3 = _mm256_loadu_ps(x.add(24));
let y3 = _mm256_loadu_ps(y.add(24));
_mm256_storeu_ps(x, y0);
_mm256_storeu_ps(y, x0);
_mm256_storeu_ps(x.add(8), y1);
_mm256_storeu_ps(y.add(8), x1);
_mm256_storeu_ps(x.add(16), y2);
_mm256_storeu_ps(y.add(16), x2);
_mm256_storeu_ps(x.add(24), y3);
_mm256_storeu_ps(y.add(24), x3);
x = x.add(32);
y = y.add(32);
}
for _ in 0..n % 32 {
let x0 = *x;
*x = *y;
*y = x0;
x = x.add(1);
y = y.add(1);
}
} else {
for _ in 0..n {
let x0 = *x;
*x = *y;
*y = x0;
x = x.add(incx);
y = y.add(incy);
}
}
}
pub unsafe fn sscal(n: usize, a: f32, mut x: *mut f32, incx: usize) {
if incx == 1 {
let a0 = _mm256_broadcast_ss(&a);
for _ in 0..n / 64 {
let mut x0 = _mm256_loadu_ps(x);
let mut x1 = _mm256_loadu_ps(x.add(8));
let mut x2 = _mm256_loadu_ps(x.add(16));
let mut x3 = _mm256_loadu_ps(x.add(24));
let mut x4 = _mm256_loadu_ps(x.add(32));
let mut x5 = _mm256_loadu_ps(x.add(40));
let mut x6 = _mm256_loadu_ps(x.add(48));
let mut x7 = _mm256_loadu_ps(x.add(56));
x0 = _mm256_mul_ps(a0, x0);
x1 = _mm256_mul_ps(a0, x1);
x2 = _mm256_mul_ps(a0, x2);
x3 = _mm256_mul_ps(a0, x3);
x4 = _mm256_mul_ps(a0, x4);
x5 = _mm256_mul_ps(a0, x5);
x6 = _mm256_mul_ps(a0, x6);
x7 = _mm256_mul_ps(a0, x7);
_mm256_storeu_ps(x, x0);
_mm256_storeu_ps(x.add(8), x1);
_mm256_storeu_ps(x.add(16), x2);
_mm256_storeu_ps(x.add(24), x3);
_mm256_storeu_ps(x.add(32), x4);
_mm256_storeu_ps(x.add(40), x5);
_mm256_storeu_ps(x.add(48), x6);
_mm256_storeu_ps(x.add(56), x7);
x = x.add(64);
}
for _ in 0..n % 64 {
*x *= a;
x = x.add(1);
}
} else {
for _ in 0..n {
*x *= a;
x = x.add(incx);
}
}
}
pub unsafe fn scopy(n: usize, mut x: *const f32, incx: usize, mut y: *mut f32, incy: usize) {
if incx == 1 && incy == 1 {
for _ in 0..n / 64 {
let x0 = _mm256_loadu_ps(x);
let x1 = _mm256_loadu_ps(x.add(8));
let x2 = _mm256_loadu_ps(x.add(16));
let x3 = _mm256_loadu_ps(x.add(24));
let x4 = _mm256_loadu_ps(x.add(32));
let x5 = _mm256_loadu_ps(x.add(40));
let x6 = _mm256_loadu_ps(x.add(48));
let x7 = _mm256_loadu_ps(x.add(56));
_mm256_storeu_ps(y, x0);
_mm256_storeu_ps(y.add(8), x1);
_mm256_storeu_ps(y.add(16), x2);
_mm256_storeu_ps(y.add(24), x3);
_mm256_storeu_ps(y.add(32), x4);
_mm256_storeu_ps(y.add(40), x5);
_mm256_storeu_ps(y.add(48), x6);
_mm256_storeu_ps(y.add(56), x7);
x = x.add(64);
y = y.add(64);
}
for _ in 0..n % 64 {
*y = *x;
x = x.add(1);
y = y.add(1);
}
} else {
for _ in 0..n {
*y = *x;
x = x.add(incx);
y = y.add(incy);
}
}
}
pub unsafe fn saxpy(
n: usize,
a: f32,
mut x: *const f32,
incx: usize,
mut y: *mut f32,
incy: usize,
) {
if incx == 1 && incy == 1 {
let a0 = _mm256_broadcast_ss(&a);
for _ in 0..n / 32 {
let x0 = _mm256_loadu_ps(x);
let y0 = _mm256_loadu_ps(y);
let x1 = _mm256_loadu_ps(x.add(8));
let y1 = _mm256_loadu_ps(y.add(8));
let x2 = _mm256_loadu_ps(x.add(16));
let y2 = _mm256_loadu_ps(y.add(16));
let x3 = _mm256_loadu_ps(x.add(24));
let y3 = _mm256_loadu_ps(y.add(24));
_mm256_storeu_ps(y, fmadd_ps(a0, x0, y0));
_mm256_storeu_ps(y.add(8), fmadd_ps(a0, x1, y1));
_mm256_storeu_ps(y.add(16), fmadd_ps(a0, x2, y2));
_mm256_storeu_ps(y.add(24), fmadd_ps(a0, x3, y3));
x = x.add(32);
y = y.add(32);
}
for _ in 0..n % 32 {
*y += a * *x;
x = x.add(1);
y = y.add(1);
}
} else {
for _ in 0..n {
*y += a * *x;
x = x.add(incx);
y = y.add(incy);
}
}
}
pub unsafe fn sdot(
n: usize,
mut x: *const f32,
incx: usize,
mut y: *const f32,
incy: usize,
) -> f32 {
if incx == 1 && incy == 1 {
let mut acc0 = _mm256_setzero_ps();
let mut acc1 = _mm256_setzero_ps();
let mut acc2 = _mm256_setzero_ps();
let mut acc3 = _mm256_setzero_ps();
for _ in 0..n / 32 {
let x0 = _mm256_loadu_ps(x);
let y0 = _mm256_loadu_ps(y);
let x1 = _mm256_loadu_ps(x.add(8));
let y1 = _mm256_loadu_ps(y.add(8));
let x2 = _mm256_loadu_ps(x.add(16));
let y2 = _mm256_loadu_ps(y.add(16));
let x3 = _mm256_loadu_ps(x.add(24));
let y3 = _mm256_loadu_ps(y.add(24));
acc0 = fmadd_ps(x0, y0, acc0);
acc1 = fmadd_ps(x1, y1, acc1);
acc2 = fmadd_ps(x2, y2, acc2);
acc3 = fmadd_ps(x3, y3, acc3);
x = x.add(32);
y = y.add(32);
}
acc0 = _mm256_add_ps(acc0, acc1);
acc2 = _mm256_add_ps(acc2, acc3);
acc0 = _mm256_add_ps(acc0, acc2);
let mut acc = hsum_ps(acc0);
for _ in 0..n % 32 {
acc += *x * *y;
x = x.add(1);
y = y.add(1);
}
acc
} else {
let mut acc = 0.0;
for _ in 0..n {
acc += *x * *y;
x = x.add(incx);
y = y.add(incy);
}
acc
}
}
pub unsafe fn snrm2(n: usize, mut x: *const f32, incx: usize) -> f32 {
if incx == 1 {
let mut acc0 = _mm256_setzero_ps();
let mut acc1 = _mm256_setzero_ps();
let mut acc2 = _mm256_setzero_ps();
let mut acc3 = _mm256_setzero_ps();
for _ in 0..n / 32 {
let x0 = _mm256_loadu_ps(x);
let x1 = _mm256_loadu_ps(x.add(8));
let x2 = _mm256_loadu_ps(x.add(16));
let x3 = _mm256_loadu_ps(x.add(24));
acc0 = fmadd_ps(x0, x0, acc0);
acc1 = fmadd_ps(x1, x1, acc1);
acc2 = fmadd_ps(x2, x2, acc2);
acc3 = fmadd_ps(x3, x3, acc3);
x = x.add(32);
}
acc0 = _mm256_add_ps(acc0, acc1);
acc2 = _mm256_add_ps(acc2, acc3);
acc0 = _mm256_add_ps(acc0, acc2);
let mut acc = hsum_ps(acc0);
for _ in 0..n % 32 {
let x0 = *x;
acc += x0 * x0;
x = x.add(1);
}
acc.sqrt()
} else {
let mut acc = 0.0;
for _ in 0..n {
let x0 = *x;
acc += x0 * x0;
x = x.add(incx);
}
acc.sqrt()
}
}
pub unsafe fn sasum(n: usize, mut x: *const f32, incx: usize) -> f32 {
if incx == 1 {
let mask = _mm256_broadcast_ss(&f32::from_bits(0x7FFF_FFFF));
let mut acc0 = _mm256_setzero_ps();
let mut acc1 = _mm256_setzero_ps();
let mut acc2 = _mm256_setzero_ps();
let mut acc3 = _mm256_setzero_ps();
let mut acc4 = _mm256_setzero_ps();
let mut acc5 = _mm256_setzero_ps();
let mut acc6 = _mm256_setzero_ps();
let mut acc7 = _mm256_setzero_ps();
for _ in 0..n / 64 {
let mut x0 = _mm256_loadu_ps(x);
let mut x1 = _mm256_loadu_ps(x.add(8));
let mut x2 = _mm256_loadu_ps(x.add(16));
let mut x3 = _mm256_loadu_ps(x.add(24));
let mut x4 = _mm256_loadu_ps(x.add(32));
let mut x5 = _mm256_loadu_ps(x.add(40));
let mut x6 = _mm256_loadu_ps(x.add(48));
let mut x7 = _mm256_loadu_ps(x.add(56));
x0 = _mm256_and_ps(mask, x0);
x1 = _mm256_and_ps(mask, x1);
x2 = _mm256_and_ps(mask, x2);
x3 = _mm256_and_ps(mask, x3);
x4 = _mm256_and_ps(mask, x4);
x5 = _mm256_and_ps(mask, x5);
x6 = _mm256_and_ps(mask, x6);
x7 = _mm256_and_ps(mask, x7);
acc0 = _mm256_add_ps(acc0, x0);
acc1 = _mm256_add_ps(acc1, x1);
acc2 = _mm256_add_ps(acc2, x2);
acc3 = _mm256_add_ps(acc3, x3);
acc4 = _mm256_add_ps(acc4, x4);
acc5 = _mm256_add_ps(acc5, x5);
acc6 = _mm256_add_ps(acc6, x6);
acc7 = _mm256_add_ps(acc7, x7);
x = x.add(64);
}
acc0 = _mm256_add_ps(acc0, acc1);
acc2 = _mm256_add_ps(acc2, acc3);
acc4 = _mm256_add_ps(acc4, acc5);
acc6 = _mm256_add_ps(acc6, acc7);
acc0 = _mm256_add_ps(acc0, acc2);
acc4 = _mm256_add_ps(acc4, acc6);
acc0 = _mm256_add_ps(acc0, acc4);
let mut acc = hsum_ps(acc0);
for _ in 0..n % 64 {
acc += (*x).abs();
x = x.add(1);
}
acc
} else {
let mut acc = 0.0;
for _ in 0..n {
acc += (*x).abs();
x = x.add(incx);
}
acc
}
}

140
src/kernel/avx/l3d.rs Normal file
View File

@ -0,0 +1,140 @@
use super::fma::fmadd_pd;
use super::intrinsics::*;
pub(crate) unsafe fn dgemm_ukr_8x4(
k: usize,
alpha: f64,
pa: *const f64,
pb: *const f64,
beta: f64,
c: *mut f64,
ldc: usize,
) {
let mut mt00 = _mm256_setzero_pd();
let mut mt01 = _mm256_setzero_pd();
let mut mt02 = _mm256_setzero_pd();
let mut mt03 = _mm256_setzero_pd();
let mut mt10 = _mm256_setzero_pd();
let mut mt11 = _mm256_setzero_pd();
let mut mt12 = _mm256_setzero_pd();
let mut mt13 = _mm256_setzero_pd();
let mut pa = pa;
let mut pb = pb;
for _ in 0..k {
let a0 = _mm256_loadu_pd(pa);
let a1 = _mm256_loadu_pd(pa.add(4));
let b0 = _mm256_broadcast_sd(&*pb);
let b1 = _mm256_broadcast_sd(&*pb.add(1));
let b2 = _mm256_broadcast_sd(&*pb.add(2));
let b3 = _mm256_broadcast_sd(&*pb.add(3));
mt00 = fmadd_pd(a0, b0, mt00);
mt01 = fmadd_pd(a0, b1, mt01);
mt02 = fmadd_pd(a0, b2, mt02);
mt03 = fmadd_pd(a0, b3, mt03);
mt10 = fmadd_pd(a1, b0, mt10);
mt11 = fmadd_pd(a1, b1, mt11);
mt12 = fmadd_pd(a1, b2, mt12);
mt13 = fmadd_pd(a1, b3, mt13);
pa = pa.add(8);
pb = pb.add(4);
}
let alpha = _mm256_broadcast_sd(&alpha);
mt00 = _mm256_mul_pd(alpha, mt00);
mt01 = _mm256_mul_pd(alpha, mt01);
mt02 = _mm256_mul_pd(alpha, mt02);
mt03 = _mm256_mul_pd(alpha, mt03);
mt10 = _mm256_mul_pd(alpha, mt10);
mt11 = _mm256_mul_pd(alpha, mt11);
mt12 = _mm256_mul_pd(alpha, mt12);
mt13 = _mm256_mul_pd(alpha, mt13);
let ccol0 = c;
let ccol1 = c.add(ldc);
let ccol2 = c.add(ldc * 2);
let ccol3 = c.add(ldc * 3);
if beta != 0.0 {
let beta = _mm256_broadcast_sd(&beta);
mt00 = fmadd_pd(beta, _mm256_loadu_pd(ccol0), mt00);
mt01 = fmadd_pd(beta, _mm256_loadu_pd(ccol1), mt01);
mt02 = fmadd_pd(beta, _mm256_loadu_pd(ccol2), mt02);
mt03 = fmadd_pd(beta, _mm256_loadu_pd(ccol3), mt03);
mt10 = fmadd_pd(beta, _mm256_loadu_pd(ccol0.add(4)), mt10);
mt11 = fmadd_pd(beta, _mm256_loadu_pd(ccol1.add(4)), mt11);
mt12 = fmadd_pd(beta, _mm256_loadu_pd(ccol2.add(4)), mt12);
mt13 = fmadd_pd(beta, _mm256_loadu_pd(ccol3.add(4)), mt13);
}
_mm256_storeu_pd(ccol0, mt00);
_mm256_storeu_pd(ccol1, mt01);
_mm256_storeu_pd(ccol2, mt02);
_mm256_storeu_pd(ccol3, mt03);
_mm256_storeu_pd(ccol0.add(4), mt10);
_mm256_storeu_pd(ccol1.add(4), mt11);
_mm256_storeu_pd(ccol2.add(4), mt12);
_mm256_storeu_pd(ccol3.add(4), mt13);
}
pub(crate) unsafe fn dgemm_sup_8x1(
k: usize,
alpha: f64,
pa: *const f64,
b: *const f64,
beta: f64,
c: *mut f64,
) {
let mut mt0 = _mm256_setzero_pd();
let mut mt1 = _mm256_setzero_pd();
let mut pa = pa;
let mut b = b;
for _ in 0..k {
let a0 = _mm256_loadu_pd(pa);
let a1 = _mm256_loadu_pd(pa.add(4));
let b0 = _mm256_broadcast_sd(&*b);
mt0 = fmadd_pd(a0, b0, mt0);
mt1 = fmadd_pd(a1, b0, mt1);
pa = pa.add(8);
b = b.add(1);
}
let alpha = _mm256_broadcast_sd(&alpha);
mt0 = _mm256_mul_pd(alpha, mt0);
mt1 = _mm256_mul_pd(alpha, mt1);
if beta != 0.0 {
let beta = _mm256_broadcast_sd(&beta);
mt0 = fmadd_pd(beta, _mm256_loadu_pd(c), mt0);
mt1 = fmadd_pd(beta, _mm256_loadu_pd(c.add(4)), mt1);
}
_mm256_storeu_pd(c, mt0);
_mm256_storeu_pd(c.add(4), mt1);
}
pub(crate) unsafe fn dgemm_pa_8x(k: usize, a: *const f64, lda: usize, pa: *mut f64) {
let mut a = a;
let mut pa = pa;
for _ in 0..k {
_mm256_storeu_pd(pa, _mm256_loadu_pd(a));
_mm256_storeu_pd(pa.add(4), _mm256_loadu_pd(a.add(4)));
pa = pa.add(8);
a = a.add(lda);
}
}

185
src/kernel/avx/l3s.rs Normal file
View File

@ -0,0 +1,185 @@
use super::fma::fmadd_ps;
use super::intrinsics::*;
pub(crate) unsafe fn sgemm_ukr_16x4(
k: usize,
alpha: f32,
pa: *const f32,
pb: *const f32,
beta: f32,
c: *mut f32,
ldc: usize,
) {
let mut mt00 = _mm256_setzero_ps();
let mut mt01 = _mm256_setzero_ps();
let mut mt02 = _mm256_setzero_ps();
let mut mt03 = _mm256_setzero_ps();
let mut mt10 = _mm256_setzero_ps();
let mut mt11 = _mm256_setzero_ps();
let mut mt12 = _mm256_setzero_ps();
let mut mt13 = _mm256_setzero_ps();
let mut pa = pa;
let mut pb = pb;
for _ in 0..k {
let a0 = _mm256_load_ps(pa);
let a1 = _mm256_load_ps(pa.add(8));
let b0 = _mm256_broadcast_ss(&*pb);
let b1 = _mm256_broadcast_ss(&*pb.add(1));
let b2 = _mm256_broadcast_ss(&*pb.add(2));
let b3 = _mm256_broadcast_ss(&*pb.add(3));
mt00 = fmadd_ps(a0, b0, mt00);
mt01 = fmadd_ps(a0, b1, mt01);
mt02 = fmadd_ps(a0, b2, mt02);
mt03 = fmadd_ps(a0, b3, mt03);
mt10 = fmadd_ps(a1, b0, mt10);
mt11 = fmadd_ps(a1, b1, mt11);
mt12 = fmadd_ps(a1, b2, mt12);
mt13 = fmadd_ps(a1, b3, mt13);
pa = pa.add(16);
pb = pb.add(4);
}
let alpha = _mm256_broadcast_ss(&alpha);
mt00 = _mm256_mul_ps(alpha, mt00);
mt01 = _mm256_mul_ps(alpha, mt01);
mt02 = _mm256_mul_ps(alpha, mt02);
mt03 = _mm256_mul_ps(alpha, mt03);
mt10 = _mm256_mul_ps(alpha, mt10);
mt11 = _mm256_mul_ps(alpha, mt11);
mt12 = _mm256_mul_ps(alpha, mt12);
mt13 = _mm256_mul_ps(alpha, mt13);
let ccol0 = c;
let ccol1 = c.add(ldc);
let ccol2 = c.add(ldc * 2);
let ccol3 = c.add(ldc * 3);
if beta != 0.0 {
let beta = _mm256_broadcast_ss(&beta);
mt00 = fmadd_ps(beta, _mm256_loadu_ps(ccol0), mt00);
mt01 = fmadd_ps(beta, _mm256_loadu_ps(ccol1), mt01);
mt02 = fmadd_ps(beta, _mm256_loadu_ps(ccol2), mt02);
mt03 = fmadd_ps(beta, _mm256_loadu_ps(ccol3), mt03);
mt10 = fmadd_ps(beta, _mm256_loadu_ps(ccol0.add(8)), mt10);
mt11 = fmadd_ps(beta, _mm256_loadu_ps(ccol1.add(8)), mt11);
mt12 = fmadd_ps(beta, _mm256_loadu_ps(ccol2.add(8)), mt12);
mt13 = fmadd_ps(beta, _mm256_loadu_ps(ccol3.add(8)), mt13);
}
_mm256_storeu_ps(ccol0, mt00);
_mm256_storeu_ps(ccol1, mt01);
_mm256_storeu_ps(ccol2, mt02);
_mm256_storeu_ps(ccol3, mt03);
_mm256_storeu_ps(ccol0.add(8), mt10);
_mm256_storeu_ps(ccol1.add(8), mt11);
_mm256_storeu_ps(ccol2.add(8), mt12);
_mm256_storeu_ps(ccol3.add(8), mt13);
}
pub(crate) unsafe fn sgemm_sup_16x1(
k: usize,
alpha: f32,
pa: *const f32,
b: *const f32,
beta: f32,
c: *mut f32,
) {
let mut mt0 = _mm256_setzero_ps();
let mut mt1 = _mm256_setzero_ps();
let mut pa = pa;
let mut b = b;
for _ in 0..k {
let a0 = _mm256_load_ps(pa);
let a1 = _mm256_load_ps(pa.add(8));
let b0 = _mm256_broadcast_ss(&*b);
mt0 = fmadd_ps(a0, b0, mt0);
mt1 = fmadd_ps(a1, b0, mt1);
pa = pa.add(16);
b = b.add(1);
}
let alpha = _mm256_broadcast_ss(&alpha);
mt0 = _mm256_mul_ps(alpha, mt0);
mt1 = _mm256_mul_ps(alpha, mt1);
if beta != 0.0 {
let beta = _mm256_broadcast_ss(&beta);
mt0 = fmadd_ps(beta, _mm256_loadu_ps(c), mt0);
mt1 = fmadd_ps(beta, _mm256_loadu_ps(c.add(8)), mt1);
}
_mm256_storeu_ps(c, mt0);
_mm256_storeu_ps(c.add(8), mt1);
}
pub(crate) unsafe fn sgemm_sup0_t(
k: usize,
alpha: f32,
pa: *const f32,
b: *const f32,
ldb: usize,
beta: f32,
c: *mut f32,
) {
let mut mt0 = _mm256_setzero_ps();
let mut mt1 = _mm256_setzero_ps();
let mut pa = pa;
let mut b = b;
for _ in 0..k {
let a0 = _mm256_load_ps(pa);
let a1 = _mm256_load_ps(pa.add(8));
let b0 = _mm256_broadcast_ss(&*b);
mt0 = fmadd_ps(a0, b0, mt0);
mt1 = fmadd_ps(a1, b0, mt1);
pa = pa.add(16);
b = b.add(ldb);
}
let alpha = _mm256_broadcast_ss(&alpha);
mt0 = _mm256_mul_ps(alpha, mt0);
mt1 = _mm256_mul_ps(alpha, mt1);
if beta != 0.0 {
let beta = _mm256_broadcast_ss(&beta);
mt0 = fmadd_ps(beta, _mm256_loadu_ps(c), mt0);
mt1 = fmadd_ps(beta, _mm256_loadu_ps(c.add(8)), mt1);
}
_mm256_storeu_ps(c, mt0);
_mm256_storeu_ps(c.add(8), mt1);
}
pub(crate) unsafe fn sgemm_pa_16x(k: usize, a: *const f32, lda: usize, pa: *mut f32) {
let mut a = a;
let mut pa = pa;
for _ in 0..k {
_mm256_store_ps(pa, _mm256_loadu_ps(a));
_mm256_store_ps(pa.add(8), _mm256_loadu_ps(a.add(8)));
pa = pa.add(16);
a = a.add(lda);
}
}

7
src/kernel/avx/mod.rs Normal file
View File

@ -0,0 +1,7 @@
mod fma;
mod hsum;
mod intrinsics;
pub mod l1d;
pub mod l1s;
pub mod l3d;
pub mod l3s;

74
src/kernel/generic/l3d.rs Normal file
View File

@ -0,0 +1,74 @@
pub(crate) unsafe fn dgemm_sup_1x4(
k: usize,
alpha: f64,
a: *const f64,
lda: usize,
pb: *const f64,
beta: f64,
c: *mut f64,
ldc: usize,
) {
let mut c0 = 0.0;
let mut c1 = 0.0;
let mut c2 = 0.0;
let mut c3 = 0.0;
let mut a = a;
let mut pb = pb;
for _ in 0..k {
let a0 = *a;
c0 += *pb * a0;
c1 += *pb.add(1) * a0;
c2 += *pb.add(2) * a0;
c3 += *pb.add(3) * a0;
a = a.add(lda);
pb = pb.add(4);
}
c0 *= alpha;
c1 *= alpha;
c2 *= alpha;
c3 *= alpha;
let ccol0 = c;
let ccol1 = c.add(ldc);
let ccol2 = c.add(ldc * 2);
let ccol3 = c.add(ldc * 3);
if beta != 0.0 {
c0 *= beta * *ccol0;
c1 *= beta * *ccol1;
c2 *= beta * *ccol2;
c3 *= beta * *ccol3;
}
*ccol0 = c0;
*ccol1 = c1;
*ccol2 = c2;
*ccol3 = c3;
}
pub(crate) unsafe fn dgemm_pb_x4(k: usize, b: *const f64, ldb: usize, pb: *mut f64) {
let mut bcol0 = b;
let mut bcol1 = b.add(ldb);
let mut bcol2 = b.add(ldb * 2);
let mut bcol3 = b.add(ldb * 3);
let mut pb = pb;
for _ in 0..k {
*pb = *bcol0;
*pb.add(1) = *bcol1;
*pb.add(2) = *bcol2;
*pb.add(3) = *bcol3;
bcol0 = bcol0.add(1);
bcol1 = bcol1.add(1);
bcol2 = bcol2.add(1);
bcol3 = bcol3.add(1);
pb = pb.add(4);
}
}

169
src/kernel/generic/l3s.rs Normal file
View File

@ -0,0 +1,169 @@
pub(crate) unsafe fn sgemm_sup_1x4(
k: usize,
alpha: f32,
a: *const f32,
lda: usize,
pb: *const f32,
beta: f32,
c: *mut f32,
ldc: usize,
) {
let mut c0 = 0.0;
let mut c1 = 0.0;
let mut c2 = 0.0;
let mut c3 = 0.0;
let mut a = a;
let mut pb = pb;
for _ in 0..k {
let a0 = *a;
c0 += *pb * a0;
c1 += *pb.add(1) * a0;
c2 += *pb.add(2) * a0;
c3 += *pb.add(3) * a0;
a = a.add(lda);
pb = pb.add(4);
}
c0 *= alpha;
c1 *= alpha;
c2 *= alpha;
c3 *= alpha;
let ccol0 = c;
let ccol1 = c.add(ldc);
let ccol2 = c.add(ldc * 2);
let ccol3 = c.add(ldc * 3);
if beta != 0.0 {
c0 += beta * *ccol0;
c1 += beta * *ccol1;
c2 += beta * *ccol2;
c3 += beta * *ccol3;
}
*ccol0 = c0;
*ccol1 = c1;
*ccol2 = c2;
*ccol3 = c3;
}
pub(crate) unsafe fn sgemm_sup1_t(
k: usize,
alpha: f32,
a: *const f32,
lda: usize,
pb: *const f32,
beta: f32,
c: *mut f32,
ldc: usize,
) {
let mut c0 = 0.0;
let mut c1 = 0.0;
let mut c2 = 0.0;
let mut c3 = 0.0;
let mut a = a;
let mut pb = pb;
for _ in 0..k {
let a0 = *a;
c0 += *pb * a0;
c1 += *pb.add(1) * a0;
c2 += *pb.add(2) * a0;
c3 += *pb.add(3) * a0;
a = a.add(1);
pb = pb.add(4);
}
c0 *= alpha;
c1 *= alpha;
c2 *= alpha;
c3 *= alpha;
let ccol0 = c;
let ccol1 = c.add(ldc);
let ccol2 = c.add(ldc * 2);
let ccol3 = c.add(ldc * 3);
if beta != 0.0 {
c0 += beta * *ccol0;
c1 += beta * *ccol1;
c2 += beta * *ccol2;
c3 += beta * *ccol3;
}
*ccol0 = c0;
*ccol1 = c1;
*ccol2 = c2;
*ccol3 = c3;
}
pub(crate) unsafe fn sgemm_pb_x4(k: usize, b: *const f32, ldb: usize, pb: *mut f32) {
let mut bcol0 = b;
let mut bcol1 = b.add(ldb);
let mut bcol2 = b.add(ldb * 2);
let mut bcol3 = b.add(ldb * 3);
let mut pb = pb;
for _ in 0..k {
*pb = *bcol0;
*pb.add(1) = *bcol1;
*pb.add(2) = *bcol2;
*pb.add(3) = *bcol3;
bcol0 = bcol0.add(1);
bcol1 = bcol1.add(1);
bcol2 = bcol2.add(1);
bcol3 = bcol3.add(1);
pb = pb.add(4);
}
}
// pub(crate) unsafe fn sgemm_pa_n(k: usize, a: *const f32, lda: usize, pa: *mut f32) {
// use crate::kernel::params::single::MR;
// let mut a = a;
// let mut pa = pa;
// for p in 0..k {
// for j in 0..MR {
// *pa.add(j * k + p) = *a.add(p * lda + j);
// }
// }
// }
pub(crate) unsafe fn sgemm_pb_t(k: usize, b: *const f32, ldb: usize, pb: *mut f32) {
use crate::kernel::params::single::NR;
let mut b = b;
let mut pb = pb;
for _ in 0..k {
for j in 0..NR {
*pb.add(j) = *b.add(j);
}
pb = pb.add(NR);
b = b.add(ldb);
}
}
pub(crate) unsafe fn sgemm_pa_t(k: usize, a: *const f32, lda: usize, pa: *mut f32) {
use crate::kernel::params::single::MR;
let mut a = a;
let mut pa = pa;
for _ in 0..k {
for j in 0..MR {
*pa.add(j) = *a.add(j * lda);
}
a = a.add(1);
pa = pa.add(MR);
}
}

View File

@ -0,0 +1,2 @@
pub mod l3d;
pub mod l3s;

60
src/kernel/mod.rs Normal file
View File

@ -0,0 +1,60 @@
#[cfg(all(
any(target_arch = "x86_64", target_arch = "x86"),
target_feature = "avx"
))]
mod avx;
#[cfg(all(
any(target_arch = "x86_64", target_arch = "x86"),
target_feature = "avx"
))]
pub use avx::{l1d::*, l1s::*};
#[cfg(all(
any(target_arch = "x86_64", target_arch = "x86"),
target_feature = "avx"
))]
pub(crate) use avx::l3s::{
sgemm_pa_16x as sgemm_pa_n, sgemm_sup_16x1 as sgemm_sup0, sgemm_ukr_16x4 as sgemm_ukr, sgemm_sup0_t,
};
#[cfg(all(
any(target_arch = "x86_64", target_arch = "x86"),
target_feature = "avx"
))]
pub(crate) use avx::l3d::{
dgemm_pa_8x as dgemm_pa, dgemm_sup_8x1 as dgemm_sup0, dgemm_ukr_8x4 as dgemm_ukr,
};
#[cfg(all(
any(target_arch = "x86_64", target_arch = "x86"),
target_feature = "avx"
))]
pub(crate) use generic::l3s::{sgemm_pb_x4 as sgemm_pb_n, sgemm_pb_t, sgemm_pa_t, sgemm_sup_1x4 as sgemm_sup1, sgemm_sup1_t};
#[cfg(all(
any(target_arch = "x86_64", target_arch = "x86"),
target_feature = "avx"
))]
pub(crate) use generic::l3d::{dgemm_pb_x4 as dgemm_pb, dgemm_sup_1x4 as dgemm_sup1};
mod generic;
pub mod params {
pub mod single {
pub const MC: usize = 256;
pub const KC: usize = 128;
pub const NC: usize = 1024;
pub const MR: usize = 16;
pub const NR: usize = 4;
}
pub mod double {
pub const MC: usize = 256;
pub const KC: usize = 512;
pub const NC: usize = 4096;
pub const MR: usize = 8;
pub const NR: usize = 4;
}
}

498
src/main.rs Normal file
View File

@ -0,0 +1,498 @@
mod matrix;
mod sgemm;
mod aligned_alloc;
mod kernel;
extern crate blas;
extern crate openblas;
use self::matrix::{ConstMatrix, MutMatrix, ConstTransposedMatrix, MutTransposedMatrix};
pub unsafe fn sgemm(
transa: bool,
transb: bool,
transc: bool,
m: usize,
n: usize,
k: usize,
alpha: f32,
a: *const f32,
lda: usize,
b: *const f32,
ldb: usize,
beta: f32,
c: *mut f32,
ldc: usize,
) {
match (transa, transb, transc) {
(false, false, false) => sgemm::sgemm(
m, n, k, alpha,
ConstMatrix::new(a, lda),
ConstMatrix::new(b, ldb),
beta,
MutMatrix::new(c, ldc)),
(false, false, true) => sgemm::sgemm(
m, n, k, alpha,
ConstMatrix::new(a, lda),
ConstMatrix::new(b, ldb),
beta,
MutTransposedMatrix::new(c, ldc)),
(false, true, false) => sgemm::sgemm(
m, n, k, alpha,
ConstMatrix::new(a, lda),
ConstTransposedMatrix::new(b, ldb),
beta,
MutMatrix::new(c, ldc)),
(false, true, true) => sgemm::sgemm(
m, n, k, alpha,
ConstMatrix::new(a, lda),
ConstTransposedMatrix::new(b, ldb),
beta,
MutTransposedMatrix::new(c, ldc)),
(true, false, false) => sgemm::sgemm(
m, n, k, alpha,
ConstTransposedMatrix::new(a, lda),
ConstMatrix::new(b, ldb),
beta,
MutMatrix::new(c, ldc)),
(true, false, true) => sgemm::sgemm(
m, n, k, alpha,
ConstTransposedMatrix::new(a, lda),
ConstMatrix::new(b, ldb),
beta,
MutTransposedMatrix::new(c, ldc)),
(true, true, false) => sgemm::sgemm(
m, n, k, alpha,
ConstTransposedMatrix::new(a, lda),
ConstTransposedMatrix::new(b, ldb),
beta,
MutMatrix::new(c, ldc)),
(true, true, true) => sgemm::sgemm(
m, n, k, alpha,
ConstTransposedMatrix::new(a, lda),
ConstTransposedMatrix::new(b, ldb),
beta,
MutTransposedMatrix::new(c, ldc)),
}
}
unsafe fn sgemm_ref_nn(
m: usize,
n: usize,
k: usize,
_alpha: f32,
a: *const f32,
lda: usize,
b: *const f32,
ldb: usize,
_beta: f32,
c: *mut f32,
ldc: usize,
) {
for j in 0..n {
for i in 0..m {
let mut ci = *c.add(i + j * ldc);
for p in 0..k {
ci += *a.add(i + p * lda) * *b.add(p + j * ldb);
}
*c.add(i + j * ldc) = ci;
}
}
}
unsafe fn sgemm_ref_nt(
m: usize,
n: usize,
k: usize,
_alpha: f32,
a: *const f32,
lda: usize,
b: *const f32,
ldb: usize,
_beta: f32,
c: *mut f32,
ldc: usize,
) {
for j in 0..n {
for i in 0..m {
let mut ci = *c.add(i + j * ldc);
for p in 0..k {
ci += *a.add(i + p * lda) * *b.add(j + p * ldb);
}
*c.add(i + j * ldc) = ci;
}
}
}
#[inline(never)]
pub fn gemm_nn(m: usize, n: usize, k: usize, alpha: f32,
a: &[f32], lda: usize,
b: &[f32], ldb: usize,
_bata: f32,
c: &mut [f32], ldc: usize) {
let mk = m * k;
let nk = n * k;
let mn = m * n;
let a = &a[0..mk];
let b = &b[0..nk];
let c = &mut c[0..mn];
for i_m in 0..m {
for i_k in 0..k {
let a_part = alpha * a[i_m * lda + i_k];
for i_n in 0..n {
c[i_m * ldc + i_n] += a_part * b[i_k * ldb + i_n];
}
}
}
}
fn main2() {
use rand::Rng;
let mut rng = rand::thread_rng();
const LEN: usize = 511;
let (m, n, k) = (LEN, LEN, LEN);
let mut a = vec![0.5; m * k];
let mut a_t = vec![0.5; m * k];
let mut b = vec![0.5; n * k];
let mut b_t = vec![0.5; n * k];
let mut c_nn = vec![0.0; m * n];
let mut c_nt = vec![0.0; m * n];
let mut c_tn = vec![0.0; m * n];
let mut c_tt = vec![0.0; m * n];
let mut cref = vec![0.0; m * n];
for i in 0..m {
for j in 0..k {
let v = rng.gen();
a[i + j * m] = v;
a_t[j + i * m] = v;
}
}
for i in 0..n {
for j in 0..k {
let v = rng.gen();
b[i + j * n] = v;
b_t[j + i * n] = v;
}
}
// let time = std::time::Instant::now();
// unsafe {
// gemm_nn(
// LEN,
// LEN,
// LEN,
// 1.0,
// a.as_slice(),
// LEN,
// b.as_slice(),
// LEN,
// 1.0,
// cref1.as_mut_slice(),
// LEN,
// )
// }
// println!("Naive (mine) {}", time.elapsed().as_millis());
let time = std::time::Instant::now();
unsafe {
sgemm_ref_nn(
m,
n,
k,
1.0,
a.as_ptr(),
m,
b.as_ptr(),
k,
1.0,
cref.as_mut_ptr(),
m,
)
}
println!("Naive {}", time.elapsed().as_millis());
let time = std::time::Instant::now();
unsafe {
sgemm(
false,
false,
false,
m,
n,
k,
1.0,
a.as_ptr(),
m,
b.as_ptr(),
k,
1.0,
c_nn.as_mut_ptr(),
m,
);
}
println!("[NN] Optimized {}", time.elapsed().as_millis());
for i in 0..LEN {
for j in 0..LEN {
let (a, b) = (c_nn[i + j * LEN], cref[i + j * LEN]);
assert!(feq(a, b), "a != b, a[{}]={}, b[{}]={}", i, a, j, b);
}
}
let time = std::time::Instant::now();
unsafe {
sgemm(
false,
true,
false,
m,
n,
k,
1.0,
a.as_ptr(),
m,
b_t.as_ptr(),
n,
1.0,
c_nt.as_mut_ptr(),
m,
);
}
println!("[NT] Optimized {}", time.elapsed().as_millis());
for i in 0..LEN {
for j in 0..LEN {
let (a, b) = (c_nt[i + j * LEN], cref[i + j * LEN]);
assert!(feq(a, b), "a != b, a[{}]={}, b[{}]={}", i, a, j, b);
}
}
let time = std::time::Instant::now();
unsafe {
sgemm(
true,
false,
false,
m,
n,
k,
1.0,
a_t.as_ptr(),
k,
b.as_ptr(),
k,
1.0,
c_tn.as_mut_ptr(),
m,
);
}
println!("[TN] Optimized {}", time.elapsed().as_millis());
for i in 0..LEN {
for j in 0..LEN {
let (a, b) = (c_tn[i + j * LEN], cref[i + j * LEN]);
assert!(feq(a, b), "a != b, a[{}]={}, b[{}]={}", i, a, j, b);
}
}
let time = std::time::Instant::now();
unsafe {
sgemm(
true,
true,
false,
m,
n,
k,
1.0,
a_t.as_ptr(),
k,
b_t.as_ptr(),
n,
1.0,
c_tt.as_mut_ptr(),
m,
);
}
println!("[TT] Optimized {}", time.elapsed().as_millis());
for i in 0..LEN {
for j in 0..LEN {
let (a, b) = (c_tt[i + j * LEN], cref[i + j * LEN]);
assert!(feq(a, b), "a != b, a[{}]={}, b[{}]={}", i, a, j, b);
}
}
}
fn main() {
use rand::Rng;
let mut rng = rand::thread_rng();
const LEN: usize = 8192;
let (m, n, k) = (LEN, LEN, LEN);
let mut a = vec![0.5; m * k];
let mut b = vec![0.5; n * k];
let mut c = vec![0.0; m * n];
let mut cref1 = vec![0.0; m * n];
let mut cref = vec![0.0; m * n];
for i in 0..m {
for j in 0..k {
a[i + j * m] = rng.gen();
}
}
for i in 0..n {
for j in 0..k {
b[i + j * n] = rng.gen();
}
}
let time = std::time::Instant::now();
unsafe {
// blas::sgemm(
// b'N',
// b'N',
// m as i32,
// n as i32,
// k as i32,
// 1.0,
// a.as_slice(),
// m as i32,
// b.as_slice(),
// k as i32,
// 1.0,
// cref1.as_mut_slice(),
// m as i32,
// );
// gemm_nn(
// m,
// n,
// k,
// 1.0,
// a.as_slice(),
// m,
// b.as_slice(),
// m,
// 1.0,
// cref1.as_mut_slice(),
// m,
// )
}
println!("Matrixmultiply (mine) {}", time.elapsed().as_millis());
// println!("Naive (mine) {}", time.elapsed().as_millis());
// let time = std::time::Instant::now();
// unsafe {
// sgemm_ref_nn(
// m,
// n,
// k,
// 1.0,
// a.as_ptr(),
// m,
// b.as_ptr(),
// k,
// 1.0,
// c.as_mut_ptr(),
// m,
// )
// }
// println!("Naive {}", time.elapsed().as_millis());
let time = std::time::Instant::now();
unsafe {
sgemm(
false,
false,
false,
m,
n,
k,
1.0,
a.as_ptr(),
m,
b.as_ptr(),
k,
1.0,
cref.as_mut_ptr(),
m,
);
}
println!("Optimized {}", time.elapsed().as_millis());
// for i in 0..LEN {
// for j in 0..LEN {
// let (a, b) = (c[i + j * LEN], cref[i + j * LEN]);
// assert!(feq(a, b), "a != b, a[{}]={}, b[{}]={}", i, a, j, b);
// }
// }
}
fn feq(a: f32, b: f32) -> bool {
if a == b {
true
} else if a == 0.0 || b == 0.0 || (a.abs() + b.abs() < std::f32::MIN_POSITIVE) {
(a - b).abs() < std::f32::EPSILON * 10.0 * std::f32::MIN_POSITIVE
} else {
(a - b).abs() / (a.abs() + b.abs()) < std::f32::EPSILON * 10.0
}
}

283
src/matrix.rs Normal file
View File

@ -0,0 +1,283 @@
pub(crate) trait Matrix: Copy {
unsafe fn sub(&self, row: usize, col: usize) -> Self;
unsafe fn is_transposed(&self) -> bool;
unsafe fn stride(&self) -> usize;
unsafe fn get(&self, row: usize, col: usize) -> f32;
unsafe fn index(&self, row: usize, col: usize) -> *const f32;
unsafe fn row(&self, row: usize) -> *const f32;
unsafe fn col(&self, col: usize) -> *const f32;
}
pub(crate) trait MatrixMut: Matrix {
unsafe fn set(&self, row: usize, col: usize, val: f32);
unsafe fn index_mut(&self, row: usize, col: usize) -> *mut f32;
unsafe fn row_mut(&self, row: usize) -> *mut f32;
unsafe fn col_mut(&self, col: usize) -> *mut f32;
}
#[derive(Copy, Clone)]
pub(crate) struct ConstMatrix {
pub stride: usize,
pub ptr: *const f32
}
unsafe impl Send for ConstMatrix {}
unsafe impl Sync for ConstMatrix {}
impl ConstMatrix {
pub fn new(ptr: *const f32, stride: usize) -> Self {
Self { ptr, stride }
}
}
impl Matrix for ConstMatrix {
#[inline]
unsafe fn sub(&self, row: usize, col: usize) -> Self {
Self {
ptr: self.index(row, col),
stride: self.stride
}
}
#[inline]
unsafe fn is_transposed(&self) -> bool {
false
}
#[inline]
unsafe fn stride(&self) -> usize {
self.stride
}
#[inline]
unsafe fn get(&self, row: usize, col: usize) -> f32 {
*self.ptr.add(row * self.stride + col)
}
#[inline]
unsafe fn index(&self, row: usize, col: usize) -> *const f32 {
self.ptr.add(row * self.stride + col)
}
#[inline]
unsafe fn row(&self, row: usize) -> *const f32 {
self.ptr.add(row * self.stride)
}
#[inline]
unsafe fn col(&self, col: usize) -> *const f32 {
self.ptr.add(col)
}
}
#[derive(Copy, Clone)]
pub(crate) struct ConstTransposedMatrix {
pub stride: usize,
pub ptr: *const f32
}
unsafe impl Send for ConstTransposedMatrix {}
unsafe impl Sync for ConstTransposedMatrix {}
impl ConstTransposedMatrix {
pub fn new(ptr: *const f32, stride: usize) -> Self {
Self { ptr, stride }
}
}
impl Matrix for ConstTransposedMatrix {
#[inline]
unsafe fn sub(&self, row: usize, col: usize) -> Self {
Self {
ptr: self.index(row, col),
stride: self.stride
}
}
#[inline]
unsafe fn is_transposed(&self) -> bool {
true
}
#[inline]
unsafe fn stride(&self) -> usize {
self.stride
}
#[inline]
unsafe fn get(&self, row: usize, col: usize) -> f32 {
*self.ptr.add(col * self.stride + row)
}
#[inline]
unsafe fn index(&self, row: usize, col: usize) -> *const f32 {
self.ptr.add(col * self.stride + row)
}
#[inline]
unsafe fn row(&self, row: usize) -> *const f32 {
self.ptr.add(row)
}
#[inline]
unsafe fn col(&self, col: usize) -> *const f32 {
self.ptr.add(col * self.stride)
}
}
#[derive(Copy, Clone)]
pub(crate) struct MutMatrix {
pub stride: usize,
pub ptr: *mut f32
}
unsafe impl Send for MutMatrix {}
unsafe impl Sync for MutMatrix {}
impl MutMatrix {
pub fn new(ptr: *mut f32, stride: usize) -> Self {
Self { ptr, stride }
}
}
impl Matrix for MutMatrix {
#[inline]
unsafe fn sub(&self, row: usize, col: usize) -> Self {
Self {
ptr: self.index_mut(row, col),
stride: self.stride
}
}
#[inline]
unsafe fn is_transposed(&self) -> bool {
false
}
#[inline]
unsafe fn stride(&self) -> usize {
self.stride
}
#[inline]
unsafe fn get(&self, row: usize, col: usize) -> f32 {
*self.ptr.add(row * self.stride + col)
}
#[inline]
unsafe fn index(&self, row: usize, col: usize) -> *const f32 {
self.ptr.add(row * self.stride + col)
}
#[inline]
unsafe fn row(&self, row: usize) -> *const f32 {
self.ptr.add(row * self.stride)
}
#[inline]
unsafe fn col(&self, col: usize) -> *const f32 {
self.ptr.add(col)
}
}
impl MatrixMut for MutMatrix {
#[inline]
unsafe fn set(&self, row: usize, col: usize, value: f32) {
*self.ptr.add(row * self.stride + col) = value;
}
#[inline]
unsafe fn index_mut(&self, row: usize, col: usize) -> *mut f32 {
self.ptr.add(row * self.stride + col)
}
#[inline]
unsafe fn row_mut(&self, row: usize) -> *mut f32 {
self.ptr.add(row * self.stride)
}
#[inline]
unsafe fn col_mut(&self, col: usize) -> *mut f32 {
self.ptr.add(col)
}
}
#[derive(Copy, Clone)]
pub(crate) struct MutTransposedMatrix {
pub stride: usize,
pub ptr: *mut f32
}
unsafe impl Send for MutTransposedMatrix {}
unsafe impl Sync for MutTransposedMatrix {}
impl MutTransposedMatrix {
pub fn new(ptr: *mut f32, stride: usize) -> Self {
Self { ptr, stride }
}
}
impl Matrix for MutTransposedMatrix {
#[inline]
unsafe fn sub(&self, row: usize, col: usize) -> Self {
Self {
ptr: self.index_mut(row, col),
stride: self.stride
}
}
#[inline]
unsafe fn is_transposed(&self) -> bool {
true
}
#[inline]
unsafe fn stride(&self) -> usize {
self.stride
}
#[inline]
unsafe fn get(&self, row: usize, col: usize) -> f32 {
*self.ptr.add(col * self.stride + row)
}
#[inline]
unsafe fn index(&self, row: usize, col: usize) -> *const f32 {
self.ptr.add(col * self.stride + row) as *const f32
}
#[inline]
unsafe fn row(&self, row: usize) -> *const f32 {
self.ptr.add(row) as *const f32
}
#[inline]
unsafe fn col(&self, col: usize) -> *const f32 {
self.ptr.add(col * self.stride) as *const f32
}
}
impl MatrixMut for MutTransposedMatrix {
#[inline]
unsafe fn set(&self, row: usize, col: usize, value: f32) {
*self.ptr.add(col * self.stride + row) = value;
}
#[inline]
unsafe fn index_mut(&self, row: usize, col: usize) -> *mut f32 {
self.ptr.add(col * self.stride + row)
}
#[inline]
unsafe fn row_mut(&self, row: usize) -> *mut f32 {
self.ptr.add(row)
}
#[inline]
unsafe fn col_mut(&self, col: usize) -> *mut f32 {
self.ptr.add(col * self.stride)
}
}

273
src/sgemm.rs Normal file
View File

@ -0,0 +1,273 @@
use crate::aligned_alloc;
use crate::matrix::{Matrix, MatrixMut, MutMatrix};
use crate::kernel::params::single::*;
use crate::kernel;
pub(crate) unsafe fn sgemm<A, B, C>(
m: usize,
n: usize,
k: usize,
alpha: f32,
a: A,
b: B,
beta: f32,
c: C
)
where A: Matrix,
B: Matrix,
C: MatrixMut,
{
let packed_a = aligned_alloc::Alloc::new(MC * KC * std::mem::size_of::<f32>());
let packed_b = aligned_alloc::Alloc::new(KC * NC * std::mem::size_of::<f32>());
for j in (0..n).step_by(NC) {
let j_b = std::cmp::min(n - j, NC);
for p in (0..k).step_by(KC) {
let p_b = std::cmp::min(k - p, KC);
for i in (0..m).step_by(MC) {
let i_b = std::cmp::min(m - i, MC);
let pa = MutMatrix::new(packed_a.ptr_f32(), p_b);
let pb = MutMatrix::new(packed_b.ptr_f32(), p_b);
inner_kernel(
i_b,
j_b,
p_b,
alpha,
a.sub(p, i),
b.sub(j, p),
beta,
c.sub(j, i),
pa,
pb,
i == 0
);
}
}
}
}
// +----------------------+
// |A : : : |B |
// | : : : | |
// + - - - - - - - - - ---+
// | : : : | |
// | : : : | |
// + - - - - - - - - - ---+
// | : : : | |
// | : : : | |
// +----------------------+
// |C | | | |D |
// +----------------------+
#[inline]
unsafe fn sgemm_pa<A: Matrix>(a: A, pa: MutMatrix, i: usize) {
if a.is_transposed() {
kernel::sgemm_pa_t(pa.stride, a.col(i), a.stride(), pa.row_mut(i));
} else {
kernel::sgemm_pa_n(pa.stride, a.col(i), a.stride(), pa.row_mut(i));
}
}
#[inline]
unsafe fn sgemm_pb<B: Matrix>(b: B, pb: MutMatrix, j: usize) {
if b.is_transposed() {
kernel::sgemm_pb_t(pb.stride, b.row(j), b.stride(), pb.row_mut(j));
} else {
kernel::sgemm_pb_n(pb.stride, b.row(j), b.stride(), pb.row_mut(j));
}
}
#[inline]
unsafe fn sgemm_ukr<C: MatrixMut>(
i: usize, j: usize,
alpha: f32,
pa: MutMatrix,
pb: MutMatrix,
beta: f32,
c: C)
{
if c.is_transposed() {
unimplemented!()
} else {
kernel::sgemm_ukr(
pa.stride,
alpha,
pa.row(i),
pb.row(j),
beta,
c.index_mut(j, i),
c.stride())
}
}
#[inline]
unsafe fn sgemm_sup1<A: Matrix, C: MatrixMut>(
i: usize, j: usize,
alpha: f32,
a: A,
pb: MutMatrix,
beta: f32,
c: C)
{
if c.is_transposed() {
unimplemented!()
} else {
if a.is_transposed() {
kernel::sgemm_sup1_t(
pb.stride,
alpha,
a.col(i),
a.stride(),
pb.row(j),
beta,
c.index_mut(j, i),
c.stride())
} else {
kernel::sgemm_sup1(
pb.stride,
alpha,
a.col(i),
a.stride(),
pb.row(j),
beta,
c.index_mut(j, i),
c.stride())
}
}
}
#[inline]
unsafe fn sgemm_sup0<B: Matrix, C: MatrixMut>(
i: usize, j: usize,
alpha: f32,
pa: MutMatrix,
b: B,
beta: f32,
c: C)
{
if c.is_transposed() {
unimplemented!()
} else {
if b.is_transposed() {
kernel::sgemm_sup0_t(
pa.stride,
alpha,
pa.row(i),
b.row(j),
b.stride(),
beta,
c.index_mut(j, i))
} else {
kernel::sgemm_sup0(
pa.stride,
alpha,
pa.row(i),
b.row(j),
beta,
c.index_mut(j, i))
}
}
}
unsafe fn inner_kernel<A, B, C>(
m: usize,
n: usize,
k: usize,
alpha: f32,
a: A,
b: B,
beta: f32,
c: C,
pa: MutMatrix,
pb: MutMatrix,
first_time: bool,
)
where A: Matrix,
B: Matrix,
C: MatrixMut,
{
let n_left = n % NR;
let n_main = n - n_left;
let m_left = m % MR;
let m_main = m - m_left;
if first_time {
for j in (0..n_main).step_by(NR) {
sgemm_pb(b, pb, j);
}
}
for i in (0..m_main).step_by(MR) {
sgemm_pa(a, pa, i);
};
for j in (0..n_main).step_by(NR) {
// Section A
for i in (0..m_main).step_by(MR) {
sgemm_ukr(
i,
j,
alpha,
pa,
pb,
beta,
c,
);
}
// Section B
for i in m_main..m {
sgemm_sup1(
i,
j,
alpha,
a,
pb,
beta,
c,
);
}
}
for j in n_main..n {
// Section C
for i in (0..m_main).step_by(MR) {
sgemm_sup0(
i,
j,
alpha,
pa,
b,
beta,
c,
);
}
// Section D
for i in m_main..m {
let mut elem = 0.0;
for p in 0..k {
elem += a.get(p, i) * b.get(j, p);
}
elem *= alpha;
if beta != 0.0 {
elem += beta * c.get(j, i);
}
c.set(j, i, elem);
}
}
}

4
src/stats.sh Normal file
View File

@ -0,0 +1,4 @@
#!/bin/bash
cargo build --release;
perf stat -B -e cache-references,cache-misses,cycles,instructions,branches,branch-misses,faults,migrations,context-switches ./target/release/gemm

0
src/utils/mod.rs Normal file
View File

5
stats.sh Executable file
View File

@ -0,0 +1,5 @@
#!/bin/bash
# export RUSTFLAGS="-C target-feature=-mmx,-sse,-sse2,-sse3,-ssse3,-sse4.1,-sse4.2,-3dnow,-3dnowa,-avx,-avx2"
export RUSTFLAGS="-C target-feature=-mmx,-sse,-sse2,-sse3,-ssse3,-sse4.1,-sse4.2,-3dnow,-3dnowa,+avx,+avx2"
cargo clean && cargo build --release
perf stat -B -e cache-references,cache-misses,cycles,instructions,branches,branch-misses,faults,migrations,context-switches ./target/release/gemm