This commit is contained in:
parent
765a1eccfa
commit
2fa6098f65
14
.drone.yml
Normal file
14
.drone.yml
Normal file
@ -0,0 +1,14 @@
|
||||
kind: pipeline
|
||||
name: default
|
||||
|
||||
steps:
|
||||
- name: build
|
||||
image: rustlang/rust:nightly
|
||||
commands:
|
||||
- cargo build --verbose --all
|
||||
|
||||
- name: fmt-check
|
||||
image: rustlang/rust:nightly
|
||||
commands:
|
||||
- rustup component add rustfmt
|
||||
- cargo fmt --all -- --check
|
417
Cargo.lock
generated
417
Cargo.lock
generated
@ -1,514 +1,363 @@
|
||||
# This file is automatically @generated by Cargo.
|
||||
# It is not intended for manual editing.
|
||||
version = 3
|
||||
|
||||
[[package]]
|
||||
name = "autocfg"
|
||||
version = "0.1.4"
|
||||
version = "1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[[package]]
|
||||
name = "backtrace"
|
||||
version = "0.3.32"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"backtrace-sys 0.1.30 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"cfg-if 0.1.9 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"libc 0.2.59 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"rustc-demangle 0.1.15 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "backtrace-sys"
|
||||
version = "0.1.30"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"cc 1.0.37 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"libc 0.2.59 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
|
||||
|
||||
[[package]]
|
||||
name = "blas"
|
||||
version = "0.20.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e4b3b6399f2fe6e70f780b06e278a64770aae273e0005c3860f94fc2f5f01ba7"
|
||||
dependencies = [
|
||||
"blas-sys 0.7.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"libc 0.2.59 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"num-complex 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"blas-sys",
|
||||
"libc",
|
||||
"num-complex",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "blas-sys"
|
||||
version = "0.7.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "13b1b279ceb25d7c4faaea95a5f7addbe7d8c34f9462044bd8e630cebcfc2440"
|
||||
dependencies = [
|
||||
"libc 0.2.59 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bumpalo"
|
||||
version = "2.5.0"
|
||||
version = "3.9.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a4a45a46ab1f2412e53d3a0ade76ffad2025804294569aae387231a0cd6e0899"
|
||||
|
||||
[[package]]
|
||||
name = "byteorder"
|
||||
version = "1.3.2"
|
||||
version = "1.4.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[[package]]
|
||||
name = "c2-chacha"
|
||||
version = "0.2.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"lazy_static 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"ppv-lite86 0.2.5 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610"
|
||||
|
||||
[[package]]
|
||||
name = "cblas"
|
||||
version = "0.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d82f331add33eceb4c41cb28d878049b96f56577016daf190831e94e4aece5db"
|
||||
dependencies = [
|
||||
"cblas-sys 0.1.4 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"libc 0.2.59 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"num-complex 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"cblas-sys",
|
||||
"libc",
|
||||
"num-complex",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cblas-sys"
|
||||
version = "0.1.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b6feecd82cce51b0204cf063f0041d69f24ce83f680d87514b004248e7b0fa65"
|
||||
dependencies = [
|
||||
"libc 0.2.59 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cc"
|
||||
version = "1.0.37"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[[package]]
|
||||
name = "cfg-if"
|
||||
version = "0.1.9"
|
||||
version = "1.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[[package]]
|
||||
name = "failure"
|
||||
version = "0.1.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"backtrace 0.3.32 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"failure_derive 0.1.5 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "failure_derive"
|
||||
version = "0.1.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"proc-macro2 0.4.30 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"quote 0.6.13 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"syn 0.15.39 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"synstructure 0.10.2 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
|
||||
|
||||
[[package]]
|
||||
name = "getrandom"
|
||||
version = "0.1.6"
|
||||
version = "0.1.16"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8fc3cb4d91f53b50155bdcfd23f6a4c39ae1969c2ae85982b135750cccaf5fce"
|
||||
dependencies = [
|
||||
"lazy_static 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"libc 0.2.59 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "heck"
|
||||
version = "0.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"unicode-segmentation 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"cfg-if",
|
||||
"libc",
|
||||
"wasi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "js-sys"
|
||||
version = "0.3.25"
|
||||
version = "0.3.56"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a38fc24e30fd564ce974c02bf1d337caddff65be6cc4735a1f7eab22a7440f04"
|
||||
dependencies = [
|
||||
"wasm-bindgen 0.2.48 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "lazy_static"
|
||||
version = "1.3.0"
|
||||
version = "1.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"spin 0.5.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
|
||||
|
||||
[[package]]
|
||||
name = "libc"
|
||||
version = "0.2.59"
|
||||
version = "0.2.120"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ad5c14e80759d0939d013e6ca49930e59fc53dd8e5009132f76240c179380c09"
|
||||
|
||||
[[package]]
|
||||
name = "log"
|
||||
version = "0.4.7"
|
||||
version = "0.4.14"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "51b9bbe6c47d51fc3e1a9b945965946b4c44142ab8792c50835a980d362c2710"
|
||||
dependencies = [
|
||||
"cfg-if 0.1.9 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"cfg-if",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "memchr"
|
||||
version = "2.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[[package]]
|
||||
name = "mnist"
|
||||
version = "0.4.0"
|
||||
version = "0.4.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9fd2e5236a5f13d41d4f6eba34f535d037cb9fc4b244896886a7ebe0764142a5"
|
||||
dependencies = [
|
||||
"byteorder 1.3.2 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "nom"
|
||||
version = "4.2.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"memchr 2.2.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"version_check 0.1.5 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"byteorder",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-complex"
|
||||
version = "0.2.3"
|
||||
version = "0.2.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b6b19411a9719e753aff12e5187b74d60d3dc449ec3f4dc21e3989c3f554bc95"
|
||||
dependencies = [
|
||||
"autocfg 0.1.4 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"num-traits 0.2.8 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"autocfg",
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-traits"
|
||||
version = "0.2.8"
|
||||
version = "0.2.14"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9a64b1ec5cda2586e284722486d802acf1f7dbdc623e2bfc57e65ca1cd099290"
|
||||
dependencies = [
|
||||
"autocfg 0.1.4 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"autocfg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "openblas-src"
|
||||
version = "0.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0b3533e568814bee9620fcc529158408384404bae5b277c73c73d66ca03fceb7"
|
||||
|
||||
[[package]]
|
||||
name = "ppv-lite86"
|
||||
version = "0.2.5"
|
||||
version = "0.2.16"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "eb9f9e6e233e5c4a35559a617bf40a4ec447db2e84c20b55a6f83167b7e57872"
|
||||
|
||||
[[package]]
|
||||
name = "proc-macro2"
|
||||
version = "0.4.30"
|
||||
version = "1.0.36"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c7342d5883fbccae1cc37a2353b09c87c9b0f3afd73f5fb9bba687a1f733b029"
|
||||
dependencies = [
|
||||
"unicode-xid 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"unicode-xid",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quote"
|
||||
version = "0.6.13"
|
||||
version = "1.0.15"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "864d3e96a899863136fc6e99f3d7cae289dafe43bf2c5ac19b70df7210c0a145"
|
||||
dependencies = [
|
||||
"proc-macro2 0.4.30 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"proc-macro2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand"
|
||||
version = "0.7.0"
|
||||
version = "0.7.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6a6b1679d49b24bbfe0c803429aa1874472f50d9b363131f0e89fc356b544d03"
|
||||
dependencies = [
|
||||
"getrandom 0.1.6 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"libc 0.2.59 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"rand_chacha 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"rand_core 0.5.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"rand_hc 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"getrandom",
|
||||
"libc",
|
||||
"rand_chacha",
|
||||
"rand_core",
|
||||
"rand_hc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand_chacha"
|
||||
version = "0.2.0"
|
||||
version = "0.2.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f4c8ed856279c9737206bf725bf36935d8666ead7aa69b52be55af369d193402"
|
||||
dependencies = [
|
||||
"autocfg 0.1.4 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"c2-chacha 0.2.2 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"rand_core 0.5.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"ppv-lite86",
|
||||
"rand_core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand_core"
|
||||
version = "0.5.0"
|
||||
version = "0.5.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19"
|
||||
dependencies = [
|
||||
"getrandom 0.1.6 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"getrandom",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand_distr"
|
||||
version = "0.2.1"
|
||||
version = "0.2.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "96977acbdd3a6576fb1d27391900035bf3863d4a16422973a409b488cf29ffb2"
|
||||
dependencies = [
|
||||
"rand 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"rand",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand_hc"
|
||||
version = "0.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ca3129af7b92a17112d59ad498c6f81eaf463253766b90396d39ea7a39d6613c"
|
||||
dependencies = [
|
||||
"rand_core 0.5.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"rand_core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustc-demangle"
|
||||
version = "0.1.15"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[[package]]
|
||||
name = "sourcefile"
|
||||
version = "0.1.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[[package]]
|
||||
name = "spin"
|
||||
version = "0.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[[package]]
|
||||
name = "syn"
|
||||
version = "0.15.39"
|
||||
version = "1.0.89"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ea297be220d52398dcc07ce15a209fce436d361735ac1db700cab3b6cdfb9f54"
|
||||
dependencies = [
|
||||
"proc-macro2 0.4.30 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"quote 0.6.13 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"unicode-xid 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"unicode-xid",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "synstructure"
|
||||
version = "0.10.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"proc-macro2 0.4.30 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"quote 0.6.13 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"syn 0.15.39 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"unicode-xid 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "unicode-segmentation"
|
||||
version = "1.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[[package]]
|
||||
name = "unicode-xid"
|
||||
version = "0.1.0"
|
||||
version = "0.2.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8ccb82d61f80a663efe1f787a51b16b5a51e3314d6ac365b08639f52387b33f3"
|
||||
|
||||
[[package]]
|
||||
name = "version_check"
|
||||
version = "0.1.5"
|
||||
name = "wasi"
|
||||
version = "0.9.0+wasi-snapshot-preview1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cccddf32554fecc6acb585f82a32a72e28b48f8c4c1883ddfeeeaa96f7d8e519"
|
||||
|
||||
[[package]]
|
||||
name = "wasm-bindgen"
|
||||
version = "0.2.48"
|
||||
version = "0.2.79"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "25f1af7423d8588a3d840681122e72e6a24ddbcb3f0ec385cac0d12d24256c06"
|
||||
dependencies = [
|
||||
"wasm-bindgen-macro 0.2.48 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"cfg-if",
|
||||
"wasm-bindgen-macro",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wasm-bindgen-backend"
|
||||
version = "0.2.48"
|
||||
version = "0.2.79"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8b21c0df030f5a177f3cba22e9bc4322695ec43e7257d865302900290bcdedca"
|
||||
dependencies = [
|
||||
"bumpalo 2.5.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"lazy_static 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"log 0.4.7 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"proc-macro2 0.4.30 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"quote 0.6.13 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"syn 0.15.39 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"wasm-bindgen-shared 0.2.48 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"bumpalo",
|
||||
"lazy_static",
|
||||
"log",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"wasm-bindgen-shared",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wasm-bindgen-macro"
|
||||
version = "0.2.48"
|
||||
version = "0.2.79"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2f4203d69e40a52ee523b2529a773d5ffc1dc0071801c87b3d270b471b80ed01"
|
||||
dependencies = [
|
||||
"quote 0.6.13 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"wasm-bindgen-macro-support 0.2.48 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"quote",
|
||||
"wasm-bindgen-macro-support",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wasm-bindgen-macro-support"
|
||||
version = "0.2.48"
|
||||
version = "0.2.79"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bfa8a30d46208db204854cadbb5d4baf5fcf8071ba5bf48190c3e59937962ebc"
|
||||
dependencies = [
|
||||
"proc-macro2 0.4.30 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"quote 0.6.13 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"syn 0.15.39 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"wasm-bindgen-backend 0.2.48 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"wasm-bindgen-shared 0.2.48 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"wasm-bindgen-backend",
|
||||
"wasm-bindgen-shared",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wasm-bindgen-shared"
|
||||
version = "0.2.48"
|
||||
version = "0.2.79"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[[package]]
|
||||
name = "wasm-bindgen-webidl"
|
||||
version = "0.2.48"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"failure 0.1.5 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"heck 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"log 0.4.7 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"proc-macro2 0.4.30 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"quote 0.6.13 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"syn 0.15.39 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"wasm-bindgen-backend 0.2.48 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"weedle 0.10.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
checksum = "3d958d035c4438e28c70e4321a2911302f10135ce78a9c7834c0cab4123d06a2"
|
||||
|
||||
[[package]]
|
||||
name = "web-sys"
|
||||
version = "0.3.25"
|
||||
version = "0.3.56"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c060b319f29dd25724f09a2ba1418f142f539b2be99fbf4d2d5a8f7330afb8eb"
|
||||
dependencies = [
|
||||
"failure 0.1.5 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"js-sys 0.3.25 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"sourcefile 0.1.4 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"wasm-bindgen 0.2.48 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"wasm-bindgen-webidl 0.2.48 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "weedle"
|
||||
version = "0.10.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"nom 4.2.3 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"js-sys",
|
||||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "yarnn"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"rand 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"rand_distr 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"rand",
|
||||
"rand_distr",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "yarnn-example-mnist"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"mnist 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"yarnn 0.1.0",
|
||||
"yarnn-model-mnist 0.1.0",
|
||||
"yarnn-native-blas 0.1.0",
|
||||
"mnist",
|
||||
"yarnn",
|
||||
"yarnn-model-mnist",
|
||||
"yarnn-native-blas",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "yarnn-example-mnist-wasm"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"js-sys 0.3.25 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"wasm-bindgen 0.2.48 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"web-sys 0.3.25 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"yarnn 0.1.0",
|
||||
"js-sys",
|
||||
"wasm-bindgen",
|
||||
"web-sys",
|
||||
"yarnn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "yarnn-example-vgg16-demo"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"yarnn 0.1.0",
|
||||
"yarnn-model-vgg16 0.1.0",
|
||||
"yarnn",
|
||||
"yarnn-model-vgg16",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "yarnn-model-mnist"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"yarnn 0.1.0",
|
||||
"yarnn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "yarnn-model-vgg16"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"yarnn 0.1.0",
|
||||
"yarnn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "yarnn-native-blas"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"blas 0.20.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"cblas 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"openblas-src 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"yarnn 0.1.0",
|
||||
"blas",
|
||||
"cblas",
|
||||
"openblas-src",
|
||||
"yarnn",
|
||||
]
|
||||
|
||||
[metadata]
|
||||
"checksum autocfg 0.1.4 (registry+https://github.com/rust-lang/crates.io-index)" = "0e49efa51329a5fd37e7c79db4621af617cd4e3e5bc224939808d076077077bf"
|
||||
"checksum backtrace 0.3.32 (registry+https://github.com/rust-lang/crates.io-index)" = "18b50f5258d1a9ad8396d2d345827875de4261b158124d4c819d9b351454fae5"
|
||||
"checksum backtrace-sys 0.1.30 (registry+https://github.com/rust-lang/crates.io-index)" = "5b3a000b9c543553af61bc01cbfc403b04b5caa9e421033866f2e98061eb3e61"
|
||||
"checksum blas 0.20.0 (registry+https://github.com/rust-lang/crates.io-index)" = "e4b3b6399f2fe6e70f780b06e278a64770aae273e0005c3860f94fc2f5f01ba7"
|
||||
"checksum blas-sys 0.7.1 (registry+https://github.com/rust-lang/crates.io-index)" = "13b1b279ceb25d7c4faaea95a5f7addbe7d8c34f9462044bd8e630cebcfc2440"
|
||||
"checksum bumpalo 2.5.0 (registry+https://github.com/rust-lang/crates.io-index)" = "2cd43d82f27d68911e6ee11ee791fb248f138f5d69424dc02e098d4f152b0b05"
|
||||
"checksum byteorder 1.3.2 (registry+https://github.com/rust-lang/crates.io-index)" = "a7c3dd8985a7111efc5c80b44e23ecdd8c007de8ade3b96595387e812b957cf5"
|
||||
"checksum c2-chacha 0.2.2 (registry+https://github.com/rust-lang/crates.io-index)" = "7d64d04786e0f528460fc884753cf8dddcc466be308f6026f8e355c41a0e4101"
|
||||
"checksum cblas 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "d82f331add33eceb4c41cb28d878049b96f56577016daf190831e94e4aece5db"
|
||||
"checksum cblas-sys 0.1.4 (registry+https://github.com/rust-lang/crates.io-index)" = "b6feecd82cce51b0204cf063f0041d69f24ce83f680d87514b004248e7b0fa65"
|
||||
"checksum cc 1.0.37 (registry+https://github.com/rust-lang/crates.io-index)" = "39f75544d7bbaf57560d2168f28fd649ff9c76153874db88bdbdfd839b1a7e7d"
|
||||
"checksum cfg-if 0.1.9 (registry+https://github.com/rust-lang/crates.io-index)" = "b486ce3ccf7ffd79fdeb678eac06a9e6c09fc88d33836340becb8fffe87c5e33"
|
||||
"checksum failure 0.1.5 (registry+https://github.com/rust-lang/crates.io-index)" = "795bd83d3abeb9220f257e597aa0080a508b27533824adf336529648f6abf7e2"
|
||||
"checksum failure_derive 0.1.5 (registry+https://github.com/rust-lang/crates.io-index)" = "ea1063915fd7ef4309e222a5a07cf9c319fb9c7836b1f89b85458672dbb127e1"
|
||||
"checksum getrandom 0.1.6 (registry+https://github.com/rust-lang/crates.io-index)" = "e65cce4e5084b14874c4e7097f38cab54f47ee554f9194673456ea379dcc4c55"
|
||||
"checksum heck 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)" = "20564e78d53d2bb135c343b3f47714a56af2061f1c928fdb541dc7b9fdd94205"
|
||||
"checksum js-sys 0.3.25 (registry+https://github.com/rust-lang/crates.io-index)" = "da3ea71161651a4cd97d999b2da139109c537b15ab33abc8ae4ead38deac8a03"
|
||||
"checksum lazy_static 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)" = "bc5729f27f159ddd61f4df6228e827e86643d4d3e7c32183cb30a1c08f604a14"
|
||||
"checksum libc 0.2.59 (registry+https://github.com/rust-lang/crates.io-index)" = "3262021842bf00fe07dbd6cf34ff25c99d7a7ebef8deea84db72be3ea3bb0aff"
|
||||
"checksum log 0.4.7 (registry+https://github.com/rust-lang/crates.io-index)" = "c275b6ad54070ac2d665eef9197db647b32239c9d244bfb6f041a766d00da5b3"
|
||||
"checksum memchr 2.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "88579771288728879b57485cc7d6b07d648c9f0141eb955f8ab7f9d45394468e"
|
||||
"checksum mnist 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "25f19bfda80095b4294000bbb50506f028149ed0ddb7fabf46ebb673b91626bc"
|
||||
"checksum nom 4.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "2ad2a91a8e869eeb30b9cb3119ae87773a8f4ae617f41b1eb9c154b2905f7bd6"
|
||||
"checksum num-complex 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "fcb0cf31fb3ff77e6d2a6ebd6800df7fdcd106f2ad89113c9130bcd07f93dffc"
|
||||
"checksum num-traits 0.2.8 (registry+https://github.com/rust-lang/crates.io-index)" = "6ba9a427cfca2be13aa6f6403b0b7e7368fe982bfa16fccc450ce74c46cd9b32"
|
||||
"checksum openblas-src 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)" = "0b3533e568814bee9620fcc529158408384404bae5b277c73c73d66ca03fceb7"
|
||||
"checksum ppv-lite86 0.2.5 (registry+https://github.com/rust-lang/crates.io-index)" = "e3cbf9f658cdb5000fcf6f362b8ea2ba154b9f146a61c7a20d647034c6b6561b"
|
||||
"checksum proc-macro2 0.4.30 (registry+https://github.com/rust-lang/crates.io-index)" = "cf3d2011ab5c909338f7887f4fc896d35932e29146c12c8d01da6b22a80ba759"
|
||||
"checksum quote 0.6.13 (registry+https://github.com/rust-lang/crates.io-index)" = "6ce23b6b870e8f94f81fb0a363d65d86675884b34a09043c81e5562f11c1f8e1"
|
||||
"checksum rand 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)" = "d47eab0e83d9693d40f825f86948aa16eff6750ead4bdffc4ab95b8b3a7f052c"
|
||||
"checksum rand_chacha 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "e193067942ef6f485a349a113329140d0ab9e2168ce92274499bb0e9a4190d9d"
|
||||
"checksum rand_core 0.5.0 (registry+https://github.com/rust-lang/crates.io-index)" = "615e683324e75af5d43d8f7a39ffe3ee4a9dc42c5c701167a71dc59c3a493aca"
|
||||
"checksum rand_distr 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "c37e54d811c5a51195156444e298a98ba57df6dca1e511f34d4791993883b921"
|
||||
"checksum rand_hc 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ca3129af7b92a17112d59ad498c6f81eaf463253766b90396d39ea7a39d6613c"
|
||||
"checksum rustc-demangle 0.1.15 (registry+https://github.com/rust-lang/crates.io-index)" = "a7f4dccf6f4891ebcc0c39f9b6eb1a83b9bf5d747cb439ec6fba4f3b977038af"
|
||||
"checksum sourcefile 0.1.4 (registry+https://github.com/rust-lang/crates.io-index)" = "4bf77cb82ba8453b42b6ae1d692e4cdc92f9a47beaf89a847c8be83f4e328ad3"
|
||||
"checksum spin 0.5.0 (registry+https://github.com/rust-lang/crates.io-index)" = "44363f6f51401c34e7be73db0db371c04705d35efbe9f7d6082e03a921a32c55"
|
||||
"checksum syn 0.15.39 (registry+https://github.com/rust-lang/crates.io-index)" = "b4d960b829a55e56db167e861ddb43602c003c7be0bee1d345021703fac2fb7c"
|
||||
"checksum synstructure 0.10.2 (registry+https://github.com/rust-lang/crates.io-index)" = "02353edf96d6e4dc81aea2d8490a7e9db177bf8acb0e951c24940bf866cb313f"
|
||||
"checksum unicode-segmentation 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)" = "1967f4cdfc355b37fd76d2a954fb2ed3871034eb4f26d60537d88795cfc332a9"
|
||||
"checksum unicode-xid 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "fc72304796d0818e357ead4e000d19c9c174ab23dc11093ac919054d20a6a7fc"
|
||||
"checksum version_check 0.1.5 (registry+https://github.com/rust-lang/crates.io-index)" = "914b1a6776c4c929a602fafd8bc742e06365d4bcbe48c30f9cca5824f70dc9dd"
|
||||
"checksum wasm-bindgen 0.2.48 (registry+https://github.com/rust-lang/crates.io-index)" = "4de97fa1806bb1a99904216f6ac5e0c050dc4f8c676dc98775047c38e5c01b55"
|
||||
"checksum wasm-bindgen-backend 0.2.48 (registry+https://github.com/rust-lang/crates.io-index)" = "5d82c170ef9f5b2c63ad4460dfcee93f3ec04a9a36a4cc20bc973c39e59ab8e3"
|
||||
"checksum wasm-bindgen-macro 0.2.48 (registry+https://github.com/rust-lang/crates.io-index)" = "f07d50f74bf7a738304f6b8157f4a581e1512cd9e9cdb5baad8c31bbe8ffd81d"
|
||||
"checksum wasm-bindgen-macro-support 0.2.48 (registry+https://github.com/rust-lang/crates.io-index)" = "95cf8fe77e45ba5f91bc8f3da0c3aa5d464b3d8ed85d84f4d4c7cc106436b1d7"
|
||||
"checksum wasm-bindgen-shared 0.2.48 (registry+https://github.com/rust-lang/crates.io-index)" = "d9c2d4d4756b2e46d3a5422e06277d02e4d3e1d62d138b76a4c681e925743623"
|
||||
"checksum wasm-bindgen-webidl 0.2.48 (registry+https://github.com/rust-lang/crates.io-index)" = "24e47859b4eba3d3b9a5c2c299f9d6f8d0b613671315f6f0c5c7f835e524b36a"
|
||||
"checksum web-sys 0.3.25 (registry+https://github.com/rust-lang/crates.io-index)" = "86d515d2f713d3a6ab198031d2181b7540f8e319e4637ec2d4a41a208335ef29"
|
||||
"checksum weedle 0.10.0 (registry+https://github.com/rust-lang/crates.io-index)" = "3bb43f70885151e629e2a19ce9e50bd730fd436cfd4b666894c9ce4de9141164"
|
||||
|
@ -1,10 +1,10 @@
|
||||
use yarnn::prelude::*;
|
||||
use yarnn::native::{Native, NativeTensor};
|
||||
use yarnn_model_mnist::*;
|
||||
use yarnn::losses::CrossEntropyLoss;
|
||||
use yarnn::optimizers::Adam;
|
||||
use yarnn_native_blas::NativeBlas;
|
||||
use mnist::{Mnist, MnistBuilder};
|
||||
use yarnn::losses::CrossEntropyLoss;
|
||||
use yarnn::native::{Native, NativeTensor};
|
||||
use yarnn::optimizers::Adam;
|
||||
use yarnn::prelude::*;
|
||||
use yarnn_model_mnist::*;
|
||||
use yarnn_native_blas::NativeBlas;
|
||||
|
||||
fn calc_accuracy<N, B: Backend<N>>(back: &B, pred: &B::Tensor, targets: &[u8]) -> f32 {
|
||||
let mut vec = vec![0.0; pred.shape().size()];
|
||||
@ -14,7 +14,7 @@ fn calc_accuracy<N, B: Backend<N>>(back: &B, pred: &B::Tensor, targets: &[u8]) -
|
||||
let mut total = 0;
|
||||
|
||||
for (x, &y) in vec.chunks(10).zip(targets.iter()) {
|
||||
let x = &x[0 .. 10];
|
||||
let x = &x[0..10];
|
||||
|
||||
let mut max = 0;
|
||||
let mut max_value = 0.0;
|
||||
@ -33,7 +33,7 @@ fn calc_accuracy<N, B: Backend<N>>(back: &B, pred: &B::Tensor, targets: &[u8]) -
|
||||
total += 1;
|
||||
}
|
||||
|
||||
(positives as f32) / (total as f32)
|
||||
(positives as f32) / (total as f32)
|
||||
}
|
||||
|
||||
fn main() {
|
||||
@ -41,7 +41,7 @@ fn main() {
|
||||
|
||||
let backend: NativeBlas<f32, Native<_>> = Default::default();
|
||||
let optimizer = Adam::default();
|
||||
|
||||
|
||||
// let mut model = MnistDenseModel::new(28, 28, 1);
|
||||
let mut model = MnistConvModel::new(28, 28, 1);
|
||||
model.init(&backend);
|
||||
@ -53,7 +53,13 @@ fn main() {
|
||||
|
||||
let loss = CrossEntropyLoss::new();
|
||||
|
||||
let Mnist { trn_img, trn_lbl, tst_img, tst_lbl, .. } = MnistBuilder::new()
|
||||
let Mnist {
|
||||
trn_img,
|
||||
trn_lbl,
|
||||
tst_img,
|
||||
tst_lbl,
|
||||
..
|
||||
} = MnistBuilder::new()
|
||||
.base_path("./datasets/mnist")
|
||||
.label_format_digit()
|
||||
.finalize();
|
||||
@ -81,14 +87,14 @@ fn main() {
|
||||
|
||||
backend.load_tensor_u8(&mut targets0, &tmp[..]);
|
||||
|
||||
for epoch in 1 ..= 4 {
|
||||
for epoch in 1..=4 {
|
||||
println!("epoch {}", epoch);
|
||||
|
||||
for step in 0 .. (60000 / BATCH_SIZE) {
|
||||
for step in 0..(60000 / BATCH_SIZE) {
|
||||
let offset = step * BATCH_SIZE;
|
||||
let mut tmp = [0u8; 10 * BATCH_SIZE];
|
||||
|
||||
let inputs_slice = &trn_img[offset * 784 .. (offset + BATCH_SIZE) * 784 ];
|
||||
let inputs_slice = &trn_img[offset * 784..(offset + BATCH_SIZE) * 784];
|
||||
let targets_slice = &trn_lbl[offset..offset + BATCH_SIZE];
|
||||
|
||||
backend.load_tensor_u8(&mut inputs, inputs_slice);
|
||||
@ -102,13 +108,16 @@ fn main() {
|
||||
|
||||
model.forward(&backend, &inputs, &mut train_ctx);
|
||||
loss.derivative(&backend, &mut deltas, train_ctx.outputs(), &targets);
|
||||
model.backward(&backend, &deltas, &inputs, &mut train_ctx);
|
||||
model.backward(&backend, &deltas, &inputs, &mut train_ctx);
|
||||
model.calc_gradients(&backend, &deltas, &inputs, &mut train_ctx);
|
||||
model.optimize(&backend, &optimizer);
|
||||
}
|
||||
|
||||
model.forward(&backend, &inputs0, &mut test_ctx);
|
||||
|
||||
println!("Accuracy {}", calc_accuracy(&backend, test_ctx.outputs(), targets0_slice));
|
||||
println!(
|
||||
"Accuracy {}",
|
||||
calc_accuracy(&backend, test_ctx.outputs(), targets0_slice)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
@ -2,7 +2,6 @@ use yarnn::native::Native;
|
||||
use yarnn::optimizers::Adam;
|
||||
use yarnn_model_vgg16::Vgg16Model;
|
||||
|
||||
|
||||
fn main() {
|
||||
let vgg16: Vgg16Model<f32, Native<_>, Adam<_, _>> = Vgg16Model::new(224, 224, 3);
|
||||
|
||||
|
@ -1,11 +1,11 @@
|
||||
#![feature(trait_alias)]
|
||||
|
||||
pub use self::dense::MnistDenseModel;
|
||||
pub use self::conv::MnistConvModel;
|
||||
pub use self::dense::MnistDenseModel;
|
||||
|
||||
mod dense {
|
||||
use yarnn::model;
|
||||
use yarnn::layers::*;
|
||||
use yarnn::model;
|
||||
|
||||
model! {
|
||||
MnistDenseModel (h: u32, w: u32, _c: u32) {
|
||||
@ -25,11 +25,10 @@ mod dense {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
mod conv {
|
||||
use yarnn::model;
|
||||
use yarnn::layers::*;
|
||||
|
||||
use yarnn::model;
|
||||
|
||||
model! {
|
||||
MnistConvModel (h: u32, w: u32, c: u32) {
|
||||
input_shape: (c, h, w),
|
||||
|
@ -1,3 +1,3 @@
|
||||
/// Prints the Cargo directive that requests dynamic linking against `cblas`.
///
/// NOTE(review): presumably this is a Cargo build script (`build.rs`); the file name is not
/// visible here — confirm.
fn main() {
    const LINK_DIRECTIVE: &str = "cargo:rustc-link-lib=dylib=cblas";

    println!("{}", LINK_DIRECTIVE);
}
|
||||
|
@ -1,110 +1,147 @@
|
||||
/// Reads one pixel of a channel-major image (`[channel][row][col]`), treating everything
/// outside the image as zero padding.
///
/// `row` and `col` are coordinates in the padded image; `pad_row` / `pad_col` are subtracted
/// to obtain coordinates inside `img`. Positions that fall outside `img_rows` x `img_cols`
/// yield `0.0`.
#[allow(dead_code)]
fn img_to_col_get_pixel(
    img: &[f32],
    img_rows: usize,
    img_cols: usize,
    mut row: isize,
    mut col: isize,
    channel: usize,
    pad_row: usize,
    pad_col: usize,
) -> f32 {
    row -= pad_row as isize;
    col -= pad_col as isize;

    if row < 0 || col < 0 || row >= img_rows as isize || col >= img_cols as isize {
        return 0.0;
    }

    img[(channel * img_rows + row as usize) * img_cols + col as usize]
}

/// Unrolls the kernel windows of a channel-major image into `col` ("im2col").
///
/// Output layout: `col[(channel * out_size + out_index) * k_size + kernel_index]`, where
/// `out_index` walks the output positions row-major, `kernel_index` walks the kernel window
/// row-major (`k_row * k_cols + k_col`), `out_size = col_rows * col_cols` and
/// `k_size = k_rows * k_cols`.
///
/// * `col` - destination; only its first `channels * k_size * out_size` elements are written.
/// * `img` - source image, `channels * img_rows * img_cols` values.
/// * `k_rows`, `k_cols` - kernel height and width (they may differ).
/// * `s_row`, `s_col` - strides along rows and columns.
/// * `pad_row`, `pad_col` - zero padding applied on each side of the image.
///
/// # Panics
///
/// Panics if a stride is zero or if `col` is shorter than the required output length.
/// A kernel larger than the padded image is not supported (the size subtraction underflows).
#[allow(dead_code)]
pub fn img_to_col(
    col: &mut [f32],
    img: &[f32],
    channels: usize,
    k_rows: usize,
    k_cols: usize,
    img_rows: usize,
    img_cols: usize,
    s_row: usize,
    s_col: usize,
    pad_row: usize,
    pad_col: usize,
) {
    let col_rows = (img_rows + 2 * pad_row - k_rows) / s_row + 1;
    let col_cols = (img_cols + 2 * pad_col - k_cols) / s_col + 1;

    let k_size = k_rows * k_cols;
    let channels_col = channels * k_size;

    let out_size = col_rows * col_cols;
    let col_size = channels_col * out_size;
    let col_s = &mut col[0..col_size];

    for ch in 0..channels_col {
        // `ch` enumerates (channel, kernel row, kernel column) with the kernel column varying
        // fastest, so the column offset is taken modulo the kernel *width* and the row offset
        // modulo the kernel *height*. Using the dimensions the other way round is only
        // correct for square kernels.
        let offset_ch = ch / k_size;
        let offset_row = (ch / k_cols) % k_rows;
        let offset_col = ch % k_cols;
        let index_col = offset_row * k_cols + offset_col;

        for out_row in 0..col_rows {
            for out_col in 0..col_cols {
                let img_row = out_row * s_row + offset_row;
                let img_col = out_col * s_col + offset_col;

                let index_row = out_row * col_cols + out_col;
                let index = offset_ch * (k_size * out_size) + index_row * k_size + index_col;

                col_s[index] = img_to_col_get_pixel(
                    img,
                    img_rows,
                    img_cols,
                    img_row as isize,
                    img_col as isize,
                    offset_ch,
                    pad_row,
                    pad_col,
                );
            }
        }
    }
}
|
||||
|
||||
/// Adds `val` to one pixel of a channel-major image (`[channel][row][col]`).
///
/// `row` and `col` are coordinates in the padded image; after removing `pad_row` / `pad_col`
/// any position that lies outside `img_rows` x `img_cols` is ignored, i.e. contributions that
/// land in the padding are dropped.
fn col_to_img_add_pixel(
    img: &mut [f32],
    img_rows: usize,
    img_cols: usize,
    row: isize,
    col: isize,
    channel: usize,
    pad_row: usize,
    pad_col: usize,
    val: f32,
) {
    let src_row = row - pad_row as isize;
    let src_col = col - pad_col as isize;

    let inside = (0..img_rows as isize).contains(&src_row)
        && (0..img_cols as isize).contains(&src_col);
    if !inside {
        return;
    }

    let index = (channel * img_rows + src_row as usize) * img_cols + src_col as usize;
    img[index] += val;
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn col_to_img(img: &mut [f32], col: &[f32], channels: usize,
|
||||
k_rows: usize, k_cols: usize,
|
||||
img_rows: usize, img_cols: usize,
|
||||
s_row: usize, s_col: usize,
|
||||
pad_row: usize, pad_col: usize) {
|
||||
|
||||
pub fn col_to_img(
|
||||
img: &mut [f32],
|
||||
col: &[f32],
|
||||
channels: usize,
|
||||
k_rows: usize,
|
||||
k_cols: usize,
|
||||
img_rows: usize,
|
||||
img_cols: usize,
|
||||
s_row: usize,
|
||||
s_col: usize,
|
||||
pad_row: usize,
|
||||
pad_col: usize,
|
||||
) {
|
||||
let col_rows = (img_rows + 2 * pad_row - k_rows) / s_row + 1;
|
||||
let col_cols = (img_cols + 2 * pad_col - k_cols) / s_col + 1;
|
||||
|
||||
|
||||
let k_size = k_rows * k_cols;
|
||||
let channels_col = channels * k_size;
|
||||
|
||||
|
||||
let out_size = col_rows * col_cols;
|
||||
let col_size = channels_col * out_size;
|
||||
|
||||
|
||||
let col_s = &col[0..col_size];
|
||||
|
||||
|
||||
for ch in 0..channels_col {
|
||||
let offset_ch = ch / k_rows / k_cols;
|
||||
let offset_row = (ch / k_rows) % k_cols;
|
||||
let offset_col = ch % k_rows;
|
||||
|
||||
|
||||
for row in 0..col_rows {
|
||||
for col in 0..col_cols {
|
||||
let img_row = row * s_row + offset_row;
|
||||
let img_col = col * s_col + offset_col;
|
||||
|
||||
|
||||
let index_row = row * col_cols + col;
|
||||
let index_col = offset_row * k_rows + offset_col;
|
||||
let index = offset_ch * (k_size * out_size) + index_row * k_size + index_col;
|
||||
|
||||
|
||||
col_to_img_add_pixel(
|
||||
img, img_rows, img_cols,
|
||||
img_row as isize, img_col as isize,
|
||||
offset_ch, pad_row, pad_col,
|
||||
col_s[index]
|
||||
img,
|
||||
img_rows,
|
||||
img_cols,
|
||||
img_row as isize,
|
||||
img_col as isize,
|
||||
offset_ch,
|
||||
pad_row,
|
||||
pad_col,
|
||||
col_s[index],
|
||||
);
|
||||
}
|
||||
}
|
||||
@ -118,77 +155,48 @@ mod tests {
|
||||
/// Unrolls a 2-channel 4x4 image with a 3x3 kernel, stride 1 and no padding, and checks the
/// resulting patch matrix (one row per channel and output position).
#[test]
fn test_img_to_col() {
    #[rustfmt::skip]
    let img: &[f32] = &[
        1.0,  2.0,  3.0,  4.0,
        5.0,  6.0,  7.0,  8.0,
        9.0,  10.0, 11.0, 12.0,
        13.0, 14.0, 15.0, 16.0,

        1.5,  2.5,  3.5,  4.5,
        5.5,  6.5,  7.5,  8.5,
        9.5,  10.5, 11.5, 12.5,
        13.5, 14.5, 15.5, 16.5,
    ];

    #[rustfmt::skip]
    let expected: &[f32] = &[
        1.0, 2.0, 3.0, 5.0,  6.0,  7.0,  9.0,  10.0, 11.0,
        2.0, 3.0, 4.0, 6.0,  7.0,  8.0,  10.0, 11.0, 12.0,
        5.0, 6.0, 7.0, 9.0,  10.0, 11.0, 13.0, 14.0, 15.0,
        6.0, 7.0, 8.0, 10.0, 11.0, 12.0, 14.0, 15.0, 16.0,
        1.5, 2.5, 3.5, 5.5,  6.5,  7.5,  9.5,  10.5, 11.5,
        2.5, 3.5, 4.5, 6.5,  7.5,  8.5,  10.5, 11.5, 12.5,
        5.5, 6.5, 7.5, 9.5,  10.5, 11.5, 13.5, 14.5, 15.5,
        6.5, 7.5, 8.5, 10.5, 11.5, 12.5, 14.5, 15.5, 16.5,
    ];

    let mut unrolled = vec![0.0; 72];
    img_to_col(&mut unrolled, img, 2, 3, 3, 4, 4, 1, 1, 0, 0);

    assert_eq!(unrolled.as_slice(), expected);
}
|
||||
|
||||
/// Folds a 3x3-kernel patch matrix back into a zeroed 2-channel 4x4 image and checks that
/// overlapping windows are summed.
#[test]
fn test_col_to_img() {
    #[rustfmt::skip]
    let patches: &[f32] = &[
        1.0, 2.0, 3.0, 5.0,  6.0,  7.0,  9.0,  10.0, 11.0,
        2.0, 3.0, 4.0, 6.0,  7.0,  8.0,  10.0, 11.0, 12.0,
        5.0, 6.0, 7.0, 9.0,  10.0, 11.0, 13.0, 14.0, 15.0,
        6.0, 7.0, 8.0, 10.0, 11.0, 12.0, 14.0, 15.0, 16.0,
        1.5, 2.5, 3.5, 5.5,  6.5,  7.5,  9.5,  10.5, 11.5,
        2.5, 3.5, 4.5, 6.5,  7.5,  8.5,  10.5, 11.5, 12.5,
        5.5, 6.5, 7.5, 9.5,  10.5, 11.5, 13.5, 14.5, 15.5,
        6.5, 7.5, 8.5, 10.5, 11.5, 12.5, 14.5, 15.5, 16.5,
    ];

    #[rustfmt::skip]
    let expected: &[f32] = &[
        1.0,  4.0,  6.0,  4.0,
        10.0, 24.0, 28.0, 16.0,
        18.0, 40.0, 44.0, 24.0,
        13.0, 28.0, 30.0, 16.0,

        1.5,  5.0,  7.0,  4.5,
        11.0, 26.0, 30.0, 17.0,
        19.0, 42.0, 46.0, 25.0,
        13.5, 29.0, 31.0, 16.5,
    ];

    // Two channels of 4x4 pixels, all starting at zero.
    let mut image = [0.0f32; 32];
    col_to_img(&mut image, patches, 2, 3, 3, 4, 4, 1, 1, 0, 0);

    assert_eq!(&image[..], expected);
}
|
||||
}
|
||||
}
|
||||
|
@ -1,23 +1,25 @@
|
||||
mod img2col;
|
||||
|
||||
use std::marker::PhantomData;
|
||||
use yarnn::backend::*;
|
||||
use yarnn::native::*;
|
||||
use yarnn::tensor::*;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
extern crate openblas_src;
|
||||
|
||||
pub struct NativeBlas<N, B>
|
||||
where N: NativeNumber,
|
||||
B: NativeBackend<N>
|
||||
pub struct NativeBlas<N, B>
|
||||
where
|
||||
N: NativeNumber,
|
||||
B: NativeBackend<N>,
|
||||
{
|
||||
inner: B,
|
||||
_m: PhantomData<fn(N)>
|
||||
_m: PhantomData<fn(N)>,
|
||||
}
|
||||
|
||||
impl<N, B> Default for NativeBlas<N, B>
|
||||
where N: NativeNumber,
|
||||
B: NativeBackend<N>
|
||||
where
|
||||
N: NativeNumber,
|
||||
B: NativeBackend<N>,
|
||||
{
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
@ -27,9 +29,10 @@ impl<N, B> Default for NativeBlas<N, B>
|
||||
}
|
||||
}
|
||||
|
||||
impl<N, B> NativeBackend<N> for NativeBlas<N, B>
|
||||
where N: NativeNumber,
|
||||
B: NativeBackend<N>
|
||||
impl<N, B> NativeBackend<N> for NativeBlas<N, B>
|
||||
where
|
||||
N: NativeNumber,
|
||||
B: NativeBackend<N>,
|
||||
{
|
||||
#[inline]
|
||||
fn read_tensor<'a>(&self, t: &'a Self::Tensor) -> &'a [N] {
|
||||
@ -42,21 +45,23 @@ impl<N, B> NativeBackend<N> for NativeBlas<N, B>
|
||||
}
|
||||
}
|
||||
|
||||
impl<N, B> NativeBlas<N, B>
|
||||
where N: NativeNumber,
|
||||
B: NativeBackend<N>
|
||||
impl<N, B> NativeBlas<N, B>
|
||||
where
|
||||
N: NativeNumber,
|
||||
B: NativeBackend<N>,
|
||||
{
|
||||
pub fn new(inner: B) -> Self {
|
||||
Self {
|
||||
inner,
|
||||
_m: Default::default()
|
||||
_m: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<N, B> Backend<N> for NativeBlas<N, B>
|
||||
where N: NativeNumber,
|
||||
B: NativeBackend<N>
|
||||
impl<N, B> Backend<N> for NativeBlas<N, B>
|
||||
where
|
||||
N: NativeNumber,
|
||||
B: NativeBackend<N>,
|
||||
{
|
||||
type Tensor = B::Tensor;
|
||||
|
||||
@ -96,8 +101,9 @@ impl<N, B> Backend<N> for NativeBlas<N, B>
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> BackendGemm<f32> for NativeBlas<f32, B>
|
||||
where B: NativeBackend<f32>
|
||||
impl<B> BackendGemm<f32> for NativeBlas<f32, B>
|
||||
where
|
||||
B: NativeBackend<f32>,
|
||||
{
|
||||
#[inline]
|
||||
fn matmul(&self, dst: &mut Self::Tensor, a: &Self::Tensor, b: &Self::Tensor) {
|
||||
@ -114,15 +120,23 @@ impl<B> BackendGemm<f32> for NativeBlas<f32, B>
|
||||
let m = a_shape.get(0) as i32;
|
||||
let n = b_shape.get(1) as i32;
|
||||
let k = b_shape.get(0) as i32;
|
||||
|
||||
|
||||
unsafe {
|
||||
blas::sgemm('N' as u8, 'N' as u8,
|
||||
n, m, k,
|
||||
1.0,
|
||||
self.read_tensor(b), n,
|
||||
self.read_tensor(a), k,
|
||||
0.0,
|
||||
self.write_tensor(dst), n);
|
||||
blas::sgemm(
|
||||
'N' as u8,
|
||||
'N' as u8,
|
||||
n,
|
||||
m,
|
||||
k,
|
||||
1.0,
|
||||
self.read_tensor(b),
|
||||
n,
|
||||
self.read_tensor(a),
|
||||
k,
|
||||
0.0,
|
||||
self.write_tensor(dst),
|
||||
n,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@ -141,15 +155,23 @@ impl<B> BackendGemm<f32> for NativeBlas<f32, B>
|
||||
let m = a_shape.get(0) as i32;
|
||||
let n = b_shape.get(0) as i32;
|
||||
let k = b_shape.get(1) as i32;
|
||||
|
||||
|
||||
unsafe {
|
||||
blas::sgemm('T' as u8, 'N' as u8,
|
||||
n, m, k,
|
||||
1.0,
|
||||
self.read_tensor(b), k,
|
||||
self.read_tensor(a), k,
|
||||
0.0,
|
||||
self.write_tensor(dst), n);
|
||||
blas::sgemm(
|
||||
'T' as u8,
|
||||
'N' as u8,
|
||||
n,
|
||||
m,
|
||||
k,
|
||||
1.0,
|
||||
self.read_tensor(b),
|
||||
k,
|
||||
self.read_tensor(a),
|
||||
k,
|
||||
0.0,
|
||||
self.write_tensor(dst),
|
||||
n,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@ -168,15 +190,23 @@ impl<B> BackendGemm<f32> for NativeBlas<f32, B>
|
||||
let m = a_shape.get(1) as i32;
|
||||
let n = b_shape.get(1) as i32;
|
||||
let k = b_shape.get(0) as i32;
|
||||
|
||||
|
||||
unsafe {
|
||||
blas::sgemm('N' as u8, 'T' as u8,
|
||||
n, m, k,
|
||||
1.0,
|
||||
self.read_tensor(b), n,
|
||||
self.read_tensor(a), m,
|
||||
0.0,
|
||||
self.write_tensor(dst), n);
|
||||
blas::sgemm(
|
||||
'N' as u8,
|
||||
'T' as u8,
|
||||
n,
|
||||
m,
|
||||
k,
|
||||
1.0,
|
||||
self.read_tensor(b),
|
||||
n,
|
||||
self.read_tensor(a),
|
||||
m,
|
||||
0.0,
|
||||
self.write_tensor(dst),
|
||||
n,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@ -186,8 +216,9 @@ impl<B> BackendGemm<f32> for NativeBlas<f32, B>
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> BackendAxpy<f32> for NativeBlas<f32, B>
|
||||
where B: NativeBackend<f32>
|
||||
impl<B> BackendAxpy<f32> for NativeBlas<f32, B>
|
||||
where
|
||||
B: NativeBackend<f32>,
|
||||
{
|
||||
#[inline]
|
||||
fn axpy(&self, dst: &mut Self::Tensor, scale: f32, x: &Self::Tensor) {
|
||||
@ -202,31 +233,26 @@ impl<B> BackendAxpy<f32> for NativeBlas<f32, B>
|
||||
self.read_tensor(x),
|
||||
1,
|
||||
self.write_tensor(dst),
|
||||
1
|
||||
1,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> BackendScale<f32> for NativeBlas<f32, B>
|
||||
where B: NativeBackend<f32>
|
||||
impl<B> BackendScale<f32> for NativeBlas<f32, B>
|
||||
where
|
||||
B: NativeBackend<f32>,
|
||||
{
|
||||
#[inline]
|
||||
fn scale(&self, dst: &mut Self::Tensor, scale: f32) {
|
||||
let dst_size = dst.shape().size();
|
||||
|
||||
unsafe {
|
||||
blas::sscal(
|
||||
dst_size as i32,
|
||||
scale,
|
||||
self.write_tensor(dst),
|
||||
1
|
||||
);
|
||||
blas::sscal(dst_size as i32, scale, self.write_tensor(dst), 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
impl<B: NativeBackend<f32> + BackendSigmoid<f32>> BackendSigmoid<f32> for NativeBlas<f32, B> {
|
||||
#[inline]
|
||||
fn sigmoid(&self, dst: &mut Self::Tensor, data: &Self::Tensor) {
|
||||
@ -265,7 +291,13 @@ impl<B: NativeBackend<f32> + BackendBias<f32>> BackendBias<f32> for NativeBlas<f
|
||||
|
||||
impl<B: NativeBackend<f32> + BackendMse<f32>> BackendMse<f32> for NativeBlas<f32, B> {
|
||||
#[inline]
|
||||
fn scaled_square_diff(&self, dst: &mut Self::Tensor, a: &Self::Tensor, b: &Self::Tensor, scale: f32) {
|
||||
fn scaled_square_diff(
|
||||
&self,
|
||||
dst: &mut Self::Tensor,
|
||||
a: &Self::Tensor,
|
||||
b: &Self::Tensor,
|
||||
scale: f32,
|
||||
) {
|
||||
self.inner.scaled_square_diff(dst, a, b, scale)
|
||||
}
|
||||
|
||||
@ -303,7 +335,6 @@ impl<B: NativeBackend<f32> + BackendMul<f32>> BackendMul<f32> for NativeBlas<f32
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
impl<B: NativeBackend<f32> + BackendCopy<f32>> BackendCopy<f32> for NativeBlas<f32, B> {
|
||||
#[inline]
|
||||
fn copy(&self, dst: &mut Self::Tensor, a: &Self::Tensor) {
|
||||
@ -318,10 +349,16 @@ impl<B: NativeBackend<f32> + BackendMaximum<f32>> BackendMaximum<f32> for Native
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
impl<B: NativeBackend<f32> + BackendAdam<f32>> BackendAdam<f32> for NativeBlas<f32, B> {
|
||||
#[inline]
|
||||
fn adam_p(&self, dst: &mut Self::Tensor, lr: f32, moms: &Self::Tensor, vels: &Self::Tensor, eps: f32) {
|
||||
fn adam_p(
|
||||
&self,
|
||||
dst: &mut Self::Tensor,
|
||||
lr: f32,
|
||||
moms: &Self::Tensor,
|
||||
vels: &Self::Tensor,
|
||||
eps: f32,
|
||||
) {
|
||||
self.inner.adam_p(dst, lr, moms, vels, eps)
|
||||
}
|
||||
}
|
||||
@ -337,17 +374,35 @@ impl<B: NativeBackend<f32> + BackendConv2d<f32>> BackendConv2d<f32> for NativeBl
|
||||
type Context = ();
|
||||
|
||||
#[inline]
|
||||
fn conv2d_forward(&self, y: &mut Self::Tensor, x: &Self::Tensor, w: &Self::Tensor, conv_info: &Conv2dInfo) {
|
||||
fn conv2d_forward(
|
||||
&self,
|
||||
y: &mut Self::Tensor,
|
||||
x: &Self::Tensor,
|
||||
w: &Self::Tensor,
|
||||
conv_info: &Conv2dInfo,
|
||||
) {
|
||||
self.inner.conv2d_forward(y, x, w, conv_info)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn conv2d_backward_input(&self, dx: &mut Self::Tensor, dy: &Self::Tensor, w: &Self::Tensor, conv_info: &Conv2dInfo) {
|
||||
self.inner.conv2d_backward_input(dx, dy, w, conv_info)
|
||||
fn conv2d_backward_input(
|
||||
&self,
|
||||
dx: &mut Self::Tensor,
|
||||
dy: &Self::Tensor,
|
||||
w: &Self::Tensor,
|
||||
conv_info: &Conv2dInfo,
|
||||
) {
|
||||
self.inner.conv2d_backward_input(dx, dy, w, conv_info)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn conv2d_backward_filter(&self, dw: &mut Self::Tensor, x: &Self::Tensor, dy: &Self::Tensor, conv_info: &Conv2dInfo) {
|
||||
fn conv2d_backward_filter(
|
||||
&self,
|
||||
dw: &mut Self::Tensor,
|
||||
x: &Self::Tensor,
|
||||
dy: &Self::Tensor,
|
||||
conv_info: &Conv2dInfo,
|
||||
) {
|
||||
self.inner.conv2d_backward_filter(dw, x, dy, conv_info)
|
||||
}
|
||||
}
|
||||
@ -359,7 +414,13 @@ impl<B: NativeBackend<f32> + BackendMaxPool2d<f32>> BackendMaxPool2d<f32> for Na
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn max_pool2d_backprop(&self, dx: &mut Self::Tensor, dy: &Self::Tensor, x: &Self::Tensor, conv_info: &Conv2dInfo) {
|
||||
fn max_pool2d_backprop(
|
||||
&self,
|
||||
dx: &mut Self::Tensor,
|
||||
dy: &Self::Tensor,
|
||||
x: &Self::Tensor,
|
||||
conv_info: &Conv2dInfo,
|
||||
) {
|
||||
self.inner.max_pool2d_backprop(dx, dy, x, conv_info)
|
||||
}
|
||||
}
|
||||
@ -371,14 +432,28 @@ impl<B: NativeBackend<f32> + BackendAvgPool2d<f32>> BackendAvgPool2d<f32> for Na
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn avg_pool2d_backprop(&self, dx: &mut Self::Tensor, dy: &Self::Tensor, x: &Self::Tensor, conv_info: &Conv2dInfo) {
|
||||
fn avg_pool2d_backprop(
|
||||
&self,
|
||||
dx: &mut Self::Tensor,
|
||||
dy: &Self::Tensor,
|
||||
x: &Self::Tensor,
|
||||
conv_info: &Conv2dInfo,
|
||||
) {
|
||||
self.inner.avg_pool2d_backprop(dx, dy, x, conv_info)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: NativeBackend<f32> + BackendPaddingCopy2d<f32>> BackendPaddingCopy2d<f32> for NativeBlas<f32, B> {
|
||||
impl<B: NativeBackend<f32> + BackendPaddingCopy2d<f32>> BackendPaddingCopy2d<f32>
|
||||
for NativeBlas<f32, B>
|
||||
{
|
||||
#[inline]
|
||||
fn copy_with_padding2d(&self, y: &mut Self::Tensor, x: &Self::Tensor, y_paddings: (u32, u32), x_paddings: (u32, u32)) {
|
||||
fn copy_with_padding2d(
|
||||
&self,
|
||||
y: &mut Self::Tensor,
|
||||
x: &Self::Tensor,
|
||||
y_paddings: (u32, u32),
|
||||
x_paddings: (u32, u32),
|
||||
) {
|
||||
self.inner.copy_with_padding2d(y, x, y_paddings, x_paddings)
|
||||
}
|
||||
}
|
||||
|
@ -1,5 +1,4 @@
|
||||
use crate::tensor::{Tensor};
|
||||
|
||||
use crate::tensor::Tensor;
|
||||
|
||||
pub trait Backend<N> {
|
||||
type Tensor: Tensor<N>;
|
||||
@ -13,22 +12,22 @@ pub trait Backend<N> {
|
||||
fn print_tensor(&self, t: &Self::Tensor);
|
||||
}
|
||||
|
||||
impl <'a, N, T: Backend<N>> Backend<N> for &'a T {
|
||||
impl<'a, N, T: Backend<N>> Backend<N> for &'a T {
|
||||
type Tensor = T::Tensor;
|
||||
|
||||
#[inline]
|
||||
fn store_tensor_f32(&self, dst: &Self::Tensor, slice: &mut [f32]) {
|
||||
(**self).store_tensor_f32(dst, slice)
|
||||
(**self).store_tensor_f32(dst, slice)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn load_tensor_u8(&self, dst: &mut Self::Tensor, slice: &[u8]) {
|
||||
(**self).load_tensor_u8(dst, slice)
|
||||
(**self).load_tensor_u8(dst, slice)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn load_tensor_f32(&self, dst: &mut Self::Tensor, slice: &[f32]) {
|
||||
(**self).load_tensor_f32(dst, slice)
|
||||
(**self).load_tensor_f32(dst, slice)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
@ -59,7 +58,7 @@ pub trait BackendGemm<N>: Backend<N> {
|
||||
fn matmul_tt(&self, dst: &mut Self::Tensor, a: &Self::Tensor, b: &Self::Tensor);
|
||||
}
|
||||
|
||||
impl <'a, N, T: BackendGemm<N>> BackendGemm<N> for &'a T {
|
||||
impl<'a, N, T: BackendGemm<N>> BackendGemm<N> for &'a T {
|
||||
#[inline]
|
||||
fn matmul(&self, dst: &mut Self::Tensor, a: &Self::Tensor, b: &Self::Tensor) {
|
||||
(**self).matmul(dst, a, b)
|
||||
@ -86,7 +85,7 @@ pub trait BackendBias<N>: Backend<N> {
|
||||
fn bias_grad(&self, bias: &mut Self::Tensor, inputs: &Self::Tensor);
|
||||
}
|
||||
|
||||
impl <'a, N, T: BackendBias<N>> BackendBias<N> for &'a T {
|
||||
impl<'a, N, T: BackendBias<N>> BackendBias<N> for &'a T {
|
||||
#[inline]
|
||||
fn bias_add(&self, dst: &mut Self::Tensor, bias: &Self::Tensor) {
|
||||
(**self).bias_add(dst, bias)
|
||||
@ -103,7 +102,7 @@ pub trait BackendSigmoid<N>: Backend<N> {
|
||||
fn sigmoid_grad(&self, dst: &mut Self::Tensor, z: &Self::Tensor, d: &Self::Tensor);
|
||||
}
|
||||
|
||||
impl <'a, N, T: BackendSigmoid<N>> BackendSigmoid<N> for &'a T {
|
||||
impl<'a, N, T: BackendSigmoid<N>> BackendSigmoid<N> for &'a T {
|
||||
#[inline]
|
||||
fn sigmoid(&self, dst: &mut Self::Tensor, data: &Self::Tensor) {
|
||||
(**self).sigmoid(dst, data)
|
||||
@ -115,13 +114,12 @@ impl <'a, N, T: BackendSigmoid<N>> BackendSigmoid<N> for &'a T {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
pub trait BackendReLu<N>: Backend<N> {
|
||||
fn relu(&self, dst: &mut Self::Tensor, data: &Self::Tensor);
|
||||
fn relu_grad(&self, dst: &mut Self::Tensor, z: &Self::Tensor, d: &Self::Tensor);
|
||||
}
|
||||
|
||||
impl <'a, N, T: BackendReLu<N>> BackendReLu<N> for &'a T {
|
||||
impl<'a, N, T: BackendReLu<N>> BackendReLu<N> for &'a T {
|
||||
#[inline]
|
||||
fn relu(&self, dst: &mut Self::Tensor, data: &Self::Tensor) {
|
||||
(**self).relu(dst, data)
|
||||
@ -137,7 +135,7 @@ pub trait BackendScale<N>: Backend<N> {
|
||||
fn scale(&self, dst: &mut Self::Tensor, scale: N);
|
||||
}
|
||||
|
||||
impl <'a, N, T: BackendScale<N>> BackendScale<N> for &'a T {
|
||||
impl<'a, N, T: BackendScale<N>> BackendScale<N> for &'a T {
|
||||
#[inline]
|
||||
fn scale(&self, dst: &mut Self::Tensor, scale: N) {
|
||||
(**self).scale(dst, scale)
|
||||
@ -145,13 +143,25 @@ impl <'a, N, T: BackendScale<N>> BackendScale<N> for &'a T {
|
||||
}
|
||||
|
||||
pub trait BackendMse<N>: Backend<N> {
|
||||
fn scaled_square_diff(&self, dst: &mut Self::Tensor, a: &Self::Tensor, b: &Self::Tensor, scale: N);
|
||||
fn scaled_square_diff(
|
||||
&self,
|
||||
dst: &mut Self::Tensor,
|
||||
a: &Self::Tensor,
|
||||
b: &Self::Tensor,
|
||||
scale: N,
|
||||
);
|
||||
fn scaled_diff(&self, dst: &mut Self::Tensor, a: &Self::Tensor, b: &Self::Tensor, scale: N);
|
||||
}
|
||||
|
||||
impl <'a, N, T: BackendMse<N>> BackendMse<N> for &'a T {
|
||||
impl<'a, N, T: BackendMse<N>> BackendMse<N> for &'a T {
|
||||
#[inline]
|
||||
fn scaled_square_diff(&self, dst: &mut Self::Tensor, a: &Self::Tensor, b: &Self::Tensor, scale: N) {
|
||||
fn scaled_square_diff(
|
||||
&self,
|
||||
dst: &mut Self::Tensor,
|
||||
a: &Self::Tensor,
|
||||
b: &Self::Tensor,
|
||||
scale: N,
|
||||
) {
|
||||
(**self).scaled_square_diff(dst, a, b, scale)
|
||||
}
|
||||
|
||||
@ -165,7 +175,7 @@ pub trait BackendAxpy<N>: Backend<N> {
|
||||
fn axpy(&self, dst: &mut Self::Tensor, scale: N, a: &Self::Tensor);
|
||||
}
|
||||
|
||||
impl <'a, N, T: BackendAxpy<N>> BackendAxpy<N> for &'a T {
|
||||
impl<'a, N, T: BackendAxpy<N>> BackendAxpy<N> for &'a T {
|
||||
#[inline]
|
||||
fn axpy(&self, dst: &mut Self::Tensor, scale: N, a: &Self::Tensor) {
|
||||
(**self).axpy(dst, scale, a)
|
||||
@ -176,7 +186,7 @@ pub trait BackendAxpys<N>: Backend<N> {
|
||||
fn axpys(&self, dst: &mut Self::Tensor, scale: N, a: &Self::Tensor);
|
||||
}
|
||||
|
||||
impl <'a, N, T: BackendAxpys<N>> BackendAxpys<N> for &'a T {
|
||||
impl<'a, N, T: BackendAxpys<N>> BackendAxpys<N> for &'a T {
|
||||
#[inline]
|
||||
fn axpys(&self, dst: &mut Self::Tensor, scale: N, a: &Self::Tensor) {
|
||||
(**self).axpys(dst, scale, a)
|
||||
@ -187,7 +197,7 @@ pub trait BackendAdd<N>: Backend<N> {
|
||||
fn add(&self, dst: &mut Self::Tensor, a: &Self::Tensor);
|
||||
}
|
||||
|
||||
impl <'a, N, T: BackendAdd<N>> BackendAdd<N> for &'a T {
|
||||
impl<'a, N, T: BackendAdd<N>> BackendAdd<N> for &'a T {
|
||||
#[inline]
|
||||
fn add(&self, dst: &mut Self::Tensor, a: &Self::Tensor) {
|
||||
(**self).add(dst, a)
|
||||
@ -198,7 +208,7 @@ pub trait BackendSub<N>: Backend<N> {
|
||||
fn sub(&self, dst: &mut Self::Tensor, a: &Self::Tensor, b: &Self::Tensor);
|
||||
}
|
||||
|
||||
impl <'a, N, T: BackendSub<N>> BackendSub<N> for &'a T {
|
||||
impl<'a, N, T: BackendSub<N>> BackendSub<N> for &'a T {
|
||||
#[inline]
|
||||
fn sub(&self, dst: &mut Self::Tensor, a: &Self::Tensor, b: &Self::Tensor) {
|
||||
(**self).sub(dst, a, b)
|
||||
@ -209,7 +219,7 @@ pub trait BackendMul<N>: Backend<N> {
|
||||
fn mul(&self, dst: &mut Self::Tensor, a: &Self::Tensor);
|
||||
}
|
||||
|
||||
impl <'a, N, T: BackendMul<N>> BackendMul<N> for &'a T {
|
||||
impl<'a, N, T: BackendMul<N>> BackendMul<N> for &'a T {
|
||||
#[inline]
|
||||
fn mul(&self, dst: &mut Self::Tensor, a: &Self::Tensor) {
|
||||
(**self).mul(dst, a)
|
||||
@ -220,7 +230,7 @@ pub trait BackendCopy<N>: Backend<N> {
|
||||
fn copy(&self, dst: &mut Self::Tensor, a: &Self::Tensor);
|
||||
}
|
||||
|
||||
impl <'a, N, T: BackendCopy<N>> BackendCopy<N> for &'a T {
|
||||
impl<'a, N, T: BackendCopy<N>> BackendCopy<N> for &'a T {
|
||||
#[inline]
|
||||
fn copy(&self, dst: &mut Self::Tensor, a: &Self::Tensor) {
|
||||
(**self).copy(dst, a)
|
||||
@ -231,20 +241,36 @@ pub trait BackendMaximum<N>: Backend<N> {
|
||||
fn maximum(&self, dst: &mut Self::Tensor, a: &Self::Tensor);
|
||||
}
|
||||
|
||||
impl <'a, N, T: BackendMaximum<N>> BackendMaximum<N> for &'a T {
|
||||
impl<'a, N, T: BackendMaximum<N>> BackendMaximum<N> for &'a T {
|
||||
#[inline]
|
||||
fn maximum(&self, dst: &mut Self::Tensor, a: &Self::Tensor) {
|
||||
(**self).maximum(dst, a)
|
||||
}
|
||||
}
|
||||
|
||||
pub trait BackendAdam<N>: BackendScale<N> + BackendAxpy<N> + BackendAxpys<N> + BackendMaximum<N> {
|
||||
fn adam_p(&self, dst: &mut Self::Tensor, lr: N, moms: &Self::Tensor, vels: &Self::Tensor, eps: N);
|
||||
pub trait BackendAdam<N>:
|
||||
BackendScale<N> + BackendAxpy<N> + BackendAxpys<N> + BackendMaximum<N>
|
||||
{
|
||||
fn adam_p(
|
||||
&self,
|
||||
dst: &mut Self::Tensor,
|
||||
lr: N,
|
||||
moms: &Self::Tensor,
|
||||
vels: &Self::Tensor,
|
||||
eps: N,
|
||||
);
|
||||
}
|
||||
|
||||
impl <'a, N, T: BackendAdam<N>> BackendAdam<N> for &'a T {
|
||||
impl<'a, N, T: BackendAdam<N>> BackendAdam<N> for &'a T {
|
||||
#[inline]
|
||||
fn adam_p(&self, dst: &mut Self::Tensor, lr: N, moms: &Self::Tensor, vels: &Self::Tensor, eps: N) {
|
||||
fn adam_p(
|
||||
&self,
|
||||
dst: &mut Self::Tensor,
|
||||
lr: N,
|
||||
moms: &Self::Tensor,
|
||||
vels: &Self::Tensor,
|
||||
eps: N,
|
||||
) {
|
||||
(**self).adam_p(dst, lr, moms, vels, eps)
|
||||
}
|
||||
}
|
||||
@ -253,7 +279,7 @@ pub trait BackendSoftmax<N>: BackendCopy<N> {
|
||||
fn softmax(&self, y: &mut Self::Tensor, x: &Self::Tensor);
|
||||
}
|
||||
|
||||
impl <'a, N, T: BackendSoftmax<N>> BackendSoftmax<N> for &'a T {
|
||||
impl<'a, N, T: BackendSoftmax<N>> BackendSoftmax<N> for &'a T {
|
||||
#[inline]
|
||||
fn softmax(&self, y: &mut Self::Tensor, x: &Self::Tensor) {
|
||||
(**self).softmax(y, x)
|
||||
@ -269,7 +295,7 @@ pub enum PaddingKind {
|
||||
|
||||
#[derive(Clone, PartialEq, Debug)]
|
||||
pub struct Conv2dInfo {
|
||||
pub padding: PaddingKind,
|
||||
pub padding: PaddingKind,
|
||||
pub strides: (u32, u32),
|
||||
pub kernel: (u32, u32),
|
||||
}
|
||||
@ -277,26 +303,62 @@ pub struct Conv2dInfo {
|
||||
pub trait BackendConv2d<N>: Backend<N> {
|
||||
type Context;
|
||||
|
||||
fn conv2d_forward(&self, y: &mut Self::Tensor, x: &Self::Tensor, filter: &Self::Tensor, conv_info: &Conv2dInfo);
|
||||
fn conv2d_backward_input(&self, dx: &mut Self::Tensor, dy: &Self::Tensor, filter: &Self::Tensor, conv_info: &Conv2dInfo);
|
||||
fn conv2d_backward_filter(&self, dw: &mut Self::Tensor, x: &Self::Tensor, dy: &Self::Tensor, conv_info: &Conv2dInfo);
|
||||
fn conv2d_forward(
|
||||
&self,
|
||||
y: &mut Self::Tensor,
|
||||
x: &Self::Tensor,
|
||||
filter: &Self::Tensor,
|
||||
conv_info: &Conv2dInfo,
|
||||
);
|
||||
fn conv2d_backward_input(
|
||||
&self,
|
||||
dx: &mut Self::Tensor,
|
||||
dy: &Self::Tensor,
|
||||
filter: &Self::Tensor,
|
||||
conv_info: &Conv2dInfo,
|
||||
);
|
||||
fn conv2d_backward_filter(
|
||||
&self,
|
||||
dw: &mut Self::Tensor,
|
||||
x: &Self::Tensor,
|
||||
dy: &Self::Tensor,
|
||||
conv_info: &Conv2dInfo,
|
||||
);
|
||||
}
|
||||
|
||||
impl <'a, N, T: BackendConv2d<N>> BackendConv2d<N> for &'a T {
|
||||
impl<'a, N, T: BackendConv2d<N>> BackendConv2d<N> for &'a T {
|
||||
type Context = ();
|
||||
|
||||
|
||||
#[inline]
|
||||
fn conv2d_forward(&self, y: &mut Self::Tensor, x: &Self::Tensor, filters: &Self::Tensor, conv_info: &Conv2dInfo) {
|
||||
fn conv2d_forward(
|
||||
&self,
|
||||
y: &mut Self::Tensor,
|
||||
x: &Self::Tensor,
|
||||
filters: &Self::Tensor,
|
||||
conv_info: &Conv2dInfo,
|
||||
) {
|
||||
(**self).conv2d_forward(y, x, filters, conv_info)
|
||||
}
|
||||
|
||||
|
||||
#[inline]
|
||||
fn conv2d_backward_input(&self, dx: &mut Self::Tensor, dy: &Self::Tensor, filters: &Self::Tensor, conv_info: &Conv2dInfo) {
|
||||
fn conv2d_backward_input(
|
||||
&self,
|
||||
dx: &mut Self::Tensor,
|
||||
dy: &Self::Tensor,
|
||||
filters: &Self::Tensor,
|
||||
conv_info: &Conv2dInfo,
|
||||
) {
|
||||
(**self).conv2d_backward_input(dx, dy, filters, conv_info)
|
||||
}
|
||||
|
||||
|
||||
#[inline]
|
||||
fn conv2d_backward_filter(&self, dw: &mut Self::Tensor, x: &Self::Tensor, dy: &Self::Tensor, conv_info: &Conv2dInfo) {
|
||||
fn conv2d_backward_filter(
|
||||
&self,
|
||||
dw: &mut Self::Tensor,
|
||||
x: &Self::Tensor,
|
||||
dy: &Self::Tensor,
|
||||
conv_info: &Conv2dInfo,
|
||||
) {
|
||||
(**self).conv2d_forward(dw, x, dy, conv_info)
|
||||
}
|
||||
}
|
||||
@ -309,45 +371,81 @@ pub enum PoolingKind {
|
||||
|
||||
pub trait BackendMaxPool2d<N>: Backend<N> {
|
||||
fn max_pool2d(&self, y: &mut Self::Tensor, x: &Self::Tensor, conv_info: &Conv2dInfo);
|
||||
fn max_pool2d_backprop(&self, dx: &mut Self::Tensor, dy: &Self::Tensor, x: &Self::Tensor, conv_info: &Conv2dInfo);
|
||||
fn max_pool2d_backprop(
|
||||
&self,
|
||||
dx: &mut Self::Tensor,
|
||||
dy: &Self::Tensor,
|
||||
x: &Self::Tensor,
|
||||
conv_info: &Conv2dInfo,
|
||||
);
|
||||
}
|
||||
|
||||
impl <'a, N, T: BackendMaxPool2d<N>> BackendMaxPool2d<N> for &'a T {
|
||||
impl<'a, N, T: BackendMaxPool2d<N>> BackendMaxPool2d<N> for &'a T {
|
||||
#[inline]
|
||||
fn max_pool2d(&self, y: &mut Self::Tensor, x: &Self::Tensor, conv_info: &Conv2dInfo) {
|
||||
(**self).max_pool2d(y, x, conv_info)
|
||||
}
|
||||
|
||||
|
||||
#[inline]
|
||||
fn max_pool2d_backprop(&self, dx: &mut Self::Tensor, dy: &Self::Tensor, x: &Self::Tensor, conv_info: &Conv2dInfo) {
|
||||
fn max_pool2d_backprop(
|
||||
&self,
|
||||
dx: &mut Self::Tensor,
|
||||
dy: &Self::Tensor,
|
||||
x: &Self::Tensor,
|
||||
conv_info: &Conv2dInfo,
|
||||
) {
|
||||
(**self).max_pool2d_backprop(dx, dy, x, conv_info)
|
||||
}
|
||||
}
|
||||
|
||||
pub trait BackendAvgPool2d<N>: Backend<N> {
|
||||
fn avg_pool2d(&self, y: &mut Self::Tensor, x: &Self::Tensor, conv_info: &Conv2dInfo);
|
||||
fn avg_pool2d_backprop(&self, dx: &mut Self::Tensor, dy: &Self::Tensor, x: &Self::Tensor, conv_info: &Conv2dInfo);
|
||||
fn avg_pool2d_backprop(
|
||||
&self,
|
||||
dx: &mut Self::Tensor,
|
||||
dy: &Self::Tensor,
|
||||
x: &Self::Tensor,
|
||||
conv_info: &Conv2dInfo,
|
||||
);
|
||||
}
|
||||
|
||||
impl <'a, N, T: BackendAvgPool2d<N>> BackendAvgPool2d<N> for &'a T {
|
||||
impl<'a, N, T: BackendAvgPool2d<N>> BackendAvgPool2d<N> for &'a T {
|
||||
#[inline]
|
||||
fn avg_pool2d(&self, y: &mut Self::Tensor, x: &Self::Tensor, conv_info: &Conv2dInfo) {
|
||||
(**self).avg_pool2d(y, x, conv_info)
|
||||
}
|
||||
|
||||
|
||||
#[inline]
|
||||
fn avg_pool2d_backprop(&self, dx: &mut Self::Tensor, dy: &Self::Tensor, x: &Self::Tensor, conv_info: &Conv2dInfo) {
|
||||
fn avg_pool2d_backprop(
|
||||
&self,
|
||||
dx: &mut Self::Tensor,
|
||||
dy: &Self::Tensor,
|
||||
x: &Self::Tensor,
|
||||
conv_info: &Conv2dInfo,
|
||||
) {
|
||||
(**self).avg_pool2d_backprop(dx, dy, x, conv_info)
|
||||
}
|
||||
}
|
||||
|
||||
pub trait BackendPaddingCopy2d<N>: Backend<N> {
|
||||
fn copy_with_padding2d(&self, y: &mut Self::Tensor, x: &Self::Tensor, y_paddings: (u32, u32), x_paddings: (u32, u32));
|
||||
fn copy_with_padding2d(
|
||||
&self,
|
||||
y: &mut Self::Tensor,
|
||||
x: &Self::Tensor,
|
||||
y_paddings: (u32, u32),
|
||||
x_paddings: (u32, u32),
|
||||
);
|
||||
}
|
||||
|
||||
impl <'a, N, T: BackendPaddingCopy2d<N>> BackendPaddingCopy2d<N> for &'a T {
|
||||
impl<'a, N, T: BackendPaddingCopy2d<N>> BackendPaddingCopy2d<N> for &'a T {
|
||||
#[inline]
|
||||
fn copy_with_padding2d(&self, y: &mut Self::Tensor, x: &Self::Tensor, y_paddings: (u32, u32), x_paddings: (u32, u32)) {
|
||||
fn copy_with_padding2d(
|
||||
&self,
|
||||
y: &mut Self::Tensor,
|
||||
x: &Self::Tensor,
|
||||
y_paddings: (u32, u32),
|
||||
x_paddings: (u32, u32),
|
||||
) {
|
||||
(**self).copy_with_padding2d(y, x, y_paddings, x_paddings)
|
||||
}
|
||||
}
|
||||
|
@ -4,10 +4,10 @@ use crate::tensor::{Tensor, TensorShape};
|
||||
|
||||
// use core::marker::PhantomData;
|
||||
|
||||
|
||||
pub trait Layer<N, B, O>
|
||||
where B: Backend<N>,
|
||||
O: Optimizer<N, B>
|
||||
where
|
||||
B: Backend<N>,
|
||||
O: Optimizer<N, B>,
|
||||
{
|
||||
type Context: LayerContext<N, B>;
|
||||
|
||||
@ -15,7 +15,7 @@ pub trait Layer<N, B, O>
|
||||
fn param_count(&self) -> usize {
|
||||
0
|
||||
}
|
||||
|
||||
|
||||
#[inline]
|
||||
fn init(&mut self, _backend: &B) {}
|
||||
|
||||
@ -25,41 +25,58 @@ pub trait Layer<N, B, O>
|
||||
fn output_shape(&self) -> TensorShape {
|
||||
self.input_shape()
|
||||
}
|
||||
|
||||
|
||||
fn forward(&self, backend: &B, x: &B::Tensor, ctx: &mut Self::Context);
|
||||
fn backward(&mut self, backend: &B, dy: &B::Tensor, x: &B::Tensor, ctx: &mut Self::Context);
|
||||
|
||||
#[inline]
|
||||
fn calc_gradients(&mut self, _backend: &B, _dy: &B::Tensor, _x: &B::Tensor, _ctx: &mut Self::Context) {}
|
||||
fn calc_gradients(
|
||||
&mut self,
|
||||
_backend: &B,
|
||||
_dy: &B::Tensor,
|
||||
_x: &B::Tensor,
|
||||
_ctx: &mut Self::Context,
|
||||
) {
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn optimize(&mut self, _backend: &B, _optimizer: &O) {}
|
||||
|
||||
fn fmt(&self, f: &mut core::fmt::Formatter, padding: usize) -> core::fmt::Result {
|
||||
writeln!(f, "{}{} -> {}[{}] -> {}", "".repeat(padding), self.input_shape(), self.name(), self.param_count(), self.output_shape())?;
|
||||
writeln!(
|
||||
f,
|
||||
"{}{} -> {}[{}] -> {}",
|
||||
"".repeat(padding),
|
||||
self.input_shape(),
|
||||
self.name(),
|
||||
self.param_count(),
|
||||
self.output_shape()
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub trait LayerExt<N, B, O>: Layer<N, B, O>
|
||||
where B: Backend<N>,
|
||||
O: Optimizer<N, B>
|
||||
where
|
||||
B: Backend<N>,
|
||||
O: Optimizer<N, B>,
|
||||
{
|
||||
type Config: Default;
|
||||
|
||||
fn create(input_shape: TensorShape, cfg: Self::Config) -> Self;
|
||||
|
||||
#[inline]
|
||||
fn add_layer<L: LayerExt<N, B, O>>(self, cfg: L::Config) -> crate::layers::Chain<N, B, O, Self, L>
|
||||
where Self: Sized
|
||||
fn add_layer<L: LayerExt<N, B, O>>(
|
||||
self,
|
||||
cfg: L::Config,
|
||||
) -> crate::layers::Chain<N, B, O, Self, L>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
let shape = self.output_shape();
|
||||
|
||||
crate::layers::Chain::new(
|
||||
self,
|
||||
L::create(shape, cfg),
|
||||
)
|
||||
crate::layers::Chain::new(self, L::create(shape, cfg))
|
||||
}
|
||||
}
|
||||
|
||||
@ -68,16 +85,17 @@ pub trait LayerContext<N, B: Backend<N>>: Default {
|
||||
fn deltas(&self) -> &B::Tensor;
|
||||
}
|
||||
|
||||
|
||||
pub struct DefaultLayerContext<N, B>
|
||||
where B: Backend<N>,
|
||||
pub struct DefaultLayerContext<N, B>
|
||||
where
|
||||
B: Backend<N>,
|
||||
{
|
||||
pub outputs: B::Tensor,
|
||||
pub deltas: B::Tensor,
|
||||
}
|
||||
|
||||
impl <N, B> Default for DefaultLayerContext<N, B>
|
||||
where B: Backend<N>,
|
||||
impl<N, B> Default for DefaultLayerContext<N, B>
|
||||
where
|
||||
B: Backend<N>,
|
||||
{
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
@ -87,8 +105,9 @@ impl <N, B> Default for DefaultLayerContext<N, B>
|
||||
}
|
||||
}
|
||||
|
||||
impl <N, B> DefaultLayerContext<N, B>
|
||||
where B: Backend<N>,
|
||||
impl<N, B> DefaultLayerContext<N, B>
|
||||
where
|
||||
B: Backend<N>,
|
||||
{
|
||||
pub fn update_deltas_shape(&mut self, bs: u32, input_shape: &TensorShape) {
|
||||
let mut new_deltas_shape = TensorShape::new1d(bs);
|
||||
@ -110,8 +129,9 @@ impl <N, B> DefaultLayerContext<N, B>
|
||||
}
|
||||
}
|
||||
|
||||
impl <N, B> LayerContext<N, B> for DefaultLayerContext<N, B>
|
||||
where B: Backend<N>,
|
||||
impl<N, B> LayerContext<N, B> for DefaultLayerContext<N, B>
|
||||
where
|
||||
B: Backend<N>,
|
||||
{
|
||||
#[inline]
|
||||
fn outputs(&self) -> &B::Tensor {
|
||||
@ -122,4 +142,4 @@ impl <N, B> LayerContext<N, B> for DefaultLayerContext<N, B>
|
||||
fn deltas(&self) -> &B::Tensor {
|
||||
&self.deltas
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,7 +1,7 @@
|
||||
use crate::tensor::{Tensor, TensorShape};
|
||||
use crate::layer::{Layer, LayerExt, DefaultLayerContext};
|
||||
use crate::backend::{Backend, PaddingKind, BackendAvgPool2d, Conv2dInfo};
|
||||
use crate::backend::{Backend, BackendAvgPool2d, Conv2dInfo, PaddingKind};
|
||||
use crate::layer::{DefaultLayerContext, Layer, LayerExt};
|
||||
use crate::optimizer::Optimizer;
|
||||
use crate::tensor::{Tensor, TensorShape};
|
||||
|
||||
use core::marker::PhantomData;
|
||||
|
||||
@ -19,24 +19,26 @@ impl Default for AvgPool2dConfig {
|
||||
}
|
||||
}
|
||||
|
||||
pub struct AvgPool2d<N, B>
|
||||
where B: Backend<N>
|
||||
pub struct AvgPool2d<N, B>
|
||||
where
|
||||
B: Backend<N>,
|
||||
{
|
||||
input_shape: TensorShape,
|
||||
conv_info: Conv2dInfo,
|
||||
_m: PhantomData<fn(N, B)>
|
||||
_m: PhantomData<fn(N, B)>,
|
||||
}
|
||||
|
||||
impl <N, B, O> Layer<N, B, O> for AvgPool2d<N, B>
|
||||
where B: Backend<N> + BackendAvgPool2d<N>,
|
||||
O: Optimizer<N, B>
|
||||
impl<N, B, O> Layer<N, B, O> for AvgPool2d<N, B>
|
||||
where
|
||||
B: Backend<N> + BackendAvgPool2d<N>,
|
||||
O: Optimizer<N, B>,
|
||||
{
|
||||
type Context = DefaultLayerContext<N, B>;
|
||||
|
||||
fn name(&self) -> &str {
|
||||
"AvgPool2d"
|
||||
}
|
||||
|
||||
|
||||
#[inline]
|
||||
fn input_shape(&self) -> TensorShape {
|
||||
self.input_shape.clone()
|
||||
@ -51,13 +53,9 @@ impl <N, B, O> Layer<N, B, O> for AvgPool2d<N, B>
|
||||
let rows = (is[0] - self.conv_info.kernel.0) / self.conv_info.strides.0 + 1;
|
||||
let cols = (is[1] - self.conv_info.kernel.1) / self.conv_info.strides.1 + 1;
|
||||
|
||||
TensorShape::new3d(
|
||||
is[0],
|
||||
rows,
|
||||
cols,
|
||||
)
|
||||
TensorShape::new3d(is[0], rows, cols)
|
||||
}
|
||||
|
||||
|
||||
#[inline]
|
||||
fn forward(&self, backend: &B, x: &B::Tensor, ctx: &mut Self::Context) {
|
||||
ctx.update_outputs_shape(x.shape().get(0), &Layer::<N, B, O>::output_shape(self));
|
||||
@ -83,9 +81,10 @@ impl <N, B, O> Layer<N, B, O> for AvgPool2d<N, B>
|
||||
}
|
||||
}
|
||||
|
||||
impl <N, B, O> LayerExt<N, B, O> for AvgPool2d<N, B>
|
||||
where B: Backend<N> + BackendAvgPool2d<N>,
|
||||
O: Optimizer<N, B>
|
||||
impl<N, B, O> LayerExt<N, B, O> for AvgPool2d<N, B>
|
||||
where
|
||||
B: Backend<N> + BackendAvgPool2d<N>,
|
||||
O: Optimizer<N, B>,
|
||||
{
|
||||
type Config = AvgPool2dConfig;
|
||||
|
||||
|
@ -5,20 +5,22 @@ use crate::tensor::TensorShape;
|
||||
|
||||
use core::marker::PhantomData;
|
||||
|
||||
pub struct ChainContext<N, B, L, R>
|
||||
where B: Backend<N>,
|
||||
L: LayerContext<N, B>,
|
||||
R: LayerContext<N, B>,
|
||||
pub struct ChainContext<N, B, L, R>
|
||||
where
|
||||
B: Backend<N>,
|
||||
L: LayerContext<N, B>,
|
||||
R: LayerContext<N, B>,
|
||||
{
|
||||
left: L,
|
||||
left: L,
|
||||
right: R,
|
||||
_m: PhantomData<fn(N, B)>
|
||||
_m: PhantomData<fn(N, B)>,
|
||||
}
|
||||
|
||||
impl<N, B, L, R> Default for ChainContext<N, B, L, R>
|
||||
where B: Backend<N>,
|
||||
L: LayerContext<N, B>,
|
||||
R: LayerContext<N, B>,
|
||||
impl<N, B, L, R> Default for ChainContext<N, B, L, R>
|
||||
where
|
||||
B: Backend<N>,
|
||||
L: LayerContext<N, B>,
|
||||
R: LayerContext<N, B>,
|
||||
{
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
@ -30,9 +32,10 @@ impl<N, B, L, R> Default for ChainContext<N, B, L, R>
|
||||
}
|
||||
|
||||
impl<N, B, L, R> LayerContext<N, B> for ChainContext<N, B, L, R>
|
||||
where B: Backend<N>,
|
||||
L: LayerContext<N, B>,
|
||||
R: LayerContext<N, B>,
|
||||
where
|
||||
B: Backend<N>,
|
||||
L: LayerContext<N, B>,
|
||||
R: LayerContext<N, B>,
|
||||
{
|
||||
#[inline]
|
||||
fn outputs(&self) -> &B::Tensor {
|
||||
@ -43,24 +46,26 @@ impl<N, B, L, R> LayerContext<N, B> for ChainContext<N, B, L, R>
|
||||
fn deltas(&self) -> &B::Tensor {
|
||||
self.left.deltas()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Chain<N, B, O, L, R>
|
||||
where B: Backend<N>,
|
||||
O: Optimizer<N, B>,
|
||||
L: Layer<N, B, O>,
|
||||
R: Layer<N, B, O>,
|
||||
{
|
||||
left: L,
|
||||
right: R,
|
||||
_m: PhantomData<fn(N, B, O)>
|
||||
}
|
||||
|
||||
impl<N, B, O, L, R> Chain<N, B, O, L, R>
|
||||
where B: Backend<N>,
|
||||
O: Optimizer<N, B>,
|
||||
L: Layer<N, B, O>,
|
||||
R: Layer<N, B, O>,
|
||||
pub struct Chain<N, B, O, L, R>
|
||||
where
|
||||
B: Backend<N>,
|
||||
O: Optimizer<N, B>,
|
||||
L: Layer<N, B, O>,
|
||||
R: Layer<N, B, O>,
|
||||
{
|
||||
left: L,
|
||||
right: R,
|
||||
_m: PhantomData<fn(N, B, O)>,
|
||||
}
|
||||
|
||||
impl<N, B, O, L, R> Chain<N, B, O, L, R>
|
||||
where
|
||||
B: Backend<N>,
|
||||
O: Optimizer<N, B>,
|
||||
L: Layer<N, B, O>,
|
||||
R: Layer<N, B, O>,
|
||||
{
|
||||
pub fn new(left: L, right: R) -> Self {
|
||||
Self {
|
||||
@ -71,7 +76,7 @@ impl<N, B, O, L, R> Chain<N, B, O, L, R>
|
||||
}
|
||||
}
|
||||
|
||||
// impl<N, B, O, L, R> core::fmt::Display for Chain<N, B, O, L, R>
|
||||
// impl<N, B, O, L, R> core::fmt::Display for Chain<N, B, O, L, R>
|
||||
// where B: Backend<N>,
|
||||
// O: Optimizer<N, B>,
|
||||
// L: Layer<N, B, O>,
|
||||
@ -85,11 +90,12 @@ impl<N, B, O, L, R> Chain<N, B, O, L, R>
|
||||
// }
|
||||
// }
|
||||
|
||||
impl<N, B, O, L, R> Layer<N, B, O> for Chain<N, B, O, L, R>
|
||||
where B: Backend<N>,
|
||||
O: Optimizer<N, B>,
|
||||
L: Layer<N, B, O>,
|
||||
R: Layer<N, B, O>,
|
||||
impl<N, B, O, L, R> Layer<N, B, O> for Chain<N, B, O, L, R>
|
||||
where
|
||||
B: Backend<N>,
|
||||
O: Optimizer<N, B>,
|
||||
L: Layer<N, B, O>,
|
||||
R: Layer<N, B, O>,
|
||||
{
|
||||
type Context = ChainContext<N, B, L::Context, R::Context>;
|
||||
|
||||
@ -101,7 +107,7 @@ impl<N, B, O, L, R> Layer<N, B, O> for Chain<N, B, O, L, R>
|
||||
#[inline]
|
||||
fn param_count(&self) -> usize {
|
||||
self.left.param_count() + self.right.param_count()
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn init(&mut self, backend: &B) {
|
||||
@ -122,19 +128,36 @@ impl<N, B, O, L, R> Layer<N, B, O> for Chain<N, B, O, L, R>
|
||||
#[inline]
|
||||
/// Runs the left layer on `inputs`, then runs the right layer on the outputs the
/// left layer stored in its context.
fn forward(&self, backend: &B, inputs: &B::Tensor, ctx: &mut Self::Context) {
    self.left.forward(backend, inputs, &mut ctx.left);

    // `ctx.left` is only read from here on, so it can be borrowed alongside `&mut ctx.right`.
    let left_outputs = ctx.left.outputs();
    self.right.forward(backend, left_outputs, &mut ctx.right);
}
|
||||
|
||||
#[inline]
|
||||
fn backward(&mut self, backend: &B, deltas: &B::Tensor, inputs: &B::Tensor, ctx: &mut Self::Context) {
|
||||
self.right.backward(backend, deltas, ctx.left.outputs(), &mut ctx.right);
|
||||
self.left.backward(backend, ctx.right.deltas(), inputs, &mut ctx.left);
|
||||
fn backward(
|
||||
&mut self,
|
||||
backend: &B,
|
||||
deltas: &B::Tensor,
|
||||
inputs: &B::Tensor,
|
||||
ctx: &mut Self::Context,
|
||||
) {
|
||||
self.right
|
||||
.backward(backend, deltas, ctx.left.outputs(), &mut ctx.right);
|
||||
self.left
|
||||
.backward(backend, ctx.right.deltas(), inputs, &mut ctx.left);
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn calc_gradients(&mut self, backend: &B, deltas: &B::Tensor, inputs: &B::Tensor, ctx: &mut Self::Context) {
|
||||
self.left.calc_gradients(backend, ctx.right.deltas(), inputs, &mut ctx.left);
|
||||
self.right.calc_gradients(backend, deltas, ctx.left.outputs(), &mut ctx.right);
|
||||
fn calc_gradients(
|
||||
&mut self,
|
||||
backend: &B,
|
||||
deltas: &B::Tensor,
|
||||
inputs: &B::Tensor,
|
||||
ctx: &mut Self::Context,
|
||||
) {
|
||||
self.left
|
||||
.calc_gradients(backend, ctx.right.deltas(), inputs, &mut ctx.left);
|
||||
self.right
|
||||
.calc_gradients(backend, deltas, ctx.left.outputs(), &mut ctx.right);
|
||||
}
|
||||
|
||||
#[inline]
|
||||
@ -146,7 +169,7 @@ impl<N, B, O, L, R> Layer<N, B, O> for Chain<N, B, O, L, R>
|
||||
fn fmt(&self, f: &mut core::fmt::Formatter, padding: usize) -> core::fmt::Result {
|
||||
self.left.fmt(f, padding)?;
|
||||
self.right.fmt(f, padding)?;
|
||||
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,8 +1,8 @@
|
||||
use crate::tensor::{Tensor, TensorShape};
|
||||
use crate::layer::{Layer, LayerExt, DefaultLayerContext};
|
||||
use crate::params::Params;
|
||||
use crate::backend::{Backend, Conv2dInfo, PaddingKind, BackendBias, BackendConv2d, BackendScale};
|
||||
use crate::backend::{Backend, BackendBias, BackendConv2d, BackendScale, Conv2dInfo, PaddingKind};
|
||||
use crate::layer::{DefaultLayerContext, Layer, LayerExt};
|
||||
use crate::optimizer::Optimizer;
|
||||
use crate::params::Params;
|
||||
use crate::tensor::{Tensor, TensorShape};
|
||||
|
||||
pub struct Conv2dConfig {
|
||||
pub filters: u32,
|
||||
@ -24,9 +24,10 @@ impl Default for Conv2dConfig {
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Conv2d<N, B, O>
|
||||
where B: Backend<N>,
|
||||
O: Optimizer<N, B>
|
||||
pub struct Conv2d<N, B, O>
|
||||
where
|
||||
B: Backend<N>,
|
||||
O: Optimizer<N, B>,
|
||||
{
|
||||
input_shape: TensorShape,
|
||||
units: u32,
|
||||
@ -36,9 +37,10 @@ pub struct Conv2d<N, B, O>
|
||||
biases: Params<N, B, O>,
|
||||
}
|
||||
|
||||
impl <N, B, O> Layer<N, B, O> for Conv2d<N, B, O>
|
||||
where B: Backend<N> + BackendConv2d<N> + BackendBias<N> + BackendScale<N>,
|
||||
O: Optimizer<N, B>
|
||||
impl<N, B, O> Layer<N, B, O> for Conv2d<N, B, O>
|
||||
where
|
||||
B: Backend<N> + BackendConv2d<N> + BackendBias<N> + BackendScale<N>,
|
||||
O: Optimizer<N, B>,
|
||||
{
|
||||
type Context = DefaultLayerContext<N, B>;
|
||||
|
||||
@ -53,10 +55,13 @@ impl <N, B, O> Layer<N, B, O> for Conv2d<N, B, O>
|
||||
} else {
|
||||
self.filters.params.shape().size()
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
fn init(&mut self, backend: &B) {
|
||||
self.filters.init_random(backend, self.conv_info.kernel.0 * self.conv_info.kernel.1 + self.filters.params.shape().get(0));
|
||||
self.filters.init_random(
|
||||
backend,
|
||||
self.conv_info.kernel.0 * self.conv_info.kernel.1 + self.filters.params.shape().get(0),
|
||||
);
|
||||
|
||||
if self.use_biases {
|
||||
self.biases.init_zero(backend);
|
||||
@ -77,13 +82,9 @@ impl <N, B, O> Layer<N, B, O> for Conv2d<N, B, O>
|
||||
let rows = (is[1] - self.conv_info.kernel.0) / self.conv_info.strides.0 + 1;
|
||||
let cols = (is[2] - self.conv_info.kernel.1) / self.conv_info.strides.1 + 1;
|
||||
|
||||
TensorShape::new3d(
|
||||
self.units,
|
||||
rows,
|
||||
cols,
|
||||
)
|
||||
TensorShape::new3d(self.units, rows, cols)
|
||||
}
|
||||
|
||||
|
||||
#[inline]
|
||||
fn forward(&self, backend: &B, x: &B::Tensor, ctx: &mut Self::Context) {
|
||||
assert_eq!(x.shape().dims, 4);
|
||||
@ -101,14 +102,20 @@ impl <N, B, O> Layer<N, B, O> for Conv2d<N, B, O>
|
||||
#[inline]
|
||||
fn backward(&mut self, backend: &B, dy: &B::Tensor, x: &B::Tensor, ctx: &mut Self::Context) {
|
||||
assert_eq!(dy.shape().dims, 4);
|
||||
|
||||
|
||||
ctx.update_deltas_shape(x.shape().get(0), &self.input_shape);
|
||||
|
||||
backend.conv2d_backward_input(&mut ctx.deltas, dy, &self.filters.params, &self.conv_info);
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn calc_gradients(&mut self, backend: &B, dy: &B::Tensor, x: &B::Tensor, _ctx: &mut Self::Context) {
|
||||
fn calc_gradients(
|
||||
&mut self,
|
||||
backend: &B,
|
||||
dy: &B::Tensor,
|
||||
x: &B::Tensor,
|
||||
_ctx: &mut Self::Context,
|
||||
) {
|
||||
assert_eq!(dy.shape().dims, 4);
|
||||
assert_eq!(x.shape().dims, 4);
|
||||
|
||||
@ -120,18 +127,24 @@ impl <N, B, O> Layer<N, B, O> for Conv2d<N, B, O>
|
||||
|
||||
#[inline]
|
||||
fn optimize(&mut self, backend: &B, optimizer: &O) {
|
||||
optimizer.update_params(backend, &mut self.filters.ctx, &mut self.filters.params, &mut self.filters.grads);
|
||||
optimizer.update_params(
|
||||
backend,
|
||||
&mut self.filters.ctx,
|
||||
&mut self.filters.params,
|
||||
&mut self.filters.grads,
|
||||
);
|
||||
|
||||
if self.use_biases {
|
||||
unimplemented!()
|
||||
// optimizer.update_params(backend, &mut self.biases.ctx, &mut self.biases.params, &self.biases.grads);
|
||||
// optimizer.update_params(backend, &mut self.biases.ctx, &mut self.biases.params, &self.biases.grads);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl <N, B, O> LayerExt<N, B, O> for Conv2d<N, B, O>
|
||||
where B: Backend<N> + BackendConv2d<N> + BackendBias<N> + BackendScale<N>,
|
||||
O: Optimizer<N, B>
|
||||
impl<N, B, O> LayerExt<N, B, O> for Conv2d<N, B, O>
|
||||
where
|
||||
B: Backend<N> + BackendConv2d<N> + BackendBias<N> + BackendScale<N>,
|
||||
O: Optimizer<N, B>,
|
||||
{
|
||||
type Config = Conv2dConfig;
|
||||
|
||||
@ -143,12 +156,12 @@ impl <N, B, O> LayerExt<N, B, O> for Conv2d<N, B, O>
|
||||
units: cfg.filters,
|
||||
conv_info: Conv2dInfo {
|
||||
kernel: cfg.kernel,
|
||||
padding: cfg.padding,
|
||||
padding: cfg.padding,
|
||||
strides: cfg.strides,
|
||||
},
|
||||
use_biases: cfg.biases,
|
||||
filters: Params::new((cfg.filters, cfg.kernel.0, cfg.kernel.1)),
|
||||
biases: Params::new((cfg.filters, )),
|
||||
biases: Params::new((cfg.filters,)),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,30 +1,31 @@
|
||||
use crate::tensor::{TensorShape, Tensor};
|
||||
use crate::backend::{Backend, BackendCopy};
|
||||
use crate::layer::{Layer, LayerExt, DefaultLayerContext};
|
||||
use crate::layer::{DefaultLayerContext, Layer, LayerExt};
|
||||
use crate::optimizer::Optimizer;
|
||||
use crate::tensor::{Tensor, TensorShape};
|
||||
use core::marker::PhantomData;
|
||||
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct FlattenConfig;
|
||||
|
||||
pub struct Flatten<N, B>
|
||||
where B: Backend<N>,
|
||||
pub struct Flatten<N, B>
|
||||
where
|
||||
B: Backend<N>,
|
||||
{
|
||||
input_shape: TensorShape,
|
||||
_x: PhantomData<fn(N, B)>,
|
||||
}
|
||||
|
||||
impl <N, B, O> Layer<N, B, O> for Flatten<N, B>
|
||||
where B: Backend<N> + BackendCopy<N>,
|
||||
O: Optimizer<N, B>
|
||||
impl<N, B, O> Layer<N, B, O> for Flatten<N, B>
|
||||
where
|
||||
B: Backend<N> + BackendCopy<N>,
|
||||
O: Optimizer<N, B>,
|
||||
{
|
||||
type Context = DefaultLayerContext<N, B>;
|
||||
|
||||
fn name(&self) -> &str {
|
||||
"Flatten"
|
||||
}
|
||||
|
||||
|
||||
#[inline]
|
||||
fn input_shape(&self) -> TensorShape {
|
||||
self.input_shape.clone()
|
||||
@ -34,11 +35,11 @@ impl <N, B, O> Layer<N, B, O> for Flatten<N, B>
|
||||
fn output_shape(&self) -> TensorShape {
|
||||
TensorShape::new1d(self.input_shape.size() as u32)
|
||||
}
|
||||
|
||||
|
||||
#[inline]
|
||||
fn forward(&self, backend: &B, x: &B::Tensor, ctx: &mut Self::Context) {
|
||||
ctx.update_outputs_shape(x.shape().get(0), &Layer::<N, B, O>::output_shape(self));
|
||||
|
||||
|
||||
backend.copy(&mut ctx.outputs, x);
|
||||
}
|
||||
|
||||
@ -50,16 +51,17 @@ impl <N, B, O> Layer<N, B, O> for Flatten<N, B>
|
||||
}
|
||||
}
|
||||
|
||||
impl <N, B, O> LayerExt<N, B, O> for Flatten<N, B>
|
||||
where B: Backend<N> + BackendCopy<N>,
|
||||
O: Optimizer<N, B>
|
||||
impl<N, B, O> LayerExt<N, B, O> for Flatten<N, B>
|
||||
where
|
||||
B: Backend<N> + BackendCopy<N>,
|
||||
O: Optimizer<N, B>,
|
||||
{
|
||||
type Config = FlattenConfig;
|
||||
|
||||
fn create(input_shape: TensorShape, _cfg: Self::Config) -> Self {
|
||||
Flatten {
|
||||
input_shape,
|
||||
_x: Default::default()
|
||||
_x: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,8 +1,8 @@
|
||||
use crate::tensor::{Tensor, TensorShape};
|
||||
use crate::layer::{Layer, LayerExt, DefaultLayerContext};
|
||||
use crate::params::Params;
|
||||
use crate::backend::{Backend, BackendGemm, BackendBias, BackendScale};
|
||||
use crate::backend::{Backend, BackendBias, BackendGemm, BackendScale};
|
||||
use crate::layer::{DefaultLayerContext, Layer, LayerExt};
|
||||
use crate::optimizer::Optimizer;
|
||||
use crate::params::Params;
|
||||
use crate::tensor::{Tensor, TensorShape};
|
||||
|
||||
pub struct LinearConfig {
|
||||
pub units: u32,
|
||||
@ -18,9 +18,10 @@ impl Default for LinearConfig {
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Linear<N, B, O>
|
||||
where B: Backend<N>,
|
||||
O: Optimizer<N, B>
|
||||
pub struct Linear<N, B, O>
|
||||
where
|
||||
B: Backend<N>,
|
||||
O: Optimizer<N, B>,
|
||||
{
|
||||
inputs: u32,
|
||||
outputs: u32,
|
||||
@ -29,9 +30,10 @@ pub struct Linear<N, B, O>
|
||||
biases: Params<N, B, O>,
|
||||
}
|
||||
|
||||
impl <N, B, O> Layer<N, B, O> for Linear<N, B, O>
|
||||
where B: Backend<N> + BackendGemm<N> + BackendBias<N> + BackendScale<N>,
|
||||
O: Optimizer<N, B>
|
||||
impl<N, B, O> Layer<N, B, O> for Linear<N, B, O>
|
||||
where
|
||||
B: Backend<N> + BackendGemm<N> + BackendBias<N> + BackendScale<N>,
|
||||
O: Optimizer<N, B>,
|
||||
{
|
||||
type Context = DefaultLayerContext<N, B>;
|
||||
|
||||
@ -46,10 +48,11 @@ impl <N, B, O> Layer<N, B, O> for Linear<N, B, O>
|
||||
} else {
|
||||
self.weights.params.shape().size()
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
fn init(&mut self, backend: &B) {
|
||||
self.weights.init_random(backend, self.inputs + self.outputs);
|
||||
self.weights
|
||||
.init_random(backend, self.inputs + self.outputs);
|
||||
if self.use_biases {
|
||||
self.biases.init_zero(backend);
|
||||
}
|
||||
@ -64,7 +67,7 @@ impl <N, B, O> Layer<N, B, O> for Linear<N, B, O>
|
||||
fn output_shape(&self) -> TensorShape {
|
||||
TensorShape::new1d(self.outputs)
|
||||
}
|
||||
|
||||
|
||||
#[inline]
|
||||
fn forward(&self, backend: &B, x: &B::Tensor, ctx: &mut Self::Context) {
|
||||
ctx.update_outputs_shape(x.shape().get(0), &self.output_shape());
|
||||
@ -79,11 +82,17 @@ impl <N, B, O> Layer<N, B, O> for Linear<N, B, O>
|
||||
#[inline]
|
||||
fn backward(&mut self, backend: &B, dy: &B::Tensor, x: &B::Tensor, ctx: &mut Self::Context) {
|
||||
ctx.update_deltas_shape(x.shape().get(0), &self.input_shape());
|
||||
|
||||
|
||||
backend.matmul_nt(&mut ctx.deltas, dy, &self.weights.params);
|
||||
}
|
||||
|
||||
fn calc_gradients(&mut self, backend: &B, dy: &B::Tensor, x: &B::Tensor, _ctx: &mut Self::Context) {
|
||||
fn calc_gradients(
|
||||
&mut self,
|
||||
backend: &B,
|
||||
dy: &B::Tensor,
|
||||
x: &B::Tensor,
|
||||
_ctx: &mut Self::Context,
|
||||
) {
|
||||
let prescaler = 1.0 / x.shape().get(0) as f32;
|
||||
|
||||
backend.matmul_tn(&mut self.weights.grads, x, dy);
|
||||
@ -97,17 +106,28 @@ impl <N, B, O> Layer<N, B, O> for Linear<N, B, O>
|
||||
|
||||
#[inline]
|
||||
fn optimize(&mut self, backend: &B, optimizer: &O) {
|
||||
optimizer.update_params(backend, &mut self.weights.ctx, &mut self.weights.params, &mut self.weights.grads);
|
||||
optimizer.update_params(
|
||||
backend,
|
||||
&mut self.weights.ctx,
|
||||
&mut self.weights.params,
|
||||
&mut self.weights.grads,
|
||||
);
|
||||
|
||||
if self.use_biases {
|
||||
optimizer.update_params(backend, &mut self.biases.ctx, &mut self.biases.params, &mut self.biases.grads);
|
||||
optimizer.update_params(
|
||||
backend,
|
||||
&mut self.biases.ctx,
|
||||
&mut self.biases.params,
|
||||
&mut self.biases.grads,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl <N, B, O> LayerExt<N, B, O> for Linear<N, B, O>
|
||||
where B: Backend<N> + BackendGemm<N> + BackendBias<N> + BackendScale<N>,
|
||||
O: Optimizer<N, B>
|
||||
impl<N, B, O> LayerExt<N, B, O> for Linear<N, B, O>
|
||||
where
|
||||
B: Backend<N> + BackendGemm<N> + BackendBias<N> + BackendScale<N>,
|
||||
O: Optimizer<N, B>,
|
||||
{
|
||||
type Config = LinearConfig;
|
||||
|
||||
@ -121,7 +141,7 @@ impl <N, B, O> LayerExt<N, B, O> for Linear<N, B, O>
|
||||
outputs: cfg.units,
|
||||
use_biases: cfg.biases,
|
||||
weights: Params::new((inputs, cfg.units)),
|
||||
biases: Params::new((cfg.units, )),
|
||||
biases: Params::new((cfg.units,)),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,7 +1,7 @@
|
||||
use crate::tensor::{TensorShape, Tensor};
|
||||
use crate::layer::{Layer, LayerExt, DefaultLayerContext};
|
||||
use crate::backend::{Backend, PaddingKind, BackendMaxPool2d, Conv2dInfo};
|
||||
use crate::backend::{Backend, BackendMaxPool2d, Conv2dInfo, PaddingKind};
|
||||
use crate::layer::{DefaultLayerContext, Layer, LayerExt};
|
||||
use crate::optimizer::Optimizer;
|
||||
use crate::tensor::{Tensor, TensorShape};
|
||||
use core::marker::PhantomData;
|
||||
|
||||
pub struct MaxPool2dConfig {
|
||||
@ -18,24 +18,26 @@ impl Default for MaxPool2dConfig {
|
||||
}
|
||||
}
|
||||
|
||||
pub struct MaxPool2d<N, B>
|
||||
where B: Backend<N>
|
||||
pub struct MaxPool2d<N, B>
|
||||
where
|
||||
B: Backend<N>,
|
||||
{
|
||||
input_shape: TensorShape,
|
||||
conv_info: Conv2dInfo,
|
||||
_m: PhantomData<fn(N, B)>
|
||||
_m: PhantomData<fn(N, B)>,
|
||||
}
|
||||
|
||||
impl <N, B, O> Layer<N, B, O> for MaxPool2d<N, B>
|
||||
where B: Backend<N> + BackendMaxPool2d<N>,
|
||||
O: Optimizer<N, B>
|
||||
impl<N, B, O> Layer<N, B, O> for MaxPool2d<N, B>
|
||||
where
|
||||
B: Backend<N> + BackendMaxPool2d<N>,
|
||||
O: Optimizer<N, B>,
|
||||
{
|
||||
type Context = DefaultLayerContext<N, B>;
|
||||
|
||||
fn name(&self) -> &str {
|
||||
"MaxPool2d"
|
||||
}
|
||||
|
||||
|
||||
#[inline]
|
||||
fn input_shape(&self) -> TensorShape {
|
||||
self.input_shape.clone()
|
||||
@ -44,26 +46,22 @@ impl <N, B, O> Layer<N, B, O> for MaxPool2d<N, B>
|
||||
#[inline]
|
||||
fn output_shape(&self) -> TensorShape {
|
||||
let is = self.input_shape.as_slice();
|
||||
|
||||
|
||||
// O = (W - K + 2P) / S + 1
|
||||
|
||||
let rows = (is[1] - self.conv_info.kernel.0) / self.conv_info.strides.0 + 1;
|
||||
let cols = (is[2] - self.conv_info.kernel.1) / self.conv_info.strides.1 + 1;
|
||||
|
||||
TensorShape::new3d(
|
||||
is[0],
|
||||
rows,
|
||||
cols,
|
||||
)
|
||||
TensorShape::new3d(is[0], rows, cols)
|
||||
}
|
||||
|
||||
|
||||
#[inline]
|
||||
fn forward(&self, backend: &B, x: &B::Tensor, ctx: &mut Self::Context) {
|
||||
ctx.update_outputs_shape(x.shape().get(0), &Layer::<N, B, O>::output_shape(self));
|
||||
|
||||
|
||||
backend.max_pool2d(&mut ctx.outputs, x, &self.conv_info)
|
||||
}
|
||||
|
||||
|
||||
#[inline]
|
||||
fn backward(&mut self, backend: &B, dy: &B::Tensor, x: &B::Tensor, ctx: &mut Self::Context) {
|
||||
ctx.update_deltas_shape(x.shape().get(0), &self.input_shape);
|
||||
@ -72,9 +70,10 @@ impl <N, B, O> Layer<N, B, O> for MaxPool2d<N, B>
|
||||
}
|
||||
}
|
||||
|
||||
impl <N, B, O> LayerExt<N, B, O> for MaxPool2d<N, B>
|
||||
where B: Backend<N> + BackendMaxPool2d<N>,
|
||||
O: Optimizer<N, B>
|
||||
impl<N, B, O> LayerExt<N, B, O> for MaxPool2d<N, B>
|
||||
where
|
||||
B: Backend<N> + BackendMaxPool2d<N>,
|
||||
O: Optimizer<N, B>,
|
||||
{
|
||||
type Config = MaxPool2dConfig;
|
||||
|
||||
@ -91,4 +90,4 @@ impl <N, B, O> LayerExt<N, B, O> for MaxPool2d<N, B>
|
||||
_m: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,21 +1,21 @@
|
||||
mod linear;
|
||||
mod sigmoid;
|
||||
mod chain;
|
||||
mod relu;
|
||||
mod softmax;
|
||||
mod avgpool2d;
|
||||
mod maxpool2d;
|
||||
mod zeropadding2d;
|
||||
mod chain;
|
||||
mod conv2d;
|
||||
mod flatten;
|
||||
mod linear;
|
||||
mod maxpool2d;
|
||||
mod relu;
|
||||
mod sigmoid;
|
||||
mod softmax;
|
||||
mod zeropadding2d;
|
||||
|
||||
pub use self::linear::*;
|
||||
pub use self::sigmoid::*;
|
||||
pub use self::chain::*;
|
||||
pub use self::relu::*;
|
||||
pub use self::softmax::*;
|
||||
pub use self::conv2d::*;
|
||||
pub use self::avgpool2d::*;
|
||||
pub use self::maxpool2d::*;
|
||||
pub use self::chain::*;
|
||||
pub use self::conv2d::*;
|
||||
pub use self::flatten::*;
|
||||
pub use self::zeropadding2d::*;
|
||||
pub use self::linear::*;
|
||||
pub use self::maxpool2d::*;
|
||||
pub use self::relu::*;
|
||||
pub use self::sigmoid::*;
|
||||
pub use self::softmax::*;
|
||||
pub use self::zeropadding2d::*;
|
||||
|
@ -1,59 +1,62 @@
|
||||
use crate::tensor::{Tensor, TensorShape};
|
||||
use crate::backend::{Backend, BackendReLu};
|
||||
use crate::layer::{Layer, LayerExt, DefaultLayerContext};
|
||||
use crate::layer::{DefaultLayerContext, Layer, LayerExt};
|
||||
use crate::optimizer::Optimizer;
|
||||
use crate::tensor::{Tensor, TensorShape};
|
||||
use core::marker::PhantomData;
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct ReLuConfig;
|
||||
|
||||
pub struct ReLu<N, B>
|
||||
where B: Backend<N>,
|
||||
pub struct ReLu<N, B>
|
||||
where
|
||||
B: Backend<N>,
|
||||
{
|
||||
input_shape: TensorShape,
|
||||
_x: PhantomData<fn(N, B)>,
|
||||
}
|
||||
|
||||
impl <N, B, O> Layer<N, B, O> for ReLu<N, B>
|
||||
where B: Backend<N> + BackendReLu<N>,
|
||||
O: Optimizer<N, B>
|
||||
impl<N, B, O> Layer<N, B, O> for ReLu<N, B>
|
||||
where
|
||||
B: Backend<N> + BackendReLu<N>,
|
||||
O: Optimizer<N, B>,
|
||||
{
|
||||
type Context = DefaultLayerContext<N, B>;
|
||||
|
||||
fn name(&self) -> &str {
|
||||
"ReLU"
|
||||
}
|
||||
|
||||
|
||||
#[inline]
|
||||
fn input_shape(&self) -> TensorShape {
|
||||
self.input_shape.clone()
|
||||
}
|
||||
|
||||
|
||||
#[inline]
|
||||
fn forward(&self, backend: &B, x: &B::Tensor, ctx: &mut Self::Context) {
|
||||
ctx.update_outputs_shape(x.shape().get(0), &self.input_shape);
|
||||
|
||||
|
||||
backend.relu(&mut ctx.outputs, x);
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn backward(&mut self, backend: &B, dy: &B::Tensor, _: &B::Tensor, ctx: &mut Self::Context) {
|
||||
ctx.update_deltas_shape(dy.shape().get(0), &self.input_shape);
|
||||
|
||||
|
||||
backend.relu_grad(&mut ctx.deltas, &ctx.outputs, dy);
|
||||
}
|
||||
}
|
||||
|
||||
impl <N, B, O> LayerExt<N, B, O> for ReLu<N, B>
|
||||
where B: Backend<N> + BackendReLu<N>,
|
||||
O: Optimizer<N, B>
|
||||
impl<N, B, O> LayerExt<N, B, O> for ReLu<N, B>
|
||||
where
|
||||
B: Backend<N> + BackendReLu<N>,
|
||||
O: Optimizer<N, B>,
|
||||
{
|
||||
type Config = ReLuConfig;
|
||||
|
||||
fn create(input_shape: TensorShape, _cfg: Self::Config) -> Self {
|
||||
ReLu {
|
||||
input_shape,
|
||||
_x: Default::default()
|
||||
_x: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,23 +1,25 @@
|
||||
use crate::tensor::{Tensor, TensorShape};
|
||||
use crate::backend::{Backend, BackendSigmoid};
|
||||
use crate::layer::{Layer, LayerExt, DefaultLayerContext};
|
||||
use crate::layer::{DefaultLayerContext, Layer, LayerExt};
|
||||
use crate::optimizer::Optimizer;
|
||||
use crate::tensor::{Tensor, TensorShape};
|
||||
use core::marker::PhantomData;
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct SigmoidConfig;
|
||||
|
||||
pub struct Sigmoid<N, B>
|
||||
where B: Backend<N>,
|
||||
pub struct Sigmoid<N, B>
|
||||
where
|
||||
B: Backend<N>,
|
||||
{
|
||||
input_shape: TensorShape,
|
||||
_x: PhantomData<fn(N, B)>,
|
||||
}
|
||||
|
||||
impl <N, B, O> Layer<N, B, O> for Sigmoid<N, B>
|
||||
where B: Backend<N> + BackendSigmoid<N>,
|
||||
O: Optimizer<N, B>
|
||||
{
|
||||
impl<N, B, O> Layer<N, B, O> for Sigmoid<N, B>
|
||||
where
|
||||
B: Backend<N> + BackendSigmoid<N>,
|
||||
O: Optimizer<N, B>,
|
||||
{
|
||||
type Context = DefaultLayerContext<N, B>;
|
||||
|
||||
fn name(&self) -> &str {
|
||||
@ -32,7 +34,7 @@ impl <N, B, O> Layer<N, B, O> for Sigmoid<N, B>
|
||||
#[inline]
|
||||
fn forward(&self, backend: &B, x: &B::Tensor, ctx: &mut Self::Context) {
|
||||
ctx.update_outputs_shape(x.shape().get(0), &self.input_shape);
|
||||
|
||||
|
||||
backend.sigmoid(&mut ctx.outputs, x);
|
||||
}
|
||||
|
||||
@ -44,16 +46,17 @@ impl <N, B, O> Layer<N, B, O> for Sigmoid<N, B>
|
||||
}
|
||||
}
|
||||
|
||||
impl <N, B, O> LayerExt<N, B, O> for Sigmoid<N, B>
|
||||
where B: Backend<N> + BackendSigmoid<N>,
|
||||
O: Optimizer<N, B>
|
||||
impl<N, B, O> LayerExt<N, B, O> for Sigmoid<N, B>
|
||||
where
|
||||
B: Backend<N> + BackendSigmoid<N>,
|
||||
O: Optimizer<N, B>,
|
||||
{
|
||||
type Config = SigmoidConfig;
|
||||
|
||||
|
||||
fn create(input_shape: TensorShape, _cfg: Self::Config) -> Self {
|
||||
Sigmoid {
|
||||
input_shape,
|
||||
_x: Default::default()
|
||||
_x: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,29 +1,31 @@
|
||||
use crate::tensor::{Tensor, TensorShape};
|
||||
use crate::backend::{Backend, BackendSoftmax};
|
||||
use crate::layer::{Layer, LayerExt, DefaultLayerContext};
|
||||
use crate::layer::{DefaultLayerContext, Layer, LayerExt};
|
||||
use crate::optimizer::Optimizer;
|
||||
use crate::tensor::{Tensor, TensorShape};
|
||||
use core::marker::PhantomData;
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct SoftmaxConfig;
|
||||
|
||||
pub struct Softmax<N, B>
|
||||
where B: Backend<N>,
|
||||
pub struct Softmax<N, B>
|
||||
where
|
||||
B: Backend<N>,
|
||||
{
|
||||
input_shape: TensorShape,
|
||||
_x: PhantomData<fn(N, B)>,
|
||||
}
|
||||
|
||||
impl <N, B, O> Layer<N, B, O> for Softmax<N, B>
|
||||
where B: Backend<N> + BackendSoftmax<N>,
|
||||
O: Optimizer<N, B>
|
||||
{
|
||||
impl<N, B, O> Layer<N, B, O> for Softmax<N, B>
|
||||
where
|
||||
B: Backend<N> + BackendSoftmax<N>,
|
||||
O: Optimizer<N, B>,
|
||||
{
|
||||
type Context = DefaultLayerContext<N, B>;
|
||||
|
||||
fn name(&self) -> &str {
|
||||
"Softmax"
|
||||
}
|
||||
|
||||
|
||||
#[inline]
|
||||
fn input_shape(&self) -> TensorShape {
|
||||
self.input_shape.clone()
|
||||
@ -44,16 +46,17 @@ impl <N, B, O> Layer<N, B, O> for Softmax<N, B>
|
||||
}
|
||||
}
|
||||
|
||||
impl <N, B, O> LayerExt<N, B, O> for Softmax<N, B>
|
||||
where B: Backend<N> + BackendSoftmax<N>,
|
||||
O: Optimizer<N, B>
|
||||
impl<N, B, O> LayerExt<N, B, O> for Softmax<N, B>
|
||||
where
|
||||
B: Backend<N> + BackendSoftmax<N>,
|
||||
O: Optimizer<N, B>,
|
||||
{
|
||||
type Config = SoftmaxConfig;
|
||||
|
||||
fn create(input_shape: TensorShape, _cfg: Self::Config) -> Self {
|
||||
Softmax {
|
||||
input_shape,
|
||||
_x: Default::default()
|
||||
_x: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,7 +1,7 @@
|
||||
use crate::tensor::{Tensor, TensorShape};
|
||||
use crate::backend::{Backend, BackendCopy};
|
||||
use crate::layer::{Layer, LayerExt, DefaultLayerContext};
|
||||
use crate::layer::{DefaultLayerContext, Layer, LayerExt};
|
||||
use crate::optimizer::Optimizer;
|
||||
use crate::tensor::{Tensor, TensorShape};
|
||||
use core::marker::PhantomData;
|
||||
|
||||
#[derive(Default)]
|
||||
@ -9,24 +9,26 @@ pub struct ZeroPadding2dConfig {
|
||||
pub paddings: (u32, u32),
|
||||
}
|
||||
|
||||
pub struct ZeroPadding2d<N, B>
|
||||
where B: Backend<N>,
|
||||
pub struct ZeroPadding2d<N, B>
|
||||
where
|
||||
B: Backend<N>,
|
||||
{
|
||||
input_shape: TensorShape,
|
||||
config: ZeroPadding2dConfig,
|
||||
_x: PhantomData<fn(N, B)>,
|
||||
}
|
||||
|
||||
impl <N, B, O> Layer<N, B, O> for ZeroPadding2d<N, B>
|
||||
where B: Backend<N> + BackendCopy<N>,
|
||||
O: Optimizer<N, B>
|
||||
impl<N, B, O> Layer<N, B, O> for ZeroPadding2d<N, B>
|
||||
where
|
||||
B: Backend<N> + BackendCopy<N>,
|
||||
O: Optimizer<N, B>,
|
||||
{
|
||||
type Context = DefaultLayerContext<N, B>;
|
||||
|
||||
fn name(&self) -> &str {
|
||||
"ZeroPadding2d"
|
||||
}
|
||||
|
||||
|
||||
#[inline]
|
||||
fn input_shape(&self) -> TensorShape {
|
||||
self.input_shape.clone()
|
||||
@ -39,10 +41,10 @@ impl <N, B, O> Layer<N, B, O> for ZeroPadding2d<N, B>
|
||||
TensorShape::new3d(
|
||||
is[0],
|
||||
is[1] + self.config.paddings.0 * 2,
|
||||
is[2] + self.config.paddings.1 * 2
|
||||
is[2] + self.config.paddings.1 * 2,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
#[inline]
|
||||
fn forward(&self, _backend: &B, x: &B::Tensor, ctx: &mut Self::Context) {
|
||||
ctx.update_outputs_shape(x.shape().get(0), &Layer::<N, B, O>::output_shape(self));
|
||||
@ -56,9 +58,10 @@ impl <N, B, O> Layer<N, B, O> for ZeroPadding2d<N, B>
|
||||
}
|
||||
}
|
||||
|
||||
impl <N, B, O> LayerExt<N, B, O> for ZeroPadding2d<N, B>
|
||||
where B: Backend<N> + BackendCopy<N>,
|
||||
O: Optimizer<N, B>
|
||||
impl<N, B, O> LayerExt<N, B, O> for ZeroPadding2d<N, B>
|
||||
where
|
||||
B: Backend<N> + BackendCopy<N>,
|
||||
O: Optimizer<N, B>,
|
||||
{
|
||||
type Config = ZeroPadding2dConfig;
|
||||
|
||||
@ -66,8 +69,7 @@ impl <N, B, O> LayerExt<N, B, O> for ZeroPadding2d<N, B>
|
||||
ZeroPadding2d {
|
||||
input_shape,
|
||||
config,
|
||||
_x: Default::default()
|
||||
_x: Default::default(),
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
@ -1,5 +1,5 @@
|
||||
#![feature(specialization, trait_alias)]
|
||||
#![recursion_limit="128"]
|
||||
#![recursion_limit = "128"]
|
||||
|
||||
pub mod layer;
|
||||
pub mod layers;
|
||||
@ -13,15 +13,15 @@ pub mod native;
|
||||
pub mod loss;
|
||||
pub mod losses;
|
||||
|
||||
pub mod tensor;
|
||||
pub mod params;
|
||||
pub mod tensor;
|
||||
|
||||
#[macro_use]
|
||||
mod macros;
|
||||
|
||||
pub mod prelude {
|
||||
pub use super::backend::*;
|
||||
pub use super::layer::*;
|
||||
pub use super::loss::*;
|
||||
pub use super::tensor::*;
|
||||
pub use super::layer::*;
|
||||
}
|
||||
}
|
||||
|
@ -3,4 +3,4 @@ use crate::backend::Backend;
|
||||
pub trait Loss<N, B: Backend<N>> {
|
||||
fn compute(&self, backend: &B, dst: &mut B::Tensor, pred: &B::Tensor, target: &B::Tensor);
|
||||
fn derivative(&self, backend: &B, dst: &mut B::Tensor, pred: &B::Tensor, target: &B::Tensor);
|
||||
}
|
||||
}
|
||||
|
@ -1,17 +1,15 @@
|
||||
use crate::loss::Loss;
|
||||
use crate::backend::{Backend, BackendSub};
|
||||
use crate::loss::Loss;
|
||||
use core::marker::PhantomData;
|
||||
|
||||
|
||||
|
||||
pub struct CrossEntropyLoss<N, B> {
|
||||
_m: PhantomData<fn(N, B)>
|
||||
_m: PhantomData<fn(N, B)>,
|
||||
}
|
||||
|
||||
impl<N, B> CrossEntropyLoss<N, B> {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
_m: Default::default()
|
||||
_m: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -21,7 +19,7 @@ impl<N, B: Backend<N> + BackendSub<N>> Loss<N, B> for CrossEntropyLoss<N, B> {
|
||||
// TODO
|
||||
}
|
||||
|
||||
fn derivative(&self, backend: &B, dst: &mut B::Tensor, pred: &B::Tensor, target: &B::Tensor) {
|
||||
fn derivative(&self, backend: &B, dst: &mut B::Tensor, pred: &B::Tensor, target: &B::Tensor) {
|
||||
backend.sub(dst, pred, target);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,5 +1,5 @@
|
||||
mod mse;
|
||||
mod cross_entropy;
|
||||
mod mse;
|
||||
|
||||
pub use self::cross_entropy::*;
|
||||
pub use self::mse::*;
|
||||
pub use self::cross_entropy::*;
|
@ -1,24 +1,24 @@
|
||||
use crate::loss::Loss;
|
||||
use crate::backend::{Backend, BackendMse};
|
||||
use crate::loss::Loss;
|
||||
use crate::tensor::Tensor;
|
||||
|
||||
use core::marker::PhantomData;
|
||||
|
||||
|
||||
pub struct MeanSquareErrorLoss<N, B> {
|
||||
_m: PhantomData<fn(N, B)>
|
||||
_m: PhantomData<fn(N, B)>,
|
||||
}
|
||||
|
||||
impl<N, B> MeanSquareErrorLoss<N, B> {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
_m: Default::default()
|
||||
_m: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<N, B> Loss<N, B> for MeanSquareErrorLoss<N, B>
|
||||
where B: Backend<N> + BackendMse<N>
|
||||
impl<N, B> Loss<N, B> for MeanSquareErrorLoss<N, B>
|
||||
where
|
||||
B: Backend<N> + BackendMse<N>,
|
||||
{
|
||||
fn compute(&self, backend: &B, dst: &mut B::Tensor, pred: &B::Tensor, target: &B::Tensor) {
|
||||
let batch_size = pred.shape().get(0) as f32;
|
||||
@ -31,4 +31,4 @@ impl<N, B> Loss<N, B> for MeanSquareErrorLoss<N, B>
|
||||
|
||||
backend.scaled_diff(dst, pred, target, backend.scalar_f32(1.0 / batch_size));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,4 +1,3 @@
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! sequential_type {
|
||||
(input_shape: ($($shape:tt)*), layers: { $($layers:tt)* }) => {
|
||||
@ -11,13 +10,13 @@ macro_rules! sequential_type_impl {
|
||||
($t:ty {$($tt:tt)*}) => ($t);
|
||||
|
||||
($t:ty {$($xx:tt)*}, $($tt:tt)*) => {
|
||||
$crate::layers::Chain<N, B, O,
|
||||
$crate::layers::Chain<N, B, O,
|
||||
$t, $crate::sequential_type_impl!($($tt)*)
|
||||
>
|
||||
};
|
||||
($t:ty) => ($t);
|
||||
($t:ty, $($tt:tt)*) => {
|
||||
$crate::layers::Chain<N, B, O,
|
||||
$crate::layers::Chain<N, B, O,
|
||||
$t, $crate::sequential_type_impl!($($tt)*)
|
||||
>
|
||||
};
|
||||
@ -33,7 +32,7 @@ macro_rules! sequential {
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! sequential_impl {
|
||||
macro_rules! sequential_impl {
|
||||
($p:expr, $t:ty { $($name:ident : $val:expr),* }) => {{
|
||||
#[allow(unused_mut)]
|
||||
let mut params = <$t as $crate::layer::LayerExt<N, B, O>>::Config::default();
|
||||
@ -58,7 +57,7 @@ macro_rules! sequential_impl {
|
||||
);
|
||||
|
||||
let prev_shape = $crate::layer::Layer::<N, B, O>::output_shape(&layer);
|
||||
|
||||
|
||||
$crate::layers::Chain::new(
|
||||
layer, $crate::sequential_impl! { prev_shape, $($tt)* },
|
||||
)
|
||||
@ -119,7 +118,7 @@ macro_rules! model_impl {
|
||||
_m: core::marker::PhantomData<fn(N, B, O)>,
|
||||
}
|
||||
|
||||
impl<N, B, O> $name<N, B, O>
|
||||
impl<N, B, O> $name<N, B, O>
|
||||
where B: $crate::backend::Backend<N> + $trait,
|
||||
O: $crate::optimizer::Optimizer<N, B>
|
||||
{
|
||||
@ -130,7 +129,7 @@ macro_rules! model_impl {
|
||||
|
||||
#[allow(unused_imports)]
|
||||
use $crate::backend::PoolingKind::*;
|
||||
|
||||
|
||||
Self {
|
||||
inner: $crate::sequential!($($tt)*),
|
||||
_m: Default::default(),
|
||||
@ -138,7 +137,7 @@ macro_rules! model_impl {
|
||||
}
|
||||
}
|
||||
|
||||
// impl<N, B, O> core::fmt::Display for $name<N, B, O>
|
||||
// impl<N, B, O> core::fmt::Display for $name<N, B, O>
|
||||
// where B: $crate::backend::Backend<N> + $trait,
|
||||
// O: $crate::optimizer::Optimizer<N, B>
|
||||
// {
|
||||
@ -151,7 +150,7 @@ macro_rules! model_impl {
|
||||
// }
|
||||
// }
|
||||
|
||||
impl<N, B, O> $crate::layer::Layer<N, B, O> for $name<N, B, O>
|
||||
impl<N, B, O> $crate::layer::Layer<N, B, O> for $name<N, B, O>
|
||||
where B: $crate::backend::Backend<N> + $trait,
|
||||
O: $crate::optimizer::Optimizer<N, B>
|
||||
{
|
||||
@ -170,7 +169,7 @@ macro_rules! model_impl {
|
||||
#[inline]
|
||||
fn param_count(&self) -> usize {
|
||||
self.inner.param_count()
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn input_shape(&self) -> $crate::tensor::TensorShape {
|
||||
@ -206,12 +205,12 @@ macro_rules! model_impl {
|
||||
writeln!(f, "{}{}[{}] {{", "", self.name(), self.param_count())?;
|
||||
self.inner.fmt(f, padding + 2)?;
|
||||
write!(f, "}}")?;
|
||||
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl<N, B, O> core::fmt::Display for $name<N, B, O>
|
||||
impl<N, B, O> core::fmt::Display for $name<N, B, O>
|
||||
where B: $crate::backend::Backend<N> + $trait,
|
||||
O: $crate::optimizer::Optimizer<N, B>
|
||||
{
|
||||
@ -227,7 +226,7 @@ macro_rules! model_impl {
|
||||
macro_rules! model {
|
||||
($name:ident ($($init:tt)*) { $($tt:tt)* }) => {
|
||||
mod tmp {
|
||||
pub trait BackendDefault<N> = $crate::backend::BackendReLu<N>
|
||||
pub trait BackendDefault<N> = $crate::backend::BackendReLu<N>
|
||||
+ $crate::backend::BackendBias<N>
|
||||
+ $crate::backend::BackendScale<N>
|
||||
+ $crate::backend::BackendSigmoid<N>
|
||||
@ -242,4 +241,4 @@ macro_rules! model {
|
||||
($name:ident <$trait:path> ($($init:tt)*) { $($tt:tt)* }) => {
|
||||
$crate::model_impl!($name <$trait> ($($init)*) { $($tt)* });
|
||||
};
|
||||
}
|
||||
}
|
||||
|
@ -1,17 +1,24 @@
|
||||
#[allow(dead_code)]
|
||||
pub fn valid_conv2d_3x3(y: &mut [f32], x: &[f32], w: &[f32], alpha: f32,
|
||||
x_rows: isize, x_cols: isize, s_row: isize, s_col: isize) {
|
||||
|
||||
pub fn valid_conv2d_3x3(
|
||||
y: &mut [f32],
|
||||
x: &[f32],
|
||||
w: &[f32],
|
||||
alpha: f32,
|
||||
x_rows: isize,
|
||||
x_cols: isize,
|
||||
s_row: isize,
|
||||
s_col: isize,
|
||||
) {
|
||||
let y_rows = (x_rows - 3) / s_row + 1;
|
||||
let y_cols = (x_cols - 3) / s_col + 1;
|
||||
|
||||
let y = &mut y[0..(y_rows * y_cols) as usize];
|
||||
let x = &x[0..(x_rows * x_cols) as usize];
|
||||
let w = &w[0..9];
|
||||
|
||||
|
||||
for y_y in 0..y_rows {
|
||||
for y_x in 0..y_cols {
|
||||
let mut xi = s_row * y_y * x_cols + s_col * y_x;
|
||||
let mut xi = s_row * y_y * x_cols + s_col * y_x;
|
||||
let mut sum = 0.0;
|
||||
|
||||
sum += x[(xi + 0) as usize] * w[0];
|
||||
@ -19,13 +26,13 @@ pub fn valid_conv2d_3x3(y: &mut [f32], x: &[f32], w: &[f32], alpha: f32,
|
||||
sum += x[(xi + 2) as usize] * w[2];
|
||||
|
||||
xi += x_cols;
|
||||
|
||||
|
||||
sum += x[(xi + 0) as usize] * w[3];
|
||||
sum += x[(xi + 1) as usize] * w[4];
|
||||
sum += x[(xi + 2) as usize] * w[5];
|
||||
|
||||
xi += x_cols;
|
||||
|
||||
|
||||
sum += x[(xi + 0) as usize] * w[6];
|
||||
sum += x[(xi + 1) as usize] * w[7];
|
||||
sum += x[(xi + 2) as usize] * w[8];
|
||||
@ -35,19 +42,24 @@ pub fn valid_conv2d_3x3(y: &mut [f32], x: &[f32], w: &[f32], alpha: f32,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn full_xcorr2d_3x3(y: &mut [f32], x: &[f32], w: &[f32], alpha: f32,
|
||||
x_rows: isize, x_cols: isize,
|
||||
s_row: isize, s_col: isize) {
|
||||
|
||||
pub fn full_xcorr2d_3x3(
|
||||
y: &mut [f32],
|
||||
x: &[f32],
|
||||
w: &[f32],
|
||||
alpha: f32,
|
||||
x_rows: isize,
|
||||
x_cols: isize,
|
||||
s_row: isize,
|
||||
s_col: isize,
|
||||
) {
|
||||
let y_cols = (x_cols - 1) * s_col + 3;
|
||||
let y_rows = (x_rows - 1) * s_row + 3;
|
||||
|
||||
let y = &mut y[0..(y_rows * y_cols) as usize];
|
||||
let x = &x[0..(x_rows * x_cols) as usize];
|
||||
let w = &w[0..9];
|
||||
|
||||
|
||||
for x_y in 0..x_rows {
|
||||
for x_x in 0..x_cols {
|
||||
let mut yi = s_row * x_y * y_cols + s_col * x_x;
|
||||
@ -57,12 +69,12 @@ pub fn full_xcorr2d_3x3(y: &mut [f32], x: &[f32], w: &[f32], alpha: f32,
|
||||
y[(yi + 1) as usize] += z * w[7];
|
||||
y[(yi + 2) as usize] += z * w[6];
|
||||
yi += y_cols;
|
||||
|
||||
|
||||
y[(yi + 0) as usize] += z * w[5];
|
||||
y[(yi + 1) as usize] += z * w[4];
|
||||
y[(yi + 2) as usize] += z * w[3];
|
||||
yi += y_cols;
|
||||
|
||||
|
||||
y[(yi + 0) as usize] += z * w[2];
|
||||
y[(yi + 1) as usize] += z * w[1];
|
||||
y[(yi + 2) as usize] += z * w[0];
|
||||
@ -70,24 +82,32 @@ pub fn full_xcorr2d_3x3(y: &mut [f32], x: &[f32], w: &[f32], alpha: f32,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn conv2d_forward_3x3(y: &mut [f32], x: &[f32], w: &[f32],
|
||||
bs: isize, x_channels: isize, y_channels: isize,
|
||||
x_rows: isize, x_cols: isize, s_row: isize, s_col: isize) {
|
||||
|
||||
pub fn conv2d_forward_3x3(
|
||||
y: &mut [f32],
|
||||
x: &[f32],
|
||||
w: &[f32],
|
||||
bs: isize,
|
||||
x_channels: isize,
|
||||
y_channels: isize,
|
||||
x_rows: isize,
|
||||
x_cols: isize,
|
||||
s_row: isize,
|
||||
s_col: isize,
|
||||
) {
|
||||
let y_rows = (x_rows - 3) / s_row + 1;
|
||||
let y_cols = (x_cols - 3) / s_col + 1;
|
||||
|
||||
|
||||
let x_img_size = x_rows * x_cols;
|
||||
let y_img_size = y_rows * y_cols;
|
||||
let w_img_size = 9;
|
||||
|
||||
|
||||
let x_batch_size = x_channels * x_img_size;
|
||||
let y_batch_size = y_channels * y_img_size;
|
||||
|
||||
|
||||
let y = &mut y[0..(bs * y_batch_size) as usize];
|
||||
let x = &x[0..(bs * x_batch_size) as usize];
|
||||
let w = &w[0..(y_channels * w_img_size) as usize];
|
||||
|
||||
|
||||
for bi in 0..bs {
|
||||
for x_ch in 0..x_channels {
|
||||
let x_offset = (bi * x_batch_size + x_ch * x_img_size) as usize;
|
||||
@ -96,50 +116,56 @@ pub fn conv2d_forward_3x3(y: &mut [f32], x: &[f32], w: &[f32],
|
||||
for y_ch in 0..y_channels {
|
||||
let y_offset = (bi * y_batch_size + y_ch * y_img_size) as usize;
|
||||
let y_img = &mut y[y_offset..y_offset + y_img_size as usize];
|
||||
|
||||
|
||||
let w_offset = (y_ch * w_img_size) as usize;
|
||||
let w = &w[w_offset..w_offset + w_img_size as usize];
|
||||
|
||||
|
||||
valid_conv2d_3x3(y_img, x_img, w, 1.0, x_rows, x_cols, s_row, s_col);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
pub fn conv2d_backward_3x3(dx: &mut [f32], dy: &[f32], w: &[f32],
|
||||
bs: isize, x_channels: isize, y_channels: isize,
|
||||
y_rows: isize, y_cols: isize,
|
||||
s_row: isize, s_col: isize) {
|
||||
}
|
||||
|
||||
pub fn conv2d_backward_3x3(
|
||||
dx: &mut [f32],
|
||||
dy: &[f32],
|
||||
w: &[f32],
|
||||
bs: isize,
|
||||
x_channels: isize,
|
||||
y_channels: isize,
|
||||
y_rows: isize,
|
||||
y_cols: isize,
|
||||
s_row: isize,
|
||||
s_col: isize,
|
||||
) {
|
||||
let x_cols = (y_cols - 1) * s_col + 3;
|
||||
let x_rows = (y_rows - 1) * s_row + 3;
|
||||
|
||||
|
||||
let dx_img_size = x_rows * x_cols;
|
||||
let dy_img_size = y_rows * y_cols;
|
||||
let w_img_size = 9;
|
||||
|
||||
|
||||
let dx_batch_size = x_channels * dx_img_size;
|
||||
let dy_batch_size = y_channels * dy_img_size;
|
||||
|
||||
|
||||
let dx = &mut dx[0..(bs * dx_batch_size) as usize];
|
||||
let dy = &dy[0..(bs * dy_batch_size) as usize];
|
||||
let w = &w[0..(y_channels * w_img_size) as usize];
|
||||
|
||||
|
||||
for bi in 0..bs {
|
||||
for y_ch in 0..y_channels {
|
||||
let dy_offset = (bi * dy_batch_size + y_ch * dy_img_size) as usize;
|
||||
let dy_img = &dy[dy_offset..dy_offset + dy_img_size as usize];
|
||||
|
||||
|
||||
for x_ch in 0..x_channels {
|
||||
let dx_offset = (bi * dx_batch_size + x_ch * dx_img_size) as usize;
|
||||
let dx_img = &mut dx[dx_offset..dx_offset + dx_img_size as usize];
|
||||
|
||||
|
||||
let w_offset = (y_ch * w_img_size) as usize;
|
||||
let w = &w[w_offset..w_offset + w_img_size as usize];
|
||||
|
||||
|
||||
full_xcorr2d_3x3(dx_img, dy_img, w, 1.0, y_rows, y_cols, s_row, s_col);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,17 +1,24 @@
|
||||
#[allow(dead_code)]
|
||||
pub fn valid_conv2d_5x5(y: &mut [f32], x: &[f32], w: &[f32], alpha: f32,
|
||||
x_rows: isize, x_cols: isize, s_row: isize, s_col: isize) {
|
||||
|
||||
pub fn valid_conv2d_5x5(
|
||||
y: &mut [f32],
|
||||
x: &[f32],
|
||||
w: &[f32],
|
||||
alpha: f32,
|
||||
x_rows: isize,
|
||||
x_cols: isize,
|
||||
s_row: isize,
|
||||
s_col: isize,
|
||||
) {
|
||||
let y_rows = (x_rows - 5) / s_row + 1;
|
||||
let y_cols = (x_cols - 5) / s_col + 1;
|
||||
|
||||
let y = &mut y[0..(y_rows * y_cols) as usize];
|
||||
let x = &x[0..(x_rows * x_cols) as usize];
|
||||
let w = &w[0..25];
|
||||
|
||||
|
||||
for y_y in 0..y_rows {
|
||||
for y_x in 0..y_cols {
|
||||
let mut xi = s_row * y_y * x_cols + s_col * y_x;
|
||||
let mut xi = s_row * y_y * x_cols + s_col * y_x;
|
||||
let mut sum = 0.0;
|
||||
|
||||
sum += x[(xi + 0) as usize] * w[0];
|
||||
@ -20,28 +27,28 @@ pub fn valid_conv2d_5x5(y: &mut [f32], x: &[f32], w: &[f32], alpha: f32,
|
||||
sum += x[(xi + 3) as usize] * w[3];
|
||||
sum += x[(xi + 4) as usize] * w[4];
|
||||
xi += x_cols;
|
||||
|
||||
|
||||
sum += x[(xi + 0) as usize] * w[5];
|
||||
sum += x[(xi + 1) as usize] * w[6];
|
||||
sum += x[(xi + 2) as usize] * w[7];
|
||||
sum += x[(xi + 3) as usize] * w[8];
|
||||
sum += x[(xi + 4) as usize] * w[9];
|
||||
xi += x_cols;
|
||||
|
||||
|
||||
sum += x[(xi + 0) as usize] * w[10];
|
||||
sum += x[(xi + 1) as usize] * w[11];
|
||||
sum += x[(xi + 2) as usize] * w[12];
|
||||
sum += x[(xi + 3) as usize] * w[13];
|
||||
sum += x[(xi + 4) as usize] * w[14];
|
||||
xi += x_cols;
|
||||
|
||||
|
||||
sum += x[(xi + 0) as usize] * w[15];
|
||||
sum += x[(xi + 1) as usize] * w[16];
|
||||
sum += x[(xi + 2) as usize] * w[17];
|
||||
sum += x[(xi + 3) as usize] * w[18];
|
||||
sum += x[(xi + 4) as usize] * w[19];
|
||||
xi += x_cols;
|
||||
|
||||
|
||||
sum += x[(xi + 0) as usize] * w[20];
|
||||
sum += x[(xi + 1) as usize] * w[21];
|
||||
sum += x[(xi + 2) as usize] * w[22];
|
||||
@ -53,19 +60,24 @@ pub fn valid_conv2d_5x5(y: &mut [f32], x: &[f32], w: &[f32], alpha: f32,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn full_xcorr2d_5x5(y: &mut [f32], x: &[f32], w: &[f32], alpha: f32,
|
||||
x_rows: isize, x_cols: isize,
|
||||
s_row: isize, s_col: isize) {
|
||||
|
||||
pub fn full_xcorr2d_5x5(
|
||||
y: &mut [f32],
|
||||
x: &[f32],
|
||||
w: &[f32],
|
||||
alpha: f32,
|
||||
x_rows: isize,
|
||||
x_cols: isize,
|
||||
s_row: isize,
|
||||
s_col: isize,
|
||||
) {
|
||||
let y_cols = (x_cols - 1) * s_col + 5;
|
||||
let y_rows = (x_rows - 1) * s_row + 5;
|
||||
|
||||
let y = &mut y[0..(y_rows * y_cols) as usize];
|
||||
let x = &x[0..(x_rows * x_cols) as usize];
|
||||
let w = &w[0..25];
|
||||
|
||||
|
||||
for x_y in 0..x_rows {
|
||||
for x_x in 0..x_cols {
|
||||
let mut yi = s_row * x_y * y_cols + s_col * x_x;
|
||||
@ -77,28 +89,28 @@ pub fn full_xcorr2d_5x5(y: &mut [f32], x: &[f32], w: &[f32], alpha: f32,
|
||||
y[(yi + 3) as usize] += z * w[21];
|
||||
y[(yi + 4) as usize] += z * w[20];
|
||||
yi += y_cols;
|
||||
|
||||
|
||||
y[(yi + 0) as usize] += z * w[19];
|
||||
y[(yi + 1) as usize] += z * w[18];
|
||||
y[(yi + 2) as usize] += z * w[17];
|
||||
y[(yi + 3) as usize] += z * w[16];
|
||||
y[(yi + 4) as usize] += z * w[15];
|
||||
yi += y_cols;
|
||||
|
||||
|
||||
y[(yi + 0) as usize] += z * w[14];
|
||||
y[(yi + 1) as usize] += z * w[13];
|
||||
y[(yi + 2) as usize] += z * w[12];
|
||||
y[(yi + 3) as usize] += z * w[11];
|
||||
y[(yi + 4) as usize] += z * w[10];
|
||||
yi += y_cols;
|
||||
|
||||
|
||||
y[(yi + 0) as usize] += z * w[9];
|
||||
y[(yi + 1) as usize] += z * w[8];
|
||||
y[(yi + 2) as usize] += z * w[7];
|
||||
y[(yi + 3) as usize] += z * w[6];
|
||||
y[(yi + 4) as usize] += z * w[5];
|
||||
yi += y_cols;
|
||||
|
||||
|
||||
y[(yi + 0) as usize] += z * w[4];
|
||||
y[(yi + 1) as usize] += z * w[3];
|
||||
y[(yi + 2) as usize] += z * w[2];
|
||||
@ -108,24 +120,32 @@ pub fn full_xcorr2d_5x5(y: &mut [f32], x: &[f32], w: &[f32], alpha: f32,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn conv2d_forward_5x5(y: &mut [f32], x: &[f32], w: &[f32],
|
||||
bs: isize, x_channels: isize, y_channels: isize,
|
||||
x_rows: isize, x_cols: isize, s_row: isize, s_col: isize) {
|
||||
|
||||
pub fn conv2d_forward_5x5(
|
||||
y: &mut [f32],
|
||||
x: &[f32],
|
||||
w: &[f32],
|
||||
bs: isize,
|
||||
x_channels: isize,
|
||||
y_channels: isize,
|
||||
x_rows: isize,
|
||||
x_cols: isize,
|
||||
s_row: isize,
|
||||
s_col: isize,
|
||||
) {
|
||||
let y_rows = (x_rows - 5) / s_row + 1;
|
||||
let y_cols = (x_cols - 5) / s_col + 1;
|
||||
|
||||
|
||||
let x_img_size = x_rows * x_cols;
|
||||
let y_img_size = y_rows * y_cols;
|
||||
let w_img_size = 25;
|
||||
|
||||
|
||||
let x_batch_size = x_channels * x_img_size;
|
||||
let y_batch_size = y_channels * y_img_size;
|
||||
|
||||
|
||||
let y = &mut y[0..(bs * y_batch_size) as usize];
|
||||
let x = &x[0..(bs * x_batch_size) as usize];
|
||||
let w = &w[0..(y_channels * w_img_size) as usize];
|
||||
|
||||
|
||||
for bi in 0..bs {
|
||||
for x_ch in 0..x_channels {
|
||||
let x_offset = (bi * x_batch_size + x_ch * x_img_size) as usize;
|
||||
@ -134,50 +154,56 @@ pub fn conv2d_forward_5x5(y: &mut [f32], x: &[f32], w: &[f32],
|
||||
for y_ch in 0..y_channels {
|
||||
let y_offset = (bi * y_batch_size + y_ch * y_img_size) as usize;
|
||||
let y_img = &mut y[y_offset..y_offset + y_img_size as usize];
|
||||
|
||||
|
||||
let w_offset = (y_ch * w_img_size) as usize;
|
||||
let w = &w[w_offset..w_offset + w_img_size as usize];
|
||||
|
||||
|
||||
valid_conv2d_5x5(y_img, x_img, w, 1.0, x_rows, x_cols, s_row, s_col);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
pub fn conv2d_backward_5x5(dx: &mut [f32], dy: &[f32], w: &[f32],
|
||||
bs: isize, x_channels: isize, y_channels: isize,
|
||||
y_rows: isize, y_cols: isize,
|
||||
s_row: isize, s_col: isize) {
|
||||
}
|
||||
|
||||
pub fn conv2d_backward_5x5(
|
||||
dx: &mut [f32],
|
||||
dy: &[f32],
|
||||
w: &[f32],
|
||||
bs: isize,
|
||||
x_channels: isize,
|
||||
y_channels: isize,
|
||||
y_rows: isize,
|
||||
y_cols: isize,
|
||||
s_row: isize,
|
||||
s_col: isize,
|
||||
) {
|
||||
let x_cols = (y_cols - 1) * s_col + 5;
|
||||
let x_rows = (y_rows - 1) * s_row + 5;
|
||||
|
||||
|
||||
let dx_img_size = x_rows * x_cols;
|
||||
let dy_img_size = y_rows * y_cols;
|
||||
let w_img_size = 25;
|
||||
|
||||
|
||||
let dx_batch_size = x_channels * dx_img_size;
|
||||
let dy_batch_size = y_channels * dy_img_size;
|
||||
|
||||
|
||||
let dx = &mut dx[0..(bs * dx_batch_size) as usize];
|
||||
let dy = &dy[0..(bs * dy_batch_size) as usize];
|
||||
let w = &w[0..(y_channels * w_img_size) as usize];
|
||||
|
||||
|
||||
for bi in 0..bs {
|
||||
for y_ch in 0..y_channels {
|
||||
let dy_offset = (bi * dy_batch_size + y_ch * dy_img_size) as usize;
|
||||
let dy_img = &dy[dy_offset..dy_offset + dy_img_size as usize];
|
||||
|
||||
|
||||
for x_ch in 0..x_channels {
|
||||
let dx_offset = (bi * dx_batch_size + x_ch * dx_img_size) as usize;
|
||||
let dx_img = &mut dx[dx_offset..dx_offset + dx_img_size as usize];
|
||||
|
||||
|
||||
let w_offset = (y_ch * w_img_size) as usize;
|
||||
let w = &w[w_offset..w_offset + w_img_size as usize];
|
||||
|
||||
|
||||
full_xcorr2d_5x5(dx_img, dy_img, w, 1.0, y_rows, y_cols, s_row, s_col);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -5,30 +5,36 @@ pub use self::kernel_3x3::*;
|
||||
pub use self::kernel_5x5::*;
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn valid_conv2d(y: &mut [f32], x: &[f32], w: &[f32], alpha: f32,
|
||||
x_rows: isize, x_cols: isize,
|
||||
w_rows: isize, w_cols: isize,
|
||||
s_row: isize, s_col: isize) {
|
||||
|
||||
pub fn valid_conv2d(
|
||||
y: &mut [f32],
|
||||
x: &[f32],
|
||||
w: &[f32],
|
||||
alpha: f32,
|
||||
x_rows: isize,
|
||||
x_cols: isize,
|
||||
w_rows: isize,
|
||||
w_cols: isize,
|
||||
s_row: isize,
|
||||
s_col: isize,
|
||||
) {
|
||||
let y_rows = (x_rows - w_rows) / s_row + 1;
|
||||
let y_cols = (x_cols - w_cols) / s_col + 1;
|
||||
|
||||
let y = &mut y[0..(y_rows * y_cols) as usize];
|
||||
let x = &x[0..(x_rows * x_cols) as usize];
|
||||
let w = &w[0..(w_rows * w_cols) as usize];
|
||||
|
||||
|
||||
for y_y in 0..y_rows {
|
||||
for y_x in 0..y_cols {
|
||||
|
||||
let mut xi = s_row * y_y * x_cols + s_col * y_x;
|
||||
let mut wi = 0;
|
||||
|
||||
|
||||
let mut sum = 0.0;
|
||||
for _ in 0..w_rows {
|
||||
for w_x in 0..w_cols {
|
||||
sum += x[(xi + w_x) as usize] * w[(wi + w_x) as usize];
|
||||
}
|
||||
|
||||
|
||||
xi += x_cols;
|
||||
wi += w_cols;
|
||||
}
|
||||
@ -39,30 +45,36 @@ pub fn valid_conv2d(y: &mut [f32], x: &[f32], w: &[f32], alpha: f32,
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn valid_xcorr2d(y: &mut [f32], x: &[f32], w: &[f32], alpha: f32,
|
||||
x_rows: isize, x_cols: isize,
|
||||
w_rows: isize, w_cols: isize,
|
||||
s_row: isize, s_col: isize) {
|
||||
|
||||
pub fn valid_xcorr2d(
|
||||
y: &mut [f32],
|
||||
x: &[f32],
|
||||
w: &[f32],
|
||||
alpha: f32,
|
||||
x_rows: isize,
|
||||
x_cols: isize,
|
||||
w_rows: isize,
|
||||
w_cols: isize,
|
||||
s_row: isize,
|
||||
s_col: isize,
|
||||
) {
|
||||
let y_rows = (x_rows - w_rows) / s_row + 1;
|
||||
let y_cols = (x_cols - w_cols) / s_col + 1;
|
||||
|
||||
let y = &mut y[0..(y_rows * y_cols) as usize];
|
||||
let x = &x[0..(x_rows * x_cols) as usize];
|
||||
let w = &w[0..(w_rows * w_cols) as usize];
|
||||
|
||||
|
||||
for y_y in 0..y_rows {
|
||||
for y_x in 0..y_cols {
|
||||
|
||||
let mut xi = s_row * y_y * x_cols + s_col * y_x;
|
||||
let mut wi = w_rows * w_cols - 1;
|
||||
|
||||
|
||||
let mut sum = 0.0;
|
||||
for _ in 0..w_rows {
|
||||
for w_x in 0..w_cols {
|
||||
sum += x[(xi + w_x) as usize] * w[(wi - w_x) as usize];
|
||||
}
|
||||
|
||||
|
||||
xi += x_cols;
|
||||
wi -= w_cols;
|
||||
}
|
||||
@ -73,29 +85,36 @@ pub fn valid_xcorr2d(y: &mut [f32], x: &[f32], w: &[f32], alpha: f32,
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn full_conv2d(y: &mut [f32], x: &[f32], w: &[f32], alpha: f32,
|
||||
x_rows: isize, x_cols: isize,
|
||||
w_rows: isize, w_cols: isize,
|
||||
s_row: isize, s_col: isize) {
|
||||
|
||||
pub fn full_conv2d(
|
||||
y: &mut [f32],
|
||||
x: &[f32],
|
||||
w: &[f32],
|
||||
alpha: f32,
|
||||
x_rows: isize,
|
||||
x_cols: isize,
|
||||
w_rows: isize,
|
||||
w_cols: isize,
|
||||
s_row: isize,
|
||||
s_col: isize,
|
||||
) {
|
||||
let y_cols = (x_cols - 1) * s_col + w_cols;
|
||||
let y_rows = (x_rows - 1) * s_row + w_rows;
|
||||
|
||||
let y = &mut y[0..(y_rows * y_cols) as usize];
|
||||
let x = &x[0..(x_rows * x_cols) as usize];
|
||||
let w = &w[0..(w_rows * w_cols) as usize];
|
||||
|
||||
|
||||
for x_y in 0..x_rows {
|
||||
for x_x in 0..x_cols {
|
||||
let mut yi = s_row * x_y * y_cols + s_col * x_x;
|
||||
let mut wi = 0;
|
||||
let z = alpha * x[(x_y * x_cols + x_x) as usize];
|
||||
|
||||
|
||||
for _ in 0..w_rows {
|
||||
for w_x in 0..w_cols {
|
||||
y[(yi + w_x) as usize] += z * w[(wi + w_x) as usize];
|
||||
}
|
||||
|
||||
|
||||
yi += y_cols;
|
||||
wi += w_cols;
|
||||
}
|
||||
@ -104,29 +123,36 @@ pub fn full_conv2d(y: &mut [f32], x: &[f32], w: &[f32], alpha: f32,
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn full_xcorr2d(y: &mut [f32], x: &[f32], w: &[f32], alpha: f32,
|
||||
x_rows: isize, x_cols: isize,
|
||||
w_rows: isize, w_cols: isize,
|
||||
s_row: isize, s_col: isize) {
|
||||
|
||||
pub fn full_xcorr2d(
|
||||
y: &mut [f32],
|
||||
x: &[f32],
|
||||
w: &[f32],
|
||||
alpha: f32,
|
||||
x_rows: isize,
|
||||
x_cols: isize,
|
||||
w_rows: isize,
|
||||
w_cols: isize,
|
||||
s_row: isize,
|
||||
s_col: isize,
|
||||
) {
|
||||
let y_cols = (x_cols - 1) * s_col + w_cols;
|
||||
let y_rows = (x_rows - 1) * s_row + w_rows;
|
||||
|
||||
let y = &mut y[0..(y_rows * y_cols) as usize];
|
||||
let x = &x[0..(x_rows * x_cols) as usize];
|
||||
let w = &w[0..(w_rows * w_cols) as usize];
|
||||
|
||||
|
||||
for x_y in 0..x_rows {
|
||||
for x_x in 0..x_cols {
|
||||
let mut yi = s_row * x_y * y_cols + s_col * x_x;
|
||||
let mut wi = w_rows * w_cols - 1;
|
||||
let z = alpha * x[(x_y * x_cols + x_x) as usize];
|
||||
|
||||
|
||||
for _ in 0..w_rows {
|
||||
for w_x in 0..w_cols {
|
||||
y[(yi + w_x) as usize] += z * w[(wi - w_x) as usize];
|
||||
}
|
||||
|
||||
|
||||
yi += y_cols;
|
||||
wi -= w_cols;
|
||||
}
|
||||
@ -134,27 +160,34 @@ pub fn full_xcorr2d(y: &mut [f32], x: &[f32], w: &[f32], alpha: f32,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
pub fn conv2d_forward(y: &mut [f32], x: &[f32], w: &[f32],
|
||||
bs: isize, x_channels: isize, y_channels: isize,
|
||||
x_rows: isize, x_cols: isize,
|
||||
w_rows: isize, w_cols: isize,
|
||||
s_row: isize, s_col: isize) {
|
||||
|
||||
pub fn conv2d_forward(
|
||||
y: &mut [f32],
|
||||
x: &[f32],
|
||||
w: &[f32],
|
||||
bs: isize,
|
||||
x_channels: isize,
|
||||
y_channels: isize,
|
||||
x_rows: isize,
|
||||
x_cols: isize,
|
||||
w_rows: isize,
|
||||
w_cols: isize,
|
||||
s_row: isize,
|
||||
s_col: isize,
|
||||
) {
|
||||
let y_rows = (x_rows - w_rows) / s_row + 1;
|
||||
let y_cols = (x_cols - w_cols) / s_col + 1;
|
||||
|
||||
|
||||
let x_img_size = x_rows * x_cols;
|
||||
let y_img_size = y_rows * y_cols;
|
||||
let w_img_size = w_rows * w_cols;
|
||||
|
||||
|
||||
let x_batch_size = x_channels * x_img_size;
|
||||
let y_batch_size = y_channels * y_img_size;
|
||||
|
||||
|
||||
let y = &mut y[0..(bs * y_batch_size) as usize];
|
||||
let x = &x[0..(bs * x_batch_size) as usize];
|
||||
let w = &w[0..(y_channels * w_img_size) as usize];
|
||||
|
||||
|
||||
for bi in 0..bs {
|
||||
for x_ch in 0..x_channels {
|
||||
let x_offset = (bi * x_batch_size + x_ch * x_img_size) as usize;
|
||||
@ -163,89 +196,111 @@ pub fn conv2d_forward(y: &mut [f32], x: &[f32], w: &[f32],
|
||||
for y_ch in 0..y_channels {
|
||||
let y_offset = (bi * y_batch_size + y_ch * y_img_size) as usize;
|
||||
let y_img = &mut y[y_offset..y_offset + y_img_size as usize];
|
||||
|
||||
|
||||
let w_offset = (y_ch * w_img_size) as usize;
|
||||
let w = &w[w_offset..w_offset + w_img_size as usize];
|
||||
|
||||
valid_conv2d(y_img, x_img, w, 1.0, x_rows, x_cols, w_rows, w_cols, s_row, s_col);
|
||||
|
||||
valid_conv2d(
|
||||
y_img, x_img, w, 1.0, x_rows, x_cols, w_rows, w_cols, s_row, s_col,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn conv2d_backward(dx: &mut [f32], dy: &[f32], w: &[f32],
|
||||
bs: isize, x_channels: isize, y_channels: isize,
|
||||
y_rows: isize, y_cols: isize,
|
||||
w_rows: isize, w_cols: isize,
|
||||
s_row: isize, s_col: isize) {
|
||||
}
|
||||
|
||||
pub fn conv2d_backward(
|
||||
dx: &mut [f32],
|
||||
dy: &[f32],
|
||||
w: &[f32],
|
||||
bs: isize,
|
||||
x_channels: isize,
|
||||
y_channels: isize,
|
||||
y_rows: isize,
|
||||
y_cols: isize,
|
||||
w_rows: isize,
|
||||
w_cols: isize,
|
||||
s_row: isize,
|
||||
s_col: isize,
|
||||
) {
|
||||
let x_cols = (y_cols - 1) * s_col + w_cols;
|
||||
let x_rows = (y_rows - 1) * s_row + w_rows;
|
||||
|
||||
|
||||
let dx_img_size = x_rows * x_cols;
|
||||
let dy_img_size = y_rows * y_cols;
|
||||
let w_img_size = w_rows * w_cols;
|
||||
|
||||
|
||||
let dx_batch_size = x_channels * dx_img_size;
|
||||
let dy_batch_size = y_channels * dy_img_size;
|
||||
|
||||
|
||||
let dx = &mut dx[0..(bs * dx_batch_size) as usize];
|
||||
let dy = &dy[0..(bs * dy_batch_size) as usize];
|
||||
let w = &w[0..(y_channels * w_img_size) as usize];
|
||||
|
||||
|
||||
for bi in 0..bs {
|
||||
for y_ch in 0..y_channels {
|
||||
let dy_offset = (bi * dy_batch_size + y_ch * dy_img_size) as usize;
|
||||
let dy_img = &dy[dy_offset..dy_offset + dy_img_size as usize];
|
||||
|
||||
|
||||
for x_ch in 0..x_channels {
|
||||
let dx_offset = (bi * dx_batch_size + x_ch * dx_img_size) as usize;
|
||||
let dx_img = &mut dx[dx_offset..dx_offset + dx_img_size as usize];
|
||||
|
||||
|
||||
let w_offset = (y_ch * w_img_size) as usize;
|
||||
let w = &w[w_offset..w_offset + w_img_size as usize];
|
||||
|
||||
full_xcorr2d(dx_img, dy_img, w, 1.0, y_rows, y_cols, w_rows, w_cols, s_row, s_col);
|
||||
|
||||
full_xcorr2d(
|
||||
dx_img, dy_img, w, 1.0, y_rows, y_cols, w_rows, w_cols, s_row, s_col,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn conv2d_grads(dw: &mut [f32], x: &[f32], dy: &[f32],
|
||||
bs: isize, x_channels: isize, y_channels: isize,
|
||||
x_rows: isize, x_cols: isize,
|
||||
y_rows: isize, y_cols: isize,
|
||||
s_row: isize, s_col: isize) {
|
||||
|
||||
pub fn conv2d_grads(
|
||||
dw: &mut [f32],
|
||||
x: &[f32],
|
||||
dy: &[f32],
|
||||
bs: isize,
|
||||
x_channels: isize,
|
||||
y_channels: isize,
|
||||
x_rows: isize,
|
||||
x_cols: isize,
|
||||
y_rows: isize,
|
||||
y_cols: isize,
|
||||
s_row: isize,
|
||||
s_col: isize,
|
||||
) {
|
||||
let w_cols = x_cols - y_cols + 1;
|
||||
let w_rows = x_rows - y_rows + 1;
|
||||
|
||||
|
||||
let x_img_size = x_rows * x_cols;
|
||||
let dy_img_size = y_rows * y_cols;
|
||||
let dw_img_size = w_rows * w_cols;
|
||||
|
||||
|
||||
let x_batch_size = x_channels * x_img_size;
|
||||
let dy_batch_size = y_channels * dy_img_size;
|
||||
|
||||
|
||||
let dw = &mut dw[0..(y_channels * dw_img_size) as usize];
|
||||
let dy = &dy[0..(bs * dy_batch_size) as usize];
|
||||
let x = &x[0..(bs * x_batch_size) as usize];
|
||||
|
||||
|
||||
for bi in 0..bs {
|
||||
for x_ch in 0..x_channels {
|
||||
let x_offset = (bi * x_batch_size + x_ch * x_img_size) as usize;
|
||||
let x_img = &x[x_offset..x_offset + x_img_size as usize];
|
||||
|
||||
|
||||
for y_ch in 0..y_channels {
|
||||
let dy_offset = (bi * dy_batch_size + y_ch * dy_img_size) as usize;
|
||||
let dy_img = &dy[dy_offset..dy_offset + dy_img_size as usize];
|
||||
|
||||
|
||||
let dw_offset = (y_ch * dw_img_size) as usize;
|
||||
let dw = &mut dw[dw_offset..dw_offset + dw_img_size as usize];
|
||||
|
||||
valid_conv2d(dw, x_img, dy_img, 1.0, x_rows, x_cols, y_rows, y_cols, s_row, s_col);
|
||||
|
||||
valid_conv2d(
|
||||
dw, x_img, dy_img, 1.0, x_rows, x_cols, y_rows, y_cols, s_row, s_col,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -256,198 +311,136 @@ mod test {
|
||||
#[test]
|
||||
fn test_valid_conv2d() {
|
||||
let x: &[f32] = &[
|
||||
0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7,
|
||||
1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7,
|
||||
2.0, 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7,
|
||||
3.0, 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7,
|
||||
4.0, 4.1, 4.2, 4.3, 4.4, 4.5, 4.6, 4.7,
|
||||
5.0, 5.1, 5.2, 5.3, 5.4, 5.5, 5.6, 5.7,
|
||||
6.0, 6.1, 6.2, 6.3, 6.4, 6.5, 6.6, 6.7,
|
||||
0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 2.0,
|
||||
2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 3.0, 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 4.0, 4.1,
|
||||
4.2, 4.3, 4.4, 4.5, 4.6, 4.7, 5.0, 5.1, 5.2, 5.3, 5.4, 5.5, 5.6, 5.7, 6.0, 6.1, 6.2,
|
||||
6.3, 6.4, 6.5, 6.6, 6.7,
|
||||
];
|
||||
|
||||
let w: &[f32] = &[
|
||||
0.1, 0.2, 0.3,
|
||||
0.4, 0.5, 0.6,
|
||||
0.7, 0.8, 0.9,
|
||||
];
|
||||
|
||||
|
||||
let w: &[f32] = &[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9];
|
||||
|
||||
let y: &mut [f32] = &mut [
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
];
|
||||
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
];
|
||||
|
||||
valid_conv2d(y, x, w, 1.0, 7, 8, 3, 3, 1, 1);
|
||||
|
||||
assert_eq!(y, &[
|
||||
6.81, 7.26, 7.71, 8.16, 8.610001, 9.06,
|
||||
11.31, 11.76, 12.210001, 12.66, 13.11, 13.56,
|
||||
15.809999, 16.26, 16.710001, 17.160002, 17.61, 18.06,
|
||||
20.31, 20.76, 21.210001, 21.66, 22.11, 22.559998,
|
||||
24.81, 25.26, 25.710001, 26.160002, 26.61, 27.060001
|
||||
])
|
||||
assert_eq!(
|
||||
y,
|
||||
&[
|
||||
6.81, 7.26, 7.71, 8.16, 8.610001, 9.06, 11.31, 11.76, 12.210001, 12.66, 13.11,
|
||||
13.56, 15.809999, 16.26, 16.710001, 17.160002, 17.61, 18.06, 20.31, 20.76,
|
||||
21.210001, 21.66, 22.11, 22.559998, 24.81, 25.26, 25.710001, 26.160002, 26.61,
|
||||
27.060001
|
||||
]
|
||||
)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_valid_xcorr2d() {
|
||||
let x: &[f32] = &[
|
||||
0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7,
|
||||
1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7,
|
||||
2.0, 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7,
|
||||
3.0, 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7,
|
||||
4.0, 4.1, 4.2, 4.3, 4.4, 4.5, 4.6, 4.7,
|
||||
5.0, 5.1, 5.2, 5.3, 5.4, 5.5, 5.6, 5.7,
|
||||
6.0, 6.1, 6.2, 6.3, 6.4, 6.5, 6.6, 6.7,
|
||||
0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 2.0,
|
||||
2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 3.0, 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 4.0, 4.1,
|
||||
4.2, 4.3, 4.4, 4.5, 4.6, 4.7, 5.0, 5.1, 5.2, 5.3, 5.4, 5.5, 5.6, 5.7, 6.0, 6.1, 6.2,
|
||||
6.3, 6.4, 6.5, 6.6, 6.7,
|
||||
];
|
||||
|
||||
let w: &[f32] = &[
|
||||
0.1, 0.2, 0.3,
|
||||
0.4, 0.5, 0.6,
|
||||
0.7, 0.8, 0.9,
|
||||
];
|
||||
|
||||
|
||||
let w: &[f32] = &[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9];
|
||||
|
||||
let y: &mut [f32] = &mut [
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
];
|
||||
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
];
|
||||
|
||||
valid_xcorr2d(y, x, w, 1.0, 7, 8, 3, 3, 1, 1);
|
||||
|
||||
assert_eq!(y, &[
|
||||
3.0900004, 3.54, 3.9900002, 4.44, 4.8900003, 5.34,
|
||||
7.59, 8.04, 8.490001, 8.940001, 9.389999, 9.84,
|
||||
12.089999, 12.540001, 12.989999, 13.44, 13.889998, 14.340001,
|
||||
16.59, 17.039999, 17.490002, 17.939999, 18.39, 18.84,
|
||||
21.09, 21.539999, 21.99, 22.44, 22.89, 23.34
|
||||
])
|
||||
assert_eq!(
|
||||
y,
|
||||
&[
|
||||
3.0900004, 3.54, 3.9900002, 4.44, 4.8900003, 5.34, 7.59, 8.04, 8.490001, 8.940001,
|
||||
9.389999, 9.84, 12.089999, 12.540001, 12.989999, 13.44, 13.889998, 14.340001,
|
||||
16.59, 17.039999, 17.490002, 17.939999, 18.39, 18.84, 21.09, 21.539999, 21.99,
|
||||
22.44, 22.89, 23.34
|
||||
]
|
||||
)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_full_conv2d() {
|
||||
let x: &[f32] = &[
|
||||
1.0, 2.0,
|
||||
3.0, 4.0,
|
||||
];
|
||||
|
||||
let w: &[f32] = &[
|
||||
1.0, 2.0, 1.0,
|
||||
3.0, 2.0, 1.0,
|
||||
1.0, 4.0, 3.0,
|
||||
];
|
||||
|
||||
let x: &[f32] = &[1.0, 2.0, 3.0, 4.0];
|
||||
|
||||
let w: &[f32] = &[1.0, 2.0, 1.0, 3.0, 2.0, 1.0, 1.0, 4.0, 3.0];
|
||||
|
||||
let y: &mut [f32] = &mut [
|
||||
0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0,
|
||||
];
|
||||
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
];
|
||||
|
||||
full_conv2d(y, x, w, 1.0, 2, 2, 3, 3, 1, 1);
|
||||
|
||||
assert_eq!(y, &[
|
||||
1.0, 4.0, 5.0, 2.0,
|
||||
6.0, 18.0, 16.0, 6.0,
|
||||
10.0, 24.0, 22.0, 10.0,
|
||||
3.0, 16.0, 25.0, 12.0,
|
||||
])
|
||||
assert_eq!(
|
||||
y,
|
||||
&[
|
||||
1.0, 4.0, 5.0, 2.0, 6.0, 18.0, 16.0, 6.0, 10.0, 24.0, 22.0, 10.0, 3.0, 16.0, 25.0,
|
||||
12.0,
|
||||
]
|
||||
)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_full_xcorr2d() {
|
||||
let x: &[f32] = &[
|
||||
1.0, 2.0,
|
||||
3.0, 4.0,
|
||||
];
|
||||
|
||||
let w: &[f32] = &[
|
||||
1.0, 2.0, 1.0,
|
||||
3.0, 2.0, 1.0,
|
||||
1.0, 4.0, 3.0,
|
||||
];
|
||||
|
||||
let x: &[f32] = &[1.0, 2.0, 3.0, 4.0];
|
||||
|
||||
let w: &[f32] = &[1.0, 2.0, 1.0, 3.0, 2.0, 1.0, 1.0, 4.0, 3.0];
|
||||
|
||||
let y: &mut [f32] = &mut [
|
||||
0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0,
|
||||
];
|
||||
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
];
|
||||
|
||||
full_xcorr2d(y, x, w, 1.0, 2, 2, 3, 3, 1, 1);
|
||||
|
||||
assert_eq!(y, &[
|
||||
3.0, 10.0, 9.0, 2.0,
|
||||
10.0, 28.0, 26.0, 10.0,
|
||||
4.0, 14.0, 22.0, 14.0,
|
||||
3.0, 10.0, 11.0, 4.0,
|
||||
])
|
||||
assert_eq!(
|
||||
y,
|
||||
&[
|
||||
3.0, 10.0, 9.0, 2.0, 10.0, 28.0, 26.0, 10.0, 4.0, 14.0, 22.0, 14.0, 3.0, 10.0,
|
||||
11.0, 4.0,
|
||||
]
|
||||
)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conv2d_grads() {
|
||||
let x: &[f32] = &[
|
||||
0.0, 0.0, 1.0, 0.0, 0.0,
|
||||
0.0, 0.0, 1.0, 0.0, 0.0,
|
||||
1.0, 1.0, 1.0, 1.0, 1.0,
|
||||
0.0, 0.0, 1.0, 0.0, 0.0,
|
||||
0.0, 0.0, 1.0, 0.0, 0.0,
|
||||
0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0,
|
||||
1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0,
|
||||
];
|
||||
|
||||
|
||||
let true_y: &[f32] = &[
|
||||
0.0, 1.0, 0.0,
|
||||
0.0, 1.0, 0.0,
|
||||
0.0, 1.0, 0.0,
|
||||
|
||||
0.0, 0.0, 0.0,
|
||||
1.0, 1.0, 1.0,
|
||||
0.0, 0.0, 0.0,
|
||||
0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0,
|
||||
0.0,
|
||||
];
|
||||
|
||||
|
||||
let w: &mut [f32] = &mut [
|
||||
0.02, -0.01, 0.01,
|
||||
-0.0012, 0.001, -0.005,
|
||||
0.021, -0.0001, 0.008,
|
||||
|
||||
0.021, -0.0001, 0.008,
|
||||
-0.0012, 0.001, -0.005,
|
||||
0.02, -0.01, 0.01,
|
||||
0.02, -0.01, 0.01, -0.0012, 0.001, -0.005, 0.021, -0.0001, 0.008, 0.021, -0.0001,
|
||||
0.008, -0.0012, 0.001, -0.005, 0.02, -0.01, 0.01,
|
||||
];
|
||||
|
||||
|
||||
for _ in 0..10 {
|
||||
let y: &mut [f32] = &mut [
|
||||
0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0,
|
||||
|
||||
0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0,
|
||||
];
|
||||
|
||||
|
||||
let dw: &mut [f32] = &mut [
|
||||
0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0,
|
||||
|
||||
0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0,
|
||||
];
|
||||
|
||||
|
||||
let dy: &mut [f32] = &mut [
|
||||
0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0,
|
||||
|
||||
0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0,
|
||||
];
|
||||
|
||||
|
||||
conv2d_forward(y, x, w, 1, 1, 2, 5, 5, 3, 3, 1, 1);
|
||||
|
||||
for i in 0..y.len() {
|
||||
@ -455,20 +448,34 @@ mod test {
|
||||
}
|
||||
|
||||
conv2d_grads(dw, x, dy, 1, 1, 2, 5, 5, 3, 3, 1, 1);
|
||||
|
||||
|
||||
for i in 0..w.len() {
|
||||
w[i] -= dw[i] * 0.01;
|
||||
}
|
||||
}
|
||||
|
||||
assert_eq!(w, &[
|
||||
0.018572275, 0.16292913, 0.011668432,
|
||||
0.0014115912, 0.17488956, 0.00011491659,
|
||||
0.018765995, 0.17117187, 0.009149003,
|
||||
|
||||
0.018765992, 0.0035881882, 0.009149002,
|
||||
0.16899528, 0.17488956, 0.16769859,
|
||||
0.018572278, -0.004654532, 0.011668428
|
||||
]);
|
||||
|
||||
assert_eq!(
|
||||
w,
|
||||
&[
|
||||
0.018572275,
|
||||
0.16292913,
|
||||
0.011668432,
|
||||
0.0014115912,
|
||||
0.17488956,
|
||||
0.00011491659,
|
||||
0.018765995,
|
||||
0.17117187,
|
||||
0.009149003,
|
||||
0.018765992,
|
||||
0.0035881882,
|
||||
0.009149002,
|
||||
0.16899528,
|
||||
0.17488956,
|
||||
0.16769859,
|
||||
0.018572278,
|
||||
-0.004654532,
|
||||
0.011668428
|
||||
]
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,9 +1,15 @@
|
||||
|
||||
fn gemm_nn(m: usize, n: usize, k: usize, alpha: f32,
|
||||
a: &[f32], lda: usize,
|
||||
b: &[f32], ldb: usize,
|
||||
c: &mut [f32], ldc: usize)
|
||||
{
|
||||
fn gemm_nn(
|
||||
m: usize,
|
||||
n: usize,
|
||||
k: usize,
|
||||
alpha: f32,
|
||||
a: &[f32],
|
||||
lda: usize,
|
||||
b: &[f32],
|
||||
ldb: usize,
|
||||
c: &mut [f32],
|
||||
ldc: usize,
|
||||
) {
|
||||
let a = &a[0..m * k];
|
||||
let b = &b[0..n * k];
|
||||
let c = &mut c[0..m * n];
|
||||
@ -18,11 +24,18 @@ fn gemm_nn(m: usize, n: usize, k: usize, alpha: f32,
|
||||
}
|
||||
}
|
||||
|
||||
fn gemm_nt(m: usize, n: usize, k: usize, alpha: f32,
|
||||
a: &[f32], lda: usize,
|
||||
b: &[f32], ldb: usize,
|
||||
c: &mut [f32], ldc: usize)
|
||||
{
|
||||
fn gemm_nt(
|
||||
m: usize,
|
||||
n: usize,
|
||||
k: usize,
|
||||
alpha: f32,
|
||||
a: &[f32],
|
||||
lda: usize,
|
||||
b: &[f32],
|
||||
ldb: usize,
|
||||
c: &mut [f32],
|
||||
ldc: usize,
|
||||
) {
|
||||
let a = &a[0..m * k];
|
||||
let b = &b[0..n * k];
|
||||
let c = &mut c[0..m * n];
|
||||
@ -40,11 +53,18 @@ fn gemm_nt(m: usize, n: usize, k: usize, alpha: f32,
|
||||
}
|
||||
}
|
||||
|
||||
fn gemm_tn(m: usize, n: usize, k: usize, alpha: f32,
|
||||
a: &[f32], lda: usize,
|
||||
b: &[f32], ldb: usize,
|
||||
c: &mut [f32], ldc: usize)
|
||||
{
|
||||
fn gemm_tn(
|
||||
m: usize,
|
||||
n: usize,
|
||||
k: usize,
|
||||
alpha: f32,
|
||||
a: &[f32],
|
||||
lda: usize,
|
||||
b: &[f32],
|
||||
ldb: usize,
|
||||
c: &mut [f32],
|
||||
ldc: usize,
|
||||
) {
|
||||
let a = &a[0..m * k];
|
||||
let b = &b[0..n * k];
|
||||
let c = &mut c[0..m * n];
|
||||
@ -60,19 +80,26 @@ fn gemm_tn(m: usize, n: usize, k: usize, alpha: f32,
|
||||
}
|
||||
}
|
||||
|
||||
fn gemm_tt(m: usize, n: usize, k: usize, alpha: f32,
|
||||
a: &[f32], lda: usize,
|
||||
b: &[f32], ldb: usize,
|
||||
c: &mut [f32], ldc: usize)
|
||||
{
|
||||
fn gemm_tt(
|
||||
m: usize,
|
||||
n: usize,
|
||||
k: usize,
|
||||
alpha: f32,
|
||||
a: &[f32],
|
||||
lda: usize,
|
||||
b: &[f32],
|
||||
ldb: usize,
|
||||
c: &mut [f32],
|
||||
ldc: usize,
|
||||
) {
|
||||
let a = &a[0..m * k];
|
||||
let b = &b[0..n * k];
|
||||
let c = &mut c[0..m * n];
|
||||
|
||||
|
||||
for i_m in 0..m {
|
||||
for i_n in 0..n {
|
||||
let mut sum = 0.0;
|
||||
|
||||
|
||||
for i_k in 0..k {
|
||||
sum += alpha * a[i_k * lda + i_m] * b[i_n * ldb + i_k];
|
||||
}
|
||||
@ -82,11 +109,21 @@ fn gemm_tt(m: usize, n: usize, k: usize, alpha: f32,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn gemm(ta: bool, tb: bool, m: usize, n: usize, k: usize, alpha: f32,
|
||||
a: &[f32], lda: usize,
|
||||
b: &[f32], ldb: usize, beta: f32,
|
||||
c: &mut [f32], ldc: usize)
|
||||
{
|
||||
pub fn gemm(
|
||||
ta: bool,
|
||||
tb: bool,
|
||||
m: usize,
|
||||
n: usize,
|
||||
k: usize,
|
||||
alpha: f32,
|
||||
a: &[f32],
|
||||
lda: usize,
|
||||
b: &[f32],
|
||||
ldb: usize,
|
||||
beta: f32,
|
||||
c: &mut [f32],
|
||||
ldc: usize,
|
||||
) {
|
||||
for i in 0..m {
|
||||
for j in 0..n {
|
||||
c[i * ldc + j] *= beta;
|
||||
@ -102,4 +139,4 @@ pub fn gemm(ta: bool, tb: bool, m: usize, n: usize, k: usize, alpha: f32,
|
||||
} else {
|
||||
gemm_tt(m, n, k, alpha, a, lda, b, ldb, c, ldc);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -1,16 +1,22 @@
|
||||
pub fn maxpool2d(y: &mut [f32], x: &[f32],
|
||||
y_rows: isize, y_cols: isize,
|
||||
x_rows: isize, x_cols: isize,
|
||||
w_rows: isize, w_cols: isize,
|
||||
s_row: isize, s_col: isize) {
|
||||
|
||||
pub fn maxpool2d(
|
||||
y: &mut [f32],
|
||||
x: &[f32],
|
||||
y_rows: isize,
|
||||
y_cols: isize,
|
||||
x_rows: isize,
|
||||
x_cols: isize,
|
||||
w_rows: isize,
|
||||
w_cols: isize,
|
||||
s_row: isize,
|
||||
s_col: isize,
|
||||
) {
|
||||
let y = &mut y[0..(y_rows * y_cols) as usize];
|
||||
let x = &x[0..(x_rows * x_cols) as usize];
|
||||
|
||||
|
||||
for y_y in 0..y_rows {
|
||||
for y_x in 0..y_cols {
|
||||
let mut xi = s_row * y_y * x_cols + s_col * y_x;
|
||||
|
||||
|
||||
let mut max = core::f32::NEG_INFINITY;
|
||||
for _ in 0..w_rows {
|
||||
for w_x in 0..w_cols {
|
||||
@ -19,25 +25,32 @@ pub fn maxpool2d(y: &mut [f32], x: &[f32],
|
||||
max = val;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
xi += x_cols;
|
||||
}
|
||||
|
||||
|
||||
y[(y_y * y_cols + y_x) as usize] = max;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn maxpool2d_backward(dx: &mut [f32], x: &[f32], dy: &[f32],
|
||||
x_rows: isize, x_cols: isize,
|
||||
y_rows: isize, y_cols: isize,
|
||||
w_rows: isize, w_cols: isize,
|
||||
s_row: isize, s_col: isize)
|
||||
{
|
||||
pub fn maxpool2d_backward(
|
||||
dx: &mut [f32],
|
||||
x: &[f32],
|
||||
dy: &[f32],
|
||||
x_rows: isize,
|
||||
x_cols: isize,
|
||||
y_rows: isize,
|
||||
y_cols: isize,
|
||||
w_rows: isize,
|
||||
w_cols: isize,
|
||||
s_row: isize,
|
||||
s_col: isize,
|
||||
) {
|
||||
let dx = &mut dx[0..(x_rows * x_cols) as usize];
|
||||
let x = &x[0..(x_rows * x_cols) as usize];
|
||||
let dy = &dy[0..(y_rows * y_cols) as usize];
|
||||
|
||||
|
||||
for dy_y in 0..y_rows {
|
||||
for dy_x in 0..y_cols {
|
||||
let mut xi = s_row * dy_y * x_cols + s_col * dy_x;
|
||||
@ -55,54 +68,67 @@ pub fn maxpool2d_backward(dx: &mut [f32], x: &[f32], dy: &[f32],
|
||||
}
|
||||
xi += x_cols;
|
||||
}
|
||||
|
||||
|
||||
dx[max_idx as usize] = dy[(dy_y * y_cols + dy_x) as usize];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn avgpool2d(y: &mut [f32], x: &[f32],
|
||||
y_rows: isize, y_cols: isize,
|
||||
x_rows: isize, x_cols: isize,
|
||||
w_rows: isize, w_cols: isize,
|
||||
s_row: isize, s_col: isize) {
|
||||
|
||||
pub fn avgpool2d(
|
||||
y: &mut [f32],
|
||||
x: &[f32],
|
||||
y_rows: isize,
|
||||
y_cols: isize,
|
||||
x_rows: isize,
|
||||
x_cols: isize,
|
||||
w_rows: isize,
|
||||
w_cols: isize,
|
||||
s_row: isize,
|
||||
s_col: isize,
|
||||
) {
|
||||
let w_size = w_rows * w_cols;
|
||||
|
||||
let y = &mut y[0..(y_rows * y_cols) as usize];
|
||||
let x = &x[0..(x_rows * x_cols) as usize];
|
||||
|
||||
|
||||
for y_y in 0..y_rows {
|
||||
for y_x in 0..y_cols {
|
||||
let mut xi = s_row * y_y * x_cols + s_col * y_x;
|
||||
|
||||
|
||||
let mut sum = 0.0;
|
||||
|
||||
for _ in 0..w_rows {
|
||||
for w_x in 0..w_cols {
|
||||
sum += x[(xi + w_x) as usize];
|
||||
}
|
||||
|
||||
|
||||
xi += x_cols;
|
||||
}
|
||||
|
||||
|
||||
y[(y_y * y_cols + y_x) as usize] = sum / w_size as f32;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn avgpool2d_backward(dx: &mut [f32], x: &[f32], dy: &[f32],
|
||||
x_rows: isize, x_cols: isize,
|
||||
y_rows: isize, y_cols: isize,
|
||||
w_rows: isize, w_cols: isize,
|
||||
s_row: isize, s_col: isize)
|
||||
{
|
||||
pub fn avgpool2d_backward(
|
||||
dx: &mut [f32],
|
||||
x: &[f32],
|
||||
dy: &[f32],
|
||||
x_rows: isize,
|
||||
x_cols: isize,
|
||||
y_rows: isize,
|
||||
y_cols: isize,
|
||||
w_rows: isize,
|
||||
w_cols: isize,
|
||||
s_row: isize,
|
||||
s_col: isize,
|
||||
) {
|
||||
let dx = &mut dx[0..(x_rows * x_cols) as usize];
|
||||
let x = &x[0..(x_rows * x_cols) as usize];
|
||||
let dy = &dy[0..(y_rows * y_cols) as usize];
|
||||
|
||||
|
||||
for dy_y in 0..y_rows {
|
||||
for dy_x in 0..y_cols {
|
||||
let mut xi = s_row * dy_y * x_cols + s_col * dy_x;
|
||||
@ -120,13 +146,12 @@ pub fn avgpool2d_backward(dx: &mut [f32], x: &[f32], dy: &[f32],
|
||||
}
|
||||
xi += x_cols;
|
||||
}
|
||||
|
||||
|
||||
dx[max_idx as usize] = dy[(dy_y * y_cols + dy_x) as usize];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
@ -134,65 +159,41 @@ mod tests {
|
||||
#[test]
|
||||
fn test_maxpool2d() {
|
||||
let x: &[f32] = &[
|
||||
1.0, 2.0, 3.0, 4.0, 5.0, 6.0,
|
||||
7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
|
||||
13.0, 14.0, 15.0, 16.0, 17.0, 18.0,
|
||||
19.0, 20.0, 21.0, 22.0, 23.0, 24.0,
|
||||
25.0, 26.0, 27.0, 28.0, 29.0, 30.0,
|
||||
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
|
||||
17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0,
|
||||
31.0, 32.0, 33.0, 34.0, 35.0, 36.0,
|
||||
];
|
||||
|
||||
let y: &mut [f32] = &mut [
|
||||
0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0,
|
||||
];
|
||||
|
||||
|
||||
let y: &mut [f32] = &mut [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0];
|
||||
|
||||
maxpool2d(y, x, 3, 3, 6, 6, 2, 2, 2, 2);
|
||||
|
||||
assert_eq!(y, &[
|
||||
8.0, 10.0, 12.0,
|
||||
20.0, 22.0, 24.0,
|
||||
32.0, 34.0, 36.0,
|
||||
])
|
||||
assert_eq!(y, &[8.0, 10.0, 12.0, 20.0, 22.0, 24.0, 32.0, 34.0, 36.0,])
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_maxpool2d_backward() {
|
||||
let x: &[f32] = &[
|
||||
1.0, 2.0, 3.0, 4.0, 5.0, 6.0,
|
||||
7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
|
||||
13.0, 14.0, 15.0, 16.0, 17.0, 18.0,
|
||||
19.0, 20.0, 21.0, 22.0, 23.0, 24.0,
|
||||
25.0, 26.0, 27.0, 28.0, 29.0, 30.0,
|
||||
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
|
||||
17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0,
|
||||
31.0, 32.0, 33.0, 34.0, 35.0, 36.0,
|
||||
];
|
||||
|
||||
let dy: &[f32] = &[
|
||||
9.0, 8.0, 7.0,
|
||||
6.0, 5.0, 4.0,
|
||||
3.0, 2.0, 1.0
|
||||
];
|
||||
|
||||
|
||||
let dy: &[f32] = &[9.0, 8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0];
|
||||
|
||||
let dx: &mut [f32] = &mut [
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0,
|
||||
];
|
||||
|
||||
|
||||
maxpool2d_backward(dx, x, dy, 6, 6, 3, 3, 2, 2, 2, 2);
|
||||
|
||||
let tt: &[f32] = &[
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 9.0, 0.0, 8.0, 0.0, 7.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 6.0, 0.0, 5.0, 0.0, 4.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 3.0, 0.0, 2.0, 0.0, 1.0
|
||||
];
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 9.0, 0.0, 8.0, 0.0, 7.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 6.0, 0.0, 5.0, 0.0, 4.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0, 0.0, 2.0,
|
||||
0.0, 1.0,
|
||||
];
|
||||
|
||||
assert_eq!(dx, tt);
|
||||
}
|
||||
@ -200,26 +201,15 @@ mod tests {
|
||||
#[test]
|
||||
fn test_avgpool2d() {
|
||||
let x: &[f32] = &[
|
||||
1.0, 2.0, 3.0, 4.0, 5.0, 6.0,
|
||||
7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
|
||||
13.0, 14.0, 15.0, 16.0, 17.0, 18.0,
|
||||
19.0, 20.0, 21.0, 22.0, 23.0, 24.0,
|
||||
25.0, 26.0, 27.0, 28.0, 29.0, 30.0,
|
||||
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
|
||||
17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0,
|
||||
31.0, 32.0, 33.0, 34.0, 35.0, 36.0,
|
||||
];
|
||||
|
||||
let y: &mut [f32] = &mut [
|
||||
0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0,
|
||||
];
|
||||
|
||||
|
||||
let y: &mut [f32] = &mut [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0];
|
||||
|
||||
avgpool2d(y, x, 3, 3, 6, 6, 2, 2, 2, 2);
|
||||
|
||||
assert_eq!(y, &[
|
||||
4.5, 6.5, 8.5,
|
||||
16.5, 18.5, 20.5,
|
||||
28.5, 30.5, 32.5,
|
||||
])
|
||||
assert_eq!(y, &[4.5, 6.5, 8.5, 16.5, 18.5, 20.5, 28.5, 30.5, 32.5,])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,6 +1,5 @@
|
||||
use crate::tensor::TensorShape;
|
||||
use crate::backend::Backend;
|
||||
|
||||
use crate::tensor::TensorShape;
|
||||
|
||||
pub trait OptimizerContext {
|
||||
fn new<S: Into<TensorShape>>(shape: S) -> Self;
|
||||
@ -9,14 +8,26 @@ pub trait OptimizerContext {
|
||||
pub trait Optimizer<N, B: Backend<N>> {
|
||||
type Context: OptimizerContext;
|
||||
|
||||
fn update_params(&self, backend: &B, ctx: &mut Self::Context, params: &mut B::Tensor, grads: &mut B::Tensor);
|
||||
fn update_params(
|
||||
&self,
|
||||
backend: &B,
|
||||
ctx: &mut Self::Context,
|
||||
params: &mut B::Tensor,
|
||||
grads: &mut B::Tensor,
|
||||
);
|
||||
}
|
||||
|
||||
impl <'a, N, B: Backend<N>, O: Optimizer<N, B>> Optimizer<N, B> for &'a O {
|
||||
impl<'a, N, B: Backend<N>, O: Optimizer<N, B>> Optimizer<N, B> for &'a O {
|
||||
type Context = O::Context;
|
||||
|
||||
#[inline]
|
||||
fn update_params(&self, backend: &B, ctx: &mut Self::Context, params: &mut B::Tensor, grads: &mut B::Tensor) {
|
||||
fn update_params(
|
||||
&self,
|
||||
backend: &B,
|
||||
ctx: &mut Self::Context,
|
||||
params: &mut B::Tensor,
|
||||
grads: &mut B::Tensor,
|
||||
) {
|
||||
(**self).update_params(backend, ctx, params, grads)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,12 +1,12 @@
|
||||
use crate::backend::{Backend, BackendAdam};
|
||||
use crate::optimizer::{Optimizer, OptimizerContext};
|
||||
use crate::tensor::{Tensor, TensorShape};
|
||||
use core::marker::PhantomData;
|
||||
use core::cell::Cell;
|
||||
use core::marker::PhantomData;
|
||||
|
||||
|
||||
pub struct AdamContext<N, B>
|
||||
where B: Backend<N>
|
||||
pub struct AdamContext<N, B>
|
||||
where
|
||||
B: Backend<N>,
|
||||
{
|
||||
moms: B::Tensor,
|
||||
vels: B::Tensor,
|
||||
@ -34,11 +34,12 @@ pub struct Adam<N, B: Backend<N>> {
|
||||
epsilon: Option<f32>,
|
||||
amsgrad: bool,
|
||||
iteration: Cell<f32>,
|
||||
_m: PhantomData<fn(N, B)>,
|
||||
_m: PhantomData<fn(N, B)>,
|
||||
}
|
||||
|
||||
impl<N, B> Default for Adam<N, B>
|
||||
where B: Backend<N>
|
||||
impl<N, B> Default for Adam<N, B>
|
||||
where
|
||||
B: Backend<N>,
|
||||
{
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
@ -70,12 +71,19 @@ impl<N, B: Backend<N>> Adam<N, B> {
|
||||
impl<N, B: Backend<N> + BackendAdam<N>> Optimizer<N, B> for Adam<N, B> {
|
||||
type Context = AdamContext<N, B>;
|
||||
|
||||
fn update_params(&self, backend: &B, ctx: &mut Self::Context, params: &mut B::Tensor, grads: &mut B::Tensor) {
|
||||
fn update_params(
|
||||
&self,
|
||||
backend: &B,
|
||||
ctx: &mut Self::Context,
|
||||
params: &mut B::Tensor,
|
||||
grads: &mut B::Tensor,
|
||||
) {
|
||||
let iter = self.iteration.get();
|
||||
let t = iter + 1.0;
|
||||
self.iteration.set(iter + 0.25);
|
||||
|
||||
let lr_t = self.learning_rate * ((1.0 - self.beta_2.powf(t)).sqrt() / (1.0 - self.beta_1.powf(t)));
|
||||
let lr_t =
|
||||
self.learning_rate * ((1.0 - self.beta_2.powf(t)).sqrt() / (1.0 - self.beta_1.powf(t)));
|
||||
|
||||
// m_t = (self.beta_1 * m) + (1. - self.beta_1) * g;
|
||||
backend.scale(&mut ctx.moms, backend.scalar_f32(self.beta_1));
|
||||
@ -87,10 +95,22 @@ impl<N, B: Backend<N> + BackendAdam<N>> Optimizer<N, B> for Adam<N, B> {
|
||||
|
||||
if self.amsgrad {
|
||||
backend.maximum(&mut ctx.vhats, &ctx.vels);
|
||||
backend.adam_p(params, backend.scalar_f32(-lr_t), &ctx.moms, &ctx.vhats, backend.scalar_f32(self.epsilon.unwrap_or(core::f32::EPSILON)));
|
||||
backend.adam_p(
|
||||
params,
|
||||
backend.scalar_f32(-lr_t),
|
||||
&ctx.moms,
|
||||
&ctx.vhats,
|
||||
backend.scalar_f32(self.epsilon.unwrap_or(core::f32::EPSILON)),
|
||||
);
|
||||
} else {
|
||||
// p_t = p - lr_t * m_t / (K.sqrt(v_t) + self.epsilon)
|
||||
backend.adam_p(params, backend.scalar_f32(-lr_t), &ctx.moms, &ctx.vels, backend.scalar_f32(self.epsilon.unwrap_or(core::f32::EPSILON)));
|
||||
backend.adam_p(
|
||||
params,
|
||||
backend.scalar_f32(-lr_t),
|
||||
&ctx.moms,
|
||||
&ctx.vels,
|
||||
backend.scalar_f32(self.epsilon.unwrap_or(core::f32::EPSILON)),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,18 +1,19 @@
|
||||
mod sgd;
|
||||
mod adam;
|
||||
mod rmsprop;
|
||||
mod sgd;
|
||||
|
||||
pub use self::sgd::*;
|
||||
pub use self::adam::*;
|
||||
pub use self::rmsprop::*;
|
||||
pub use self::sgd::*;
|
||||
|
||||
use crate::backend::{Backend, BackendAxpys};
|
||||
use crate::optimizer::Optimizer;
|
||||
use core::marker::PhantomData;
|
||||
|
||||
pub struct WeightDecay<N, B, O>
|
||||
where B: Backend<N>,
|
||||
O: Optimizer<N, B>
|
||||
where
|
||||
B: Backend<N>,
|
||||
O: Optimizer<N, B>,
|
||||
{
|
||||
lamda: f32,
|
||||
optimizer: O,
|
||||
@ -20,8 +21,9 @@ pub struct WeightDecay<N, B, O>
|
||||
}
|
||||
|
||||
impl<N, B, O> WeightDecay<N, B, O>
|
||||
where B: Backend<N>,
|
||||
O: Optimizer<N, B>
|
||||
where
|
||||
B: Backend<N>,
|
||||
O: Optimizer<N, B>,
|
||||
{
|
||||
pub fn new(lamda: f32, optimizer: O) -> Self {
|
||||
Self {
|
||||
@ -32,14 +34,21 @@ impl<N, B, O> WeightDecay<N, B, O>
|
||||
}
|
||||
}
|
||||
|
||||
impl<N, B, O> Optimizer<N, B> for WeightDecay<N, B, O>
|
||||
where B: Backend<N> + BackendAxpys<N>,
|
||||
O: Optimizer<N, B>
|
||||
impl<N, B, O> Optimizer<N, B> for WeightDecay<N, B, O>
|
||||
where
|
||||
B: Backend<N> + BackendAxpys<N>,
|
||||
O: Optimizer<N, B>,
|
||||
{
|
||||
type Context = O::Context;
|
||||
|
||||
#[inline]
|
||||
fn update_params(&self, backend: &B, ctx: &mut Self::Context, params: &mut B::Tensor, grads: &mut B::Tensor) {
|
||||
fn update_params(
|
||||
&self,
|
||||
backend: &B,
|
||||
ctx: &mut Self::Context,
|
||||
params: &mut B::Tensor,
|
||||
grads: &mut B::Tensor,
|
||||
) {
|
||||
backend.axpys(grads, backend.scalar_f32(self.lamda), params);
|
||||
|
||||
self.optimizer.update_params(backend, ctx, params, grads);
|
||||
|
@ -3,9 +3,9 @@ use crate::optimizer::{Optimizer, OptimizerContext};
|
||||
use crate::tensor::{Tensor, TensorShape};
|
||||
use core::marker::PhantomData;
|
||||
|
||||
|
||||
pub struct RMSPropContext<N, B>
|
||||
where B: Backend<N>
|
||||
pub struct RMSPropContext<N, B>
|
||||
where
|
||||
B: Backend<N>,
|
||||
{
|
||||
accum: B::Tensor,
|
||||
_m: PhantomData<fn(N, B)>,
|
||||
@ -26,11 +26,12 @@ pub struct RMSProp<N, B: Backend<N>> {
|
||||
learning_rate: f32,
|
||||
rho: f32,
|
||||
epsilon: Option<f32>,
|
||||
_m: PhantomData<fn(N, B)>,
|
||||
_m: PhantomData<fn(N, B)>,
|
||||
}
|
||||
|
||||
impl<N, B> Default for RMSProp<N, B>
|
||||
where B: Backend<N>
|
||||
impl<N, B> Default for RMSProp<N, B>
|
||||
where
|
||||
B: Backend<N>,
|
||||
{
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
@ -56,12 +57,24 @@ impl<N, B: Backend<N>> RMSProp<N, B> {
|
||||
impl<N, B: Backend<N> + BackendAdam<N>> Optimizer<N, B> for RMSProp<N, B> {
|
||||
type Context = RMSPropContext<N, B>;
|
||||
|
||||
fn update_params(&self, backend: &B, ctx: &mut Self::Context, params: &mut B::Tensor, grads: &mut B::Tensor) {
|
||||
fn update_params(
|
||||
&self,
|
||||
backend: &B,
|
||||
ctx: &mut Self::Context,
|
||||
params: &mut B::Tensor,
|
||||
grads: &mut B::Tensor,
|
||||
) {
|
||||
// new_a = self.rho * a + (1. - self.rho) * K.square(g)
|
||||
backend.scale(&mut ctx.accum, backend.scalar_f32(self.rho));
|
||||
backend.axpys(&mut ctx.accum, backend.scalar_f32(1.0 - self.rho), grads);
|
||||
|
||||
|
||||
// new_p = p - lr * g / (K.sqrt(new_a) + self.epsilon)
|
||||
backend.adam_p(params, backend.scalar_f32(-self.learning_rate), &grads, &ctx.accum, backend.scalar_f32(self.epsilon.unwrap_or(core::f32::EPSILON)));
|
||||
backend.adam_p(
|
||||
params,
|
||||
backend.scalar_f32(-self.learning_rate),
|
||||
&grads,
|
||||
&ctx.accum,
|
||||
backend.scalar_f32(self.epsilon.unwrap_or(core::f32::EPSILON)),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
@ -1,11 +1,11 @@
|
||||
use crate::backend::{Backend, BackendScale, BackendAxpy, BackendAdd};
|
||||
use crate::backend::{Backend, BackendAdd, BackendAxpy, BackendScale};
|
||||
use crate::optimizer::{Optimizer, OptimizerContext};
|
||||
use crate::tensor::{Tensor, TensorShape};
|
||||
use core::marker::PhantomData;
|
||||
|
||||
|
||||
pub struct SgdContext<N, B>
|
||||
where B: Backend<N>
|
||||
pub struct SgdContext<N, B>
|
||||
where
|
||||
B: Backend<N>,
|
||||
{
|
||||
moments: B::Tensor,
|
||||
_m: PhantomData<fn(N, B)>,
|
||||
@ -24,11 +24,12 @@ pub struct Sgd<N, B: Backend<N>> {
|
||||
learning_rate: f32,
|
||||
momentum: f32,
|
||||
nesterov: bool,
|
||||
_m: PhantomData<fn(N, B)>,
|
||||
_m: PhantomData<fn(N, B)>,
|
||||
}
|
||||
|
||||
impl<N, B> Default for Sgd<N, B>
|
||||
where B: Backend<N>
|
||||
impl<N, B> Default for Sgd<N, B>
|
||||
where
|
||||
B: Backend<N>,
|
||||
{
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
@ -40,8 +41,9 @@ impl<N, B> Default for Sgd<N, B>
|
||||
}
|
||||
}
|
||||
|
||||
impl<N, B> Sgd<N, B>
|
||||
where B: Backend<N>
|
||||
impl<N, B> Sgd<N, B>
|
||||
where
|
||||
B: Backend<N>,
|
||||
{
|
||||
pub fn new(learning_rate: f32, momentum: f32, nesterov: bool) -> Self {
|
||||
Self {
|
||||
@ -53,13 +55,25 @@ impl<N, B> Sgd<N, B>
|
||||
}
|
||||
}
|
||||
|
||||
impl<N, B: Backend<N> + BackendScale<N> + BackendAxpy<N> + BackendAdd<N>> Optimizer<N, B> for Sgd<N, B> {
|
||||
impl<N, B: Backend<N> + BackendScale<N> + BackendAxpy<N> + BackendAdd<N>> Optimizer<N, B>
|
||||
for Sgd<N, B>
|
||||
{
|
||||
type Context = SgdContext<N, B>;
|
||||
|
||||
fn update_params(&self, backend: &B, ctx: &mut Self::Context, params: &mut B::Tensor, grads: &mut B::Tensor) {
|
||||
fn update_params(
|
||||
&self,
|
||||
backend: &B,
|
||||
ctx: &mut Self::Context,
|
||||
params: &mut B::Tensor,
|
||||
grads: &mut B::Tensor,
|
||||
) {
|
||||
// m = momentum * m - lr * grads
|
||||
backend.scale(&mut ctx.moments, backend.scalar_f32(self.momentum));
|
||||
backend.axpy(&mut ctx.moments, backend.scalar_f32(-self.learning_rate), grads);
|
||||
backend.axpy(
|
||||
&mut ctx.moments,
|
||||
backend.scalar_f32(-self.learning_rate),
|
||||
grads,
|
||||
);
|
||||
|
||||
if self.nesterov {
|
||||
// p += momentum * m - lr * grads
|
||||
|
@ -1,6 +1,6 @@
|
||||
use crate::backend::Backend;
|
||||
use crate::optimizer::{Optimizer, OptimizerContext};
|
||||
use crate::tensor::{TensorShape, Tensor};
|
||||
use crate::tensor::{Tensor, TensorShape};
|
||||
|
||||
pub struct Params<N, B: Backend<N>, O: Optimizer<N, B>> {
|
||||
pub params: B::Tensor,
|
||||
|
@ -1,11 +1,10 @@
|
||||
|
||||
use core::fmt;
|
||||
|
||||
pub struct TensorShapeIter<'a> {
|
||||
shape: &'a TensorShape,
|
||||
left: usize,
|
||||
right: usize,
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Iterator for TensorShapeIter<'a> {
|
||||
type Item = u32;
|
||||
@ -50,7 +49,7 @@ pub struct TensorShape {
|
||||
impl fmt::Display for TensorShape {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(f, "(")?;
|
||||
for i in 0 .. self.dims {
|
||||
for i in 0..self.dims {
|
||||
if i != 0 {
|
||||
write!(f, ", ")?;
|
||||
}
|
||||
@ -63,7 +62,7 @@ impl fmt::Display for TensorShape {
|
||||
}
|
||||
}
|
||||
|
||||
impl TensorShape {
|
||||
impl TensorShape {
|
||||
#[inline]
|
||||
pub fn zero() -> Self {
|
||||
TensorShape {
|
||||
@ -71,7 +70,7 @@ impl TensorShape {
|
||||
dims: 0,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[inline]
|
||||
pub fn new0d() -> Self {
|
||||
TensorShape {
|
||||
@ -79,7 +78,7 @@ impl TensorShape {
|
||||
dims: 0,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[inline]
|
||||
pub fn new1d(w: u32) -> Self {
|
||||
TensorShape {
|
||||
@ -87,7 +86,7 @@ impl TensorShape {
|
||||
dims: 1,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[inline]
|
||||
pub fn new2d(h: u32, w: u32) -> Self {
|
||||
TensorShape {
|
||||
@ -95,7 +94,7 @@ impl TensorShape {
|
||||
dims: 2,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[inline]
|
||||
pub fn new3d(b: u32, h: u32, w: u32) -> Self {
|
||||
TensorShape {
|
||||
@ -103,7 +102,7 @@ impl TensorShape {
|
||||
dims: 3,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[inline]
|
||||
pub fn new4d(b: u32, c: u32, h: u32, w: u32) -> Self {
|
||||
TensorShape {
|
||||
@ -111,7 +110,7 @@ impl TensorShape {
|
||||
dims: 4,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[inline]
|
||||
pub fn iter(&self) -> TensorShapeIter<'_> {
|
||||
TensorShapeIter {
|
||||
@ -125,7 +124,7 @@ impl TensorShape {
|
||||
let s = s.borrow();
|
||||
let sd = self.dims;
|
||||
|
||||
for i in 0 .. s.dims {
|
||||
for i in 0..s.dims {
|
||||
self.shape[i + sd] = s.shape[i];
|
||||
}
|
||||
|
||||
@ -133,12 +132,12 @@ impl TensorShape {
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
|
||||
#[inline]
|
||||
pub fn get(&self, index: usize) -> u32 {
|
||||
self.shape[index]
|
||||
}
|
||||
|
||||
|
||||
#[inline]
|
||||
pub fn set(&mut self, index: usize, val: u32) {
|
||||
self.shape[index] = val;
|
||||
@ -155,17 +154,14 @@ impl TensorShape {
|
||||
}
|
||||
}
|
||||
|
||||
TensorShape {
|
||||
shape,
|
||||
dims
|
||||
}
|
||||
TensorShape { shape, dims }
|
||||
}
|
||||
|
||||
|
||||
#[inline]
|
||||
pub fn size(&self) -> usize {
|
||||
let mut product = 1;
|
||||
|
||||
for i in 0 .. self.dims {
|
||||
|
||||
for i in 0..self.dims {
|
||||
product *= self.shape[i] as usize;
|
||||
}
|
||||
|
||||
@ -176,21 +172,24 @@ impl TensorShape {
|
||||
let mut strides = [0; 4];
|
||||
let mut product = 1;
|
||||
|
||||
for i in 0..self.dims {
|
||||
for i in 0..self.dims {
|
||||
let si = self.dims - i - 1;
|
||||
|
||||
strides[si] = product;
|
||||
product *= self.shape[si];
|
||||
product *= self.shape[si];
|
||||
}
|
||||
|
||||
TensorShape { shape: strides, dims: self.dims }
|
||||
TensorShape {
|
||||
shape: strides,
|
||||
dims: self.dims,
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn as_slice(&self) -> &[u32] {
|
||||
&self.shape[0..self.dims]
|
||||
}
|
||||
|
||||
|
||||
#[inline]
|
||||
pub fn last_axis(&self) -> u32 {
|
||||
self.shape[self.dims - 1]
|
||||
@ -206,8 +205,8 @@ impl From<()> for TensorShape {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<(u32, )> for TensorShape {
|
||||
fn from(x: (u32, )) -> Self {
|
||||
impl From<(u32,)> for TensorShape {
|
||||
fn from(x: (u32,)) -> Self {
|
||||
TensorShape {
|
||||
shape: [x.0, 0, 0, 0],
|
||||
dims: 1,
|
||||
@ -227,7 +226,7 @@ impl From<(u32, u32)> for TensorShape {
|
||||
impl From<(u32, u32, u32)> for TensorShape {
|
||||
fn from(x: (u32, u32, u32)) -> Self {
|
||||
TensorShape {
|
||||
shape: [x.0 , x.1, x.2, 0],
|
||||
shape: [x.0, x.1, x.2, 0],
|
||||
dims: 3,
|
||||
}
|
||||
}
|
||||
@ -236,7 +235,7 @@ impl From<(u32, u32, u32)> for TensorShape {
|
||||
impl From<(u32, u32, u32, u32)> for TensorShape {
|
||||
fn from(x: (u32, u32, u32, u32)) -> Self {
|
||||
TensorShape {
|
||||
shape: [x.0 , x.1, x.2, x.3],
|
||||
shape: [x.0, x.1, x.2, x.3],
|
||||
dims: 4,
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user