Fixed bug with calculation

This commit is contained in:
Andrey Tkachenko 2019-12-05 17:21:05 +04:00
parent aaf8ec3203
commit a699078f73
16 changed files with 471 additions and 91 deletions

1
.gitignore vendored
View File

@ -1,2 +1,3 @@
/target
/demo/target
**/*.rs.bk

314
demo/Cargo.lock generated Normal file
View File

@ -0,0 +1,314 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
[[package]]
name = "autocfg"
version = "0.1.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "blas"
version = "0.20.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"blas-sys 0.7.1 (registry+https://github.com/rust-lang/crates.io-index)",
"libc 0.2.66 (registry+https://github.com/rust-lang/crates.io-index)",
"num-complex 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "blas-sys"
version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"libc 0.2.66 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "c2-chacha"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"ppv-lite86 0.2.6 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "cfg-if"
version = "0.1.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "crossbeam-deque"
version = "0.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"crossbeam-epoch 0.8.0 (registry+https://github.com/rust-lang/crates.io-index)",
"crossbeam-utils 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "crossbeam-epoch"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"autocfg 0.1.7 (registry+https://github.com/rust-lang/crates.io-index)",
"cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)",
"crossbeam-utils 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)",
"lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)",
"memoffset 0.5.3 (registry+https://github.com/rust-lang/crates.io-index)",
"scopeguard 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "crossbeam-queue"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"crossbeam-utils 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "crossbeam-utils"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"autocfg 0.1.7 (registry+https://github.com/rust-lang/crates.io-index)",
"cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)",
"lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "crunchy"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "demo"
version = "0.1.0"
dependencies = [
"blas 0.20.0 (registry+https://github.com/rust-lang/crates.io-index)",
"gemm 0.1.0",
"openblas 0.0.1 (registry+https://github.com/rust-lang/crates.io-index)",
"rand 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "either"
version = "1.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "gemm"
version = "0.1.0"
dependencies = [
"blas 0.20.0 (registry+https://github.com/rust-lang/crates.io-index)",
"crunchy 0.2.2 (registry+https://github.com/rust-lang/crates.io-index)",
"openblas 0.0.1 (registry+https://github.com/rust-lang/crates.io-index)",
"rand 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)",
"rayon 1.2.1 (registry+https://github.com/rust-lang/crates.io-index)",
"threadpool 1.7.1 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "getrandom"
version = "0.1.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)",
"libc 0.2.66 (registry+https://github.com/rust-lang/crates.io-index)",
"wasi 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "hermit-abi"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"libc 0.2.66 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "lazy_static"
version = "1.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "libc"
version = "0.2.66"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "memoffset"
version = "0.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"rustc_version 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "num-complex"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"autocfg 0.1.7 (registry+https://github.com/rust-lang/crates.io-index)",
"num-traits 0.2.10 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "num-traits"
version = "0.2.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"autocfg 0.1.7 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "num_cpus"
version = "1.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"hermit-abi 0.1.3 (registry+https://github.com/rust-lang/crates.io-index)",
"libc 0.2.66 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "openblas"
version = "0.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "ppv-lite86"
version = "0.2.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "rand"
version = "0.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"getrandom 0.1.13 (registry+https://github.com/rust-lang/crates.io-index)",
"libc 0.2.66 (registry+https://github.com/rust-lang/crates.io-index)",
"rand_chacha 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)",
"rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)",
"rand_hc 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "rand_chacha"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"c2-chacha 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)",
"rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "rand_core"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"getrandom 0.1.13 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "rand_hc"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "rayon"
version = "1.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"crossbeam-deque 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)",
"either 1.5.3 (registry+https://github.com/rust-lang/crates.io-index)",
"rayon-core 1.6.1 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "rayon-core"
version = "1.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"crossbeam-deque 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)",
"crossbeam-queue 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)",
"crossbeam-utils 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)",
"lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)",
"num_cpus 1.11.1 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "rustc_version"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"semver 0.9.0 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "scopeguard"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "semver"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"semver-parser 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "semver-parser"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "threadpool"
version = "1.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"num_cpus 1.11.1 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "wasi"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
[metadata]
"checksum autocfg 0.1.7 (registry+https://github.com/rust-lang/crates.io-index)" = "1d49d90015b3c36167a20fe2810c5cd875ad504b39cff3d4eae7977e6b7c1cb2"
"checksum blas 0.20.0 (registry+https://github.com/rust-lang/crates.io-index)" = "e4b3b6399f2fe6e70f780b06e278a64770aae273e0005c3860f94fc2f5f01ba7"
"checksum blas-sys 0.7.1 (registry+https://github.com/rust-lang/crates.io-index)" = "13b1b279ceb25d7c4faaea95a5f7addbe7d8c34f9462044bd8e630cebcfc2440"
"checksum c2-chacha 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "214238caa1bf3a496ec3392968969cab8549f96ff30652c9e56885329315f6bb"
"checksum cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)" = "4785bdd1c96b2a846b2bd7cc02e86b6b3dbf14e7e53446c4f54c92a361040822"
"checksum crossbeam-deque 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)" = "c3aa945d63861bfe624b55d153a39684da1e8c0bc8fba932f7ee3a3c16cea3ca"
"checksum crossbeam-epoch 0.8.0 (registry+https://github.com/rust-lang/crates.io-index)" = "5064ebdbf05ce3cb95e45c8b086f72263f4166b29b97f6baff7ef7fe047b55ac"
"checksum crossbeam-queue 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "dfd6515864a82d2f877b42813d4553292c6659498c9a2aa31bab5a15243c2700"
"checksum crossbeam-utils 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ce446db02cdc3165b94ae73111e570793400d0794e46125cc4056c81cbb039f4"
"checksum crunchy 0.2.2 (registry+https://github.com/rust-lang/crates.io-index)" = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7"
"checksum either 1.5.3 (registry+https://github.com/rust-lang/crates.io-index)" = "bb1f6b1ce1c140482ea30ddd3335fc0024ac7ee112895426e0a629a6c20adfe3"
"checksum getrandom 0.1.13 (registry+https://github.com/rust-lang/crates.io-index)" = "e7db7ca94ed4cd01190ceee0d8a8052f08a247aa1b469a7f68c6a3b71afcf407"
"checksum hermit-abi 0.1.3 (registry+https://github.com/rust-lang/crates.io-index)" = "307c3c9f937f38e3534b1d6447ecf090cafcc9744e4a6360e8b037b2cf5af120"
"checksum lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
"checksum libc 0.2.66 (registry+https://github.com/rust-lang/crates.io-index)" = "d515b1f41455adea1313a4a2ac8a8a477634fbae63cc6100e3aebb207ce61558"
"checksum memoffset 0.5.3 (registry+https://github.com/rust-lang/crates.io-index)" = "75189eb85871ea5c2e2c15abbdd541185f63b408415e5051f5cac122d8c774b9"
"checksum num-complex 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "fcb0cf31fb3ff77e6d2a6ebd6800df7fdcd106f2ad89113c9130bcd07f93dffc"
"checksum num-traits 0.2.10 (registry+https://github.com/rust-lang/crates.io-index)" = "d4c81ffc11c212fa327657cb19dd85eb7419e163b5b076bede2bdb5c974c07e4"
"checksum num_cpus 1.11.1 (registry+https://github.com/rust-lang/crates.io-index)" = "76dac5ed2a876980778b8b85f75a71b6cbf0db0b1232ee12f826bccb00d09d72"
"checksum openblas 0.0.1 (registry+https://github.com/rust-lang/crates.io-index)" = "f005eb075fb562358060245c5d00a877757b4ddec8e8cd69cab96d8cc55998e8"
"checksum ppv-lite86 0.2.6 (registry+https://github.com/rust-lang/crates.io-index)" = "74490b50b9fbe561ac330df47c08f3f33073d2d00c150f719147d7c54522fa1b"
"checksum rand 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)" = "3ae1b169243eaf61759b8475a998f0a385e42042370f3a7dbaf35246eacc8412"
"checksum rand_chacha 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "03a2a90da8c7523f554344f921aa97283eadf6ac484a6d2a7d0212fa7f8d6853"
"checksum rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)" = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19"
"checksum rand_hc 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ca3129af7b92a17112d59ad498c6f81eaf463253766b90396d39ea7a39d6613c"
"checksum rayon 1.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "43739f8831493b276363637423d3622d4bd6394ab6f0a9c4a552e208aeb7fddd"
"checksum rayon-core 1.6.1 (registry+https://github.com/rust-lang/crates.io-index)" = "f8bf17de6f23b05473c437eb958b9c850bfc8af0961fe17b4cc92d5a627b4791"
"checksum rustc_version 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "138e3e0acb6c9fb258b19b67cb8abd63c00679d2851805ea151465464fe9030a"
"checksum scopeguard 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)" = "b42e15e59b18a828bbf5c58ea01debb36b9b096346de35d941dcb89009f24a0d"
"checksum semver 0.9.0 (registry+https://github.com/rust-lang/crates.io-index)" = "1d7eb9ef2c18661902cc47e535f9bc51b78acd254da71d375c2f6720d9a40403"
"checksum semver-parser 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)" = "388a1df253eca08550bef6c72392cfe7c30914bf41df5269b68cbd6ff8f570a3"
"checksum threadpool 1.7.1 (registry+https://github.com/rust-lang/crates.io-index)" = "e2f0c90a5f3459330ac8bc0d2f879c693bb7a2f59689c1083fc4ef83834da865"
"checksum wasi 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)" = "b89c3ce4ce14bdc6fb6beaf9ec7928ca331de5df7e5ea278375642a2f478570d"

14
demo/Cargo.toml Normal file
View File

@ -0,0 +1,14 @@
[package]
name = "demo"
version = "0.1.0"
authors = ["Andrey Tkachenko <andreytkachenko64@gmail.com>"]
edition = "2018"
build = "build.rs"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
gemm = {path="../"}
openblas = "0.0.1"
blas = "0.20.0"
rand = "0.7.2"

7
demo/build.rs Normal file
View File

@ -0,0 +1,7 @@
fn main() {
// Always linking openblas
// Compiling blas just for testing is tedious -- install it on your system
// and run this.
println!("cargo:rerun-if-changed=build.rs");
println!("cargo:rustc-link-lib={}=openblas", "dylib");
}

View File

@ -1,16 +1,5 @@
mod matrix;
mod gemm;
mod sgemm;
mod aligned_alloc;
mod kernel;
mod dim;
mod executor;
extern crate blas;
extern crate openblas;
use crate::sgemm::sgemm;
extern crate blas;
fn main() {
use rand::Rng;
@ -97,11 +86,12 @@ fn main() {
let time = std::time::Instant::now();
unsafe {
sgemm(
&executor::RayonExecutor,
gemm::sgemm(
&gemm::executor::DefaultExecutor,
false,
false,
false,
m,
n,
k,

View File

@ -7,7 +7,7 @@ pub trait Executor {
f: F,
);
fn synchronize(&self);
fn synchronize(&self) {}
}
pub struct DefaultExecutor;
@ -24,10 +24,6 @@ impl Executor for DefaultExecutor {
f(i);
}
}
fn synchronize(&self) {
}
}
pub struct RayonExecutor;
@ -48,10 +44,6 @@ impl Executor for RayonExecutor {
.map(|x| x * step)
.for_each(|x|f(x));
}
fn synchronize(&self) {
}
}

View File

@ -1,6 +1,6 @@
use crate::aligned_alloc;
use crate::matrix::{Number, Matrix, MatrixMut, MutMatrix, ConstMatrix, ConstTransposedMatrix, MutTransposedMatrix};
use crate::kernel::params::single::*;
use crate::kernel::params::single::{NC, MC, KC};
use crate::kernel;
use crate::kernel::GemmKernel;
use crate::dim::Dim;
@ -140,19 +140,25 @@ where E: Executor,
}
}
//
// | MR |
// +----------------------+
// |TL : : : |TR|
// | : : : | |
// + - - - - - - - - - ---+----
// | : : : | | NR
// | : : : | |
// + - - - - - - - - - ---+----
// | : : : | |
// | : : : | |
// +----------------------+
// |BL | | | |BR|
// +----------------------+
// +-----------------------------+----+
// | : | |
// | TL : | TR |
// | : | |
// + - - - - - - -:- - - - - - - - - -+ ----
// | : | |
// | : | | NR
// | : | |
// + - - - - - - -:- - - - - - - - - -+ ----
// | : | |
// | : | |
// | : | |
// +-----------------------------+----+
// | BL | | BR |
// +-----------------------------+----+
//
unsafe fn inner_kernel<E, F, K, MR, NR, A, B, C>(
e: &E,
m: usize,
@ -177,25 +183,25 @@ unsafe fn inner_kernel<E, F, K, MR, NR, A, B, C>(
C: MatrixMut<F>,
{
let n_left = n % NR;
let n_left = n % NR::DIM;
let n_main = n - n_left;
let m_left = m % MR;
let m_left = m % MR::DIM;
let m_main = m - m_left;
if first_time {
e.execute(0, n_main, NR, move |j|
K::pack_row_b(b, pb, j));
e.execute(0, n_main, NR::DIM, move |j|
K::pack_row_b(b, pb));
}
e.execute(0, m_main, MR, move |i|
K::pack_row_a(a, pa, i));
e.execute(0, m_main, MR::DIM, move |i|
K::pack_row_a(a.sub_col(i), pa.sub_row(i)));
e.synchronize();
e.execute(0, n_main, NR, move |j| {
e.execute(0, n_main, NR::DIM, move |j| {
// Section TL
for i in (0..m_main).step_by(MR) {
for i in (0..m_main).step_by(MR::DIM) {
K::main_tl(alpha,
pa.sub_row(i),
pb.sub_row(j),
@ -214,9 +220,11 @@ unsafe fn inner_kernel<E, F, K, MR, NR, A, B, C>(
}
});
e.execute(n_main, n, 1, move |j| {
e.execute(0, n_left * m_main, MR::DIM, move |ji| {
let j = n_main + ji / m_main;
let i = ji % m_main;
// Section BL
for i in (0..m_main).step_by(MR) {
K::sup_bl(
alpha,
pa.sub_row(i),
@ -224,9 +232,10 @@ unsafe fn inner_kernel<E, F, K, MR, NR, A, B, C>(
beta,
c.sub(j, i)
);
}
});
// Section BR
for j in n_main..n {
for i in m_main..m {
K::sup_br(
k,
@ -236,7 +245,7 @@ unsafe fn inner_kernel<E, F, K, MR, NR, A, B, C>(
beta,
c.sub(j, i))
}
});
};
e.synchronize();
}

View File

@ -137,6 +137,18 @@ pub(crate) unsafe fn sgemm_ukr_16x8<C: MatrixMut<f32>>(
// let ccol6 = c.row_mut(6);
// let ccol7 = c.add(ldc * 7);
mt00 = _mm256_add_ps(_mm256_loadu_ps(ccol0), mt00);
mt01 = _mm256_add_ps(_mm256_loadu_ps(ccol1), mt01);
mt02 = _mm256_add_ps(_mm256_loadu_ps(ccol2), mt02);
mt03 = _mm256_add_ps(_mm256_loadu_ps(ccol3), mt03);
mt04 = _mm256_add_ps(_mm256_loadu_ps(ccol4), mt04);
mt10 = _mm256_add_ps(_mm256_loadu_ps(ccol0.add(8)), mt10);
mt11 = _mm256_add_ps(_mm256_loadu_ps(ccol1.add(8)), mt11);
mt12 = _mm256_add_ps(_mm256_loadu_ps(ccol2.add(8)), mt12);
mt13 = _mm256_add_ps(_mm256_loadu_ps(ccol3.add(8)), mt13);
mt14 = _mm256_add_ps(_mm256_loadu_ps(ccol4.add(8)), mt14);
// if beta != 0.0 {
// let beta = _mm256_broadcast_ss(&beta);
@ -210,6 +222,9 @@ pub(crate) unsafe fn sgemm_sup_16x1<B: Matrix<f32>, C: MatrixMut<f32>>(
// mt0 = _mm256_mul_ps(alpha, mt0);
// mt1 = _mm256_mul_ps(alpha, mt1);
let ccol0 = c.ptr_mut();
let ccol1 = c.ptr_mut().add(8);
// if beta != 0.0 {
// let beta = _mm256_broadcast_ss(&beta);
@ -217,8 +232,11 @@ pub(crate) unsafe fn sgemm_sup_16x1<B: Matrix<f32>, C: MatrixMut<f32>>(
// mt1 = fmadd_ps(beta, _mm256_loadu_ps(c.add(8)), mt1);
// }
_mm256_storeu_ps(c.ptr_mut(), mt0);
_mm256_storeu_ps(c.col_mut(8), mt1);
mt0 = _mm256_add_ps(_mm256_loadu_ps(ccol0), mt0);
mt1 = _mm256_add_ps(_mm256_loadu_ps(ccol1), mt1);
_mm256_storeu_ps(ccol0, mt0);
_mm256_storeu_ps(ccol1, mt1);
}
pub(crate) unsafe fn sgemm_pa_16x(k: usize, a: *const f32, lda: usize, pa: *mut f32) {

View File

@ -63,17 +63,17 @@ impl<I> GemmKernel<f32, A16, A5> for AvxKernel<f32, I>
where I: GemmKernel<f32, A16, A5>
{
#[inline]
unsafe fn pack_row_a<A: Matrix<f32>>(a: A, pa: MutMatrix<f32>, i: usize) {
unsafe fn pack_row_a<A: Matrix<f32>>(a: A, pa: MutMatrix<f32>) {
if a.is_transposed() {
I::pack_row_a(a, pa, i);
I::pack_row_a(a, pa);
} else {
self::l3s::sgemm_pa_16x(pa.stride, a.col(i), a.stride(), pa.row_mut(i));
self::l3s::sgemm_pa_16x(pa.stride, a.ptr(), a.stride(), pa.ptr_mut());
}
}
#[inline]
unsafe fn pack_row_b<B: Matrix<f32>>(b: B, pb: MutMatrix<f32>, j: usize) {
I::pack_row_b(b, pb, j);
unsafe fn pack_row_b<B: Matrix<f32>>(b: B, pb: MutMatrix<f32>) {
I::pack_row_b(b, pb);
}
#[inline]

View File

@ -67,11 +67,11 @@ pub(crate) unsafe fn sgemm_sup_1x8<A: Matrix<f32>, C: MatrixMut<f32>>(
// // c7 += beta * *ccol7;
// }
*ccol0 = c0;
*ccol1 = c1;
*ccol2 = c2;
*ccol3 = c3;
*ccol4 = c4;
*ccol0 += c0;
*ccol1 += c1;
*ccol2 += c2;
*ccol3 += c3;
*ccol4 += c4;
// *ccol5 = c5;
// *ccol6 = c6;
// *ccol7 = c7;
@ -113,7 +113,6 @@ pub(crate) unsafe fn sgemm_pb_x8(k: usize, b: *const f32, ldb: usize, pb: *mut f
}
pub(crate) unsafe fn sgemm_pb_t(k: usize, b: *const f32, ldb: usize, pb: *mut f32) {
use crate::kernel::params::single::NR;
let mut b = b;
let mut pb = pb;

View File

@ -61,27 +61,27 @@ impl GemmKernelSup<f32> for GenericKernel {
elem += beta * *c.ptr();
}
*c.ptr_mut() = elem;
*c.ptr_mut() += elem;
}
}
impl GemmKernel<f32, A16, A5> for GenericKernel {
#[inline]
unsafe fn pack_row_a<A: Matrix<f32>>(a: A, pa: MutMatrix<f32>, i: usize) {
unsafe fn pack_row_a<A: Matrix<f32>>(a: A, pa: MutMatrix<f32>) {
if a.is_transposed() {
self::l3s::sgemm_pa_t(pa.stride, a.col(i), a.stride(), pa.row_mut(i));
self::l3s::sgemm_pa_t(pa.stride, a.ptr(), a.stride(), pa.ptr_mut());
} else {
unimplemented!()
}
}
#[inline]
unsafe fn pack_row_b<B: Matrix<f32>>(b: B, pb: MutMatrix<f32>, j: usize) {
unsafe fn pack_row_b<B: Matrix<f32>>(b: B, pb: MutMatrix<f32>) {
if b.is_transposed() {
self::l3s::sgemm_pb_t(pb.stride, b.row(j), b.stride(), pb.row_mut(j));
self::l3s::sgemm_pb_t(pb.stride, b.ptr(), b.stride(), pb.ptr_mut());
} else {
self::l3s::sgemm_pb_x8(pb.stride, b.row(j), b.stride(), pb.row_mut(j));
self::l3s::sgemm_pb_x8(pb.stride, b.ptr(), b.stride(), pb.ptr_mut());
}
}

View File

@ -71,8 +71,8 @@ pub trait GemmKernelSupNr<F: Number, NR: Dim> {
}
pub trait GemmKernel<F: Number, MR: Dim, NR: Dim>: GemmKernelSupMr<F, MR> + GemmKernelSupNr<F, NR> + GemmKernelSup<F> {
unsafe fn pack_row_a<A: Matrix<F>>(a: A, pa: MutMatrix<F>, i: usize);
unsafe fn pack_row_b<B: Matrix<F>>(b: B, pb: MutMatrix<F>, j: usize);
unsafe fn pack_row_a<A: Matrix<F>>(a: A, pa: MutMatrix<F>);
unsafe fn pack_row_b<B: Matrix<F>>(b: B, pb: MutMatrix<F>);
unsafe fn main_tl<C: MatrixMut<F>>(
alpha: F,

View File

@ -64,13 +64,13 @@ impl<I> GemmKernel<f32, A16, A5> for SseKernel<f32, I>
where I: GemmKernel<f32, A16, A5>
{
#[inline]
unsafe fn pack_row_a<A: Matrix<f32>>(a: A, pa: MutMatrix<f32>, i: usize) {
I::pack_row_a(a, pa, i);
unsafe fn pack_row_a<A: Matrix<f32>>(a: A, pa: MutMatrix<f32>) {
I::pack_row_a(a, pa);
}
#[inline]
unsafe fn pack_row_b<B: Matrix<f32>>(b: B, pb: MutMatrix<f32>, j: usize) {
I::pack_row_b(b, pb, j);
unsafe fn pack_row_b<B: Matrix<f32>>(b: B, pb: MutMatrix<f32>) {
I::pack_row_b(b, pb);
}
#[inline]

View File

@ -1,5 +1,5 @@
pub trait Number: Copy + Send + Sync + 'static {}
pub trait Number: core::fmt::Display + Copy + Send + Sync + 'static {}
impl Number for f32 {}
impl Number for f64 {}

View File

@ -1,8 +1,8 @@
use rand::Rng;
const M_LEN: usize = 8192;
const N_LEN: usize = 8192;
const K_LEN: usize = 8192;
const M_LEN: usize = 17;
const N_LEN: usize = 6;
const K_LEN: usize = 4;
fn make_matrices() -> (Vec<f32>, Vec<f32>, Vec<f32>, Vec<f32>) {
let mut rng = rand::thread_rng();
@ -15,19 +15,23 @@ fn make_matrices() -> (Vec<f32>, Vec<f32>, Vec<f32>, Vec<f32>) {
let mut b = vec![0.0; n * k];
let mut b_t = vec![0.0; n * k];
let mut counter = 0;
for row in 0..k {
for col in 0..m {
let v = rng.gen();
let v = counter as f32;
a[row * m + col] = v;
a_t[col * k + row] = v;
counter += 1;
}
}
let mut counter = 0;
for row in 0..n {
for col in 0..k {
let v = rng.gen();
let v = counter as f32;
b[row * k + col] = v;
b_t[col * n + row] = v;
counter += 1;
}
}
@ -80,6 +84,17 @@ fn test_sgemm_nn() {
let mut c = vec![0.0; m * n];
let mut cref = vec![0.0; m * n];
for row in 0..N_LEN {
for col in 0..K_LEN {
let index = row * K_LEN + col;
print!("{:>5} ", b[index]);
}
println!();
}
println!();
unsafe {
blas::sgemm(
b'N',
@ -118,6 +133,27 @@ fn test_sgemm_nn() {
);
}
for row in 0..N_LEN {
for col in 0..M_LEN {
let index = row * M_LEN + col;
print!("{:>5} ", c[index]);
}
println!();
}
println!();
for row in 0..N_LEN {
for col in 0..M_LEN {
let index = row * M_LEN + col;
print!("{:>5} ", cref[index]);
}
println!();
}
for row in 0..N_LEN {
for col in 0..M_LEN {
let index = row * M_LEN + col;
@ -127,8 +163,8 @@ fn test_sgemm_nn() {
assert!(feq(a, b), "a != b, a[{}, {}]={}, b[{}, {}]={}", row, col, a, row, col, b);
}
}
}
}
#[test]
fn test_sgemm_nt() {

View File