diff --git a/Cargo.toml b/Cargo.toml index 5554596..a5fe4b1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "cblas-src" -version = "0.1.1" +version = "0.1.2" authors = ["Andrey Tkachenko "] edition = "2018" build="build.rs" diff --git a/build.rs b/build.rs index 3493412..90f5f88 100644 --- a/build.rs +++ b/build.rs @@ -1,15 +1,14 @@ -use std::{env, fs, io}; -use std::path::{Path, PathBuf}; -use flate2::read::GzDecoder; -use tar::Archive; use curl::easy::Easy; +use flate2::read::GzDecoder; use std::io::Write; - +use std::path::{Path, PathBuf}; +use std::process::Command; +use std::{env, fs, io}; +use tar::Archive; const CBLAS_SRC: &str = "http://www.netlib.org/blas/blast-forum/cblas.tgz"; - -fn download>(source_url: &str, target_file: P) -> anyhow::Result<()> { +fn download>(source_url: &str, target_file: P) -> anyhow::Result<()> { let f = fs::File::create(&target_file)?; let mut writer = io::BufWriter::new(f); let mut easy = Easy::new(); @@ -36,7 +35,7 @@ fn extract, P2: AsRef>(filename: P1, outpath: P2) -> anyho let tar = GzDecoder::new(file); let mut archive = Archive::new(tar); archive.unpack(outpath.as_ref())?; - + Ok(()) } @@ -49,13 +48,60 @@ fn main() { download(CBLAS_SRC, &cblas_name).unwrap(); extract(cblas_name, &out_path).unwrap(); } - + + + Command::new("gfortran") + .current_dir(cblas_dir.join("src")) + .arg("-O3") + .arg("-c").arg(cblas_dir.join("src/cdotcsub.f")) + .arg("-c").arg(cblas_dir.join("src/cdotusub.f")) + .arg("-c").arg(cblas_dir.join("src/dasumsub.f")) + .arg("-c").arg(cblas_dir.join("src/ddotsub.f")) + .arg("-c").arg(cblas_dir.join("src/dnrm2sub.f")) + .arg("-c").arg(cblas_dir.join("src/dsdotsub.f")) + .arg("-c").arg(cblas_dir.join("src/dzasumsub.f")) + .arg("-c").arg(cblas_dir.join("src/dznrm2sub.f")) + .arg("-c").arg(cblas_dir.join("src/icamaxsub.f")) + .arg("-c").arg(cblas_dir.join("src/idamaxsub.f")) + .arg("-c").arg(cblas_dir.join("src/isamaxsub.f")) + .arg("-c").arg(cblas_dir.join("src/izamaxsub.f")) + .arg("-c").arg(cblas_dir.join("src/sasumsub.f")) + .arg("-c").arg(cblas_dir.join("src/scasumsub.f")) + .arg("-c").arg(cblas_dir.join("src/scnrm2sub.f")) + .arg("-c").arg(cblas_dir.join("src/sdotsub.f")) + .arg("-c").arg(cblas_dir.join("src/sdsdotsub.f")) + .arg("-c").arg(cblas_dir.join("src/snrm2sub.f")) + .arg("-c").arg(cblas_dir.join("src/zdotcsub.f")) + .arg("-c").arg(cblas_dir.join("src/zdotusub.f")) + .status() + .expect("fortran failed"); + cc::Build::new() .warnings(false) .extra_warnings(false) .static_flag(true) .define("ADD_", None) .include(cblas_dir.join("include")) + .object(cblas_dir.join("src/cdotcsub.o")) + .object(cblas_dir.join("src/cdotusub.o")) + .object(cblas_dir.join("src/dasumsub.o")) + .object(cblas_dir.join("src/ddotsub.o")) + .object(cblas_dir.join("src/dnrm2sub.o")) + .object(cblas_dir.join("src/dsdotsub.o")) + .object(cblas_dir.join("src/dzasumsub.o")) + .object(cblas_dir.join("src/dznrm2sub.o")) + .object(cblas_dir.join("src/icamaxsub.o")) + .object(cblas_dir.join("src/idamaxsub.o")) + .object(cblas_dir.join("src/isamaxsub.o")) + .object(cblas_dir.join("src/izamaxsub.o")) + .object(cblas_dir.join("src/sasumsub.o")) + .object(cblas_dir.join("src/scasumsub.o")) + .object(cblas_dir.join("src/scnrm2sub.o")) + .object(cblas_dir.join("src/sdotsub.o")) + .object(cblas_dir.join("src/sdsdotsub.o")) + .object(cblas_dir.join("src/snrm2sub.o")) + .object(cblas_dir.join("src/zdotcsub.o")) + .object(cblas_dir.join("src/zdotusub.o")) .file(cblas_dir.join("src/cblas_caxpy.c")) .file(cblas_dir.join("src/cblas_ccopy.c")) .file(cblas_dir.join("src/cblas_cdotc_sub.c")) @@ -200,8 +246,6 @@ fn main() { .file(cblas_dir.join("src/cblas_ztrmv.c")) .file(cblas_dir.join("src/cblas_ztrsm.c")) .file(cblas_dir.join("src/cblas_ztrsv.c")) - - // TODO build fortran files aswell - + // TODO build fortran files aswell .compile("libcblas"); -} \ No newline at end of file +} diff --git a/src/lib.rs b/src/lib.rs index 6768a93..99bf7c0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,4 @@ + #[cfg(test)] mod tests { extern crate openblas_src; @@ -12,6 +13,7 @@ mod tests { 2.0, 5.0, 3.0, 6.0, ]; + let b = vec![ 1.0, 5.0, 9.0, 2.0, 6.0, 10.0, @@ -24,11 +26,17 @@ mod tests { 0.0, 7.0, 4.0, 2.0, ]; + + let sdot_res = unsafe { + ddot(6, &a, 1, &a, 1) + }; unsafe { dgemm(Layout::ColumnMajor, Transpose::None, Transpose::None, m, n, k, 1.0, &a, m, &b, k, 1.0, &mut c, m); } + + assert_eq!(sdot_res, 91.0); assert!( c == vec![