Add CI
Some checks failed
continuous-integration/drone/push Build is failing

Andrey Tkachenko 2022-03-17 15:32:55 +04:00
parent 765a1eccfa
commit 2fa6098f65
40 changed files with 2054 additions and 1659 deletions

.drone.yml (new file; 14 lines)

@@ -0,0 +1,14 @@
kind: pipeline
name: default

steps:
  - name: build
    image: rustlang/rust:nightly
    commands:
      - cargo build --verbose --all

  - name: fmt-check
    image: rustlang/rust:nightly
    commands:
      - rustup component add rustfmt
      - cargo fmt --all -- --check

Cargo.lock (generated; 417 lines changed)

@@ -1,514 +1,363 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
version = 3
[[package]]
name = "autocfg"
version = "0.1.4"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "backtrace"
version = "0.3.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"backtrace-sys 0.1.30 (registry+https://github.com/rust-lang/crates.io-index)",
"cfg-if 0.1.9 (registry+https://github.com/rust-lang/crates.io-index)",
"libc 0.2.59 (registry+https://github.com/rust-lang/crates.io-index)",
"rustc-demangle 0.1.15 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "backtrace-sys"
version = "0.1.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"cc 1.0.37 (registry+https://github.com/rust-lang/crates.io-index)",
"libc 0.2.59 (registry+https://github.com/rust-lang/crates.io-index)",
]
checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
[[package]]
name = "blas"
version = "0.20.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e4b3b6399f2fe6e70f780b06e278a64770aae273e0005c3860f94fc2f5f01ba7"
dependencies = [
"blas-sys 0.7.1 (registry+https://github.com/rust-lang/crates.io-index)",
"libc 0.2.59 (registry+https://github.com/rust-lang/crates.io-index)",
"num-complex 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)",
"blas-sys",
"libc",
"num-complex",
]
[[package]]
name = "blas-sys"
version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "13b1b279ceb25d7c4faaea95a5f7addbe7d8c34f9462044bd8e630cebcfc2440"
dependencies = [
"libc 0.2.59 (registry+https://github.com/rust-lang/crates.io-index)",
"libc",
]
[[package]]
name = "bumpalo"
version = "2.5.0"
version = "3.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a4a45a46ab1f2412e53d3a0ade76ffad2025804294569aae387231a0cd6e0899"
[[package]]
name = "byteorder"
version = "1.3.2"
version = "1.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "c2-chacha"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"lazy_static 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)",
"ppv-lite86 0.2.5 (registry+https://github.com/rust-lang/crates.io-index)",
]
checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610"
[[package]]
name = "cblas"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d82f331add33eceb4c41cb28d878049b96f56577016daf190831e94e4aece5db"
dependencies = [
"cblas-sys 0.1.4 (registry+https://github.com/rust-lang/crates.io-index)",
"libc 0.2.59 (registry+https://github.com/rust-lang/crates.io-index)",
"num-complex 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)",
"cblas-sys",
"libc",
"num-complex",
]
[[package]]
name = "cblas-sys"
version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b6feecd82cce51b0204cf063f0041d69f24ce83f680d87514b004248e7b0fa65"
dependencies = [
"libc 0.2.59 (registry+https://github.com/rust-lang/crates.io-index)",
"libc",
]
[[package]]
name = "cc"
version = "1.0.37"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "cfg-if"
version = "0.1.9"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "failure"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"backtrace 0.3.32 (registry+https://github.com/rust-lang/crates.io-index)",
"failure_derive 0.1.5 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "failure_derive"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"proc-macro2 0.4.30 (registry+https://github.com/rust-lang/crates.io-index)",
"quote 0.6.13 (registry+https://github.com/rust-lang/crates.io-index)",
"syn 0.15.39 (registry+https://github.com/rust-lang/crates.io-index)",
"synstructure 0.10.2 (registry+https://github.com/rust-lang/crates.io-index)",
]
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "getrandom"
version = "0.1.6"
version = "0.1.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8fc3cb4d91f53b50155bdcfd23f6a4c39ae1969c2ae85982b135750cccaf5fce"
dependencies = [
"lazy_static 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)",
"libc 0.2.59 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "heck"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"unicode-segmentation 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)",
"cfg-if",
"libc",
"wasi",
]
[[package]]
name = "js-sys"
version = "0.3.25"
version = "0.3.56"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a38fc24e30fd564ce974c02bf1d337caddff65be6cc4735a1f7eab22a7440f04"
dependencies = [
"wasm-bindgen 0.2.48 (registry+https://github.com/rust-lang/crates.io-index)",
"wasm-bindgen",
]
[[package]]
name = "lazy_static"
version = "1.3.0"
version = "1.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"spin 0.5.0 (registry+https://github.com/rust-lang/crates.io-index)",
]
checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
[[package]]
name = "libc"
version = "0.2.59"
version = "0.2.120"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ad5c14e80759d0939d013e6ca49930e59fc53dd8e5009132f76240c179380c09"
[[package]]
name = "log"
version = "0.4.7"
version = "0.4.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "51b9bbe6c47d51fc3e1a9b945965946b4c44142ab8792c50835a980d362c2710"
dependencies = [
"cfg-if 0.1.9 (registry+https://github.com/rust-lang/crates.io-index)",
"cfg-if",
]
[[package]]
name = "memchr"
version = "2.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "mnist"
version = "0.4.0"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9fd2e5236a5f13d41d4f6eba34f535d037cb9fc4b244896886a7ebe0764142a5"
dependencies = [
"byteorder 1.3.2 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "nom"
version = "4.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"memchr 2.2.1 (registry+https://github.com/rust-lang/crates.io-index)",
"version_check 0.1.5 (registry+https://github.com/rust-lang/crates.io-index)",
"byteorder",
]
[[package]]
name = "num-complex"
version = "0.2.3"
version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b6b19411a9719e753aff12e5187b74d60d3dc449ec3f4dc21e3989c3f554bc95"
dependencies = [
"autocfg 0.1.4 (registry+https://github.com/rust-lang/crates.io-index)",
"num-traits 0.2.8 (registry+https://github.com/rust-lang/crates.io-index)",
"autocfg",
"num-traits",
]
[[package]]
name = "num-traits"
version = "0.2.8"
version = "0.2.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9a64b1ec5cda2586e284722486d802acf1f7dbdc623e2bfc57e65ca1cd099290"
dependencies = [
"autocfg 0.1.4 (registry+https://github.com/rust-lang/crates.io-index)",
"autocfg",
]
[[package]]
name = "openblas-src"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b3533e568814bee9620fcc529158408384404bae5b277c73c73d66ca03fceb7"
[[package]]
name = "ppv-lite86"
version = "0.2.5"
version = "0.2.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eb9f9e6e233e5c4a35559a617bf40a4ec447db2e84c20b55a6f83167b7e57872"
[[package]]
name = "proc-macro2"
version = "0.4.30"
version = "1.0.36"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c7342d5883fbccae1cc37a2353b09c87c9b0f3afd73f5fb9bba687a1f733b029"
dependencies = [
"unicode-xid 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)",
"unicode-xid",
]
[[package]]
name = "quote"
version = "0.6.13"
version = "1.0.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "864d3e96a899863136fc6e99f3d7cae289dafe43bf2c5ac19b70df7210c0a145"
dependencies = [
"proc-macro2 0.4.30 (registry+https://github.com/rust-lang/crates.io-index)",
"proc-macro2",
]
[[package]]
name = "rand"
version = "0.7.0"
version = "0.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6a6b1679d49b24bbfe0c803429aa1874472f50d9b363131f0e89fc356b544d03"
dependencies = [
"getrandom 0.1.6 (registry+https://github.com/rust-lang/crates.io-index)",
"libc 0.2.59 (registry+https://github.com/rust-lang/crates.io-index)",
"rand_chacha 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)",
"rand_core 0.5.0 (registry+https://github.com/rust-lang/crates.io-index)",
"rand_hc 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)",
"getrandom",
"libc",
"rand_chacha",
"rand_core",
"rand_hc",
]
[[package]]
name = "rand_chacha"
version = "0.2.0"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f4c8ed856279c9737206bf725bf36935d8666ead7aa69b52be55af369d193402"
dependencies = [
"autocfg 0.1.4 (registry+https://github.com/rust-lang/crates.io-index)",
"c2-chacha 0.2.2 (registry+https://github.com/rust-lang/crates.io-index)",
"rand_core 0.5.0 (registry+https://github.com/rust-lang/crates.io-index)",
"ppv-lite86",
"rand_core",
]
[[package]]
name = "rand_core"
version = "0.5.0"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19"
dependencies = [
"getrandom 0.1.6 (registry+https://github.com/rust-lang/crates.io-index)",
"getrandom",
]
[[package]]
name = "rand_distr"
version = "0.2.1"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "96977acbdd3a6576fb1d27391900035bf3863d4a16422973a409b488cf29ffb2"
dependencies = [
"rand 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)",
"rand",
]
[[package]]
name = "rand_hc"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ca3129af7b92a17112d59ad498c6f81eaf463253766b90396d39ea7a39d6613c"
dependencies = [
"rand_core 0.5.0 (registry+https://github.com/rust-lang/crates.io-index)",
"rand_core",
]
[[package]]
name = "rustc-demangle"
version = "0.1.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "sourcefile"
version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "spin"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "syn"
version = "0.15.39"
version = "1.0.89"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ea297be220d52398dcc07ce15a209fce436d361735ac1db700cab3b6cdfb9f54"
dependencies = [
"proc-macro2 0.4.30 (registry+https://github.com/rust-lang/crates.io-index)",
"quote 0.6.13 (registry+https://github.com/rust-lang/crates.io-index)",
"unicode-xid 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)",
"proc-macro2",
"quote",
"unicode-xid",
]
[[package]]
name = "synstructure"
version = "0.10.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"proc-macro2 0.4.30 (registry+https://github.com/rust-lang/crates.io-index)",
"quote 0.6.13 (registry+https://github.com/rust-lang/crates.io-index)",
"syn 0.15.39 (registry+https://github.com/rust-lang/crates.io-index)",
"unicode-xid 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "unicode-segmentation"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "unicode-xid"
version = "0.1.0"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8ccb82d61f80a663efe1f787a51b16b5a51e3314d6ac365b08639f52387b33f3"
[[package]]
name = "version_check"
version = "0.1.5"
name = "wasi"
version = "0.9.0+wasi-snapshot-preview1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cccddf32554fecc6acb585f82a32a72e28b48f8c4c1883ddfeeeaa96f7d8e519"
[[package]]
name = "wasm-bindgen"
version = "0.2.48"
version = "0.2.79"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "25f1af7423d8588a3d840681122e72e6a24ddbcb3f0ec385cac0d12d24256c06"
dependencies = [
"wasm-bindgen-macro 0.2.48 (registry+https://github.com/rust-lang/crates.io-index)",
"cfg-if",
"wasm-bindgen-macro",
]
[[package]]
name = "wasm-bindgen-backend"
version = "0.2.48"
version = "0.2.79"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b21c0df030f5a177f3cba22e9bc4322695ec43e7257d865302900290bcdedca"
dependencies = [
"bumpalo 2.5.0 (registry+https://github.com/rust-lang/crates.io-index)",
"lazy_static 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)",
"log 0.4.7 (registry+https://github.com/rust-lang/crates.io-index)",
"proc-macro2 0.4.30 (registry+https://github.com/rust-lang/crates.io-index)",
"quote 0.6.13 (registry+https://github.com/rust-lang/crates.io-index)",
"syn 0.15.39 (registry+https://github.com/rust-lang/crates.io-index)",
"wasm-bindgen-shared 0.2.48 (registry+https://github.com/rust-lang/crates.io-index)",
"bumpalo",
"lazy_static",
"log",
"proc-macro2",
"quote",
"syn",
"wasm-bindgen-shared",
]
[[package]]
name = "wasm-bindgen-macro"
version = "0.2.48"
version = "0.2.79"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2f4203d69e40a52ee523b2529a773d5ffc1dc0071801c87b3d270b471b80ed01"
dependencies = [
"quote 0.6.13 (registry+https://github.com/rust-lang/crates.io-index)",
"wasm-bindgen-macro-support 0.2.48 (registry+https://github.com/rust-lang/crates.io-index)",
"quote",
"wasm-bindgen-macro-support",
]
[[package]]
name = "wasm-bindgen-macro-support"
version = "0.2.48"
version = "0.2.79"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bfa8a30d46208db204854cadbb5d4baf5fcf8071ba5bf48190c3e59937962ebc"
dependencies = [
"proc-macro2 0.4.30 (registry+https://github.com/rust-lang/crates.io-index)",
"quote 0.6.13 (registry+https://github.com/rust-lang/crates.io-index)",
"syn 0.15.39 (registry+https://github.com/rust-lang/crates.io-index)",
"wasm-bindgen-backend 0.2.48 (registry+https://github.com/rust-lang/crates.io-index)",
"wasm-bindgen-shared 0.2.48 (registry+https://github.com/rust-lang/crates.io-index)",
"proc-macro2",
"quote",
"syn",
"wasm-bindgen-backend",
"wasm-bindgen-shared",
]
[[package]]
name = "wasm-bindgen-shared"
version = "0.2.48"
version = "0.2.79"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "wasm-bindgen-webidl"
version = "0.2.48"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"failure 0.1.5 (registry+https://github.com/rust-lang/crates.io-index)",
"heck 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)",
"log 0.4.7 (registry+https://github.com/rust-lang/crates.io-index)",
"proc-macro2 0.4.30 (registry+https://github.com/rust-lang/crates.io-index)",
"quote 0.6.13 (registry+https://github.com/rust-lang/crates.io-index)",
"syn 0.15.39 (registry+https://github.com/rust-lang/crates.io-index)",
"wasm-bindgen-backend 0.2.48 (registry+https://github.com/rust-lang/crates.io-index)",
"weedle 0.10.0 (registry+https://github.com/rust-lang/crates.io-index)",
]
checksum = "3d958d035c4438e28c70e4321a2911302f10135ce78a9c7834c0cab4123d06a2"
[[package]]
name = "web-sys"
version = "0.3.25"
version = "0.3.56"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c060b319f29dd25724f09a2ba1418f142f539b2be99fbf4d2d5a8f7330afb8eb"
dependencies = [
"failure 0.1.5 (registry+https://github.com/rust-lang/crates.io-index)",
"js-sys 0.3.25 (registry+https://github.com/rust-lang/crates.io-index)",
"sourcefile 0.1.4 (registry+https://github.com/rust-lang/crates.io-index)",
"wasm-bindgen 0.2.48 (registry+https://github.com/rust-lang/crates.io-index)",
"wasm-bindgen-webidl 0.2.48 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "weedle"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"nom 4.2.3 (registry+https://github.com/rust-lang/crates.io-index)",
"js-sys",
"wasm-bindgen",
]
[[package]]
name = "yarnn"
version = "0.1.0"
dependencies = [
"rand 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)",
"rand_distr 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)",
"rand",
"rand_distr",
]
[[package]]
name = "yarnn-example-mnist"
version = "0.1.0"
dependencies = [
"mnist 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)",
"yarnn 0.1.0",
"yarnn-model-mnist 0.1.0",
"yarnn-native-blas 0.1.0",
"mnist",
"yarnn",
"yarnn-model-mnist",
"yarnn-native-blas",
]
[[package]]
name = "yarnn-example-mnist-wasm"
version = "0.1.0"
dependencies = [
"js-sys 0.3.25 (registry+https://github.com/rust-lang/crates.io-index)",
"wasm-bindgen 0.2.48 (registry+https://github.com/rust-lang/crates.io-index)",
"web-sys 0.3.25 (registry+https://github.com/rust-lang/crates.io-index)",
"yarnn 0.1.0",
"js-sys",
"wasm-bindgen",
"web-sys",
"yarnn",
]
[[package]]
name = "yarnn-example-vgg16-demo"
version = "0.1.0"
dependencies = [
"yarnn 0.1.0",
"yarnn-model-vgg16 0.1.0",
"yarnn",
"yarnn-model-vgg16",
]
[[package]]
name = "yarnn-model-mnist"
version = "0.1.0"
dependencies = [
"yarnn 0.1.0",
"yarnn",
]
[[package]]
name = "yarnn-model-vgg16"
version = "0.1.0"
dependencies = [
"yarnn 0.1.0",
"yarnn",
]
[[package]]
name = "yarnn-native-blas"
version = "0.1.0"
dependencies = [
"blas 0.20.0 (registry+https://github.com/rust-lang/crates.io-index)",
"cblas 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)",
"openblas-src 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)",
"yarnn 0.1.0",
"blas",
"cblas",
"openblas-src",
"yarnn",
]
[metadata]
"checksum autocfg 0.1.4 (registry+https://github.com/rust-lang/crates.io-index)" = "0e49efa51329a5fd37e7c79db4621af617cd4e3e5bc224939808d076077077bf"
"checksum backtrace 0.3.32 (registry+https://github.com/rust-lang/crates.io-index)" = "18b50f5258d1a9ad8396d2d345827875de4261b158124d4c819d9b351454fae5"
"checksum backtrace-sys 0.1.30 (registry+https://github.com/rust-lang/crates.io-index)" = "5b3a000b9c543553af61bc01cbfc403b04b5caa9e421033866f2e98061eb3e61"
"checksum blas 0.20.0 (registry+https://github.com/rust-lang/crates.io-index)" = "e4b3b6399f2fe6e70f780b06e278a64770aae273e0005c3860f94fc2f5f01ba7"
"checksum blas-sys 0.7.1 (registry+https://github.com/rust-lang/crates.io-index)" = "13b1b279ceb25d7c4faaea95a5f7addbe7d8c34f9462044bd8e630cebcfc2440"
"checksum bumpalo 2.5.0 (registry+https://github.com/rust-lang/crates.io-index)" = "2cd43d82f27d68911e6ee11ee791fb248f138f5d69424dc02e098d4f152b0b05"
"checksum byteorder 1.3.2 (registry+https://github.com/rust-lang/crates.io-index)" = "a7c3dd8985a7111efc5c80b44e23ecdd8c007de8ade3b96595387e812b957cf5"
"checksum c2-chacha 0.2.2 (registry+https://github.com/rust-lang/crates.io-index)" = "7d64d04786e0f528460fc884753cf8dddcc466be308f6026f8e355c41a0e4101"
"checksum cblas 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "d82f331add33eceb4c41cb28d878049b96f56577016daf190831e94e4aece5db"
"checksum cblas-sys 0.1.4 (registry+https://github.com/rust-lang/crates.io-index)" = "b6feecd82cce51b0204cf063f0041d69f24ce83f680d87514b004248e7b0fa65"
"checksum cc 1.0.37 (registry+https://github.com/rust-lang/crates.io-index)" = "39f75544d7bbaf57560d2168f28fd649ff9c76153874db88bdbdfd839b1a7e7d"
"checksum cfg-if 0.1.9 (registry+https://github.com/rust-lang/crates.io-index)" = "b486ce3ccf7ffd79fdeb678eac06a9e6c09fc88d33836340becb8fffe87c5e33"
"checksum failure 0.1.5 (registry+https://github.com/rust-lang/crates.io-index)" = "795bd83d3abeb9220f257e597aa0080a508b27533824adf336529648f6abf7e2"
"checksum failure_derive 0.1.5 (registry+https://github.com/rust-lang/crates.io-index)" = "ea1063915fd7ef4309e222a5a07cf9c319fb9c7836b1f89b85458672dbb127e1"
"checksum getrandom 0.1.6 (registry+https://github.com/rust-lang/crates.io-index)" = "e65cce4e5084b14874c4e7097f38cab54f47ee554f9194673456ea379dcc4c55"
"checksum heck 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)" = "20564e78d53d2bb135c343b3f47714a56af2061f1c928fdb541dc7b9fdd94205"
"checksum js-sys 0.3.25 (registry+https://github.com/rust-lang/crates.io-index)" = "da3ea71161651a4cd97d999b2da139109c537b15ab33abc8ae4ead38deac8a03"
"checksum lazy_static 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)" = "bc5729f27f159ddd61f4df6228e827e86643d4d3e7c32183cb30a1c08f604a14"
"checksum libc 0.2.59 (registry+https://github.com/rust-lang/crates.io-index)" = "3262021842bf00fe07dbd6cf34ff25c99d7a7ebef8deea84db72be3ea3bb0aff"
"checksum log 0.4.7 (registry+https://github.com/rust-lang/crates.io-index)" = "c275b6ad54070ac2d665eef9197db647b32239c9d244bfb6f041a766d00da5b3"
"checksum memchr 2.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "88579771288728879b57485cc7d6b07d648c9f0141eb955f8ab7f9d45394468e"
"checksum mnist 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "25f19bfda80095b4294000bbb50506f028149ed0ddb7fabf46ebb673b91626bc"
"checksum nom 4.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "2ad2a91a8e869eeb30b9cb3119ae87773a8f4ae617f41b1eb9c154b2905f7bd6"
"checksum num-complex 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "fcb0cf31fb3ff77e6d2a6ebd6800df7fdcd106f2ad89113c9130bcd07f93dffc"
"checksum num-traits 0.2.8 (registry+https://github.com/rust-lang/crates.io-index)" = "6ba9a427cfca2be13aa6f6403b0b7e7368fe982bfa16fccc450ce74c46cd9b32"
"checksum openblas-src 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)" = "0b3533e568814bee9620fcc529158408384404bae5b277c73c73d66ca03fceb7"
"checksum ppv-lite86 0.2.5 (registry+https://github.com/rust-lang/crates.io-index)" = "e3cbf9f658cdb5000fcf6f362b8ea2ba154b9f146a61c7a20d647034c6b6561b"
"checksum proc-macro2 0.4.30 (registry+https://github.com/rust-lang/crates.io-index)" = "cf3d2011ab5c909338f7887f4fc896d35932e29146c12c8d01da6b22a80ba759"
"checksum quote 0.6.13 (registry+https://github.com/rust-lang/crates.io-index)" = "6ce23b6b870e8f94f81fb0a363d65d86675884b34a09043c81e5562f11c1f8e1"
"checksum rand 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)" = "d47eab0e83d9693d40f825f86948aa16eff6750ead4bdffc4ab95b8b3a7f052c"
"checksum rand_chacha 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "e193067942ef6f485a349a113329140d0ab9e2168ce92274499bb0e9a4190d9d"
"checksum rand_core 0.5.0 (registry+https://github.com/rust-lang/crates.io-index)" = "615e683324e75af5d43d8f7a39ffe3ee4a9dc42c5c701167a71dc59c3a493aca"
"checksum rand_distr 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "c37e54d811c5a51195156444e298a98ba57df6dca1e511f34d4791993883b921"
"checksum rand_hc 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ca3129af7b92a17112d59ad498c6f81eaf463253766b90396d39ea7a39d6613c"
"checksum rustc-demangle 0.1.15 (registry+https://github.com/rust-lang/crates.io-index)" = "a7f4dccf6f4891ebcc0c39f9b6eb1a83b9bf5d747cb439ec6fba4f3b977038af"
"checksum sourcefile 0.1.4 (registry+https://github.com/rust-lang/crates.io-index)" = "4bf77cb82ba8453b42b6ae1d692e4cdc92f9a47beaf89a847c8be83f4e328ad3"
"checksum spin 0.5.0 (registry+https://github.com/rust-lang/crates.io-index)" = "44363f6f51401c34e7be73db0db371c04705d35efbe9f7d6082e03a921a32c55"
"checksum syn 0.15.39 (registry+https://github.com/rust-lang/crates.io-index)" = "b4d960b829a55e56db167e861ddb43602c003c7be0bee1d345021703fac2fb7c"
"checksum synstructure 0.10.2 (registry+https://github.com/rust-lang/crates.io-index)" = "02353edf96d6e4dc81aea2d8490a7e9db177bf8acb0e951c24940bf866cb313f"
"checksum unicode-segmentation 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)" = "1967f4cdfc355b37fd76d2a954fb2ed3871034eb4f26d60537d88795cfc332a9"
"checksum unicode-xid 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "fc72304796d0818e357ead4e000d19c9c174ab23dc11093ac919054d20a6a7fc"
"checksum version_check 0.1.5 (registry+https://github.com/rust-lang/crates.io-index)" = "914b1a6776c4c929a602fafd8bc742e06365d4bcbe48c30f9cca5824f70dc9dd"
"checksum wasm-bindgen 0.2.48 (registry+https://github.com/rust-lang/crates.io-index)" = "4de97fa1806bb1a99904216f6ac5e0c050dc4f8c676dc98775047c38e5c01b55"
"checksum wasm-bindgen-backend 0.2.48 (registry+https://github.com/rust-lang/crates.io-index)" = "5d82c170ef9f5b2c63ad4460dfcee93f3ec04a9a36a4cc20bc973c39e59ab8e3"
"checksum wasm-bindgen-macro 0.2.48 (registry+https://github.com/rust-lang/crates.io-index)" = "f07d50f74bf7a738304f6b8157f4a581e1512cd9e9cdb5baad8c31bbe8ffd81d"
"checksum wasm-bindgen-macro-support 0.2.48 (registry+https://github.com/rust-lang/crates.io-index)" = "95cf8fe77e45ba5f91bc8f3da0c3aa5d464b3d8ed85d84f4d4c7cc106436b1d7"
"checksum wasm-bindgen-shared 0.2.48 (registry+https://github.com/rust-lang/crates.io-index)" = "d9c2d4d4756b2e46d3a5422e06277d02e4d3e1d62d138b76a4c681e925743623"
"checksum wasm-bindgen-webidl 0.2.48 (registry+https://github.com/rust-lang/crates.io-index)" = "24e47859b4eba3d3b9a5c2c299f9d6f8d0b613671315f6f0c5c7f835e524b36a"
"checksum web-sys 0.3.25 (registry+https://github.com/rust-lang/crates.io-index)" = "86d515d2f713d3a6ab198031d2181b7540f8e319e4637ec2d4a41a208335ef29"
"checksum weedle 0.10.0 (registry+https://github.com/rust-lang/crates.io-index)" = "3bb43f70885151e629e2a19ce9e50bd730fd436cfd4b666894c9ce4de9141164"


@@ -1,10 +1,10 @@
use yarnn::prelude::*;
use yarnn::native::{Native, NativeTensor};
use yarnn_model_mnist::*;
use yarnn::losses::CrossEntropyLoss;
use yarnn::optimizers::Adam;
use yarnn_native_blas::NativeBlas;
use mnist::{Mnist, MnistBuilder};
use yarnn::losses::CrossEntropyLoss;
use yarnn::native::{Native, NativeTensor};
use yarnn::optimizers::Adam;
use yarnn::prelude::*;
use yarnn_model_mnist::*;
use yarnn_native_blas::NativeBlas;
fn calc_accuracy<N, B: Backend<N>>(back: &B, pred: &B::Tensor, targets: &[u8]) -> f32 {
let mut vec = vec![0.0; pred.shape().size()];
@@ -14,7 +14,7 @@ fn calc_accuracy<N, B: Backend<N>>(back: &B, pred: &B::Tensor, targets: &[u8]) -> f32 {
let mut total = 0;
for (x, &y) in vec.chunks(10).zip(targets.iter()) {
let x = &x[0 .. 10];
let x = &x[0..10];
let mut max = 0;
let mut max_value = 0.0;
@@ -33,7 +33,7 @@ fn calc_accuracy<N, B: Backend<N>>(back: &B, pred: &B::Tensor, targets: &[u8]) -> f32 {
total += 1;
}
(positives as f32) / (total as f32)
}
fn main() {
@@ -41,7 +41,7 @@ fn main() {
let backend: NativeBlas<f32, Native<_>> = Default::default();
let optimizer = Adam::default();
// let mut model = MnistDenseModel::new(28, 28, 1);
let mut model = MnistConvModel::new(28, 28, 1);
model.init(&backend);
@@ -53,7 +53,13 @@
let loss = CrossEntropyLoss::new();
let Mnist { trn_img, trn_lbl, tst_img, tst_lbl, .. } = MnistBuilder::new()
let Mnist {
trn_img,
trn_lbl,
tst_img,
tst_lbl,
..
} = MnistBuilder::new()
.base_path("./datasets/mnist")
.label_format_digit()
.finalize();
@@ -81,14 +87,14 @@ fn main() {
backend.load_tensor_u8(&mut targets0, &tmp[..]);
for epoch in 1 ..= 4 {
for epoch in 1..=4 {
println!("epoch {}", epoch);
for step in 0 .. (60000 / BATCH_SIZE) {
for step in 0..(60000 / BATCH_SIZE) {
let offset = step * BATCH_SIZE;
let mut tmp = [0u8; 10 * BATCH_SIZE];
let inputs_slice = &trn_img[offset * 784 .. (offset + BATCH_SIZE) * 784 ];
let inputs_slice = &trn_img[offset * 784..(offset + BATCH_SIZE) * 784];
let targets_slice = &trn_lbl[offset..offset + BATCH_SIZE];
backend.load_tensor_u8(&mut inputs, inputs_slice);
@@ -102,13 +108,16 @@ fn main() {
model.forward(&backend, &inputs, &mut train_ctx);
loss.derivative(&backend, &mut deltas, train_ctx.outputs(), &targets);
model.backward(&backend, &deltas, &inputs, &mut train_ctx);
model.calc_gradients(&backend, &deltas, &inputs, &mut train_ctx);
model.optimize(&backend, &optimizer);
}
model.forward(&backend, &inputs0, &mut test_ctx);
println!("Accuracy {}", calc_accuracy(&backend, test_ctx.outputs(), targets0_slice));
println!(
"Accuracy {}",
calc_accuracy(&backend, test_ctx.outputs(), targets0_slice)
);
}
}
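
Aside: calc_accuracy above reads the network output back into a Vec, takes an argmax over each sample's ten class scores, and counts matches against the labels. A self-contained sketch of that computation in plain Rust, independent of the yarnn tensor types (names here are illustrative):

// Accuracy over flattened scores: each sample contributes 10 class scores;
// the prediction is the argmax, and accuracy is the fraction of argmax hits.
fn accuracy(scores: &[f32], labels: &[u8]) -> f32 {
    let mut hits = 0usize;
    let mut total = 0usize;
    for (row, &label) in scores.chunks(10).zip(labels.iter()) {
        let mut best = 0usize;
        let mut best_value = f32::MIN;
        for (i, &v) in row.iter().enumerate() {
            if v > best_value {
                best_value = v;
                best = i;
            }
        }
        if best == label as usize {
            hits += 1;
        }
        total += 1;
    }
    hits as f32 / total as f32
}

fn main() {
    // Two samples: the first predicts class 1 (correct), the second predicts 0 but is labeled 3.
    let scores = [
        0.1, 0.9, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, // sample 0
        0.8, 0.0, 0.0, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1, // sample 1
    ];
    assert_eq!(accuracy(&scores, &[1, 3]), 0.5);
}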


@@ -2,7 +2,6 @@ use yarnn::native::Native;
use yarnn::optimizers::Adam;
use yarnn_model_vgg16::Vgg16Model;
fn main() {
let vgg16: Vgg16Model<f32, Native<_>, Adam<_, _>> = Vgg16Model::new(224, 224, 3);


@@ -1,11 +1,11 @@
#![feature(trait_alias)]
pub use self::dense::MnistDenseModel;
pub use self::conv::MnistConvModel;
pub use self::dense::MnistDenseModel;
mod dense {
use yarnn::model;
use yarnn::layers::*;
use yarnn::model;
model! {
MnistDenseModel (h: u32, w: u32, _c: u32) {
@@ -25,11 +25,10 @@ mod dense {
}
}
mod conv {
use yarnn::model;
use yarnn::layers::*;
use yarnn::model;
model! {
MnistConvModel (h: u32, w: u32, c: u32) {
input_shape: (c, h, w),


@@ -1,3 +1,3 @@
fn main() {
println!("cargo:rustc-link-lib=dylib=cblas");
}
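
For context: this build script only tells Cargo to link the system cblas shared library. If the BLAS install lives outside the default linker search path, a build script can additionally emit a search-path directive; a hedged sketch (the path shown is purely illustrative):

fn main() {
    // Illustrative only: point the linker at a non-standard BLAS install location.
    println!("cargo:rustc-link-search=native=/opt/openblas/lib");
    // Same directive as the real build script above.
    println!("cargo:rustc-link-lib=dylib=cblas");
}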


@@ -1,110 +1,147 @@
#[allow(dead_code)]
fn img_to_col_get_pixel(img: &[f32], img_rows: usize, img_cols: usize,
mut row: isize, mut col: isize, channel: usize,
pad_row: usize, pad_col: usize) -> f32
{
fn img_to_col_get_pixel(
img: &[f32],
img_rows: usize,
img_cols: usize,
mut row: isize,
mut col: isize,
channel: usize,
pad_row: usize,
pad_col: usize,
) -> f32 {
row -= pad_row as isize;
col -= pad_col as isize;
if row < 0 || col < 0 ||
row >= img_rows as isize ||
col >= img_cols as isize { return 0.0 }
if row < 0 || col < 0 || row >= img_rows as isize || col >= img_cols as isize {
return 0.0;
}
img[(channel * img_rows + row as usize) * img_cols + col as usize]
}
#[allow(dead_code)]
pub fn img_to_col(col: &mut [f32], img: &[f32], channels: usize,
k_rows: usize, k_cols: usize,
img_rows: usize, img_cols: usize,
s_row: usize, s_col: usize,
pad_row: usize, pad_col: usize)
{
pub fn img_to_col(
col: &mut [f32],
img: &[f32],
channels: usize,
k_rows: usize,
k_cols: usize,
img_rows: usize,
img_cols: usize,
s_row: usize,
s_col: usize,
pad_row: usize,
pad_col: usize,
) {
let col_rows = (img_rows + 2 * pad_row - k_rows) / s_row + 1;
let col_cols = (img_cols + 2 * pad_col - k_cols) / s_col + 1;
let k_size = k_rows * k_cols;
let channels_col = channels * k_size;
let out_size = col_rows * col_cols;
let col_size = channels_col * out_size;
let col_s = &mut col[0..col_size];
for ch in 0..channels_col {
let offset_ch = ch / k_rows / k_cols;
let offset_row = (ch / k_rows) % k_cols;
let offset_col = ch % k_rows;
for row in 0..col_rows {
for col in 0..col_cols {
let img_row = row * s_row + offset_row;
let img_col = col * s_col + offset_col;
let index_row = row * col_cols + col;
let index_col = offset_row * k_rows + offset_col;
let index = offset_ch * (k_size * out_size) + index_row * k_size + index_col;
col_s[index] = img_to_col_get_pixel(
img, img_rows, img_cols,
img_row as isize, img_col as isize,
offset_ch, pad_row, pad_col
img,
img_rows,
img_cols,
img_row as isize,
img_col as isize,
offset_ch,
pad_row,
pad_col,
);
}
}
}
}
fn col_to_img_add_pixel(img: &mut [f32], img_rows: usize, img_cols: usize,
mut row: isize, mut col: isize, channel: usize,
pad_row: usize, pad_col: usize, val: f32) {
fn col_to_img_add_pixel(
img: &mut [f32],
img_rows: usize,
img_cols: usize,
mut row: isize,
mut col: isize,
channel: usize,
pad_row: usize,
pad_col: usize,
val: f32,
) {
row -= pad_row as isize;
col -= pad_col as isize;
if row < 0 || col < 0 ||
row >= img_rows as isize ||
col >= img_cols as isize { return; }
if row < 0 || col < 0 || row >= img_rows as isize || col >= img_cols as isize {
return;
}
img[(channel * img_rows + row as usize) * img_cols + col as usize] += val;
}
#[allow(dead_code)]
pub fn col_to_img(img: &mut [f32], col: &[f32], channels: usize,
k_rows: usize, k_cols: usize,
img_rows: usize, img_cols: usize,
s_row: usize, s_col: usize,
pad_row: usize, pad_col: usize) {
pub fn col_to_img(
img: &mut [f32],
col: &[f32],
channels: usize,
k_rows: usize,
k_cols: usize,
img_rows: usize,
img_cols: usize,
s_row: usize,
s_col: usize,
pad_row: usize,
pad_col: usize,
) {
let col_rows = (img_rows + 2 * pad_row - k_rows) / s_row + 1;
let col_cols = (img_cols + 2 * pad_col - k_cols) / s_col + 1;
let k_size = k_rows * k_cols;
let channels_col = channels * k_size;
let out_size = col_rows * col_cols;
let col_size = channels_col * out_size;
let col_s = &col[0..col_size];
for ch in 0..channels_col {
let offset_ch = ch / k_rows / k_cols;
let offset_row = (ch / k_rows) % k_cols;
let offset_col = ch % k_rows;
for row in 0..col_rows {
for col in 0..col_cols {
let img_row = row * s_row + offset_row;
let img_col = col * s_col + offset_col;
let index_row = row * col_cols + col;
let index_col = offset_row * k_rows + offset_col;
let index = offset_ch * (k_size * out_size) + index_row * k_size + index_col;
col_to_img_add_pixel(
img, img_rows, img_cols,
img_row as isize, img_col as isize,
offset_ch, pad_row, pad_col,
col_s[index]
img,
img_rows,
img_cols,
img_row as isize,
img_col as isize,
offset_ch,
pad_row,
pad_col,
col_s[index],
);
}
}
@@ -118,77 +155,48 @@ mod tests {
#[test]
fn test_img_to_col() {
let img: &[f32] = &[
1.0, 2.0, 3.0, 4.0,
5.0, 6.0, 7.0, 8.0,
9.0, 10.0, 11.0, 12.0,
13.0, 14.0, 15.0, 16.0,
1.5, 2.5, 3.5, 4.5,
5.5, 6.5, 7.5, 8.5,
9.5, 10.5, 11.5, 12.5,
13.5, 14.5, 15.5, 16.5,
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5, 10.5, 11.5, 12.5, 13.5, 14.5, 15.5, 16.5,
];
let mut col = vec![0.0; 72];
img_to_col(&mut col, img, 2, 3, 3, 4, 4, 1, 1, 0, 0);
let tmp: &[f32] = &[
1.0, 2.0, 3.0, 5.0, 6.0, 7.0, 9.0, 10.0, 11.0,
2.0, 3.0, 4.0, 6.0, 7.0, 8.0, 10.0, 11.0, 12.0,
5.0, 6.0, 7.0, 9.0, 10.0, 11.0, 13.0, 14.0, 15.0,
6.0, 7.0, 8.0, 10.0, 11.0, 12.0, 14.0, 15.0, 16.0,
1.5, 2.5, 3.5, 5.5, 6.5, 7.5, 9.5, 10.5, 11.5,
2.5, 3.5, 4.5, 6.5, 7.5, 8.5, 10.5, 11.5, 12.5,
5.5, 6.5, 7.5, 9.5, 10.5, 11.5, 13.5, 14.5, 15.5,
6.5, 7.5, 8.5, 10.5, 11.5, 12.5, 14.5, 15.5, 16.5,
1.0, 2.0, 3.0, 5.0, 6.0, 7.0, 9.0, 10.0, 11.0, 2.0, 3.0, 4.0, 6.0, 7.0, 8.0, 10.0,
11.0, 12.0, 5.0, 6.0, 7.0, 9.0, 10.0, 11.0, 13.0, 14.0, 15.0, 6.0, 7.0, 8.0, 10.0,
11.0, 12.0, 14.0, 15.0, 16.0, 1.5, 2.5, 3.5, 5.5, 6.5, 7.5, 9.5, 10.5, 11.5, 2.5, 3.5,
4.5, 6.5, 7.5, 8.5, 10.5, 11.5, 12.5, 5.5, 6.5, 7.5, 9.5, 10.5, 11.5, 13.5, 14.5, 15.5,
6.5, 7.5, 8.5, 10.5, 11.5, 12.5, 14.5, 15.5, 16.5,
];
assert_eq!(
col.as_slice(),
tmp
);
assert_eq!(col.as_slice(), tmp);
}
#[test]
fn test_col_to_img() {
let col: &[f32] = &[
1.0, 2.0, 3.0, 5.0, 6.0, 7.0, 9.0, 10.0, 11.0,
2.0, 3.0, 4.0, 6.0, 7.0, 8.0, 10.0, 11.0, 12.0,
5.0, 6.0, 7.0, 9.0, 10.0, 11.0, 13.0, 14.0, 15.0,
6.0, 7.0, 8.0, 10.0, 11.0, 12.0, 14.0, 15.0, 16.0,
1.5, 2.5, 3.5, 5.5, 6.5, 7.5, 9.5, 10.5, 11.5,
2.5, 3.5, 4.5, 6.5, 7.5, 8.5, 10.5, 11.5, 12.5,
5.5, 6.5, 7.5, 9.5, 10.5, 11.5, 13.5, 14.5, 15.5,
6.5, 7.5, 8.5, 10.5, 11.5, 12.5, 14.5, 15.5, 16.5,
1.0, 2.0, 3.0, 5.0, 6.0, 7.0, 9.0, 10.0, 11.0, 2.0, 3.0, 4.0, 6.0, 7.0, 8.0, 10.0,
11.0, 12.0, 5.0, 6.0, 7.0, 9.0, 10.0, 11.0, 13.0, 14.0, 15.0, 6.0, 7.0, 8.0, 10.0,
11.0, 12.0, 14.0, 15.0, 16.0, 1.5, 2.5, 3.5, 5.5, 6.5, 7.5, 9.5, 10.5, 11.5, 2.5, 3.5,
4.5, 6.5, 7.5, 8.5, 10.5, 11.5, 12.5, 5.5, 6.5, 7.5, 9.5, 10.5, 11.5, 13.5, 14.5, 15.5,
6.5, 7.5, 8.5, 10.5, 11.5, 12.5, 14.5, 15.5, 16.5,
];
let y: &mut [f32] = &mut [
0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
];
col_to_img(y, col, 2, 3, 3, 4, 4, 1, 1, 0, 0);
let tmp: &[f32] = &[
1.0, 4.0, 6.0, 4.0,
10.0, 24.0, 28.0, 16.0,
18.0, 40.0, 44.0, 24.0,
13.0, 28.0, 30.0, 16.0,
1.5, 5.0, 7.0, 4.5,
11.0, 26.0, 30.0, 17.0,
19.0, 42.0, 46.0, 25.0,
13.5, 29.0, 31.0, 16.5,
1.0, 4.0, 6.0, 4.0, 10.0, 24.0, 28.0, 16.0, 18.0, 40.0, 44.0, 24.0, 13.0, 28.0, 30.0,
16.0, 1.5, 5.0, 7.0, 4.5, 11.0, 26.0, 30.0, 17.0, 19.0, 42.0, 46.0, 25.0, 13.5, 29.0,
31.0, 16.5,
];
assert_eq!(y, tmp);
}
}
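
The two routines above are the classic im2col/col2im lowering: every kernel-sized patch of the image is unrolled into a column, so convolution can be computed as one dense matrix product. A minimal, self-contained illustration of the idea, independent of the channel/offset memory layout used above:

fn main() {
    // 3x3 single-channel image, 2x2 kernel, stride 1, no padding -> 2x2 output.
    let img = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]; // row-major 3x3
    let kernel = [1.0f32, 0.0, 0.0, 1.0]; // flattened row-major 2x2
    let (w, k, out_w) = (3usize, 2usize, 2usize);
    let out_size = out_w * out_w; // 4 output pixels
    let k_size = k * k; // 4 kernel weights

    // im2col: column j holds the patch under output pixel j, one row per kernel offset.
    let mut col = [0.0f32; 16]; // k_size rows x out_size columns
    for j in 0..out_size {
        let (r, c) = (j / out_w, j % out_w);
        for i in 0..k_size {
            let (kr, kc) = (i / k, i % k);
            col[i * out_size + j] = img[(r + kr) * w + (c + kc)];
        }
    }

    // Convolution is now a single (1 x k_size) * (k_size x out_size) matrix product.
    let mut out = [0.0f32; 4];
    for j in 0..out_size {
        for i in 0..k_size {
            out[j] += kernel[i] * col[i * out_size + j];
        }
    }
    // The kernel [[1, 0], [0, 1]] sums each patch's main diagonal.
    assert_eq!(out, [6.0, 8.0, 12.0, 14.0]);
}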


@@ -1,23 +1,25 @@
mod img2col;
use std::marker::PhantomData;
use yarnn::backend::*;
use yarnn::native::*;
use yarnn::tensor::*;
use std::marker::PhantomData;
extern crate openblas_src;
pub struct NativeBlas<N, B>
where N: NativeNumber,
B: NativeBackend<N>
pub struct NativeBlas<N, B>
where
N: NativeNumber,
B: NativeBackend<N>,
{
inner: B,
_m: PhantomData<fn(N)>
_m: PhantomData<fn(N)>,
}
impl<N, B> Default for NativeBlas<N, B>
where N: NativeNumber,
B: NativeBackend<N>
where
N: NativeNumber,
B: NativeBackend<N>,
{
fn default() -> Self {
Self {
@@ -27,9 +29,10 @@ impl<N, B> Default for NativeBlas<N, B>
}
}
impl<N, B> NativeBackend<N> for NativeBlas<N, B>
where N: NativeNumber,
B: NativeBackend<N>
impl<N, B> NativeBackend<N> for NativeBlas<N, B>
where
N: NativeNumber,
B: NativeBackend<N>,
{
#[inline]
fn read_tensor<'a>(&self, t: &'a Self::Tensor) -> &'a [N] {
@@ -42,21 +45,23 @@ impl<N, B> NativeBackend<N> for NativeBlas<N, B>
}
}
impl<N, B> NativeBlas<N, B>
where N: NativeNumber,
B: NativeBackend<N>
impl<N, B> NativeBlas<N, B>
where
N: NativeNumber,
B: NativeBackend<N>,
{
pub fn new(inner: B) -> Self {
Self {
inner,
_m: Default::default()
_m: Default::default(),
}
}
}
impl<N, B> Backend<N> for NativeBlas<N, B>
where N: NativeNumber,
B: NativeBackend<N>
impl<N, B> Backend<N> for NativeBlas<N, B>
where
N: NativeNumber,
B: NativeBackend<N>,
{
type Tensor = B::Tensor;
@@ -96,8 +101,9 @@ impl<N, B> Backend<N> for NativeBlas<N, B>
}
}
impl<B> BackendGemm<f32> for NativeBlas<f32, B>
where B: NativeBackend<f32>
impl<B> BackendGemm<f32> for NativeBlas<f32, B>
where
B: NativeBackend<f32>,
{
#[inline]
fn matmul(&self, dst: &mut Self::Tensor, a: &Self::Tensor, b: &Self::Tensor) {
@@ -114,15 +120,23 @@ impl<B> BackendGemm<f32> for NativeBlas<f32, B>
let m = a_shape.get(0) as i32;
let n = b_shape.get(1) as i32;
let k = b_shape.get(0) as i32;
unsafe {
blas::sgemm('N' as u8, 'N' as u8,
n, m, k,
1.0,
self.read_tensor(b), n,
self.read_tensor(a), k,
0.0,
self.write_tensor(dst), n);
blas::sgemm(
'N' as u8,
'N' as u8,
n,
m,
k,
1.0,
self.read_tensor(b),
n,
self.read_tensor(a),
k,
0.0,
self.write_tensor(dst),
n,
);
}
}
@@ -141,15 +155,23 @@ impl<B> BackendGemm<f32> for NativeBlas<f32, B>
let m = a_shape.get(0) as i32;
let n = b_shape.get(0) as i32;
let k = b_shape.get(1) as i32;
unsafe {
blas::sgemm('T' as u8, 'N' as u8,
n, m, k,
1.0,
self.read_tensor(b), k,
self.read_tensor(a), k,
0.0,
self.write_tensor(dst), n);
blas::sgemm(
'T' as u8,
'N' as u8,
n,
m,
k,
1.0,
self.read_tensor(b),
k,
self.read_tensor(a),
k,
0.0,
self.write_tensor(dst),
n,
);
}
}
@@ -168,15 +190,23 @@ impl<B> BackendGemm<f32> for NativeBlas<f32, B>
let m = a_shape.get(1) as i32;
let n = b_shape.get(1) as i32;
let k = b_shape.get(0) as i32;
unsafe {
blas::sgemm('N' as u8, 'T' as u8,
n, m, k,
1.0,
self.read_tensor(b), n,
self.read_tensor(a), m,
0.0,
self.write_tensor(dst), n);
blas::sgemm(
'N' as u8,
'T' as u8,
n,
m,
k,
1.0,
self.read_tensor(b),
n,
self.read_tensor(a),
m,
0.0,
self.write_tensor(dst),
n,
);
}
}
@@ -186,8 +216,9 @@
}
}
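
A note on the sgemm calls above: BLAS assumes column-major storage while these tensors are row-major, and reinterpreting row-major data as column-major transposes it. The wrappers exploit the identity (A·B)ᵀ = Bᵀ·Aᵀ: asking BLAS for the column-major product of b and a (b passed first, with the leading dimensions seen above) writes exactly the bytes of the row-major A·B. A small pure-Rust check of that operand order, with no BLAS linkage (the helper is illustrative):

// Column-major GEMM, like sgemm('N', 'N'): c (m x n) = a (m x k) * b (k x n).
fn gemm_colmajor(m: usize, n: usize, k: usize, a: &[f32], b: &[f32], c: &mut [f32]) {
    for j in 0..n {
        for i in 0..m {
            let mut acc = 0.0;
            for p in 0..k {
                acc += a[p * m + i] * b[j * k + p];
            }
            c[j * m + i] = acc;
        }
    }
}

fn main() {
    // Row-major A (2x3) and B (3x2), and their row-major product A*B (2x2).
    let a = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    let b = [7.0f32, 8.0, 9.0, 10.0, 11.0, 12.0];
    let expected = [58.0f32, 64.0, 139.0, 154.0];

    // Swap the operands, as the matmul wrapper above does: the column-major
    // product "b * a" lands in memory as the row-major product a * b.
    let mut c = [0.0f32; 4];
    gemm_colmajor(2, 2, 3, &b, &a, &mut c);
    assert_eq!(c, expected);
}
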
impl<B> BackendAxpy<f32> for NativeBlas<f32, B>
where B: NativeBackend<f32>
impl<B> BackendAxpy<f32> for NativeBlas<f32, B>
where
B: NativeBackend<f32>,
{
#[inline]
fn axpy(&self, dst: &mut Self::Tensor, scale: f32, x: &Self::Tensor) {
@@ -202,31 +233,26 @@
self.read_tensor(x),
1,
self.write_tensor(dst),
1
1,
);
}
}
}
impl<B> BackendScale<f32> for NativeBlas<f32, B>
where B: NativeBackend<f32>
impl<B> BackendScale<f32> for NativeBlas<f32, B>
where
B: NativeBackend<f32>,
{
#[inline]
fn scale(&self, dst: &mut Self::Tensor, scale: f32) {
let dst_size = dst.shape().size();
unsafe {
blas::sscal(
dst_size as i32,
scale,
self.write_tensor(dst),
1
);
blas::sscal(dst_size as i32, scale, self.write_tensor(dst), 1);
}
}
}
impl<B: NativeBackend<f32> + BackendSigmoid<f32>> BackendSigmoid<f32> for NativeBlas<f32, B> {
#[inline]
fn sigmoid(&self, dst: &mut Self::Tensor, data: &Self::Tensor) {
@@ -265,7 +291,13 @@ impl<B: NativeBackend<f32> + BackendBias<f32>> BackendBias<f32> for NativeBlas<f32, B> {
impl<B: NativeBackend<f32> + BackendMse<f32>> BackendMse<f32> for NativeBlas<f32, B> {
#[inline]
fn scaled_square_diff(&self, dst: &mut Self::Tensor, a: &Self::Tensor, b: &Self::Tensor, scale: f32) {
fn scaled_square_diff(
&self,
dst: &mut Self::Tensor,
a: &Self::Tensor,
b: &Self::Tensor,
scale: f32,
) {
self.inner.scaled_square_diff(dst, a, b, scale)
}
@@ -303,7 +335,6 @@ impl<B: NativeBackend<f32> + BackendMul<f32>> BackendMul<f32> for NativeBlas<f32, B> {
}
}
impl<B: NativeBackend<f32> + BackendCopy<f32>> BackendCopy<f32> for NativeBlas<f32, B> {
#[inline]
fn copy(&self, dst: &mut Self::Tensor, a: &Self::Tensor) {
@@ -318,10 +349,16 @@ impl<B: NativeBackend<f32> + BackendMaximum<f32>> BackendMaximum<f32> for NativeBlas<f32, B> {
}
}
impl<B: NativeBackend<f32> + BackendAdam<f32>> BackendAdam<f32> for NativeBlas<f32, B> {
#[inline]
fn adam_p(&self, dst: &mut Self::Tensor, lr: f32, moms: &Self::Tensor, vels: &Self::Tensor, eps: f32) {
fn adam_p(
&self,
dst: &mut Self::Tensor,
lr: f32,
moms: &Self::Tensor,
vels: &Self::Tensor,
eps: f32,
) {
self.inner.adam_p(dst, lr, moms, vels, eps)
}
}
@@ -337,17 +374,35 @@ impl<B: NativeBackend<f32> + BackendConv2d<f32>> BackendConv2d<f32> for NativeBlas<f32, B> {
type Context = ();
#[inline]
fn conv2d_forward(&self, y: &mut Self::Tensor, x: &Self::Tensor, w: &Self::Tensor, conv_info: &Conv2dInfo) {
fn conv2d_forward(
&self,
y: &mut Self::Tensor,
x: &Self::Tensor,
w: &Self::Tensor,
conv_info: &Conv2dInfo,
) {
self.inner.conv2d_forward(y, x, w, conv_info)
}
#[inline]
fn conv2d_backward_input(&self, dx: &mut Self::Tensor, dy: &Self::Tensor, w: &Self::Tensor, conv_info: &Conv2dInfo) {
self.inner.conv2d_backward_input(dx, dy, w, conv_info)
fn conv2d_backward_input(
&self,
dx: &mut Self::Tensor,
dy: &Self::Tensor,
w: &Self::Tensor,
conv_info: &Conv2dInfo,
) {
self.inner.conv2d_backward_input(dx, dy, w, conv_info)
}
#[inline]
fn conv2d_backward_filter(&self, dw: &mut Self::Tensor, x: &Self::Tensor, dy: &Self::Tensor, conv_info: &Conv2dInfo) {
fn conv2d_backward_filter(
&self,
dw: &mut Self::Tensor,
x: &Self::Tensor,
dy: &Self::Tensor,
conv_info: &Conv2dInfo,
) {
self.inner.conv2d_backward_filter(dw, x, dy, conv_info)
}
}
@@ -359,7 +414,13 @@ impl<B: NativeBackend<f32> + BackendMaxPool2d<f32>> BackendMaxPool2d<f32> for NativeBlas<f32, B> {
}
#[inline]
fn max_pool2d_backprop(&self, dx: &mut Self::Tensor, dy: &Self::Tensor, x: &Self::Tensor, conv_info: &Conv2dInfo) {
fn max_pool2d_backprop(
&self,
dx: &mut Self::Tensor,
dy: &Self::Tensor,
x: &Self::Tensor,
conv_info: &Conv2dInfo,
) {
self.inner.max_pool2d_backprop(dx, dy, x, conv_info)
}
}
@@ -371,14 +432,28 @@ impl<B: NativeBackend<f32> + BackendAvgPool2d<f32>> BackendAvgPool2d<f32> for NativeBlas<f32, B> {
}
#[inline]
fn avg_pool2d_backprop(&self, dx: &mut Self::Tensor, dy: &Self::Tensor, x: &Self::Tensor, conv_info: &Conv2dInfo) {
fn avg_pool2d_backprop(
&self,
dx: &mut Self::Tensor,
dy: &Self::Tensor,
x: &Self::Tensor,
conv_info: &Conv2dInfo,
) {
self.inner.avg_pool2d_backprop(dx, dy, x, conv_info)
}
}
impl<B: NativeBackend<f32> + BackendPaddingCopy2d<f32>> BackendPaddingCopy2d<f32> for NativeBlas<f32, B> {
impl<B: NativeBackend<f32> + BackendPaddingCopy2d<f32>> BackendPaddingCopy2d<f32>
for NativeBlas<f32, B>
{
#[inline]
fn copy_with_padding2d(&self, y: &mut Self::Tensor, x: &Self::Tensor, y_paddings: (u32, u32), x_paddings: (u32, u32)) {
fn copy_with_padding2d(
&self,
y: &mut Self::Tensor,
x: &Self::Tensor,
y_paddings: (u32, u32),
x_paddings: (u32, u32),
) {
self.inner.copy_with_padding2d(y, x, y_paddings, x_paddings)
}
}
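
Taken as a whole, NativeBlas is a decorator: it overrides only the ops that have a BLAS kernel (sgemm, saxpy, sscal) and forwards everything else to the wrapped inner backend. Construction mirrors the MNIST example earlier in this commit; a minimal sketch, assuming the workspace crates are in scope and that Native<f32> is the inner type the example's Native<_> infers to:

use yarnn::native::Native;
use yarnn_native_blas::NativeBlas;

fn main() {
    // Default::default() builds the wrapper around a default inner backend:
    // matrix products and axpy/scale go to BLAS, the rest delegate to `inner`.
    let backend: NativeBlas<f32, Native<f32>> = Default::default();
    let _ = &backend;
}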


@@ -1,5 +1,4 @@
use crate::tensor::{Tensor};
use crate::tensor::Tensor;
pub trait Backend<N> {
type Tensor: Tensor<N>;
@@ -13,22 +12,22 @@ pub trait Backend<N> {
fn print_tensor(&self, t: &Self::Tensor);
}
impl <'a, N, T: Backend<N>> Backend<N> for &'a T {
impl<'a, N, T: Backend<N>> Backend<N> for &'a T {
type Tensor = T::Tensor;
#[inline]
fn store_tensor_f32(&self, dst: &Self::Tensor, slice: &mut [f32]) {
(**self).store_tensor_f32(dst, slice)
}
#[inline]
fn load_tensor_u8(&self, dst: &mut Self::Tensor, slice: &[u8]) {
(**self).load_tensor_u8(dst, slice)
}
#[inline]
fn load_tensor_f32(&self, dst: &mut Self::Tensor, slice: &[f32]) {
(**self).load_tensor_f32(dst, slice)
}
#[inline]
@@ -59,7 +58,7 @@ pub trait BackendGemm<N>: Backend<N> {
fn matmul_tt(&self, dst: &mut Self::Tensor, a: &Self::Tensor, b: &Self::Tensor);
}
impl <'a, N, T: BackendGemm<N>> BackendGemm<N> for &'a T {
impl<'a, N, T: BackendGemm<N>> BackendGemm<N> for &'a T {
#[inline]
fn matmul(&self, dst: &mut Self::Tensor, a: &Self::Tensor, b: &Self::Tensor) {
(**self).matmul(dst, a, b)
@@ -86,7 +85,7 @@ pub trait BackendBias<N>: Backend<N> {
fn bias_grad(&self, bias: &mut Self::Tensor, inputs: &Self::Tensor);
}
impl <'a, N, T: BackendBias<N>> BackendBias<N> for &'a T {
impl<'a, N, T: BackendBias<N>> BackendBias<N> for &'a T {
#[inline]
fn bias_add(&self, dst: &mut Self::Tensor, bias: &Self::Tensor) {
(**self).bias_add(dst, bias)
@@ -103,7 +102,7 @@ pub trait BackendSigmoid<N>: Backend<N> {
fn sigmoid_grad(&self, dst: &mut Self::Tensor, z: &Self::Tensor, d: &Self::Tensor);
}
impl <'a, N, T: BackendSigmoid<N>> BackendSigmoid<N> for &'a T {
impl<'a, N, T: BackendSigmoid<N>> BackendSigmoid<N> for &'a T {
#[inline]
fn sigmoid(&self, dst: &mut Self::Tensor, data: &Self::Tensor) {
(**self).sigmoid(dst, data)
@@ -115,13 +114,12 @@ impl <'a, N, T: BackendSigmoid<N>> BackendSigmoid<N> for &'a T {
}
}
pub trait BackendReLu<N>: Backend<N> {
fn relu(&self, dst: &mut Self::Tensor, data: &Self::Tensor);
fn relu_grad(&self, dst: &mut Self::Tensor, z: &Self::Tensor, d: &Self::Tensor);
}
impl <'a, N, T: BackendReLu<N>> BackendReLu<N> for &'a T {
impl<'a, N, T: BackendReLu<N>> BackendReLu<N> for &'a T {
#[inline]
fn relu(&self, dst: &mut Self::Tensor, data: &Self::Tensor) {
(**self).relu(dst, data)
@@ -137,7 +135,7 @@ pub trait BackendScale<N>: Backend<N> {
fn scale(&self, dst: &mut Self::Tensor, scale: N);
}
impl <'a, N, T: BackendScale<N>> BackendScale<N> for &'a T {
impl<'a, N, T: BackendScale<N>> BackendScale<N> for &'a T {
#[inline]
fn scale(&self, dst: &mut Self::Tensor, scale: N) {
(**self).scale(dst, scale)
@@ -145,13 +143,25 @@ impl <'a, N, T: BackendScale<N>> BackendScale<N> for &'a T {
}
pub trait BackendMse<N>: Backend<N> {
fn scaled_square_diff(&self, dst: &mut Self::Tensor, a: &Self::Tensor, b: &Self::Tensor, scale: N);
fn scaled_square_diff(
&self,
dst: &mut Self::Tensor,
a: &Self::Tensor,
b: &Self::Tensor,
scale: N,
);
fn scaled_diff(&self, dst: &mut Self::Tensor, a: &Self::Tensor, b: &Self::Tensor, scale: N);
}
impl <'a, N, T: BackendMse<N>> BackendMse<N> for &'a T {
impl<'a, N, T: BackendMse<N>> BackendMse<N> for &'a T {
#[inline]
fn scaled_square_diff(&self, dst: &mut Self::Tensor, a: &Self::Tensor, b: &Self::Tensor, scale: N) {
fn scaled_square_diff(
&self,
dst: &mut Self::Tensor,
a: &Self::Tensor,
b: &Self::Tensor,
scale: N,
) {
(**self).scaled_square_diff(dst, a, b, scale)
}
@@ -165,7 +175,7 @@ pub trait BackendAxpy<N>: Backend<N> {
fn axpy(&self, dst: &mut Self::Tensor, scale: N, a: &Self::Tensor);
}
impl <'a, N, T: BackendAxpy<N>> BackendAxpy<N> for &'a T {
impl<'a, N, T: BackendAxpy<N>> BackendAxpy<N> for &'a T {
#[inline]
fn axpy(&self, dst: &mut Self::Tensor, scale: N, a: &Self::Tensor) {
(**self).axpy(dst, scale, a)
@@ -176,7 +186,7 @@ pub trait BackendAxpys<N>: Backend<N> {
fn axpys(&self, dst: &mut Self::Tensor, scale: N, a: &Self::Tensor);
}
impl <'a, N, T: BackendAxpys<N>> BackendAxpys<N> for &'a T {
impl<'a, N, T: BackendAxpys<N>> BackendAxpys<N> for &'a T {
#[inline]
fn axpys(&self, dst: &mut Self::Tensor, scale: N, a: &Self::Tensor) {
(**self).axpys(dst, scale, a)
@@ -187,7 +197,7 @@ pub trait BackendAdd<N>: Backend<N> {
fn add(&self, dst: &mut Self::Tensor, a: &Self::Tensor);
}
impl <'a, N, T: BackendAdd<N>> BackendAdd<N> for &'a T {
impl<'a, N, T: BackendAdd<N>> BackendAdd<N> for &'a T {
#[inline]
fn add(&self, dst: &mut Self::Tensor, a: &Self::Tensor) {
(**self).add(dst, a)
@@ -198,7 +208,7 @@ pub trait BackendSub<N>: Backend<N> {
fn sub(&self, dst: &mut Self::Tensor, a: &Self::Tensor, b: &Self::Tensor);
}
impl <'a, N, T: BackendSub<N>> BackendSub<N> for &'a T {
impl<'a, N, T: BackendSub<N>> BackendSub<N> for &'a T {
#[inline]
fn sub(&self, dst: &mut Self::Tensor, a: &Self::Tensor, b: &Self::Tensor) {
(**self).sub(dst, a, b)
@@ -209,7 +219,7 @@ pub trait BackendMul<N>: Backend<N> {
fn mul(&self, dst: &mut Self::Tensor, a: &Self::Tensor);
}
impl <'a, N, T: BackendMul<N>> BackendMul<N> for &'a T {
impl<'a, N, T: BackendMul<N>> BackendMul<N> for &'a T {
#[inline]
fn mul(&self, dst: &mut Self::Tensor, a: &Self::Tensor) {
(**self).mul(dst, a)
@@ -220,7 +230,7 @@ pub trait BackendCopy<N>: Backend<N> {
fn copy(&self, dst: &mut Self::Tensor, a: &Self::Tensor);
}
impl <'a, N, T: BackendCopy<N>> BackendCopy<N> for &'a T {
impl<'a, N, T: BackendCopy<N>> BackendCopy<N> for &'a T {
#[inline]
fn copy(&self, dst: &mut Self::Tensor, a: &Self::Tensor) {
(**self).copy(dst, a)
@@ -231,20 +241,36 @@ pub trait BackendMaximum<N>: Backend<N> {
fn maximum(&self, dst: &mut Self::Tensor, a: &Self::Tensor);
}
impl <'a, N, T: BackendMaximum<N>> BackendMaximum<N> for &'a T {
impl<'a, N, T: BackendMaximum<N>> BackendMaximum<N> for &'a T {
#[inline]
fn maximum(&self, dst: &mut Self::Tensor, a: &Self::Tensor) {
(**self).maximum(dst, a)
}
}
pub trait BackendAdam<N>: BackendScale<N> + BackendAxpy<N> + BackendAxpys<N> + BackendMaximum<N> {
fn adam_p(&self, dst: &mut Self::Tensor, lr: N, moms: &Self::Tensor, vels: &Self::Tensor, eps: N);
pub trait BackendAdam<N>:
BackendScale<N> + BackendAxpy<N> + BackendAxpys<N> + BackendMaximum<N>
{
fn adam_p(
&self,
dst: &mut Self::Tensor,
lr: N,
moms: &Self::Tensor,
vels: &Self::Tensor,
eps: N,
);
}
impl <'a, N, T: BackendAdam<N>> BackendAdam<N> for &'a T {
impl<'a, N, T: BackendAdam<N>> BackendAdam<N> for &'a T {
#[inline]
fn adam_p(&self, dst: &mut Self::Tensor, lr: N, moms: &Self::Tensor, vels: &Self::Tensor, eps: N) {
fn adam_p(
&self,
dst: &mut Self::Tensor,
lr: N,
moms: &Self::Tensor,
vels: &Self::Tensor,
eps: N,
) {
(**self).adam_p(dst, lr, moms, vels, eps)
}
}
@@ -253,7 +279,7 @@ pub trait BackendSoftmax<N>: BackendCopy<N> {
fn softmax(&self, y: &mut Self::Tensor, x: &Self::Tensor);
}
impl <'a, N, T: BackendSoftmax<N>> BackendSoftmax<N> for &'a T {
impl<'a, N, T: BackendSoftmax<N>> BackendSoftmax<N> for &'a T {
#[inline]
fn softmax(&self, y: &mut Self::Tensor, x: &Self::Tensor) {
(**self).softmax(y, x)
@@ -269,7 +295,7 @@ pub enum PaddingKind {
#[derive(Clone, PartialEq, Debug)]
pub struct Conv2dInfo {
pub padding: PaddingKind,
pub strides: (u32, u32),
pub kernel: (u32, u32),
}
@@ -277,26 +303,62 @@ pub struct Conv2dInfo {
pub trait BackendConv2d<N>: Backend<N> {
type Context;
fn conv2d_forward(&self, y: &mut Self::Tensor, x: &Self::Tensor, filter: &Self::Tensor, conv_info: &Conv2dInfo);
fn conv2d_backward_input(&self, dx: &mut Self::Tensor, dy: &Self::Tensor, filter: &Self::Tensor, conv_info: &Conv2dInfo);
fn conv2d_backward_filter(&self, dw: &mut Self::Tensor, x: &Self::Tensor, dy: &Self::Tensor, conv_info: &Conv2dInfo);
fn conv2d_forward(
&self,
y: &mut Self::Tensor,
x: &Self::Tensor,
filter: &Self::Tensor,
conv_info: &Conv2dInfo,
);
fn conv2d_backward_input(
&self,
dx: &mut Self::Tensor,
dy: &Self::Tensor,
filter: &Self::Tensor,
conv_info: &Conv2dInfo,
);
fn conv2d_backward_filter(
&self,
dw: &mut Self::Tensor,
x: &Self::Tensor,
dy: &Self::Tensor,
conv_info: &Conv2dInfo,
);
}
impl <'a, N, T: BackendConv2d<N>> BackendConv2d<N> for &'a T {
impl<'a, N, T: BackendConv2d<N>> BackendConv2d<N> for &'a T {
type Context = ();
#[inline]
fn conv2d_forward(&self, y: &mut Self::Tensor, x: &Self::Tensor, filters: &Self::Tensor, conv_info: &Conv2dInfo) {
fn conv2d_forward(
&self,
y: &mut Self::Tensor,
x: &Self::Tensor,
filters: &Self::Tensor,
conv_info: &Conv2dInfo,
) {
(**self).conv2d_forward(y, x, filters, conv_info)
}
#[inline]
fn conv2d_backward_input(&self, dx: &mut Self::Tensor, dy: &Self::Tensor, filters: &Self::Tensor, conv_info: &Conv2dInfo) {
fn conv2d_backward_input(
&self,
dx: &mut Self::Tensor,
dy: &Self::Tensor,
filters: &Self::Tensor,
conv_info: &Conv2dInfo,
) {
(**self).conv2d_backward_input(dx, dy, filters, conv_info)
}
#[inline]
fn conv2d_backward_filter(&self, dw: &mut Self::Tensor, x: &Self::Tensor, dy: &Self::Tensor, conv_info: &Conv2dInfo) {
fn conv2d_backward_filter(
&self,
dw: &mut Self::Tensor,
x: &Self::Tensor,
dy: &Self::Tensor,
conv_info: &Conv2dInfo,
) {
(**self).conv2d_backward_filter(dw, x, dy, conv_info)
}
}
@@ -309,45 +371,81 @@ pub enum PoolingKind {
pub trait BackendMaxPool2d<N>: Backend<N> {
fn max_pool2d(&self, y: &mut Self::Tensor, x: &Self::Tensor, conv_info: &Conv2dInfo);
fn max_pool2d_backprop(&self, dx: &mut Self::Tensor, dy: &Self::Tensor, x: &Self::Tensor, conv_info: &Conv2dInfo);
fn max_pool2d_backprop(
&self,
dx: &mut Self::Tensor,
dy: &Self::Tensor,
x: &Self::Tensor,
conv_info: &Conv2dInfo,
);
}
impl <'a, N, T: BackendMaxPool2d<N>> BackendMaxPool2d<N> for &'a T {
impl<'a, N, T: BackendMaxPool2d<N>> BackendMaxPool2d<N> for &'a T {
#[inline]
fn max_pool2d(&self, y: &mut Self::Tensor, x: &Self::Tensor, conv_info: &Conv2dInfo) {
(**self).max_pool2d(y, x, conv_info)
}
#[inline]
fn max_pool2d_backprop(&self, dx: &mut Self::Tensor, dy: &Self::Tensor, x: &Self::Tensor, conv_info: &Conv2dInfo) {
fn max_pool2d_backprop(
&self,
dx: &mut Self::Tensor,
dy: &Self::Tensor,
x: &Self::Tensor,
conv_info: &Conv2dInfo,
) {
(**self).max_pool2d_backprop(dx, dy, x, conv_info)
}
}
pub trait BackendAvgPool2d<N>: Backend<N> {
fn avg_pool2d(&self, y: &mut Self::Tensor, x: &Self::Tensor, conv_info: &Conv2dInfo);
fn avg_pool2d_backprop(&self, dx: &mut Self::Tensor, dy: &Self::Tensor, x: &Self::Tensor, conv_info: &Conv2dInfo);
fn avg_pool2d_backprop(
&self,
dx: &mut Self::Tensor,
dy: &Self::Tensor,
x: &Self::Tensor,
conv_info: &Conv2dInfo,
);
}
impl <'a, N, T: BackendAvgPool2d<N>> BackendAvgPool2d<N> for &'a T {
impl<'a, N, T: BackendAvgPool2d<N>> BackendAvgPool2d<N> for &'a T {
#[inline]
fn avg_pool2d(&self, y: &mut Self::Tensor, x: &Self::Tensor, conv_info: &Conv2dInfo) {
(**self).avg_pool2d(y, x, conv_info)
}
#[inline]
fn avg_pool2d_backprop(&self, dx: &mut Self::Tensor, dy: &Self::Tensor, x: &Self::Tensor, conv_info: &Conv2dInfo) {
fn avg_pool2d_backprop(
&self,
dx: &mut Self::Tensor,
dy: &Self::Tensor,
x: &Self::Tensor,
conv_info: &Conv2dInfo,
) {
(**self).avg_pool2d_backprop(dx, dy, x, conv_info)
}
}
pub trait BackendPaddingCopy2d<N>: Backend<N> {
fn copy_with_padding2d(&self, y: &mut Self::Tensor, x: &Self::Tensor, y_paddings: (u32, u32), x_paddings: (u32, u32));
fn copy_with_padding2d(
&self,
y: &mut Self::Tensor,
x: &Self::Tensor,
y_paddings: (u32, u32),
x_paddings: (u32, u32),
);
}
impl <'a, N, T: BackendPaddingCopy2d<N>> BackendPaddingCopy2d<N> for &'a T {
impl<'a, N, T: BackendPaddingCopy2d<N>> BackendPaddingCopy2d<N> for &'a T {
#[inline]
fn copy_with_padding2d(&self, y: &mut Self::Tensor, x: &Self::Tensor, y_paddings: (u32, u32), x_paddings: (u32, u32)) {
fn copy_with_padding2d(
&self,
y: &mut Self::Tensor,
x: &Self::Tensor,
y_paddings: (u32, u32),
x_paddings: (u32, u32),
) {
(**self).copy_with_padding2d(y, x, y_paddings, x_paddings)
}
}
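
All of the impl<'a, N, T: ...> ... for &'a T blocks above apply one pattern: implement each backend trait for shared references by delegating through (**self), so a &backend satisfies the same bounds as an owned backend and can be handed to several consumers at once. A stripped-down sketch of the pattern with an illustrative one-method trait:

// A toy stand-in for the backend traits above.
trait Scale {
    fn scale(&self, dst: &mut [f32], s: f32);
}

struct Cpu;

impl Scale for Cpu {
    fn scale(&self, dst: &mut [f32], s: f32) {
        for v in dst.iter_mut() {
            *v *= s;
        }
    }
}

// The blanket impl: &Cpu now implements Scale too, by double-deref delegation.
impl<'a, T: Scale> Scale for &'a T {
    fn scale(&self, dst: &mut [f32], s: f32) {
        (**self).scale(dst, s)
    }
}

fn consume<B: Scale>(backend: B, data: &mut [f32]) {
    backend.scale(data, 0.5);
}

fn main() {
    let cpu = Cpu;
    let mut xs = [2.0f32, 4.0];
    consume(&cpu, &mut xs); // pass by reference; `cpu` stays usable afterwards
    assert_eq!(xs, [1.0, 2.0]);
}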


@@ -4,10 +4,10 @@ use crate::tensor::{Tensor, TensorShape};
// use core::marker::PhantomData;
pub trait Layer<N, B, O>
where B: Backend<N>,
O: Optimizer<N, B>
where
B: Backend<N>,
O: Optimizer<N, B>,
{
type Context: LayerContext<N, B>;
@@ -15,7 +15,7 @@ pub trait Layer<N, B, O>
fn param_count(&self) -> usize {
0
}
#[inline]
fn init(&mut self, _backend: &B) {}
@@ -25,41 +25,58 @@ pub trait Layer<N, B, O>
fn output_shape(&self) -> TensorShape {
self.input_shape()
}
fn forward(&self, backend: &B, x: &B::Tensor, ctx: &mut Self::Context);
fn backward(&mut self, backend: &B, dy: &B::Tensor, x: &B::Tensor, ctx: &mut Self::Context);
#[inline]
fn calc_gradients(&mut self, _backend: &B, _dy: &B::Tensor, _x: &B::Tensor, _ctx: &mut Self::Context) {}
fn calc_gradients(
&mut self,
_backend: &B,
_dy: &B::Tensor,
_x: &B::Tensor,
_ctx: &mut Self::Context,
) {
}
#[inline]
fn optimize(&mut self, _backend: &B, _optimizer: &O) {}
fn fmt(&self, f: &mut core::fmt::Formatter, padding: usize) -> core::fmt::Result {
writeln!(f, "{}{} -> {}[{}] -> {}", "".repeat(padding), self.input_shape(), self.name(), self.param_count(), self.output_shape())?;
writeln!(
f,
"{}{} -> {}[{}] -> {}",
"".repeat(padding),
self.input_shape(),
self.name(),
self.param_count(),
self.output_shape()
)?;
Ok(())
}
}
pub trait LayerExt<N, B, O>: Layer<N, B, O>
where B: Backend<N>,
O: Optimizer<N, B>
where
B: Backend<N>,
O: Optimizer<N, B>,
{
type Config: Default;
fn create(input_shape: TensorShape, cfg: Self::Config) -> Self;
#[inline]
fn add_layer<L: LayerExt<N, B, O>>(self, cfg: L::Config) -> crate::layers::Chain<N, B, O, Self, L>
where Self: Sized
fn add_layer<L: LayerExt<N, B, O>>(
self,
cfg: L::Config,
) -> crate::layers::Chain<N, B, O, Self, L>
where
Self: Sized,
{
let shape = self.output_shape();
crate::layers::Chain::new(
self,
L::create(shape, cfg),
)
crate::layers::Chain::new(self, L::create(shape, cfg))
}
}
@ -68,16 +85,17 @@ pub trait LayerContext<N, B: Backend<N>>: Default {
fn deltas(&self) -> &B::Tensor;
}
pub struct DefaultLayerContext<N, B>
where B: Backend<N>,
pub struct DefaultLayerContext<N, B>
where
B: Backend<N>,
{
pub outputs: B::Tensor,
pub deltas: B::Tensor,
}
impl <N, B> Default for DefaultLayerContext<N, B>
where B: Backend<N>,
impl<N, B> Default for DefaultLayerContext<N, B>
where
B: Backend<N>,
{
fn default() -> Self {
Self {
@ -87,8 +105,9 @@ impl <N, B> Default for DefaultLayerContext<N, B>
}
}
impl <N, B> DefaultLayerContext<N, B>
where B: Backend<N>,
impl<N, B> DefaultLayerContext<N, B>
where
B: Backend<N>,
{
pub fn update_deltas_shape(&mut self, bs: u32, input_shape: &TensorShape) {
let mut new_deltas_shape = TensorShape::new1d(bs);
@ -110,8 +129,9 @@ impl <N, B> DefaultLayerContext<N, B>
}
}
impl <N, B> LayerContext<N, B> for DefaultLayerContext<N, B>
where B: Backend<N>,
impl<N, B> LayerContext<N, B> for DefaultLayerContext<N, B>
where
B: Backend<N>,
{
#[inline]
fn outputs(&self) -> &B::Tensor {
@ -122,4 +142,4 @@ impl <N, B> LayerContext<N, B> for DefaultLayerContext<N, B>
fn deltas(&self) -> &B::Tensor {
&self.deltas
}
}
}


@ -1,7 +1,7 @@
use crate::tensor::{Tensor, TensorShape};
use crate::layer::{Layer, LayerExt, DefaultLayerContext};
use crate::backend::{Backend, PaddingKind, BackendAvgPool2d, Conv2dInfo};
use crate::backend::{Backend, BackendAvgPool2d, Conv2dInfo, PaddingKind};
use crate::layer::{DefaultLayerContext, Layer, LayerExt};
use crate::optimizer::Optimizer;
use crate::tensor::{Tensor, TensorShape};
use core::marker::PhantomData;
@ -19,24 +19,26 @@ impl Default for AvgPool2dConfig {
}
}
pub struct AvgPool2d<N, B>
where B: Backend<N>
pub struct AvgPool2d<N, B>
where
B: Backend<N>,
{
input_shape: TensorShape,
conv_info: Conv2dInfo,
_m: PhantomData<fn(N, B)>
_m: PhantomData<fn(N, B)>,
}
impl <N, B, O> Layer<N, B, O> for AvgPool2d<N, B>
where B: Backend<N> + BackendAvgPool2d<N>,
O: Optimizer<N, B>
impl<N, B, O> Layer<N, B, O> for AvgPool2d<N, B>
where
B: Backend<N> + BackendAvgPool2d<N>,
O: Optimizer<N, B>,
{
type Context = DefaultLayerContext<N, B>;
fn name(&self) -> &str {
"AvgPool2d"
}
#[inline]
fn input_shape(&self) -> TensorShape {
self.input_shape.clone()
@ -51,13 +53,9 @@ impl <N, B, O> Layer<N, B, O> for AvgPool2d<N, B>
let rows = (is[1] - self.conv_info.kernel.0) / self.conv_info.strides.0 + 1;
let cols = (is[2] - self.conv_info.kernel.1) / self.conv_info.strides.1 + 1;
TensorShape::new3d(
is[0],
rows,
cols,
)
TensorShape::new3d(is[0], rows, cols)
}
#[inline]
fn forward(&self, backend: &B, x: &B::Tensor, ctx: &mut Self::Context) {
ctx.update_outputs_shape(x.shape().get(0), &Layer::<N, B, O>::output_shape(self));
@ -83,9 +81,10 @@ impl <N, B, O> Layer<N, B, O> for AvgPool2d<N, B>
}
}
impl <N, B, O> LayerExt<N, B, O> for AvgPool2d<N, B>
where B: Backend<N> + BackendAvgPool2d<N>,
O: Optimizer<N, B>
impl<N, B, O> LayerExt<N, B, O> for AvgPool2d<N, B>
where
B: Backend<N> + BackendAvgPool2d<N>,
O: Optimizer<N, B>,
{
type Config = AvgPool2dConfig;


@ -5,20 +5,22 @@ use crate::tensor::TensorShape;
use core::marker::PhantomData;
pub struct ChainContext<N, B, L, R>
where B: Backend<N>,
L: LayerContext<N, B>,
R: LayerContext<N, B>,
pub struct ChainContext<N, B, L, R>
where
B: Backend<N>,
L: LayerContext<N, B>,
R: LayerContext<N, B>,
{
left: L,
left: L,
right: R,
_m: PhantomData<fn(N, B)>
_m: PhantomData<fn(N, B)>,
}
impl<N, B, L, R> Default for ChainContext<N, B, L, R>
where B: Backend<N>,
L: LayerContext<N, B>,
R: LayerContext<N, B>,
impl<N, B, L, R> Default for ChainContext<N, B, L, R>
where
B: Backend<N>,
L: LayerContext<N, B>,
R: LayerContext<N, B>,
{
fn default() -> Self {
Self {
@ -30,9 +32,10 @@ impl<N, B, L, R> Default for ChainContext<N, B, L, R>
}
impl<N, B, L, R> LayerContext<N, B> for ChainContext<N, B, L, R>
where B: Backend<N>,
L: LayerContext<N, B>,
R: LayerContext<N, B>,
where
B: Backend<N>,
L: LayerContext<N, B>,
R: LayerContext<N, B>,
{
#[inline]
fn outputs(&self) -> &B::Tensor {
@ -43,24 +46,26 @@ impl<N, B, L, R> LayerContext<N, B> for ChainContext<N, B, L, R>
fn deltas(&self) -> &B::Tensor {
self.left.deltas()
}
}
pub struct Chain<N, B, O, L, R>
where B: Backend<N>,
O: Optimizer<N, B>,
L: Layer<N, B, O>,
R: Layer<N, B, O>,
{
left: L,
right: R,
_m: PhantomData<fn(N, B, O)>
}
impl<N, B, O, L, R> Chain<N, B, O, L, R>
where B: Backend<N>,
O: Optimizer<N, B>,
L: Layer<N, B, O>,
R: Layer<N, B, O>,
pub struct Chain<N, B, O, L, R>
where
B: Backend<N>,
O: Optimizer<N, B>,
L: Layer<N, B, O>,
R: Layer<N, B, O>,
{
left: L,
right: R,
_m: PhantomData<fn(N, B, O)>,
}
impl<N, B, O, L, R> Chain<N, B, O, L, R>
where
B: Backend<N>,
O: Optimizer<N, B>,
L: Layer<N, B, O>,
R: Layer<N, B, O>,
{
pub fn new(left: L, right: R) -> Self {
Self {
@ -71,7 +76,7 @@ impl<N, B, O, L, R> Chain<N, B, O, L, R>
}
}
// impl<N, B, O, L, R> core::fmt::Display for Chain<N, B, O, L, R>
// impl<N, B, O, L, R> core::fmt::Display for Chain<N, B, O, L, R>
// where B: Backend<N>,
// O: Optimizer<N, B>,
// L: Layer<N, B, O>,
@ -85,11 +90,12 @@ impl<N, B, O, L, R> Chain<N, B, O, L, R>
// }
// }
impl<N, B, O, L, R> Layer<N, B, O> for Chain<N, B, O, L, R>
where B: Backend<N>,
O: Optimizer<N, B>,
L: Layer<N, B, O>,
R: Layer<N, B, O>,
impl<N, B, O, L, R> Layer<N, B, O> for Chain<N, B, O, L, R>
where
B: Backend<N>,
O: Optimizer<N, B>,
L: Layer<N, B, O>,
R: Layer<N, B, O>,
{
type Context = ChainContext<N, B, L::Context, R::Context>;
@ -101,7 +107,7 @@ impl<N, B, O, L, R> Layer<N, B, O> for Chain<N, B, O, L, R>
#[inline]
fn param_count(&self) -> usize {
self.left.param_count() + self.right.param_count()
}
}
#[inline]
fn init(&mut self, backend: &B) {
@ -122,19 +128,36 @@ impl<N, B, O, L, R> Layer<N, B, O> for Chain<N, B, O, L, R>
#[inline]
fn forward(&self, backend: &B, inputs: &B::Tensor, ctx: &mut Self::Context) {
self.left.forward(backend, inputs, &mut ctx.left);
self.right.forward(backend, ctx.left.outputs(), &mut ctx.right);
self.right
.forward(backend, ctx.left.outputs(), &mut ctx.right);
}
#[inline]
fn backward(&mut self, backend: &B, deltas: &B::Tensor, inputs: &B::Tensor, ctx: &mut Self::Context) {
self.right.backward(backend, deltas, ctx.left.outputs(), &mut ctx.right);
self.left.backward(backend, ctx.right.deltas(), inputs, &mut ctx.left);
fn backward(
&mut self,
backend: &B,
deltas: &B::Tensor,
inputs: &B::Tensor,
ctx: &mut Self::Context,
) {
self.right
.backward(backend, deltas, ctx.left.outputs(), &mut ctx.right);
self.left
.backward(backend, ctx.right.deltas(), inputs, &mut ctx.left);
}
#[inline]
fn calc_gradients(&mut self, backend: &B, deltas: &B::Tensor, inputs: &B::Tensor, ctx: &mut Self::Context) {
self.left.calc_gradients(backend, ctx.right.deltas(), inputs, &mut ctx.left);
self.right.calc_gradients(backend, deltas, ctx.left.outputs(), &mut ctx.right);
fn calc_gradients(
&mut self,
backend: &B,
deltas: &B::Tensor,
inputs: &B::Tensor,
ctx: &mut Self::Context,
) {
self.left
.calc_gradients(backend, ctx.right.deltas(), inputs, &mut ctx.left);
self.right
.calc_gradients(backend, deltas, ctx.left.outputs(), &mut ctx.right);
}
#[inline]
@ -146,7 +169,7 @@ impl<N, B, O, L, R> Layer<N, B, O> for Chain<N, B, O, L, R>
fn fmt(&self, f: &mut core::fmt::Formatter, padding: usize) -> core::fmt::Result {
self.left.fmt(f, padding)?;
self.right.fmt(f, padding)?;
Ok(())
}
}
}
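Taken together, `Chain` is plain reverse-mode composition: `forward` threads the input through the left layer and feeds `ctx.left.outputs()` to the right, while `backward` and `calc_gradients` run right-to-left, each side consuming the deltas the other produced. A hedged sketch of one training step over a chain (the `step` helper is illustrative, not part of the crate):

fn step<N, B, O, L, R>(
    net: &mut Chain<N, B, O, L, R>,
    backend: &B,
    optimizer: &O,
    x: &B::Tensor,
    dy: &B::Tensor, // loss derivative w.r.t. the chain's outputs
    ctx: &mut ChainContext<N, B, L::Context, R::Context>,
) where
    B: Backend<N>,
    O: Optimizer<N, B>,
    L: Layer<N, B, O>,
    R: Layer<N, B, O>,
{
    net.forward(backend, x, ctx); // left, then right
    net.backward(backend, dy, x, ctx); // right, then left
    net.calc_gradients(backend, dy, x, ctx);
    net.optimize(backend, optimizer);
}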


@ -1,8 +1,8 @@
use crate::tensor::{Tensor, TensorShape};
use crate::layer::{Layer, LayerExt, DefaultLayerContext};
use crate::params::Params;
use crate::backend::{Backend, Conv2dInfo, PaddingKind, BackendBias, BackendConv2d, BackendScale};
use crate::backend::{Backend, BackendBias, BackendConv2d, BackendScale, Conv2dInfo, PaddingKind};
use crate::layer::{DefaultLayerContext, Layer, LayerExt};
use crate::optimizer::Optimizer;
use crate::params::Params;
use crate::tensor::{Tensor, TensorShape};
pub struct Conv2dConfig {
pub filters: u32,
@ -24,9 +24,10 @@ impl Default for Conv2dConfig {
}
}
pub struct Conv2d<N, B, O>
where B: Backend<N>,
O: Optimizer<N, B>
pub struct Conv2d<N, B, O>
where
B: Backend<N>,
O: Optimizer<N, B>,
{
input_shape: TensorShape,
units: u32,
@ -36,9 +37,10 @@ pub struct Conv2d<N, B, O>
biases: Params<N, B, O>,
}
impl <N, B, O> Layer<N, B, O> for Conv2d<N, B, O>
where B: Backend<N> + BackendConv2d<N> + BackendBias<N> + BackendScale<N>,
O: Optimizer<N, B>
impl<N, B, O> Layer<N, B, O> for Conv2d<N, B, O>
where
B: Backend<N> + BackendConv2d<N> + BackendBias<N> + BackendScale<N>,
O: Optimizer<N, B>,
{
type Context = DefaultLayerContext<N, B>;
@ -53,10 +55,13 @@ impl <N, B, O> Layer<N, B, O> for Conv2d<N, B, O>
} else {
self.filters.params.shape().size()
}
}
}
fn init(&mut self, backend: &B) {
self.filters.init_random(backend, self.conv_info.kernel.0 * self.conv_info.kernel.1 + self.filters.params.shape().get(0));
self.filters.init_random(
backend,
self.conv_info.kernel.0 * self.conv_info.kernel.1 + self.filters.params.shape().get(0),
);
if self.use_biases {
self.biases.init_zero(backend);
@ -77,13 +82,9 @@ impl <N, B, O> Layer<N, B, O> for Conv2d<N, B, O>
let rows = (is[1] - self.conv_info.kernel.0) / self.conv_info.strides.0 + 1;
let cols = (is[2] - self.conv_info.kernel.1) / self.conv_info.strides.1 + 1;
TensorShape::new3d(
self.units,
rows,
cols,
)
TensorShape::new3d(self.units, rows, cols)
}
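A note on the `init_random` call above: its second argument, kernel area plus filter count, has the shape of the fan-in + fan-out sum used by Glorot-style initialisation (compare `Linear`, which passes `inputs + outputs`). Purely as an illustration of the arithmetic:

// Illustrative only: for Conv2dConfig { filters: 16, kernel: (3, 3), .. }
// the argument works out to 3 * 3 + 16 = 25 -- kernel area (fan per input
// channel) plus output channels, i.e. a fan-in + fan-out style scale.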
#[inline]
fn forward(&self, backend: &B, x: &B::Tensor, ctx: &mut Self::Context) {
assert_eq!(x.shape().dims, 4);
@ -101,14 +102,20 @@ impl <N, B, O> Layer<N, B, O> for Conv2d<N, B, O>
#[inline]
fn backward(&mut self, backend: &B, dy: &B::Tensor, x: &B::Tensor, ctx: &mut Self::Context) {
assert_eq!(dy.shape().dims, 4);
ctx.update_deltas_shape(x.shape().get(0), &self.input_shape);
backend.conv2d_backward_input(&mut ctx.deltas, dy, &self.filters.params, &self.conv_info);
}
#[inline]
fn calc_gradients(&mut self, backend: &B, dy: &B::Tensor, x: &B::Tensor, _ctx: &mut Self::Context) {
fn calc_gradients(
&mut self,
backend: &B,
dy: &B::Tensor,
x: &B::Tensor,
_ctx: &mut Self::Context,
) {
assert_eq!(dy.shape().dims, 4);
assert_eq!(x.shape().dims, 4);
@ -120,18 +127,24 @@ impl <N, B, O> Layer<N, B, O> for Conv2d<N, B, O>
#[inline]
fn optimize(&mut self, backend: &B, optimizer: &O) {
optimizer.update_params(backend, &mut self.filters.ctx, &mut self.filters.params, &mut self.filters.grads);
optimizer.update_params(
backend,
&mut self.filters.ctx,
&mut self.filters.params,
&mut self.filters.grads,
);
if self.use_biases {
unimplemented!()
// optimizer.update_params(backend, &mut self.biases.ctx, &mut self.biases.params, &self.biases.grads);
// optimizer.update_params(backend, &mut self.biases.ctx, &mut self.biases.params, &self.biases.grads);
}
}
}
impl <N, B, O> LayerExt<N, B, O> for Conv2d<N, B, O>
where B: Backend<N> + BackendConv2d<N> + BackendBias<N> + BackendScale<N>,
O: Optimizer<N, B>
impl<N, B, O> LayerExt<N, B, O> for Conv2d<N, B, O>
where
B: Backend<N> + BackendConv2d<N> + BackendBias<N> + BackendScale<N>,
O: Optimizer<N, B>,
{
type Config = Conv2dConfig;
@ -143,12 +156,12 @@ impl <N, B, O> LayerExt<N, B, O> for Conv2d<N, B, O>
units: cfg.filters,
conv_info: Conv2dInfo {
kernel: cfg.kernel,
padding: cfg.padding,
padding: cfg.padding,
strides: cfg.strides,
},
use_biases: cfg.biases,
filters: Params::new((cfg.filters, cfg.kernel.0, cfg.kernel.1)),
biases: Params::new((cfg.filters, )),
biases: Params::new((cfg.filters,)),
}
}
}
}


@ -1,30 +1,31 @@
use crate::tensor::{TensorShape, Tensor};
use crate::backend::{Backend, BackendCopy};
use crate::layer::{Layer, LayerExt, DefaultLayerContext};
use crate::layer::{DefaultLayerContext, Layer, LayerExt};
use crate::optimizer::Optimizer;
use crate::tensor::{Tensor, TensorShape};
use core::marker::PhantomData;
#[derive(Default)]
pub struct FlattenConfig;
pub struct Flatten<N, B>
where B: Backend<N>,
pub struct Flatten<N, B>
where
B: Backend<N>,
{
input_shape: TensorShape,
_x: PhantomData<fn(N, B)>,
}
impl <N, B, O> Layer<N, B, O> for Flatten<N, B>
where B: Backend<N> + BackendCopy<N>,
O: Optimizer<N, B>
impl<N, B, O> Layer<N, B, O> for Flatten<N, B>
where
B: Backend<N> + BackendCopy<N>,
O: Optimizer<N, B>,
{
type Context = DefaultLayerContext<N, B>;
fn name(&self) -> &str {
"Flatten"
}
#[inline]
fn input_shape(&self) -> TensorShape {
self.input_shape.clone()
@ -34,11 +35,11 @@ impl <N, B, O> Layer<N, B, O> for Flatten<N, B>
fn output_shape(&self) -> TensorShape {
TensorShape::new1d(self.input_shape.size() as u32)
}
#[inline]
fn forward(&self, backend: &B, x: &B::Tensor, ctx: &mut Self::Context) {
ctx.update_outputs_shape(x.shape().get(0), &Layer::<N, B, O>::output_shape(self));
backend.copy(&mut ctx.outputs, x);
}
@ -50,16 +51,17 @@ impl <N, B, O> Layer<N, B, O> for Flatten<N, B>
}
}
impl <N, B, O> LayerExt<N, B, O> for Flatten<N, B>
where B: Backend<N> + BackendCopy<N>,
O: Optimizer<N, B>
impl<N, B, O> LayerExt<N, B, O> for Flatten<N, B>
where
B: Backend<N> + BackendCopy<N>,
O: Optimizer<N, B>,
{
type Config = FlattenConfig;
fn create(input_shape: TensorShape, _cfg: Self::Config) -> Self {
Flatten {
input_shape,
_x: Default::default()
_x: Default::default(),
}
}
}
}


@ -1,8 +1,8 @@
use crate::tensor::{Tensor, TensorShape};
use crate::layer::{Layer, LayerExt, DefaultLayerContext};
use crate::params::Params;
use crate::backend::{Backend, BackendGemm, BackendBias, BackendScale};
use crate::backend::{Backend, BackendBias, BackendGemm, BackendScale};
use crate::layer::{DefaultLayerContext, Layer, LayerExt};
use crate::optimizer::Optimizer;
use crate::params::Params;
use crate::tensor::{Tensor, TensorShape};
pub struct LinearConfig {
pub units: u32,
@ -18,9 +18,10 @@ impl Default for LinearConfig {
}
}
pub struct Linear<N, B, O>
where B: Backend<N>,
O: Optimizer<N, B>
pub struct Linear<N, B, O>
where
B: Backend<N>,
O: Optimizer<N, B>,
{
inputs: u32,
outputs: u32,
@ -29,9 +30,10 @@ pub struct Linear<N, B, O>
biases: Params<N, B, O>,
}
impl <N, B, O> Layer<N, B, O> for Linear<N, B, O>
where B: Backend<N> + BackendGemm<N> + BackendBias<N> + BackendScale<N>,
O: Optimizer<N, B>
impl<N, B, O> Layer<N, B, O> for Linear<N, B, O>
where
B: Backend<N> + BackendGemm<N> + BackendBias<N> + BackendScale<N>,
O: Optimizer<N, B>,
{
type Context = DefaultLayerContext<N, B>;
@ -46,10 +48,11 @@ impl <N, B, O> Layer<N, B, O> for Linear<N, B, O>
} else {
self.weights.params.shape().size()
}
}
}
fn init(&mut self, backend: &B) {
self.weights.init_random(backend, self.inputs + self.outputs);
self.weights
.init_random(backend, self.inputs + self.outputs);
if self.use_biases {
self.biases.init_zero(backend);
}
@ -64,7 +67,7 @@ impl <N, B, O> Layer<N, B, O> for Linear<N, B, O>
fn output_shape(&self) -> TensorShape {
TensorShape::new1d(self.outputs)
}
#[inline]
fn forward(&self, backend: &B, x: &B::Tensor, ctx: &mut Self::Context) {
ctx.update_outputs_shape(x.shape().get(0), &self.output_shape());
@ -79,11 +82,17 @@ impl <N, B, O> Layer<N, B, O> for Linear<N, B, O>
#[inline]
fn backward(&mut self, backend: &B, dy: &B::Tensor, x: &B::Tensor, ctx: &mut Self::Context) {
ctx.update_deltas_shape(x.shape().get(0), &self.input_shape());
backend.matmul_nt(&mut ctx.deltas, dy, &self.weights.params);
}
fn calc_gradients(&mut self, backend: &B, dy: &B::Tensor, x: &B::Tensor, _ctx: &mut Self::Context) {
fn calc_gradients(
&mut self,
backend: &B,
dy: &B::Tensor,
x: &B::Tensor,
_ctx: &mut Self::Context,
) {
let prescaler = 1.0 / x.shape().get(0) as f32;
backend.matmul_tn(&mut self.weights.grads, x, dy);
@ -97,17 +106,28 @@ impl <N, B, O> Layer<N, B, O> for Linear<N, B, O>
#[inline]
fn optimize(&mut self, backend: &B, optimizer: &O) {
optimizer.update_params(backend, &mut self.weights.ctx, &mut self.weights.params, &mut self.weights.grads);
optimizer.update_params(
backend,
&mut self.weights.ctx,
&mut self.weights.params,
&mut self.weights.grads,
);
if self.use_biases {
optimizer.update_params(backend, &mut self.biases.ctx, &mut self.biases.params, &mut self.biases.grads);
optimizer.update_params(
backend,
&mut self.biases.ctx,
&mut self.biases.params,
&mut self.biases.grads,
);
}
}
}
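For reference, the maths behind the `Linear` backward pass above, with batch size B, inputs X (B×in), weights W (in×out) and output deltas dY (B×out): `matmul_nt` gives the input deltas and `matmul_tn` the weight gradient, with the 1/B `prescaler` averaging over the batch (its application site falls in the elided part of the hunk):

dX = dY \, W^{\top}, \qquad dW = \frac{1}{B} \, X^{\top} dY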
impl <N, B, O> LayerExt<N, B, O> for Linear<N, B, O>
where B: Backend<N> + BackendGemm<N> + BackendBias<N> + BackendScale<N>,
O: Optimizer<N, B>
impl<N, B, O> LayerExt<N, B, O> for Linear<N, B, O>
where
B: Backend<N> + BackendGemm<N> + BackendBias<N> + BackendScale<N>,
O: Optimizer<N, B>,
{
type Config = LinearConfig;
@ -121,7 +141,7 @@ impl <N, B, O> LayerExt<N, B, O> for Linear<N, B, O>
outputs: cfg.units,
use_biases: cfg.biases,
weights: Params::new((inputs, cfg.units)),
biases: Params::new((cfg.units, )),
biases: Params::new((cfg.units,)),
}
}
}
}


@ -1,7 +1,7 @@
use crate::tensor::{TensorShape, Tensor};
use crate::layer::{Layer, LayerExt, DefaultLayerContext};
use crate::backend::{Backend, PaddingKind, BackendMaxPool2d, Conv2dInfo};
use crate::backend::{Backend, BackendMaxPool2d, Conv2dInfo, PaddingKind};
use crate::layer::{DefaultLayerContext, Layer, LayerExt};
use crate::optimizer::Optimizer;
use crate::tensor::{Tensor, TensorShape};
use core::marker::PhantomData;
pub struct MaxPool2dConfig {
@ -18,24 +18,26 @@ impl Default for MaxPool2dConfig {
}
}
pub struct MaxPool2d<N, B>
where B: Backend<N>
pub struct MaxPool2d<N, B>
where
B: Backend<N>,
{
input_shape: TensorShape,
conv_info: Conv2dInfo,
_m: PhantomData<fn(N, B)>
_m: PhantomData<fn(N, B)>,
}
impl <N, B, O> Layer<N, B, O> for MaxPool2d<N, B>
where B: Backend<N> + BackendMaxPool2d<N>,
O: Optimizer<N, B>
impl<N, B, O> Layer<N, B, O> for MaxPool2d<N, B>
where
B: Backend<N> + BackendMaxPool2d<N>,
O: Optimizer<N, B>,
{
type Context = DefaultLayerContext<N, B>;
fn name(&self) -> &str {
"MaxPool2d"
}
#[inline]
fn input_shape(&self) -> TensorShape {
self.input_shape.clone()
@ -44,26 +46,22 @@ impl <N, B, O> Layer<N, B, O> for MaxPool2d<N, B>
#[inline]
fn output_shape(&self) -> TensorShape {
let is = self.input_shape.as_slice();
// O = (W - K + 2P) / S + 1
let rows = (is[1] - self.conv_info.kernel.0) / self.conv_info.strides.0 + 1;
let cols = (is[2] - self.conv_info.kernel.1) / self.conv_info.strides.1 + 1;
TensorShape::new3d(
is[0],
rows,
cols,
)
TensorShape::new3d(is[0], rows, cols)
}
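The commented formula `O = (W - K + 2P) / S + 1` is the standard output-size rule; the code above is its P = 0 (valid) case. A worked instance, assuming nothing beyond the formula:

// 28x28 input, 2x2 kernel, stride 2, no padding:
// rows = (28 - 2) / 2 + 1 = 14
// cols = (28 - 2) / 2 + 1 = 14
// so output_shape maps (C, 28, 28) to (C, 14, 14).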
#[inline]
fn forward(&self, backend: &B, x: &B::Tensor, ctx: &mut Self::Context) {
ctx.update_outputs_shape(x.shape().get(0), &Layer::<N, B, O>::output_shape(self));
backend.max_pool2d(&mut ctx.outputs, x, &self.conv_info)
}
#[inline]
fn backward(&mut self, backend: &B, dy: &B::Tensor, x: &B::Tensor, ctx: &mut Self::Context) {
ctx.update_deltas_shape(x.shape().get(0), &self.input_shape);
@ -72,9 +70,10 @@ impl <N, B, O> Layer<N, B, O> for MaxPool2d<N, B>
}
}
impl <N, B, O> LayerExt<N, B, O> for MaxPool2d<N, B>
where B: Backend<N> + BackendMaxPool2d<N>,
O: Optimizer<N, B>
impl<N, B, O> LayerExt<N, B, O> for MaxPool2d<N, B>
where
B: Backend<N> + BackendMaxPool2d<N>,
O: Optimizer<N, B>,
{
type Config = MaxPool2dConfig;
@ -91,4 +90,4 @@ impl <N, B, O> LayerExt<N, B, O> for MaxPool2d<N, B>
_m: Default::default(),
}
}
}
}


@ -1,21 +1,21 @@
mod linear;
mod sigmoid;
mod chain;
mod relu;
mod softmax;
mod avgpool2d;
mod maxpool2d;
mod zeropadding2d;
mod chain;
mod conv2d;
mod flatten;
mod linear;
mod maxpool2d;
mod relu;
mod sigmoid;
mod softmax;
mod zeropadding2d;
pub use self::linear::*;
pub use self::sigmoid::*;
pub use self::chain::*;
pub use self::relu::*;
pub use self::softmax::*;
pub use self::conv2d::*;
pub use self::avgpool2d::*;
pub use self::maxpool2d::*;
pub use self::chain::*;
pub use self::conv2d::*;
pub use self::flatten::*;
pub use self::zeropadding2d::*;
pub use self::linear::*;
pub use self::maxpool2d::*;
pub use self::relu::*;
pub use self::sigmoid::*;
pub use self::softmax::*;
pub use self::zeropadding2d::*;


@ -1,59 +1,62 @@
use crate::tensor::{Tensor, TensorShape};
use crate::backend::{Backend, BackendReLu};
use crate::layer::{Layer, LayerExt, DefaultLayerContext};
use crate::layer::{DefaultLayerContext, Layer, LayerExt};
use crate::optimizer::Optimizer;
use crate::tensor::{Tensor, TensorShape};
use core::marker::PhantomData;
#[derive(Default)]
pub struct ReLuConfig;
pub struct ReLu<N, B>
where B: Backend<N>,
pub struct ReLu<N, B>
where
B: Backend<N>,
{
input_shape: TensorShape,
_x: PhantomData<fn(N, B)>,
}
impl <N, B, O> Layer<N, B, O> for ReLu<N, B>
where B: Backend<N> + BackendReLu<N>,
O: Optimizer<N, B>
impl<N, B, O> Layer<N, B, O> for ReLu<N, B>
where
B: Backend<N> + BackendReLu<N>,
O: Optimizer<N, B>,
{
type Context = DefaultLayerContext<N, B>;
fn name(&self) -> &str {
"ReLU"
}
#[inline]
fn input_shape(&self) -> TensorShape {
self.input_shape.clone()
}
#[inline]
fn forward(&self, backend: &B, x: &B::Tensor, ctx: &mut Self::Context) {
ctx.update_outputs_shape(x.shape().get(0), &self.input_shape);
backend.relu(&mut ctx.outputs, x);
}
#[inline]
fn backward(&mut self, backend: &B, dy: &B::Tensor, _: &B::Tensor, ctx: &mut Self::Context) {
ctx.update_deltas_shape(dy.shape().get(0), &self.input_shape);
backend.relu_grad(&mut ctx.deltas, &ctx.outputs, dy);
}
}
impl <N, B, O> LayerExt<N, B, O> for ReLu<N, B>
where B: Backend<N> + BackendReLu<N>,
O: Optimizer<N, B>
impl<N, B, O> LayerExt<N, B, O> for ReLu<N, B>
where
B: Backend<N> + BackendReLu<N>,
O: Optimizer<N, B>,
{
type Config = ReLuConfig;
fn create(input_shape: TensorShape, _cfg: Self::Config) -> Self {
ReLu {
input_shape,
_x: Default::default()
_x: Default::default(),
}
}
}
}


@ -1,23 +1,25 @@
use crate::tensor::{Tensor, TensorShape};
use crate::backend::{Backend, BackendSigmoid};
use crate::layer::{Layer, LayerExt, DefaultLayerContext};
use crate::layer::{DefaultLayerContext, Layer, LayerExt};
use crate::optimizer::Optimizer;
use crate::tensor::{Tensor, TensorShape};
use core::marker::PhantomData;
#[derive(Default)]
pub struct SigmoidConfig;
pub struct Sigmoid<N, B>
where B: Backend<N>,
pub struct Sigmoid<N, B>
where
B: Backend<N>,
{
input_shape: TensorShape,
_x: PhantomData<fn(N, B)>,
}
impl <N, B, O> Layer<N, B, O> for Sigmoid<N, B>
where B: Backend<N> + BackendSigmoid<N>,
O: Optimizer<N, B>
{
impl<N, B, O> Layer<N, B, O> for Sigmoid<N, B>
where
B: Backend<N> + BackendSigmoid<N>,
O: Optimizer<N, B>,
{
type Context = DefaultLayerContext<N, B>;
fn name(&self) -> &str {
@ -32,7 +34,7 @@ impl <N, B, O> Layer<N, B, O> for Sigmoid<N, B>
#[inline]
fn forward(&self, backend: &B, x: &B::Tensor, ctx: &mut Self::Context) {
ctx.update_outputs_shape(x.shape().get(0), &self.input_shape);
backend.sigmoid(&mut ctx.outputs, x);
}
@ -44,16 +46,17 @@ impl <N, B, O> Layer<N, B, O> for Sigmoid<N, B>
}
}
impl <N, B, O> LayerExt<N, B, O> for Sigmoid<N, B>
where B: Backend<N> + BackendSigmoid<N>,
O: Optimizer<N, B>
impl<N, B, O> LayerExt<N, B, O> for Sigmoid<N, B>
where
B: Backend<N> + BackendSigmoid<N>,
O: Optimizer<N, B>,
{
type Config = SigmoidConfig;
fn create(input_shape: TensorShape, _cfg: Self::Config) -> Self {
Sigmoid {
input_shape,
_x: Default::default()
_x: Default::default(),
}
}
}
}


@ -1,29 +1,31 @@
use crate::tensor::{Tensor, TensorShape};
use crate::backend::{Backend, BackendSoftmax};
use crate::layer::{Layer, LayerExt, DefaultLayerContext};
use crate::layer::{DefaultLayerContext, Layer, LayerExt};
use crate::optimizer::Optimizer;
use crate::tensor::{Tensor, TensorShape};
use core::marker::PhantomData;
#[derive(Default)]
pub struct SoftmaxConfig;
pub struct Softmax<N, B>
where B: Backend<N>,
pub struct Softmax<N, B>
where
B: Backend<N>,
{
input_shape: TensorShape,
_x: PhantomData<fn(N, B)>,
}
impl <N, B, O> Layer<N, B, O> for Softmax<N, B>
where B: Backend<N> + BackendSoftmax<N>,
O: Optimizer<N, B>
{
impl<N, B, O> Layer<N, B, O> for Softmax<N, B>
where
B: Backend<N> + BackendSoftmax<N>,
O: Optimizer<N, B>,
{
type Context = DefaultLayerContext<N, B>;
fn name(&self) -> &str {
"Softmax"
}
#[inline]
fn input_shape(&self) -> TensorShape {
self.input_shape.clone()
@ -44,16 +46,17 @@ impl <N, B, O> Layer<N, B, O> for Softmax<N, B>
}
}
impl <N, B, O> LayerExt<N, B, O> for Softmax<N, B>
where B: Backend<N> + BackendSoftmax<N>,
O: Optimizer<N, B>
impl<N, B, O> LayerExt<N, B, O> for Softmax<N, B>
where
B: Backend<N> + BackendSoftmax<N>,
O: Optimizer<N, B>,
{
type Config = SoftmaxConfig;
fn create(input_shape: TensorShape, _cfg: Self::Config) -> Self {
Softmax {
input_shape,
_x: Default::default()
_x: Default::default(),
}
}
}
}


@ -1,7 +1,7 @@
use crate::tensor::{Tensor, TensorShape};
use crate::backend::{Backend, BackendCopy};
use crate::layer::{Layer, LayerExt, DefaultLayerContext};
use crate::layer::{DefaultLayerContext, Layer, LayerExt};
use crate::optimizer::Optimizer;
use crate::tensor::{Tensor, TensorShape};
use core::marker::PhantomData;
#[derive(Default)]
@ -9,24 +9,26 @@ pub struct ZeroPadding2dConfig {
pub paddings: (u32, u32),
}
pub struct ZeroPadding2d<N, B>
where B: Backend<N>,
pub struct ZeroPadding2d<N, B>
where
B: Backend<N>,
{
input_shape: TensorShape,
config: ZeroPadding2dConfig,
_x: PhantomData<fn(N, B)>,
}
impl <N, B, O> Layer<N, B, O> for ZeroPadding2d<N, B>
where B: Backend<N> + BackendCopy<N>,
O: Optimizer<N, B>
impl<N, B, O> Layer<N, B, O> for ZeroPadding2d<N, B>
where
B: Backend<N> + BackendCopy<N>,
O: Optimizer<N, B>,
{
type Context = DefaultLayerContext<N, B>;
fn name(&self) -> &str {
"ZeroPadding2d"
}
#[inline]
fn input_shape(&self) -> TensorShape {
self.input_shape.clone()
@ -39,10 +41,10 @@ impl <N, B, O> Layer<N, B, O> for ZeroPadding2d<N, B>
TensorShape::new3d(
is[0],
is[1] + self.config.paddings.0 * 2,
is[2] + self.config.paddings.1 * 2
is[2] + self.config.paddings.1 * 2,
)
}
#[inline]
fn forward(&self, _backend: &B, x: &B::Tensor, ctx: &mut Self::Context) {
ctx.update_outputs_shape(x.shape().get(0), &Layer::<N, B, O>::output_shape(self));
@ -56,9 +58,10 @@ impl <N, B, O> Layer<N, B, O> for ZeroPadding2d<N, B>
}
}
impl <N, B, O> LayerExt<N, B, O> for ZeroPadding2d<N, B>
where B: Backend<N> + BackendCopy<N>,
O: Optimizer<N, B>
impl<N, B, O> LayerExt<N, B, O> for ZeroPadding2d<N, B>
where
B: Backend<N> + BackendCopy<N>,
O: Optimizer<N, B>,
{
type Config = ZeroPadding2dConfig;
@ -66,8 +69,7 @@ impl <N, B, O> LayerExt<N, B, O> for ZeroPadding2d<N, B>
ZeroPadding2d {
input_shape,
config,
_x: Default::default()
_x: Default::default(),
}
}
}
}


@ -1,5 +1,5 @@
#![feature(specialization, trait_alias)]
#![recursion_limit="128"]
#![recursion_limit = "128"]
pub mod layer;
pub mod layers;
@ -13,15 +13,15 @@ pub mod native;
pub mod loss;
pub mod losses;
pub mod tensor;
pub mod params;
pub mod tensor;
#[macro_use]
mod macros;
pub mod prelude {
pub use super::backend::*;
pub use super::layer::*;
pub use super::loss::*;
pub use super::tensor::*;
pub use super::layer::*;
}
}


@ -3,4 +3,4 @@ use crate::backend::Backend;
pub trait Loss<N, B: Backend<N>> {
fn compute(&self, backend: &B, dst: &mut B::Tensor, pred: &B::Tensor, target: &B::Tensor);
fn derivative(&self, backend: &B, dst: &mut B::Tensor, pred: &B::Tensor, target: &B::Tensor);
}
}


@ -1,17 +1,15 @@
use crate::loss::Loss;
use crate::backend::{Backend, BackendSub};
use crate::loss::Loss;
use core::marker::PhantomData;
pub struct CrossEntropyLoss<N, B> {
_m: PhantomData<fn(N, B)>
_m: PhantomData<fn(N, B)>,
}
impl<N, B> CrossEntropyLoss<N, B> {
pub fn new() -> Self {
Self {
_m: Default::default()
_m: Default::default(),
}
}
}
@ -21,7 +19,7 @@ impl<N, B: Backend<N> + BackendSub<N>> Loss<N, B> for CrossEntropyLoss<N, B> {
// TODO
}
fn derivative(&self, backend: &B, dst: &mut B::Tensor, pred: &B::Tensor, target: &B::Tensor) {
fn derivative(&self, backend: &B, dst: &mut B::Tensor, pred: &B::Tensor, target: &B::Tensor) {
backend.sub(dst, pred, target);
}
}
}
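The `pred - target` derivative above is the usual softmax-plus-cross-entropy shortcut: for p = softmax(z) and L = -\sum_i t_i \log p_i, the gradient with respect to the logits collapses to

\frac{\partial L}{\partial z_i} = p_i - t_i

which is why training works even while `compute` is still a TODO — only `derivative` feeds backpropagation.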


@ -1,5 +1,5 @@
mod mse;
mod cross_entropy;
mod mse;
pub use self::cross_entropy::*;
pub use self::mse::*;
pub use self::cross_entropy::*;


@ -1,24 +1,24 @@
use crate::loss::Loss;
use crate::backend::{Backend, BackendMse};
use crate::loss::Loss;
use crate::tensor::Tensor;
use core::marker::PhantomData;
pub struct MeanSquareErrorLoss<N, B> {
_m: PhantomData<fn(N, B)>
_m: PhantomData<fn(N, B)>,
}
impl<N, B> MeanSquareErrorLoss<N, B> {
pub fn new() -> Self {
Self {
_m: Default::default()
_m: Default::default(),
}
}
}
impl<N, B> Loss<N, B> for MeanSquareErrorLoss<N, B>
where B: Backend<N> + BackendMse<N>
impl<N, B> Loss<N, B> for MeanSquareErrorLoss<N, B>
where
B: Backend<N> + BackendMse<N>,
{
fn compute(&self, backend: &B, dst: &mut B::Tensor, pred: &B::Tensor, target: &B::Tensor) {
let batch_size = pred.shape().get(0) as f32;
@ -31,4 +31,4 @@ impl<N, B> Loss<N, B> for MeanSquareErrorLoss<N, B>
backend.scaled_diff(dst, pred, target, backend.scalar_f32(1.0 / batch_size));
}
}
}
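Similarly, the 1/B scale in `scaled_diff` above matches an MSE loss defined with a 1/2 factor:

L = \frac{1}{2B} \sum_{i=1}^{B} \lVert p_i - t_i \rVert^2
\quad\Longrightarrow\quad
\frac{\partial L}{\partial p} = \frac{1}{B} \, (p - t)

A definition without the 1/2 would need a 2/B scale instead.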


@ -1,4 +1,3 @@
#[macro_export]
macro_rules! sequential_type {
(input_shape: ($($shape:tt)*), layers: { $($layers:tt)* }) => {
@ -11,13 +10,13 @@ macro_rules! sequential_type_impl {
($t:ty {$($tt:tt)*}) => ($t);
($t:ty {$($xx:tt)*}, $($tt:tt)*) => {
$crate::layers::Chain<N, B, O,
$crate::layers::Chain<N, B, O,
$t, $crate::sequential_type_impl!($($tt)*)
>
};
($t:ty) => ($t);
($t:ty, $($tt:tt)*) => {
$crate::layers::Chain<N, B, O,
$crate::layers::Chain<N, B, O,
$t, $crate::sequential_type_impl!($($tt)*)
>
};
@ -33,7 +32,7 @@ macro_rules! sequential {
}
#[macro_export]
macro_rules! sequential_impl {
macro_rules! sequential_impl {
($p:expr, $t:ty { $($name:ident : $val:expr),* }) => {{
#[allow(unused_mut)]
let mut params = <$t as $crate::layer::LayerExt<N, B, O>>::Config::default();
@ -58,7 +57,7 @@ macro_rules! sequential_impl {
);
let prev_shape = $crate::layer::Layer::<N, B, O>::output_shape(&layer);
$crate::layers::Chain::new(
layer, $crate::sequential_impl! { prev_shape, $($tt)* },
)
@ -119,7 +118,7 @@ macro_rules! model_impl {
_m: core::marker::PhantomData<fn(N, B, O)>,
}
impl<N, B, O> $name<N, B, O>
impl<N, B, O> $name<N, B, O>
where B: $crate::backend::Backend<N> + $trait,
O: $crate::optimizer::Optimizer<N, B>
{
@ -130,7 +129,7 @@ macro_rules! model_impl {
#[allow(unused_imports)]
use $crate::backend::PoolingKind::*;
Self {
inner: $crate::sequential!($($tt)*),
_m: Default::default(),
@ -138,7 +137,7 @@ macro_rules! model_impl {
}
}
// impl<N, B, O> core::fmt::Display for $name<N, B, O>
// impl<N, B, O> core::fmt::Display for $name<N, B, O>
// where B: $crate::backend::Backend<N> + $trait,
// O: $crate::optimizer::Optimizer<N, B>
// {
@ -151,7 +150,7 @@ macro_rules! model_impl {
// }
// }
impl<N, B, O> $crate::layer::Layer<N, B, O> for $name<N, B, O>
impl<N, B, O> $crate::layer::Layer<N, B, O> for $name<N, B, O>
where B: $crate::backend::Backend<N> + $trait,
O: $crate::optimizer::Optimizer<N, B>
{
@ -170,7 +169,7 @@ macro_rules! model_impl {
#[inline]
fn param_count(&self) -> usize {
self.inner.param_count()
}
}
#[inline]
fn input_shape(&self) -> $crate::tensor::TensorShape {
@ -206,12 +205,12 @@ macro_rules! model_impl {
writeln!(f, "{}{}[{}] {{", "", self.name(), self.param_count())?;
self.inner.fmt(f, padding + 2)?;
write!(f, "}}")?;
Ok(())
}
}
impl<N, B, O> core::fmt::Display for $name<N, B, O>
impl<N, B, O> core::fmt::Display for $name<N, B, O>
where B: $crate::backend::Backend<N> + $trait,
O: $crate::optimizer::Optimizer<N, B>
{
@ -227,7 +226,7 @@ macro_rules! model_impl {
macro_rules! model {
($name:ident ($($init:tt)*) { $($tt:tt)* }) => {
mod tmp {
pub trait BackendDefault<N> = $crate::backend::BackendReLu<N>
pub trait BackendDefault<N> = $crate::backend::BackendReLu<N>
+ $crate::backend::BackendBias<N>
+ $crate::backend::BackendScale<N>
+ $crate::backend::BackendSigmoid<N>
@ -242,4 +241,4 @@ macro_rules! model {
($name:ident <$trait:path> ($($init:tt)*) { $($tt:tt)* }) => {
$crate::model_impl!($name <$trait> ($($init)*) { $($tt)* });
};
}
}
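From the macro arms visible above (parts of `model_impl!` are elided by the diff), an invocation would look roughly like the following — a guess at the intended surface syntax based on the `sequential_type!`/`sequential_impl!` patterns, not a verified example:

model!(Mlp () {
    input_shape: (784),
    layers: {
        Linear { units: 128 },
        ReLu,
        Linear { units: 10 },
        Softmax
    }
});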


@ -1,17 +1,24 @@
#[allow(dead_code)]
pub fn valid_conv2d_3x3(y: &mut [f32], x: &[f32], w: &[f32], alpha: f32,
x_rows: isize, x_cols: isize, s_row: isize, s_col: isize) {
pub fn valid_conv2d_3x3(
y: &mut [f32],
x: &[f32],
w: &[f32],
alpha: f32,
x_rows: isize,
x_cols: isize,
s_row: isize,
s_col: isize,
) {
let y_rows = (x_rows - 3) / s_row + 1;
let y_cols = (x_cols - 3) / s_col + 1;
let y = &mut y[0..(y_rows * y_cols) as usize];
let x = &x[0..(x_rows * x_cols) as usize];
let w = &w[0..9];
for y_y in 0..y_rows {
for y_x in 0..y_cols {
let mut xi = s_row * y_y * x_cols + s_col * y_x;
let mut xi = s_row * y_y * x_cols + s_col * y_x;
let mut sum = 0.0;
sum += x[(xi + 0) as usize] * w[0];
@ -19,13 +26,13 @@ pub fn valid_conv2d_3x3(y: &mut [f32], x: &[f32], w: &[f32], alpha: f32,
sum += x[(xi + 2) as usize] * w[2];
xi += x_cols;
sum += x[(xi + 0) as usize] * w[3];
sum += x[(xi + 1) as usize] * w[4];
sum += x[(xi + 2) as usize] * w[5];
xi += x_cols;
sum += x[(xi + 0) as usize] * w[6];
sum += x[(xi + 1) as usize] * w[7];
sum += x[(xi + 2) as usize] * w[8];
@ -35,19 +42,24 @@ pub fn valid_conv2d_3x3(y: &mut [f32], x: &[f32], w: &[f32], alpha: f32,
}
}
#[allow(dead_code)]
pub fn full_xcorr2d_3x3(y: &mut [f32], x: &[f32], w: &[f32], alpha: f32,
x_rows: isize, x_cols: isize,
s_row: isize, s_col: isize) {
pub fn full_xcorr2d_3x3(
y: &mut [f32],
x: &[f32],
w: &[f32],
alpha: f32,
x_rows: isize,
x_cols: isize,
s_row: isize,
s_col: isize,
) {
let y_cols = (x_cols - 1) * s_col + 3;
let y_rows = (x_rows - 1) * s_row + 3;
let y = &mut y[0..(y_rows * y_cols) as usize];
let x = &x[0..(x_rows * x_cols) as usize];
let w = &w[0..9];
for x_y in 0..x_rows {
for x_x in 0..x_cols {
let mut yi = s_row * x_y * y_cols + s_col * x_x;
@ -57,12 +69,12 @@ pub fn full_xcorr2d_3x3(y: &mut [f32], x: &[f32], w: &[f32], alpha: f32,
y[(yi + 1) as usize] += z * w[7];
y[(yi + 2) as usize] += z * w[6];
yi += y_cols;
y[(yi + 0) as usize] += z * w[5];
y[(yi + 1) as usize] += z * w[4];
y[(yi + 2) as usize] += z * w[3];
yi += y_cols;
y[(yi + 0) as usize] += z * w[2];
y[(yi + 1) as usize] += z * w[1];
y[(yi + 2) as usize] += z * w[0];
@ -70,24 +82,32 @@ pub fn full_xcorr2d_3x3(y: &mut [f32], x: &[f32], w: &[f32], alpha: f32,
}
}
pub fn conv2d_forward_3x3(y: &mut [f32], x: &[f32], w: &[f32],
bs: isize, x_channels: isize, y_channels: isize,
x_rows: isize, x_cols: isize, s_row: isize, s_col: isize) {
pub fn conv2d_forward_3x3(
y: &mut [f32],
x: &[f32],
w: &[f32],
bs: isize,
x_channels: isize,
y_channels: isize,
x_rows: isize,
x_cols: isize,
s_row: isize,
s_col: isize,
) {
let y_rows = (x_rows - 3) / s_row + 1;
let y_cols = (x_cols - 3) / s_col + 1;
let x_img_size = x_rows * x_cols;
let y_img_size = y_rows * y_cols;
let w_img_size = 9;
let x_batch_size = x_channels * x_img_size;
let y_batch_size = y_channels * y_img_size;
let y = &mut y[0..(bs * y_batch_size) as usize];
let x = &x[0..(bs * x_batch_size) as usize];
let w = &w[0..(y_channels * w_img_size) as usize];
for bi in 0..bs {
for x_ch in 0..x_channels {
let x_offset = (bi * x_batch_size + x_ch * x_img_size) as usize;
@ -96,50 +116,56 @@ pub fn conv2d_forward_3x3(y: &mut [f32], x: &[f32], w: &[f32],
for y_ch in 0..y_channels {
let y_offset = (bi * y_batch_size + y_ch * y_img_size) as usize;
let y_img = &mut y[y_offset..y_offset + y_img_size as usize];
let w_offset = (y_ch * w_img_size) as usize;
let w = &w[w_offset..w_offset + w_img_size as usize];
valid_conv2d_3x3(y_img, x_img, w, 1.0, x_rows, x_cols, s_row, s_col);
}
}
}
}
}
pub fn conv2d_backward_3x3(dx: &mut [f32], dy: &[f32], w: &[f32],
bs: isize, x_channels: isize, y_channels: isize,
y_rows: isize, y_cols: isize,
s_row: isize, s_col: isize) {
}
pub fn conv2d_backward_3x3(
dx: &mut [f32],
dy: &[f32],
w: &[f32],
bs: isize,
x_channels: isize,
y_channels: isize,
y_rows: isize,
y_cols: isize,
s_row: isize,
s_col: isize,
) {
let x_cols = (y_cols - 1) * s_col + 3;
let x_rows = (y_rows - 1) * s_row + 3;
let dx_img_size = x_rows * x_cols;
let dy_img_size = y_rows * y_cols;
let w_img_size = 9;
let dx_batch_size = x_channels * dx_img_size;
let dy_batch_size = y_channels * dy_img_size;
let dx = &mut dx[0..(bs * dx_batch_size) as usize];
let dy = &dy[0..(bs * dy_batch_size) as usize];
let w = &w[0..(y_channels * w_img_size) as usize];
for bi in 0..bs {
for y_ch in 0..y_channels {
let dy_offset = (bi * dy_batch_size + y_ch * dy_img_size) as usize;
let dy_img = &dy[dy_offset..dy_offset + dy_img_size as usize];
for x_ch in 0..x_channels {
let dx_offset = (bi * dx_batch_size + x_ch * dx_img_size) as usize;
let dx_img = &mut dx[dx_offset..dx_offset + dx_img_size as usize];
let w_offset = (y_ch * w_img_size) as usize;
let w = &w[w_offset..w_offset + w_img_size as usize];
full_xcorr2d_3x3(dx_img, dy_img, w, 1.0, y_rows, y_cols, s_row, s_col);
}
}
}
}
}
}
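The structure of `conv2d_backward_3x3` mirrors the forward pass as an adjoint: output deltas are scattered back through the same 3×3 kernel by `full_xcorr2d_3x3`, whose body walks the weights in reverse (`w[8]` down to `w[0]`), i.e. the 180°-rotated kernel. That is the standard identity

y = x \ast_{\mathrm{valid}} w
\quad\Longrightarrow\quad
\frac{\partial L}{\partial x} = \frac{\partial L}{\partial y} \ast_{\mathrm{full}} \mathrm{rot180}(w)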


@ -1,17 +1,24 @@
#[allow(dead_code)]
pub fn valid_conv2d_5x5(y: &mut [f32], x: &[f32], w: &[f32], alpha: f32,
x_rows: isize, x_cols: isize, s_row: isize, s_col: isize) {
pub fn valid_conv2d_5x5(
y: &mut [f32],
x: &[f32],
w: &[f32],
alpha: f32,
x_rows: isize,
x_cols: isize,
s_row: isize,
s_col: isize,
) {
let y_rows = (x_rows - 5) / s_row + 1;
let y_cols = (x_cols - 5) / s_col + 1;
let y = &mut y[0..(y_rows * y_cols) as usize];
let x = &x[0..(x_rows * x_cols) as usize];
let w = &w[0..25];
for y_y in 0..y_rows {
for y_x in 0..y_cols {
let mut xi = s_row * y_y * x_cols + s_col * y_x;
let mut xi = s_row * y_y * x_cols + s_col * y_x;
let mut sum = 0.0;
sum += x[(xi + 0) as usize] * w[0];
@ -20,28 +27,28 @@ pub fn valid_conv2d_5x5(y: &mut [f32], x: &[f32], w: &[f32], alpha: f32,
sum += x[(xi + 3) as usize] * w[3];
sum += x[(xi + 4) as usize] * w[4];
xi += x_cols;
sum += x[(xi + 0) as usize] * w[5];
sum += x[(xi + 1) as usize] * w[6];
sum += x[(xi + 2) as usize] * w[7];
sum += x[(xi + 3) as usize] * w[8];
sum += x[(xi + 4) as usize] * w[9];
xi += x_cols;
sum += x[(xi + 0) as usize] * w[10];
sum += x[(xi + 1) as usize] * w[11];
sum += x[(xi + 2) as usize] * w[12];
sum += x[(xi + 3) as usize] * w[13];
sum += x[(xi + 4) as usize] * w[14];
xi += x_cols;
sum += x[(xi + 0) as usize] * w[15];
sum += x[(xi + 1) as usize] * w[16];
sum += x[(xi + 2) as usize] * w[17];
sum += x[(xi + 3) as usize] * w[18];
sum += x[(xi + 4) as usize] * w[19];
xi += x_cols;
sum += x[(xi + 0) as usize] * w[20];
sum += x[(xi + 1) as usize] * w[21];
sum += x[(xi + 2) as usize] * w[22];
@ -53,19 +60,24 @@ pub fn valid_conv2d_5x5(y: &mut [f32], x: &[f32], w: &[f32], alpha: f32,
}
}
#[allow(dead_code)]
pub fn full_xcorr2d_5x5(y: &mut [f32], x: &[f32], w: &[f32], alpha: f32,
x_rows: isize, x_cols: isize,
s_row: isize, s_col: isize) {
pub fn full_xcorr2d_5x5(
y: &mut [f32],
x: &[f32],
w: &[f32],
alpha: f32,
x_rows: isize,
x_cols: isize,
s_row: isize,
s_col: isize,
) {
let y_cols = (x_cols - 1) * s_col + 5;
let y_rows = (x_rows - 1) * s_row + 5;
let y = &mut y[0..(y_rows * y_cols) as usize];
let x = &x[0..(x_rows * x_cols) as usize];
let w = &w[0..25];
for x_y in 0..x_rows {
for x_x in 0..x_cols {
let mut yi = s_row * x_y * y_cols + s_col * x_x;
@ -77,28 +89,28 @@ pub fn full_xcorr2d_5x5(y: &mut [f32], x: &[f32], w: &[f32], alpha: f32,
y[(yi + 3) as usize] += z * w[21];
y[(yi + 4) as usize] += z * w[20];
yi += y_cols;
y[(yi + 0) as usize] += z * w[19];
y[(yi + 1) as usize] += z * w[18];
y[(yi + 2) as usize] += z * w[17];
y[(yi + 3) as usize] += z * w[16];
y[(yi + 4) as usize] += z * w[15];
yi += y_cols;
y[(yi + 0) as usize] += z * w[14];
y[(yi + 1) as usize] += z * w[13];
y[(yi + 2) as usize] += z * w[12];
y[(yi + 3) as usize] += z * w[11];
y[(yi + 4) as usize] += z * w[10];
yi += y_cols;
y[(yi + 0) as usize] += z * w[9];
y[(yi + 1) as usize] += z * w[8];
y[(yi + 2) as usize] += z * w[7];
y[(yi + 3) as usize] += z * w[6];
y[(yi + 4) as usize] += z * w[5];
yi += y_cols;
y[(yi + 0) as usize] += z * w[4];
y[(yi + 1) as usize] += z * w[3];
y[(yi + 2) as usize] += z * w[2];
@ -108,24 +120,32 @@ pub fn full_xcorr2d_5x5(y: &mut [f32], x: &[f32], w: &[f32], alpha: f32,
}
}
pub fn conv2d_forward_5x5(y: &mut [f32], x: &[f32], w: &[f32],
bs: isize, x_channels: isize, y_channels: isize,
x_rows: isize, x_cols: isize, s_row: isize, s_col: isize) {
pub fn conv2d_forward_5x5(
y: &mut [f32],
x: &[f32],
w: &[f32],
bs: isize,
x_channels: isize,
y_channels: isize,
x_rows: isize,
x_cols: isize,
s_row: isize,
s_col: isize,
) {
let y_rows = (x_rows - 5) / s_row + 1;
let y_cols = (x_cols - 5) / s_col + 1;
let x_img_size = x_rows * x_cols;
let y_img_size = y_rows * y_cols;
let w_img_size = 25;
let x_batch_size = x_channels * x_img_size;
let y_batch_size = y_channels * y_img_size;
let y = &mut y[0..(bs * y_batch_size) as usize];
let x = &x[0..(bs * x_batch_size) as usize];
let w = &w[0..(y_channels * w_img_size) as usize];
for bi in 0..bs {
for x_ch in 0..x_channels {
let x_offset = (bi * x_batch_size + x_ch * x_img_size) as usize;
@ -134,50 +154,56 @@ pub fn conv2d_forward_5x5(y: &mut [f32], x: &[f32], w: &[f32],
for y_ch in 0..y_channels {
let y_offset = (bi * y_batch_size + y_ch * y_img_size) as usize;
let y_img = &mut y[y_offset..y_offset + y_img_size as usize];
let w_offset = (y_ch * w_img_size) as usize;
let w = &w[w_offset..w_offset + w_img_size as usize];
valid_conv2d_5x5(y_img, x_img, w, 1.0, x_rows, x_cols, s_row, s_col);
}
}
}
}
}
pub fn conv2d_backward_5x5(dx: &mut [f32], dy: &[f32], w: &[f32],
bs: isize, x_channels: isize, y_channels: isize,
y_rows: isize, y_cols: isize,
s_row: isize, s_col: isize) {
}
pub fn conv2d_backward_5x5(
dx: &mut [f32],
dy: &[f32],
w: &[f32],
bs: isize,
x_channels: isize,
y_channels: isize,
y_rows: isize,
y_cols: isize,
s_row: isize,
s_col: isize,
) {
let x_cols = (y_cols - 1) * s_col + 5;
let x_rows = (y_rows - 1) * s_row + 5;
let dx_img_size = x_rows * x_cols;
let dy_img_size = y_rows * y_cols;
let w_img_size = 25;
let dx_batch_size = x_channels * dx_img_size;
let dy_batch_size = y_channels * dy_img_size;
let dx = &mut dx[0..(bs * dx_batch_size) as usize];
let dy = &dy[0..(bs * dy_batch_size) as usize];
let w = &w[0..(y_channels * w_img_size) as usize];
for bi in 0..bs {
for y_ch in 0..y_channels {
let dy_offset = (bi * dy_batch_size + y_ch * dy_img_size) as usize;
let dy_img = &dy[dy_offset..dy_offset + dy_img_size as usize];
for x_ch in 0..x_channels {
let dx_offset = (bi * dx_batch_size + x_ch * dx_img_size) as usize;
let dx_img = &mut dx[dx_offset..dx_offset + dx_img_size as usize];
let w_offset = (y_ch * w_img_size) as usize;
let w = &w[w_offset..w_offset + w_img_size as usize];
full_xcorr2d_5x5(dx_img, dy_img, w, 1.0, y_rows, y_cols, s_row, s_col);
}
}
}
}
}
}


@ -5,30 +5,36 @@ pub use self::kernel_3x3::*;
pub use self::kernel_5x5::*;
#[allow(dead_code)]
pub fn valid_conv2d(y: &mut [f32], x: &[f32], w: &[f32], alpha: f32,
x_rows: isize, x_cols: isize,
w_rows: isize, w_cols: isize,
s_row: isize, s_col: isize) {
pub fn valid_conv2d(
y: &mut [f32],
x: &[f32],
w: &[f32],
alpha: f32,
x_rows: isize,
x_cols: isize,
w_rows: isize,
w_cols: isize,
s_row: isize,
s_col: isize,
) {
let y_rows = (x_rows - w_rows) / s_row + 1;
let y_cols = (x_cols - w_cols) / s_col + 1;
let y = &mut y[0..(y_rows * y_cols) as usize];
let x = &x[0..(x_rows * x_cols) as usize];
let w = &w[0..(w_rows * w_cols) as usize];
for y_y in 0..y_rows {
for y_x in 0..y_cols {
let mut xi = s_row * y_y * x_cols + s_col * y_x;
let mut wi = 0;
let mut sum = 0.0;
for _ in 0..w_rows {
for w_x in 0..w_cols {
sum += x[(xi + w_x) as usize] * w[(wi + w_x) as usize];
}
xi += x_cols;
wi += w_cols;
}
@ -39,30 +45,36 @@ pub fn valid_conv2d(y: &mut [f32], x: &[f32], w: &[f32], alpha: f32,
}
#[allow(dead_code)]
pub fn valid_xcorr2d(y: &mut [f32], x: &[f32], w: &[f32], alpha: f32,
x_rows: isize, x_cols: isize,
w_rows: isize, w_cols: isize,
s_row: isize, s_col: isize) {
pub fn valid_xcorr2d(
y: &mut [f32],
x: &[f32],
w: &[f32],
alpha: f32,
x_rows: isize,
x_cols: isize,
w_rows: isize,
w_cols: isize,
s_row: isize,
s_col: isize,
) {
let y_rows = (x_rows - w_rows) / s_row + 1;
let y_cols = (x_cols - w_cols) / s_col + 1;
let y = &mut y[0..(y_rows * y_cols) as usize];
let x = &x[0..(x_rows * x_cols) as usize];
let w = &w[0..(w_rows * w_cols) as usize];
for y_y in 0..y_rows {
for y_x in 0..y_cols {
let mut xi = s_row * y_y * x_cols + s_col * y_x;
let mut wi = w_rows * w_cols - 1;
let mut sum = 0.0;
for _ in 0..w_rows {
for w_x in 0..w_cols {
sum += x[(xi + w_x) as usize] * w[(wi - w_x) as usize];
}
xi += x_cols;
wi -= w_cols;
}
@ -73,29 +85,36 @@ pub fn valid_xcorr2d(y: &mut [f32], x: &[f32], w: &[f32], alpha: f32,
}
#[allow(dead_code)]
pub fn full_conv2d(y: &mut [f32], x: &[f32], w: &[f32], alpha: f32,
x_rows: isize, x_cols: isize,
w_rows: isize, w_cols: isize,
s_row: isize, s_col: isize) {
pub fn full_conv2d(
y: &mut [f32],
x: &[f32],
w: &[f32],
alpha: f32,
x_rows: isize,
x_cols: isize,
w_rows: isize,
w_cols: isize,
s_row: isize,
s_col: isize,
) {
let y_cols = (x_cols - 1) * s_col + w_cols;
let y_rows = (x_rows - 1) * s_row + w_rows;
let y = &mut y[0..(y_rows * y_cols) as usize];
let x = &x[0..(x_rows * x_cols) as usize];
let w = &w[0..(w_rows * w_cols) as usize];
for x_y in 0..x_rows {
for x_x in 0..x_cols {
let mut yi = s_row * x_y * y_cols + s_col * x_x;
let mut wi = 0;
let z = alpha * x[(x_y * x_cols + x_x) as usize];
for _ in 0..w_rows {
for w_x in 0..w_cols {
y[(yi + w_x) as usize] += z * w[(wi + w_x) as usize];
}
yi += y_cols;
wi += w_cols;
}
@ -104,29 +123,36 @@ pub fn full_conv2d(y: &mut [f32], x: &[f32], w: &[f32], alpha: f32,
}
#[allow(dead_code)]
pub fn full_xcorr2d(y: &mut [f32], x: &[f32], w: &[f32], alpha: f32,
x_rows: isize, x_cols: isize,
w_rows: isize, w_cols: isize,
s_row: isize, s_col: isize) {
pub fn full_xcorr2d(
y: &mut [f32],
x: &[f32],
w: &[f32],
alpha: f32,
x_rows: isize,
x_cols: isize,
w_rows: isize,
w_cols: isize,
s_row: isize,
s_col: isize,
) {
let y_cols = (x_cols - 1) * s_col + w_cols;
let y_rows = (x_rows - 1) * s_row + w_rows;
let y = &mut y[0..(y_rows * y_cols) as usize];
let x = &x[0..(x_rows * x_cols) as usize];
let w = &w[0..(w_rows * w_cols) as usize];
for x_y in 0..x_rows {
for x_x in 0..x_cols {
let mut yi = s_row * x_y * y_cols + s_col * x_x;
let mut wi = w_rows * w_cols - 1;
let z = alpha * x[(x_y * x_cols + x_x) as usize];
for _ in 0..w_rows {
for w_x in 0..w_cols {
y[(yi + w_x) as usize] += z * w[(wi - w_x) as usize];
}
yi += y_cols;
wi -= w_cols;
}
@ -134,27 +160,34 @@ pub fn full_xcorr2d(y: &mut [f32], x: &[f32], w: &[f32], alpha: f32,
}
}
pub fn conv2d_forward(y: &mut [f32], x: &[f32], w: &[f32],
bs: isize, x_channels: isize, y_channels: isize,
x_rows: isize, x_cols: isize,
w_rows: isize, w_cols: isize,
s_row: isize, s_col: isize) {
pub fn conv2d_forward(
y: &mut [f32],
x: &[f32],
w: &[f32],
bs: isize,
x_channels: isize,
y_channels: isize,
x_rows: isize,
x_cols: isize,
w_rows: isize,
w_cols: isize,
s_row: isize,
s_col: isize,
) {
let y_rows = (x_rows - w_rows) / s_row + 1;
let y_cols = (x_cols - w_cols) / s_col + 1;
let x_img_size = x_rows * x_cols;
let y_img_size = y_rows * y_cols;
let w_img_size = w_rows * w_cols;
let x_batch_size = x_channels * x_img_size;
let y_batch_size = y_channels * y_img_size;
let y = &mut y[0..(bs * y_batch_size) as usize];
let x = &x[0..(bs * x_batch_size) as usize];
let w = &w[0..(y_channels * w_img_size) as usize];
for bi in 0..bs {
for x_ch in 0..x_channels {
let x_offset = (bi * x_batch_size + x_ch * x_img_size) as usize;
@ -163,89 +196,111 @@ pub fn conv2d_forward(y: &mut [f32], x: &[f32], w: &[f32],
for y_ch in 0..y_channels {
let y_offset = (bi * y_batch_size + y_ch * y_img_size) as usize;
let y_img = &mut y[y_offset..y_offset + y_img_size as usize];
let w_offset = (y_ch * w_img_size) as usize;
let w = &w[w_offset..w_offset + w_img_size as usize];
valid_conv2d(y_img, x_img, w, 1.0, x_rows, x_cols, w_rows, w_cols, s_row, s_col);
valid_conv2d(
y_img, x_img, w, 1.0, x_rows, x_cols, w_rows, w_cols, s_row, s_col,
);
}
}
}
}
}
pub fn conv2d_backward(dx: &mut [f32], dy: &[f32], w: &[f32],
bs: isize, x_channels: isize, y_channels: isize,
y_rows: isize, y_cols: isize,
w_rows: isize, w_cols: isize,
s_row: isize, s_col: isize) {
}
pub fn conv2d_backward(
dx: &mut [f32],
dy: &[f32],
w: &[f32],
bs: isize,
x_channels: isize,
y_channels: isize,
y_rows: isize,
y_cols: isize,
w_rows: isize,
w_cols: isize,
s_row: isize,
s_col: isize,
) {
let x_cols = (y_cols - 1) * s_col + w_cols;
let x_rows = (y_rows - 1) * s_row + w_rows;
let dx_img_size = x_rows * x_cols;
let dy_img_size = y_rows * y_cols;
let w_img_size = w_rows * w_cols;
let dx_batch_size = x_channels * dx_img_size;
let dy_batch_size = y_channels * dy_img_size;
let dx = &mut dx[0..(bs * dx_batch_size) as usize];
let dy = &dy[0..(bs * dy_batch_size) as usize];
let w = &w[0..(y_channels * w_img_size) as usize];
for bi in 0..bs {
for y_ch in 0..y_channels {
let dy_offset = (bi * dy_batch_size + y_ch * dy_img_size) as usize;
let dy_img = &dy[dy_offset..dy_offset + dy_img_size as usize];
for x_ch in 0..x_channels {
let dx_offset = (bi * dx_batch_size + x_ch * dx_img_size) as usize;
let dx_img = &mut dx[dx_offset..dx_offset + dx_img_size as usize];
let w_offset = (y_ch * w_img_size) as usize;
let w = &w[w_offset..w_offset + w_img_size as usize];
full_xcorr2d(dx_img, dy_img, w, 1.0, y_rows, y_cols, w_rows, w_cols, s_row, s_col);
full_xcorr2d(
dx_img, dy_img, w, 1.0, y_rows, y_cols, w_rows, w_cols, s_row, s_col,
);
}
}
}
}
}
pub fn conv2d_grads(dw: &mut [f32], x: &[f32], dy: &[f32],
bs: isize, x_channels: isize, y_channels: isize,
x_rows: isize, x_cols: isize,
y_rows: isize, y_cols: isize,
s_row: isize, s_col: isize) {
pub fn conv2d_grads(
dw: &mut [f32],
x: &[f32],
dy: &[f32],
bs: isize,
x_channels: isize,
y_channels: isize,
x_rows: isize,
x_cols: isize,
y_rows: isize,
y_cols: isize,
s_row: isize,
s_col: isize,
) {
let w_cols = x_cols - y_cols + 1;
let w_rows = x_rows - y_rows + 1;
let x_img_size = x_rows * x_cols;
let dy_img_size = y_rows * y_cols;
let dw_img_size = w_rows * w_cols;
let x_batch_size = x_channels * x_img_size;
let dy_batch_size = y_channels * dy_img_size;
let dw = &mut dw[0..(y_channels * dw_img_size) as usize];
let dy = &dy[0..(bs * dy_batch_size) as usize];
let x = &x[0..(bs * x_batch_size) as usize];
for bi in 0..bs {
for x_ch in 0..x_channels {
let x_offset = (bi * x_batch_size + x_ch * x_img_size) as usize;
let x_img = &x[x_offset..x_offset + x_img_size as usize];
for y_ch in 0..y_channels {
let dy_offset = (bi * dy_batch_size + y_ch * dy_img_size) as usize;
let dy_img = &dy[dy_offset..dy_offset + dy_img_size as usize];
let dw_offset = (y_ch * dw_img_size) as usize;
let dw = &mut dw[dw_offset..dw_offset + dw_img_size as usize];
valid_conv2d(dw, x_img, dy_img, 1.0, x_rows, x_cols, y_rows, y_cols, s_row, s_col);
valid_conv2d(
dw, x_img, dy_img, 1.0, x_rows, x_cols, y_rows, y_cols, s_row, s_col,
);
}
}
}
}
}
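`conv2d_grads` closes the loop: for stride 1, the filter gradient is itself a valid correlation of the input with the output deltas — note the role swap, with `dy_img` passed where the kernel normally goes — and the kernel size is recovered from w = x - y + 1, the inverse of the forward rule y = x - w + 1:

\frac{\partial L}{\partial w} = x \star_{\mathrm{valid}} \frac{\partial L}{\partial y},
\qquad w_{\mathrm{rows}} = x_{\mathrm{rows}} - y_{\mathrm{rows}} + 1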
@ -256,198 +311,136 @@ mod test {
#[test]
fn test_valid_conv2d() {
let x: &[f32] = &[
0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7,
1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7,
2.0, 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7,
3.0, 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7,
4.0, 4.1, 4.2, 4.3, 4.4, 4.5, 4.6, 4.7,
5.0, 5.1, 5.2, 5.3, 5.4, 5.5, 5.6, 5.7,
6.0, 6.1, 6.2, 6.3, 6.4, 6.5, 6.6, 6.7,
0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 2.0,
2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 3.0, 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 4.0, 4.1,
4.2, 4.3, 4.4, 4.5, 4.6, 4.7, 5.0, 5.1, 5.2, 5.3, 5.4, 5.5, 5.6, 5.7, 6.0, 6.1, 6.2,
6.3, 6.4, 6.5, 6.6, 6.7,
];
let w: &[f32] = &[
0.1, 0.2, 0.3,
0.4, 0.5, 0.6,
0.7, 0.8, 0.9,
];
let w: &[f32] = &[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9];
let y: &mut [f32] = &mut [
0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
];
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
];
valid_conv2d(y, x, w, 1.0, 7, 8, 3, 3, 1, 1);
assert_eq!(y, &[
6.81, 7.26, 7.71, 8.16, 8.610001, 9.06,
11.31, 11.76, 12.210001, 12.66, 13.11, 13.56,
15.809999, 16.26, 16.710001, 17.160002, 17.61, 18.06,
20.31, 20.76, 21.210001, 21.66, 22.11, 22.559998,
24.81, 25.26, 25.710001, 26.160002, 26.61, 27.060001
])
assert_eq!(
y,
&[
6.81, 7.26, 7.71, 8.16, 8.610001, 9.06, 11.31, 11.76, 12.210001, 12.66, 13.11,
13.56, 15.809999, 16.26, 16.710001, 17.160002, 17.61, 18.06, 20.31, 20.76,
21.210001, 21.66, 22.11, 22.559998, 24.81, 25.26, 25.710001, 26.160002, 26.61,
27.060001
]
)
}
#[test]
fn test_valid_xcorr2d() {
let x: &[f32] = &[
0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7,
1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7,
2.0, 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7,
3.0, 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7,
4.0, 4.1, 4.2, 4.3, 4.4, 4.5, 4.6, 4.7,
5.0, 5.1, 5.2, 5.3, 5.4, 5.5, 5.6, 5.7,
6.0, 6.1, 6.2, 6.3, 6.4, 6.5, 6.6, 6.7,
0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 2.0,
2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 3.0, 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 4.0, 4.1,
4.2, 4.3, 4.4, 4.5, 4.6, 4.7, 5.0, 5.1, 5.2, 5.3, 5.4, 5.5, 5.6, 5.7, 6.0, 6.1, 6.2,
6.3, 6.4, 6.5, 6.6, 6.7,
];
let w: &[f32] = &[
0.1, 0.2, 0.3,
0.4, 0.5, 0.6,
0.7, 0.8, 0.9,
];
let w: &[f32] = &[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9];
let y: &mut [f32] = &mut [
0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
];
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
];
valid_xcorr2d(y, x, w, 1.0, 7, 8, 3, 3, 1, 1);
assert_eq!(y, &[
3.0900004, 3.54, 3.9900002, 4.44, 4.8900003, 5.34,
7.59, 8.04, 8.490001, 8.940001, 9.389999, 9.84,
12.089999, 12.540001, 12.989999, 13.44, 13.889998, 14.340001,
16.59, 17.039999, 17.490002, 17.939999, 18.39, 18.84,
21.09, 21.539999, 21.99, 22.44, 22.89, 23.34
])
assert_eq!(
y,
&[
3.0900004, 3.54, 3.9900002, 4.44, 4.8900003, 5.34, 7.59, 8.04, 8.490001, 8.940001,
9.389999, 9.84, 12.089999, 12.540001, 12.989999, 13.44, 13.889998, 14.340001,
16.59, 17.039999, 17.490002, 17.939999, 18.39, 18.84, 21.09, 21.539999, 21.99,
22.44, 22.89, 23.34
]
)
}
#[test]
fn test_full_conv2d() {
let x: &[f32] = &[
1.0, 2.0,
3.0, 4.0,
];
let w: &[f32] = &[
1.0, 2.0, 1.0,
3.0, 2.0, 1.0,
1.0, 4.0, 3.0,
];
let x: &[f32] = &[1.0, 2.0, 3.0, 4.0];
let w: &[f32] = &[1.0, 2.0, 1.0, 3.0, 2.0, 1.0, 1.0, 4.0, 3.0];
let y: &mut [f32] = &mut [
0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0,
];
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
];
full_conv2d(y, x, w, 1.0, 2, 2, 3, 3, 1, 1);
assert_eq!(y, &[
1.0, 4.0, 5.0, 2.0,
6.0, 18.0, 16.0, 6.0,
10.0, 24.0, 22.0, 10.0,
3.0, 16.0, 25.0, 12.0,
])
assert_eq!(
y,
&[
1.0, 4.0, 5.0, 2.0, 6.0, 18.0, 16.0, 6.0, 10.0, 24.0, 22.0, 10.0, 3.0, 16.0, 25.0,
12.0,
]
)
}
#[test]
fn test_full_xcorr2d() {
let x: &[f32] = &[
1.0, 2.0,
3.0, 4.0,
];
let w: &[f32] = &[
1.0, 2.0, 1.0,
3.0, 2.0, 1.0,
1.0, 4.0, 3.0,
];
let x: &[f32] = &[1.0, 2.0, 3.0, 4.0];
let w: &[f32] = &[1.0, 2.0, 1.0, 3.0, 2.0, 1.0, 1.0, 4.0, 3.0];
let y: &mut [f32] = &mut [
0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0,
];
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
];
full_xcorr2d(y, x, w, 1.0, 2, 2, 3, 3, 1, 1);
assert_eq!(y, &[
3.0, 10.0, 9.0, 2.0,
10.0, 28.0, 26.0, 10.0,
4.0, 14.0, 22.0, 14.0,
3.0, 10.0, 11.0, 4.0,
])
assert_eq!(
y,
&[
3.0, 10.0, 9.0, 2.0, 10.0, 28.0, 26.0, 10.0, 4.0, 14.0, 22.0, 14.0, 3.0, 10.0,
11.0, 4.0,
]
)
}
#[test]
fn test_conv2d_grads() {
let x: &[f32] = &[
0.0, 0.0, 1.0, 0.0, 0.0,
0.0, 0.0, 1.0, 0.0, 0.0,
1.0, 1.0, 1.0, 1.0, 1.0,
0.0, 0.0, 1.0, 0.0, 0.0,
0.0, 0.0, 1.0, 0.0, 0.0,
0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0,
1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0,
];
let true_y: &[f32] = &[
0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0,
0.0,
];
let w: &mut [f32] = &mut [
0.02, -0.01, 0.01, -0.0012, 0.001, -0.005, 0.021, -0.0001, 0.008, 0.021, -0.0001,
0.008, -0.0012, 0.001, -0.005, 0.02, -0.01, 0.01,
];
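// Ten steps of plain gradient descent (w[i] -= 0.01 * dw[i] below); the elided hunk
// presumably derives dy from the gap between y and true_y.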
for _ in 0..10 {
let y: &mut [f32] = &mut [
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0,
];
let dw: &mut [f32] = &mut [
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0,
];
let dy: &mut [f32] = &mut [
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0,
];
conv2d_forward(y, x, w, 1, 1, 2, 5, 5, 3, 3, 1, 1);
for i in 0..y.len() {
@ -455,20 +448,34 @@ mod test {
}
conv2d_grads(dw, x, dy, 1, 1, 2, 5, 5, 3, 3, 1, 1);
for i in 0..w.len() {
w[i] -= dw[i] * 0.01;
}
}
assert_eq!(
w,
&[
0.018572275,
0.16292913,
0.011668432,
0.0014115912,
0.17488956,
0.00011491659,
0.018765995,
0.17117187,
0.009149003,
0.018765992,
0.0035881882,
0.009149002,
0.16899528,
0.17488956,
0.16769859,
0.018572278,
-0.004654532,
0.011668428
]
);
}
}

View File

@ -1,9 +1,15 @@
fn gemm_nn(
m: usize,
n: usize,
k: usize,
alpha: f32,
a: &[f32],
lda: usize,
b: &[f32],
ldb: usize,
c: &mut [f32],
ldc: usize,
) {
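// Row-major kernel (assumed from the name): C[i,j] += alpha * sum_k A[i,k] * B[k,j],
// with lda/ldb/ldc as the leading dimensions.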
let a = &a[0..m * k];
let b = &b[0..n * k];
let c = &mut c[0..m * n];
@ -18,11 +24,18 @@ fn gemm_nn(m: usize, n: usize, k: usize, alpha: f32,
}
}
fn gemm_nt(
m: usize,
n: usize,
k: usize,
alpha: f32,
a: &[f32],
lda: usize,
b: &[f32],
ldb: usize,
c: &mut [f32],
ldc: usize,
) {
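// "nt" (assumed from the name): B is transposed, so C[i,j] += alpha * sum_k A[i,k] * B[j,k].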
let a = &a[0..m * k];
let b = &b[0..n * k];
let c = &mut c[0..m * n];
@ -40,11 +53,18 @@ fn gemm_nt(m: usize, n: usize, k: usize, alpha: f32,
}
}
fn gemm_tn(
m: usize,
n: usize,
k: usize,
alpha: f32,
a: &[f32],
lda: usize,
b: &[f32],
ldb: usize,
c: &mut [f32],
ldc: usize,
) {
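// "tn" (assumed from the name): A is transposed, so C[i,j] += alpha * sum_k A[k,i] * B[k,j].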
let a = &a[0..m * k];
let b = &b[0..n * k];
let c = &mut c[0..m * n];
@ -60,19 +80,26 @@ fn gemm_tn(m: usize, n: usize, k: usize, alpha: f32,
}
}
fn gemm_tt(
m: usize,
n: usize,
k: usize,
alpha: f32,
a: &[f32],
lda: usize,
b: &[f32],
ldb: usize,
c: &mut [f32],
ldc: usize,
) {
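// Both operands transposed: the loop below computes C[i_m,i_n] += alpha * A[i_k,i_m] * B[i_n,i_k].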
let a = &a[0..m * k];
let b = &b[0..n * k];
let c = &mut c[0..m * n];
for i_m in 0..m {
for i_n in 0..n {
let mut sum = 0.0;
for i_k in 0..k {
sum += alpha * a[i_k * lda + i_m] * b[i_n * ldb + i_k];
}
@ -82,11 +109,21 @@ fn gemm_tt(m: usize, n: usize, k: usize, alpha: f32,
}
}
pub fn gemm(
ta: bool,
tb: bool,
m: usize,
n: usize,
k: usize,
alpha: f32,
a: &[f32],
lda: usize,
b: &[f32],
ldb: usize,
beta: f32,
c: &mut [f32],
ldc: usize,
) {
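// Scale C by beta first, then dispatch on (ta, tb) to one of the four kernels above.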
for i in 0..m {
for j in 0..n {
c[i * ldc + j] *= beta;
@ -102,4 +139,4 @@ pub fn gemm(ta: bool, tb: bool, m: usize, n: usize, k: usize, alpha: f32,
} else {
gemm_tt(m, n, k, alpha, a, lda, b, ldb, c, ldc);
}
}
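// A minimal usage sketch (hypothetical, not from this commit), assuming the row-major
// semantics noted above; beta = 0.0 makes gemm overwrite C.
#[test]
fn gemm_2x2_sketch() {
    let a = [1.0f32, 2.0, 3.0, 4.0]; // 2x2, row-major
    let b = [5.0f32, 6.0, 7.0, 8.0]; // 2x2, row-major
    let mut c = [0.0f32; 4];
    gemm(false, false, 2, 2, 2, 1.0, &a, 2, &b, 2, 0.0, &mut c, 2);
    assert_eq!(c, [19.0, 22.0, 43.0, 50.0]); // [[1,2],[3,4]] * [[5,6],[7,8]]
}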

File diff suppressed because it is too large

View File

@ -1,16 +1,22 @@
pub fn maxpool2d(
y: &mut [f32],
x: &[f32],
y_rows: isize,
y_cols: isize,
x_rows: isize,
x_cols: isize,
w_rows: isize,
w_cols: isize,
s_row: isize,
s_col: isize,
) {
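// Slides a w_rows x w_cols window over x with stride (s_row, s_col) and writes each
// window's maximum into the y_rows x y_cols output.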
let y = &mut y[0..(y_rows * y_cols) as usize];
let x = &x[0..(x_rows * x_cols) as usize];
for y_y in 0..y_rows {
for y_x in 0..y_cols {
let mut xi = s_row * y_y * x_cols + s_col * y_x;
let mut max = core::f32::NEG_INFINITY;
for _ in 0..w_rows {
for w_x in 0..w_cols {
@ -19,25 +25,32 @@ pub fn maxpool2d(y: &mut [f32], x: &[f32],
max = val;
}
}
xi += x_cols;
}
y[(y_y * y_cols + y_x) as usize] = max;
}
}
}
pub fn maxpool2d_backward(
dx: &mut [f32],
x: &[f32],
dy: &[f32],
x_rows: isize,
x_cols: isize,
y_rows: isize,
y_cols: isize,
w_rows: isize,
w_cols: isize,
s_row: isize,
s_col: isize,
) {
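// Routes each upstream gradient in dy back to the input position that produced the
// window maximum in the forward pass; the other dx entries are left untouched.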
let dx = &mut dx[0..(x_rows * x_cols) as usize];
let x = &x[0..(x_rows * x_cols) as usize];
let dy = &dy[0..(y_rows * y_cols) as usize];
for dy_y in 0..y_rows {
for dy_x in 0..y_cols {
let mut xi = s_row * dy_y * x_cols + s_col * dy_x;
@ -55,54 +68,67 @@ pub fn maxpool2d_backward(dx: &mut [f32], x: &[f32], dy: &[f32],
}
xi += x_cols;
}
dx[max_idx as usize] = dy[(dy_y * y_cols + dy_x) as usize];
}
}
}
#[allow(dead_code)]
pub fn avgpool2d(
y: &mut [f32],
x: &[f32],
y_rows: isize,
y_cols: isize,
x_rows: isize,
x_cols: isize,
w_rows: isize,
w_cols: isize,
s_row: isize,
s_col: isize,
) {
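// Same window sweep as maxpool2d, but writes the window mean (sum / w_size) into y.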
let w_size = w_rows * w_cols;
let y = &mut y[0..(y_rows * y_cols) as usize];
let x = &x[0..(x_rows * x_cols) as usize];
for y_y in 0..y_rows {
for y_x in 0..y_cols {
let mut xi = s_row * y_y * x_cols + s_col * y_x;
let mut sum = 0.0;
for _ in 0..w_rows {
for w_x in 0..w_cols {
sum += x[(xi + w_x) as usize];
}
xi += x_cols;
}
y[(y_y * y_cols + y_x) as usize] = sum / w_size as f32;
}
}
}
#[allow(dead_code)]
pub fn avgpool2d_backward(
dx: &mut [f32],
x: &[f32],
dy: &[f32],
x_rows: isize,
x_cols: isize,
y_rows: isize,
y_cols: isize,
w_rows: isize,
w_cols: isize,
s_row: isize,
s_col: isize,
) {
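// Backward pass for average pooling; the elided body mirrors maxpool2d_backward above.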
let dx = &mut dx[0..(x_rows * x_cols) as usize];
let x = &x[0..(x_rows * x_cols) as usize];
let dy = &dy[0..(y_rows * y_cols) as usize];
for dy_y in 0..y_rows {
for dy_x in 0..y_cols {
let mut xi = s_row * dy_y * x_cols + s_col * dy_x;
@ -120,13 +146,12 @@ pub fn avgpool2d_backward(dx: &mut [f32], x: &[f32], dy: &[f32],
}
xi += x_cols;
}
dx[max_idx as usize] = dy[(dy_y * y_cols + dy_x) as usize];
}
}
}
#[cfg(test)]
mod tests {
use super::*;
@ -134,65 +159,41 @@ mod tests {
#[test]
fn test_maxpool2d() {
let x: &[f32] = &[
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0,
31.0, 32.0, 33.0, 34.0, 35.0, 36.0,
];
let y: &mut [f32] = &mut [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0];
maxpool2d(y, x, 3, 3, 6, 6, 2, 2, 2, 2);
assert_eq!(y, &[8.0, 10.0, 12.0, 20.0, 22.0, 24.0, 32.0, 34.0, 36.0,])
}
#[test]
fn test_maxpool2d_backward() {
let x: &[f32] = &[
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0,
31.0, 32.0, 33.0, 34.0, 35.0, 36.0,
];
let dy: &[f32] = &[9.0, 8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0];
let dx: &mut [f32] = &mut [
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0,
];
maxpool2d_backward(dx, x, dy, 6, 6, 3, 3, 2, 2, 2, 2);
let tt: &[f32] = &[
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 9.0, 0.0, 8.0, 0.0, 7.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 6.0, 0.0, 5.0, 0.0, 4.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0, 0.0, 2.0,
0.0, 1.0,
];
assert_eq!(dx, tt);
}
@ -200,26 +201,15 @@ mod tests {
#[test]
fn test_avgpool2d() {
let x: &[f32] = &[
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0,
31.0, 32.0, 33.0, 34.0, 35.0, 36.0,
];
let y: &mut [f32] = &mut [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0];
avgpool2d(y, x, 3, 3, 6, 6, 2, 2, 2, 2);
assert_eq!(y, &[4.5, 6.5, 8.5, 16.5, 18.5, 20.5, 28.5, 30.5, 32.5,])
}
}

View File

@ -1,6 +1,5 @@
use crate::backend::Backend;
use crate::tensor::TensorShape;
pub trait OptimizerContext {
fn new<S: Into<TensorShape>>(shape: S) -> Self;
@ -9,14 +8,26 @@ pub trait OptimizerContext {
pub trait Optimizer<N, B: Backend<N>> {
type Context: OptimizerContext;
fn update_params(
&self,
backend: &B,
ctx: &mut Self::Context,
params: &mut B::Tensor,
grads: &mut B::Tensor,
);
}
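// Blanket impl: a shared reference to an optimizer is itself an optimizer, so callers
// can pass &O around instead of moving the optimizer.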
impl<'a, N, B: Backend<N>, O: Optimizer<N, B>> Optimizer<N, B> for &'a O {
type Context = O::Context;
#[inline]
fn update_params(
&self,
backend: &B,
ctx: &mut Self::Context,
params: &mut B::Tensor,
grads: &mut B::Tensor,
) {
(**self).update_params(backend, ctx, params, grads)
}
}

View File

@ -1,12 +1,12 @@
use crate::backend::{Backend, BackendAdam};
use crate::optimizer::{Optimizer, OptimizerContext};
use crate::tensor::{Tensor, TensorShape};
use core::cell::Cell;
use core::marker::PhantomData;
pub struct AdamContext<N, B>
where
B: Backend<N>,
{
moms: B::Tensor,
vels: B::Tensor,
@ -34,11 +34,12 @@ pub struct Adam<N, B: Backend<N>> {
epsilon: Option<f32>,
amsgrad: bool,
iteration: Cell<f32>,
_m: PhantomData<fn(N, B)>,
}
impl<N, B> Default for Adam<N, B>
where
B: Backend<N>,
{
fn default() -> Self {
Self {
@ -70,12 +71,19 @@ impl<N, B: Backend<N>> Adam<N, B> {
impl<N, B: Backend<N> + BackendAdam<N>> Optimizer<N, B> for Adam<N, B> {
type Context = AdamContext<N, B>;
fn update_params(
&self,
backend: &B,
ctx: &mut Self::Context,
params: &mut B::Tensor,
grads: &mut B::Tensor,
) {
let iter = self.iteration.get();
let t = iter + 1.0;
self.iteration.set(iter + 0.25);
let lr_t =
self.learning_rate * ((1.0 - self.beta_2.powf(t)).sqrt() / (1.0 - self.beta_1.powf(t)));
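// ^ bias-corrected step size from the Adam paper: lr * sqrt(1 - beta_2^t) / (1 - beta_1^t).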
// m_t = (self.beta_1 * m) + (1. - self.beta_1) * g;
backend.scale(&mut ctx.moms, backend.scalar_f32(self.beta_1));
@ -87,10 +95,22 @@ impl<N, B: Backend<N> + BackendAdam<N>> Optimizer<N, B> for Adam<N, B> {
if self.amsgrad {
backend.maximum(&mut ctx.vhats, &ctx.vels);
backend.adam_p(
params,
backend.scalar_f32(-lr_t),
&ctx.moms,
&ctx.vhats,
backend.scalar_f32(self.epsilon.unwrap_or(core::f32::EPSILON)),
);
} else {
// p_t = p - lr_t * m_t / (K.sqrt(v_t) + self.epsilon)
backend.adam_p(
params,
backend.scalar_f32(-lr_t),
&ctx.moms,
&ctx.vels,
backend.scalar_f32(self.epsilon.unwrap_or(core::f32::EPSILON)),
);
}
}
}

View File

@ -1,18 +1,19 @@
mod adam;
mod rmsprop;
mod sgd;
pub use self::adam::*;
pub use self::rmsprop::*;
pub use self::sgd::*;
use crate::backend::{Backend, BackendAxpys};
use crate::optimizer::Optimizer;
use core::marker::PhantomData;
pub struct WeightDecay<N, B, O>
where
B: Backend<N>,
O: Optimizer<N, B>,
{
lamda: f32,
optimizer: O,
@ -20,8 +21,9 @@ pub struct WeightDecay<N, B, O>
}
impl<N, B, O> WeightDecay<N, B, O>
where
B: Backend<N>,
O: Optimizer<N, B>,
{
pub fn new(lamda: f32, optimizer: O) -> Self {
Self {
@ -32,14 +34,21 @@ impl<N, B, O> WeightDecay<N, B, O>
}
}
impl<N, B, O> Optimizer<N, B> for WeightDecay<N, B, O>
where
B: Backend<N> + BackendAxpys<N>,
O: Optimizer<N, B>,
{
type Context = O::Context;
#[inline]
fn update_params(
&self,
backend: &B,
ctx: &mut Self::Context,
params: &mut B::Tensor,
grads: &mut B::Tensor,
) {
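// Fold the lamda-scaled decay term into grads via axpys, then hand the actual
// parameter update to the wrapped optimizer.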
backend.axpys(grads, backend.scalar_f32(self.lamda), params);
self.optimizer.update_params(backend, ctx, params, grads);

View File

@ -3,9 +3,9 @@ use crate::optimizer::{Optimizer, OptimizerContext};
use crate::tensor::{Tensor, TensorShape};
use core::marker::PhantomData;
pub struct RMSPropContext<N, B>
where
B: Backend<N>,
{
accum: B::Tensor,
_m: PhantomData<fn(N, B)>,
@ -26,11 +26,12 @@ pub struct RMSProp<N, B: Backend<N>> {
learning_rate: f32,
rho: f32,
epsilon: Option<f32>,
_m: PhantomData<fn(N, B)>,
}
impl<N, B> Default for RMSProp<N, B>
where
B: Backend<N>,
{
fn default() -> Self {
Self {
@ -56,12 +57,24 @@ impl<N, B: Backend<N>> RMSProp<N, B> {
impl<N, B: Backend<N> + BackendAdam<N>> Optimizer<N, B> for RMSProp<N, B> {
type Context = RMSPropContext<N, B>;
fn update_params(
&self,
backend: &B,
ctx: &mut Self::Context,
params: &mut B::Tensor,
grads: &mut B::Tensor,
) {
// new_a = self.rho * a + (1. - self.rho) * K.square(g)
backend.scale(&mut ctx.accum, backend.scalar_f32(self.rho));
backend.axpys(&mut ctx.accum, backend.scalar_f32(1.0 - self.rho), grads);
// new_p = p - lr * g / (K.sqrt(new_a) + self.epsilon)
backend.adam_p(
params,
backend.scalar_f32(-self.learning_rate),
&grads,
&ctx.accum,
backend.scalar_f32(self.epsilon.unwrap_or(core::f32::EPSILON)),
);
}
}

View File

@ -1,11 +1,11 @@
use crate::backend::{Backend, BackendAdd, BackendAxpy, BackendScale};
use crate::optimizer::{Optimizer, OptimizerContext};
use crate::tensor::{Tensor, TensorShape};
use core::marker::PhantomData;
pub struct SgdContext<N, B>
where
B: Backend<N>,
{
moments: B::Tensor,
_m: PhantomData<fn(N, B)>,
@ -24,11 +24,12 @@ pub struct Sgd<N, B: Backend<N>> {
learning_rate: f32,
momentum: f32,
nesterov: bool,
_m: PhantomData<fn(N, B)>,
}
impl<N, B> Default for Sgd<N, B>
where
B: Backend<N>,
{
fn default() -> Self {
Self {
@ -40,8 +41,9 @@ impl<N, B> Default for Sgd<N, B>
}
}
impl<N, B> Sgd<N, B>
where
B: Backend<N>,
{
pub fn new(learning_rate: f32, momentum: f32, nesterov: bool) -> Self {
Self {
@ -53,13 +55,25 @@ impl<N, B> Sgd<N, B>
}
}
impl<N, B: Backend<N> + BackendScale<N> + BackendAxpy<N> + BackendAdd<N>> Optimizer<N, B>
for Sgd<N, B>
{
type Context = SgdContext<N, B>;
fn update_params(
&self,
backend: &B,
ctx: &mut Self::Context,
params: &mut B::Tensor,
grads: &mut B::Tensor,
) {
// m = momentum * m - lr * grads
backend.scale(&mut ctx.moments, backend.scalar_f32(self.momentum));
backend.axpy(
&mut ctx.moments,
backend.scalar_f32(-self.learning_rate),
grads,
);
if self.nesterov {
// p += momentum * m - lr * grads

View File

@ -1,6 +1,6 @@
use crate::backend::Backend;
use crate::optimizer::{Optimizer, OptimizerContext};
use crate::tensor::{Tensor, TensorShape};
pub struct Params<N, B: Backend<N>, O: Optimizer<N, B>> {
pub params: B::Tensor,

View File

@ -1,11 +1,10 @@
use core::fmt;
pub struct TensorShapeIter<'a> {
shape: &'a TensorShape,
left: usize,
right: usize,
}
impl<'a> Iterator for TensorShapeIter<'a> {
type Item = u32;
@ -50,7 +49,7 @@ pub struct TensorShape {
impl fmt::Display for TensorShape {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "(")?;
for i in 0..self.dims {
if i != 0 {
write!(f, ", ")?;
}
@ -63,7 +62,7 @@ impl fmt::Display for TensorShape {
}
}
impl TensorShape {
#[inline]
pub fn zero() -> Self {
TensorShape {
@ -71,7 +70,7 @@ impl TensorShape {
dims: 0,
}
}
#[inline]
pub fn new0d() -> Self {
TensorShape {
@ -79,7 +78,7 @@ impl TensorShape {
dims: 0,
}
}
#[inline]
pub fn new1d(w: u32) -> Self {
TensorShape {
@ -87,7 +86,7 @@ impl TensorShape {
dims: 1,
}
}
#[inline]
pub fn new2d(h: u32, w: u32) -> Self {
TensorShape {
@ -95,7 +94,7 @@ impl TensorShape {
dims: 2,
}
}
#[inline]
pub fn new3d(b: u32, h: u32, w: u32) -> Self {
TensorShape {
@ -103,7 +102,7 @@ impl TensorShape {
dims: 3,
}
}
#[inline]
pub fn new4d(b: u32, c: u32, h: u32, w: u32) -> Self {
TensorShape {
@ -111,7 +110,7 @@ impl TensorShape {
dims: 4,
}
}
#[inline]
pub fn iter(&self) -> TensorShapeIter<'_> {
TensorShapeIter {
@ -125,7 +124,7 @@ impl TensorShape {
let s = s.borrow();
let sd = self.dims;
for i in 0..s.dims {
self.shape[i + sd] = s.shape[i];
}
@ -133,12 +132,12 @@ impl TensorShape {
self
}
#[inline]
pub fn get(&self, index: usize) -> u32 {
self.shape[index]
}
#[inline]
pub fn set(&mut self, index: usize, val: u32) {
self.shape[index] = val;
@ -155,17 +154,14 @@ impl TensorShape {
}
}
TensorShape { shape, dims }
}
#[inline]
pub fn size(&self) -> usize {
let mut product = 1;
for i in 0..self.dims {
product *= self.shape[i] as usize;
}
@ -176,21 +172,24 @@ impl TensorShape {
let mut strides = [0; 4];
let mut product = 1;
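// Row-major strides: the last axis gets stride 1, each earlier axis the product
// of all later extents.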
for i in 0..self.dims {
let si = self.dims - i - 1;
strides[si] = product;
product *= self.shape[si];
}
TensorShape {
shape: strides,
dims: self.dims,
}
}
#[inline]
pub fn as_slice(&self) -> &[u32] {
&self.shape[0..self.dims]
}
#[inline]
pub fn last_axis(&self) -> u32 {
self.shape[self.dims - 1]
@ -206,8 +205,8 @@ impl From<()> for TensorShape {
}
}
impl From<(u32,)> for TensorShape {
fn from(x: (u32,)) -> Self {
TensorShape {
shape: [x.0, 0, 0, 0],
dims: 1,
@ -227,7 +226,7 @@ impl From<(u32, u32)> for TensorShape {
impl From<(u32, u32, u32)> for TensorShape {
fn from(x: (u32, u32, u32)) -> Self {
TensorShape {
shape: [x.0, x.1, x.2, 0],
dims: 3,
}
}
@ -236,7 +235,7 @@ impl From<(u32, u32, u32)> for TensorShape {
impl From<(u32, u32, u32, u32)> for TensorShape {
fn from(x: (u32, u32, u32, u32)) -> Self {
TensorShape {
shape: [x.0, x.1, x.2, x.3],
dims: 4,
}
}
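// A small sketch (hypothetical, not from this commit) of how the tuple conversions
// and accessors above compose:
fn tensor_shape_sketch() {
    let s: TensorShape = (2u32, 3, 4).into();
    assert_eq!(s.as_slice(), &[2, 3, 4]);
    assert_eq!(s.size(), 24);
    assert_eq!(s.default_strides().as_slice(), &[12, 4, 1]); // row-major
}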