Convolution optimizations

This commit is contained in:
Andrey Tkachenko 2019-07-15 17:13:22 +04:00
parent 1ccc5f45c9
commit 0ff1204276
11 changed files with 1010 additions and 420 deletions

295
Cargo.lock generated
View File

@ -6,23 +6,30 @@ version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "blas"
version = "0.20.0"
name = "backtrace"
version = "0.3.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"blas-sys 0.7.1 (registry+https://github.com/rust-lang/crates.io-index)",
"backtrace-sys 0.1.30 (registry+https://github.com/rust-lang/crates.io-index)",
"cfg-if 0.1.9 (registry+https://github.com/rust-lang/crates.io-index)",
"libc 0.2.59 (registry+https://github.com/rust-lang/crates.io-index)",
"num-complex 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)",
"rustc-demangle 0.1.15 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "blas-sys"
version = "0.7.1"
name = "backtrace-sys"
version = "0.1.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"cc 1.0.37 (registry+https://github.com/rust-lang/crates.io-index)",
"libc 0.2.59 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "bumpalo"
version = "2.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "byteorder"
version = "1.3.2"
@ -38,21 +45,33 @@ dependencies = [
]
[[package]]
name = "cblas"
version = "0.2.0"
name = "cc"
version = "1.0.37"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "cfg-if"
version = "0.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "failure"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"cblas-sys 0.1.4 (registry+https://github.com/rust-lang/crates.io-index)",
"libc 0.2.59 (registry+https://github.com/rust-lang/crates.io-index)",
"num-complex 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)",
"backtrace 0.3.32 (registry+https://github.com/rust-lang/crates.io-index)",
"failure_derive 0.1.5 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "cblas-sys"
version = "0.1.4"
name = "failure_derive"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"libc 0.2.59 (registry+https://github.com/rust-lang/crates.io-index)",
"proc-macro2 0.4.30 (registry+https://github.com/rust-lang/crates.io-index)",
"quote 0.6.13 (registry+https://github.com/rust-lang/crates.io-index)",
"syn 0.15.39 (registry+https://github.com/rust-lang/crates.io-index)",
"synstructure 0.10.2 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
@ -64,6 +83,22 @@ dependencies = [
"libc 0.2.59 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "heck"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"unicode-segmentation 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "js-sys"
version = "0.3.25"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"wasm-bindgen 0.2.48 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "lazy_static"
version = "1.3.0"
@ -77,6 +112,19 @@ name = "libc"
version = "0.2.59"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "log"
version = "0.4.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"cfg-if 0.1.9 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "memchr"
version = "2.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "mnist"
version = "0.4.0"
@ -86,32 +134,35 @@ dependencies = [
]
[[package]]
name = "num-complex"
version = "0.2.3"
name = "nom"
version = "4.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"autocfg 0.1.4 (registry+https://github.com/rust-lang/crates.io-index)",
"num-traits 0.2.8 (registry+https://github.com/rust-lang/crates.io-index)",
"memchr 2.2.1 (registry+https://github.com/rust-lang/crates.io-index)",
"version_check 0.1.5 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "num-traits"
version = "0.2.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"autocfg 0.1.4 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "openblas-src"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "ppv-lite86"
version = "0.2.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "proc-macro2"
version = "0.4.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"unicode-xid 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "quote"
version = "0.6.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"proc-macro2 0.4.30 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "rand"
version = "0.7.0"
@ -158,11 +209,140 @@ dependencies = [
"rand_core 0.5.0 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "rustc-demangle"
version = "0.1.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "sourcefile"
version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "spin"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "syn"
version = "0.15.39"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"proc-macro2 0.4.30 (registry+https://github.com/rust-lang/crates.io-index)",
"quote 0.6.13 (registry+https://github.com/rust-lang/crates.io-index)",
"unicode-xid 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "synstructure"
version = "0.10.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"proc-macro2 0.4.30 (registry+https://github.com/rust-lang/crates.io-index)",
"quote 0.6.13 (registry+https://github.com/rust-lang/crates.io-index)",
"syn 0.15.39 (registry+https://github.com/rust-lang/crates.io-index)",
"unicode-xid 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "unicode-segmentation"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "unicode-xid"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "version_check"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "wasm-bindgen"
version = "0.2.48"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"wasm-bindgen-macro 0.2.48 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "wasm-bindgen-backend"
version = "0.2.48"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"bumpalo 2.5.0 (registry+https://github.com/rust-lang/crates.io-index)",
"lazy_static 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)",
"log 0.4.7 (registry+https://github.com/rust-lang/crates.io-index)",
"proc-macro2 0.4.30 (registry+https://github.com/rust-lang/crates.io-index)",
"quote 0.6.13 (registry+https://github.com/rust-lang/crates.io-index)",
"syn 0.15.39 (registry+https://github.com/rust-lang/crates.io-index)",
"wasm-bindgen-shared 0.2.48 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "wasm-bindgen-macro"
version = "0.2.48"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"quote 0.6.13 (registry+https://github.com/rust-lang/crates.io-index)",
"wasm-bindgen-macro-support 0.2.48 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "wasm-bindgen-macro-support"
version = "0.2.48"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"proc-macro2 0.4.30 (registry+https://github.com/rust-lang/crates.io-index)",
"quote 0.6.13 (registry+https://github.com/rust-lang/crates.io-index)",
"syn 0.15.39 (registry+https://github.com/rust-lang/crates.io-index)",
"wasm-bindgen-backend 0.2.48 (registry+https://github.com/rust-lang/crates.io-index)",
"wasm-bindgen-shared 0.2.48 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "wasm-bindgen-shared"
version = "0.2.48"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "wasm-bindgen-webidl"
version = "0.2.48"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"failure 0.1.5 (registry+https://github.com/rust-lang/crates.io-index)",
"heck 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)",
"log 0.4.7 (registry+https://github.com/rust-lang/crates.io-index)",
"proc-macro2 0.4.30 (registry+https://github.com/rust-lang/crates.io-index)",
"quote 0.6.13 (registry+https://github.com/rust-lang/crates.io-index)",
"syn 0.15.39 (registry+https://github.com/rust-lang/crates.io-index)",
"wasm-bindgen-backend 0.2.48 (registry+https://github.com/rust-lang/crates.io-index)",
"weedle 0.10.0 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "web-sys"
version = "0.3.25"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"failure 0.1.5 (registry+https://github.com/rust-lang/crates.io-index)",
"js-sys 0.3.25 (registry+https://github.com/rust-lang/crates.io-index)",
"sourcefile 0.1.4 (registry+https://github.com/rust-lang/crates.io-index)",
"wasm-bindgen 0.2.48 (registry+https://github.com/rust-lang/crates.io-index)",
"wasm-bindgen-webidl 0.2.48 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "weedle"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"nom 4.2.3 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "yarnn"
version = "0.1.0"
@ -180,6 +360,16 @@ dependencies = [
"yarnn-model-mnist 0.1.0",
]
[[package]]
name = "yarnn-example-mnist-wasm"
version = "0.1.0"
dependencies = [
"js-sys 0.3.25 (registry+https://github.com/rust-lang/crates.io-index)",
"wasm-bindgen 0.2.48 (registry+https://github.com/rust-lang/crates.io-index)",
"web-sys 0.3.25 (registry+https://github.com/rust-lang/crates.io-index)",
"yarnn 0.1.0",
]
[[package]]
name = "yarnn-example-vgg16-demo"
version = "0.1.0"
@ -202,34 +392,51 @@ dependencies = [
"yarnn 0.1.0",
]
[[package]]
[[patch.unused]]
name = "yarnn-native-blas"
version = "0.1.0"
dependencies = [
"blas 0.20.0 (registry+https://github.com/rust-lang/crates.io-index)",
"cblas 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)",
"openblas-src 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)",
]
[metadata]
"checksum autocfg 0.1.4 (registry+https://github.com/rust-lang/crates.io-index)" = "0e49efa51329a5fd37e7c79db4621af617cd4e3e5bc224939808d076077077bf"
"checksum blas 0.20.0 (registry+https://github.com/rust-lang/crates.io-index)" = "e4b3b6399f2fe6e70f780b06e278a64770aae273e0005c3860f94fc2f5f01ba7"
"checksum blas-sys 0.7.1 (registry+https://github.com/rust-lang/crates.io-index)" = "13b1b279ceb25d7c4faaea95a5f7addbe7d8c34f9462044bd8e630cebcfc2440"
"checksum backtrace 0.3.32 (registry+https://github.com/rust-lang/crates.io-index)" = "18b50f5258d1a9ad8396d2d345827875de4261b158124d4c819d9b351454fae5"
"checksum backtrace-sys 0.1.30 (registry+https://github.com/rust-lang/crates.io-index)" = "5b3a000b9c543553af61bc01cbfc403b04b5caa9e421033866f2e98061eb3e61"
"checksum bumpalo 2.5.0 (registry+https://github.com/rust-lang/crates.io-index)" = "2cd43d82f27d68911e6ee11ee791fb248f138f5d69424dc02e098d4f152b0b05"
"checksum byteorder 1.3.2 (registry+https://github.com/rust-lang/crates.io-index)" = "a7c3dd8985a7111efc5c80b44e23ecdd8c007de8ade3b96595387e812b957cf5"
"checksum c2-chacha 0.2.2 (registry+https://github.com/rust-lang/crates.io-index)" = "7d64d04786e0f528460fc884753cf8dddcc466be308f6026f8e355c41a0e4101"
"checksum cblas 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "d82f331add33eceb4c41cb28d878049b96f56577016daf190831e94e4aece5db"
"checksum cblas-sys 0.1.4 (registry+https://github.com/rust-lang/crates.io-index)" = "b6feecd82cce51b0204cf063f0041d69f24ce83f680d87514b004248e7b0fa65"
"checksum cc 1.0.37 (registry+https://github.com/rust-lang/crates.io-index)" = "39f75544d7bbaf57560d2168f28fd649ff9c76153874db88bdbdfd839b1a7e7d"
"checksum cfg-if 0.1.9 (registry+https://github.com/rust-lang/crates.io-index)" = "b486ce3ccf7ffd79fdeb678eac06a9e6c09fc88d33836340becb8fffe87c5e33"
"checksum failure 0.1.5 (registry+https://github.com/rust-lang/crates.io-index)" = "795bd83d3abeb9220f257e597aa0080a508b27533824adf336529648f6abf7e2"
"checksum failure_derive 0.1.5 (registry+https://github.com/rust-lang/crates.io-index)" = "ea1063915fd7ef4309e222a5a07cf9c319fb9c7836b1f89b85458672dbb127e1"
"checksum getrandom 0.1.6 (registry+https://github.com/rust-lang/crates.io-index)" = "e65cce4e5084b14874c4e7097f38cab54f47ee554f9194673456ea379dcc4c55"
"checksum heck 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)" = "20564e78d53d2bb135c343b3f47714a56af2061f1c928fdb541dc7b9fdd94205"
"checksum js-sys 0.3.25 (registry+https://github.com/rust-lang/crates.io-index)" = "da3ea71161651a4cd97d999b2da139109c537b15ab33abc8ae4ead38deac8a03"
"checksum lazy_static 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)" = "bc5729f27f159ddd61f4df6228e827e86643d4d3e7c32183cb30a1c08f604a14"
"checksum libc 0.2.59 (registry+https://github.com/rust-lang/crates.io-index)" = "3262021842bf00fe07dbd6cf34ff25c99d7a7ebef8deea84db72be3ea3bb0aff"
"checksum log 0.4.7 (registry+https://github.com/rust-lang/crates.io-index)" = "c275b6ad54070ac2d665eef9197db647b32239c9d244bfb6f041a766d00da5b3"
"checksum memchr 2.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "88579771288728879b57485cc7d6b07d648c9f0141eb955f8ab7f9d45394468e"
"checksum mnist 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "25f19bfda80095b4294000bbb50506f028149ed0ddb7fabf46ebb673b91626bc"
"checksum num-complex 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "fcb0cf31fb3ff77e6d2a6ebd6800df7fdcd106f2ad89113c9130bcd07f93dffc"
"checksum num-traits 0.2.8 (registry+https://github.com/rust-lang/crates.io-index)" = "6ba9a427cfca2be13aa6f6403b0b7e7368fe982bfa16fccc450ce74c46cd9b32"
"checksum openblas-src 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)" = "0b3533e568814bee9620fcc529158408384404bae5b277c73c73d66ca03fceb7"
"checksum nom 4.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "2ad2a91a8e869eeb30b9cb3119ae87773a8f4ae617f41b1eb9c154b2905f7bd6"
"checksum ppv-lite86 0.2.5 (registry+https://github.com/rust-lang/crates.io-index)" = "e3cbf9f658cdb5000fcf6f362b8ea2ba154b9f146a61c7a20d647034c6b6561b"
"checksum proc-macro2 0.4.30 (registry+https://github.com/rust-lang/crates.io-index)" = "cf3d2011ab5c909338f7887f4fc896d35932e29146c12c8d01da6b22a80ba759"
"checksum quote 0.6.13 (registry+https://github.com/rust-lang/crates.io-index)" = "6ce23b6b870e8f94f81fb0a363d65d86675884b34a09043c81e5562f11c1f8e1"
"checksum rand 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)" = "d47eab0e83d9693d40f825f86948aa16eff6750ead4bdffc4ab95b8b3a7f052c"
"checksum rand_chacha 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "e193067942ef6f485a349a113329140d0ab9e2168ce92274499bb0e9a4190d9d"
"checksum rand_core 0.5.0 (registry+https://github.com/rust-lang/crates.io-index)" = "615e683324e75af5d43d8f7a39ffe3ee4a9dc42c5c701167a71dc59c3a493aca"
"checksum rand_distr 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "c37e54d811c5a51195156444e298a98ba57df6dca1e511f34d4791993883b921"
"checksum rand_hc 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ca3129af7b92a17112d59ad498c6f81eaf463253766b90396d39ea7a39d6613c"
"checksum rustc-demangle 0.1.15 (registry+https://github.com/rust-lang/crates.io-index)" = "a7f4dccf6f4891ebcc0c39f9b6eb1a83b9bf5d747cb439ec6fba4f3b977038af"
"checksum sourcefile 0.1.4 (registry+https://github.com/rust-lang/crates.io-index)" = "4bf77cb82ba8453b42b6ae1d692e4cdc92f9a47beaf89a847c8be83f4e328ad3"
"checksum spin 0.5.0 (registry+https://github.com/rust-lang/crates.io-index)" = "44363f6f51401c34e7be73db0db371c04705d35efbe9f7d6082e03a921a32c55"
"checksum syn 0.15.39 (registry+https://github.com/rust-lang/crates.io-index)" = "b4d960b829a55e56db167e861ddb43602c003c7be0bee1d345021703fac2fb7c"
"checksum synstructure 0.10.2 (registry+https://github.com/rust-lang/crates.io-index)" = "02353edf96d6e4dc81aea2d8490a7e9db177bf8acb0e951c24940bf866cb313f"
"checksum unicode-segmentation 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)" = "1967f4cdfc355b37fd76d2a954fb2ed3871034eb4f26d60537d88795cfc332a9"
"checksum unicode-xid 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "fc72304796d0818e357ead4e000d19c9c174ab23dc11093ac919054d20a6a7fc"
"checksum version_check 0.1.5 (registry+https://github.com/rust-lang/crates.io-index)" = "914b1a6776c4c929a602fafd8bc742e06365d4bcbe48c30f9cca5824f70dc9dd"
"checksum wasm-bindgen 0.2.48 (registry+https://github.com/rust-lang/crates.io-index)" = "4de97fa1806bb1a99904216f6ac5e0c050dc4f8c676dc98775047c38e5c01b55"
"checksum wasm-bindgen-backend 0.2.48 (registry+https://github.com/rust-lang/crates.io-index)" = "5d82c170ef9f5b2c63ad4460dfcee93f3ec04a9a36a4cc20bc973c39e59ab8e3"
"checksum wasm-bindgen-macro 0.2.48 (registry+https://github.com/rust-lang/crates.io-index)" = "f07d50f74bf7a738304f6b8157f4a581e1512cd9e9cdb5baad8c31bbe8ffd81d"
"checksum wasm-bindgen-macro-support 0.2.48 (registry+https://github.com/rust-lang/crates.io-index)" = "95cf8fe77e45ba5f91bc8f3da0c3aa5d464b3d8ed85d84f4d4c7cc106436b1d7"
"checksum wasm-bindgen-shared 0.2.48 (registry+https://github.com/rust-lang/crates.io-index)" = "d9c2d4d4756b2e46d3a5422e06277d02e4d3e1d62d138b76a4c681e925743623"
"checksum wasm-bindgen-webidl 0.2.48 (registry+https://github.com/rust-lang/crates.io-index)" = "24e47859b4eba3d3b9a5c2c299f9d6f8d0b613671315f6f0c5c7f835e524b36a"
"checksum web-sys 0.3.25 (registry+https://github.com/rust-lang/crates.io-index)" = "86d515d2f713d3a6ab198031d2181b7540f8e319e4637ec2d4a41a208335ef29"
"checksum weedle 0.10.0 (registry+https://github.com/rust-lang/crates.io-index)" = "3bb43f70885151e629e2a19ce9e50bd730fd436cfd4b666894c9ce4de9141164"

View File

@ -2,13 +2,13 @@
members = [
"yarnn",
"yarnn-native-blas",
# "yarnn-native-blas",
"yarnn-models/mnist",
"yarnn-models/vgg16",
"yarnn-examples/mnist",
# "yarnn-examples/mnist-wasm",
"yarnn-examples/mnist-wasm",
# "yarnn-examples/mnist-blas",
"yarnn-examples/vgg16-demo",
]

View File

@ -6,6 +6,7 @@ edition = "2018"
build = "build.rs"
[dependencies]
yarnn = "0.1.0"
openblas-src = {version = "0.7.0", features = ["system"]}
blas = "0.20.0"
cblas = "0.2.0"

View File

@ -1,164 +1,173 @@
// pub struct NativeBlas {
// inner: Native,
// }
use yarnn::backend::*;
// impl NativeBlas {
extern crate openblas_src;
// }
/// Backend wrapper that layers BLAS-backed kernels on top of an existing
/// native backend `B`.
///
/// `N` is the element type the wrapper is instantiated for. It is not stored
/// in any field, so it is pinned with a zero-sized `PhantomData` marker;
/// without the marker the declaration is rejected with E0392
/// ("parameter `N` is never used").
pub struct NativeBlas<N, B: Native> {
    /// The wrapped backend.
    inner: B,
    /// Zero-sized marker for `N`; `fn() -> N` avoids implying that the
    /// wrapper owns a value of type `N`.
    _element: std::marker::PhantomData<fn() -> N>,
}

/// Marks the wrapper itself as a `Native` backend (the impl has no items).
impl<N, B: Native> Native for NativeBlas<N, B> {}

impl<N, B> NativeBlas<N, B>
    where N: NativeTensor,
          B: NativeBackend<N>
{
    /// Wraps `native` in a new `NativeBlas`.
    pub fn new(native: B) -> Self {
        Self {
            inner: native,
            _element: std::marker::PhantomData,
        }
    }
}
// #[inline]
// fn load_tensor_u8(&self, t: &mut Self::Tensor, data: &[u8]) {
// self.inner.load_tensor_u8(t, data);
// }
/// Core tensor operations for [`NativeBlas`].
///
/// The wrapper has no tensor storage of its own: `Tensor` is the wrapped
/// backend's tensor type and every operation is forwarded to `self.inner`.
/// Previously the bodies were empty stubs, so the calls were silently dropped
/// and `scalar_f32` produced no value for its declared `N` return type.
impl<N, B> Backend<N> for NativeBlas<N, B>
    where N: NativeTensor,
          B: NativeBackend<N>
{
    type Tensor = B::Tensor;

    /// Forwards to the wrapped backend's `store_tensor_f32`.
    #[inline]
    fn store_tensor_f32(&self, t: &Self::Tensor, data: &mut [f32]) {
        self.inner.store_tensor_f32(t, data)
    }

    /// Forwards to the wrapped backend's `load_tensor_u8`.
    #[inline]
    fn load_tensor_u8(&self, t: &mut Self::Tensor, data: &[u8]) {
        self.inner.load_tensor_u8(t, data)
    }

    /// Forwards to the wrapped backend's `load_tensor_f32`.
    #[inline]
    fn load_tensor_f32(&self, t: &mut Self::Tensor, data: &[f32]) {
        self.inner.load_tensor_f32(t, data)
    }

    /// Forwards to the wrapped backend's `scalar_f32` and returns its result.
    #[inline]
    fn scalar_f32(&self, val: f32) -> N {
        self.inner.scalar_f32(val)
    }

    /// Forwards to the wrapped backend's `fill_scalar`.
    #[inline]
    fn fill_scalar(&self, t: &mut Self::Tensor, scalar: N) {
        self.inner.fill_scalar(t, scalar)
    }

    /// Forwards to the wrapped backend's `fill_random`.
    #[inline]
    fn fill_random(&self, t: &mut Self::Tensor, from: N, to: N) {
        self.inner.fill_random(t, from, to)
    }

    /// Forwards to the wrapped backend's `print_tensor`.
    #[inline]
    fn print_tensor(&self, t: &Self::Tensor) {
        self.inner.print_tensor(t)
    }
}
// let m = a_shape.get(0) as usize;
// let n = b_shape.get(1) as usize;
// let k = b_shape.get(0) as usize;
/// Single-precision matrix products for `NativeBlas`, computed with `blas::sgemm`.
///
/// Every implemented method checks the operand shapes with assertions and
/// panics on a mismatch. In each `sgemm` call `b` is passed before `a` and the
/// dimensions are given as `n, m, k`.
/// NOTE(review): presumably the tensors are stored row-major and the operands
/// are swapped so that the column-major BLAS routine yields the row-major
/// result — confirm against the tensor layout.
impl<B> BackendGemm<f32> for NativeBlas<f32, B>
where B: NativeBackend<f32>
{
/// Matrix product of `a` and `b` (neither transposed), written to `dst`.
///
/// # Panics
///
/// Panics if `a` or `b` is not 2-dimensional, if dimension 0 of `a` differs
/// from dimension 0 of `dst`, or if dimension 1 of `b` differs from
/// dimension 1 of `dst`.
fn matmul(&self, dst: &mut Self::Tensor, a: &Self::Tensor, b: &Self::Tensor) {
let a_shape = a.shape();
let b_shape = b.shape();
let c_shape = dst.shape().clone();
// NOTE(review): the shared inner dimension (dimension 1 of `a` vs dimension 0
// of `b`) is not asserted here.
assert_eq!(a_shape.get(0), c_shape.get(0));
assert_eq!(b_shape.get(1), c_shape.get(1));
assert_eq!(a_shape.dims, 2);
assert_eq!(b_shape.dims, 2);
// m: dimension 0 of `a`; n: dimension 1 of `b`; k: dimension 0 of `b`.
let m = a_shape.get(0) as usize;
let n = b_shape.get(1) as usize;
let k = b_shape.get(0) as usize;
// alpha = 1.0 and beta = 0.0: per the BLAS GEMM contract the previous
// contents of `dst` are not accumulated into the result.
unsafe {
blas::sgemm('N' as u8, 'N' as u8,
n, m, k,
1.0,
b.read(), n,
a.read(), k,
0.0,
&mut dst.write(), n);
}
}
/// Matrix product of `a` and transposed `b`, written to `dst`.
///
/// # Panics
///
/// Panics if `a` or `b` is not 2-dimensional, if dimension 0 of `a` differs
/// from dimension 0 of `dst`, or if dimension 0 of `b` differs from
/// dimension 1 of `dst`.
fn matmul_nt(&self, dst: &mut Self::Tensor, a: &Self::Tensor, b: &Self::Tensor) {
let a_shape = a.shape();
let b_shape = b.shape();
let c_shape = dst.shape().clone();
assert_eq!(a_shape.get(0), c_shape.get(0));
assert_eq!(b_shape.get(0), c_shape.get(1));
assert_eq!(a_shape.dims, 2);
assert_eq!(b_shape.dims, 2);
// m: dimension 0 of `a`; n: dimension 0 of `b`; k: dimension 1 of `b`.
let m = a_shape.get(0) as usize;
let n = b_shape.get(0) as usize;
let k = b_shape.get(1) as usize;
// The first operand (`b`) carries the 'T' flag; both inputs use leading
// dimension `k`.
unsafe {
blas::sgemm('T' as u8, 'N' as u8,
n, m, k,
1.0,
b.read(), k,
a.read(), k,
0.0,
&mut dst.write(), n);
}
}
/// Matrix product of transposed `a` and `b`, written to `dst`.
///
/// # Panics
///
/// Panics if `a` or `b` is not 2-dimensional, if dimension 1 of `a` differs
/// from dimension 0 of `dst`, or if dimension 1 of `b` differs from
/// dimension 1 of `dst`.
fn matmul_tn(&self, dst: &mut Self::Tensor, a: &Self::Tensor, b: &Self::Tensor) {
let a_shape = a.shape();
let b_shape = b.shape();
let c_shape = dst.shape().clone();
assert_eq!(a_shape.get(1), c_shape.get(0));
assert_eq!(b_shape.get(1), c_shape.get(1));
assert_eq!(a_shape.dims, 2);
assert_eq!(b_shape.dims, 2);
// m: dimension 1 of `a`; n: dimension 1 of `b`; k: dimension 0 of `b`.
let m = a_shape.get(1) as usize;
let n = b_shape.get(1) as usize;
let k = b_shape.get(0) as usize;
// The second operand (`a`) carries the 'T' flag and uses leading dimension `m`.
unsafe {
blas::sgemm('N' as u8, 'T' as u8,
n, m, k,
1.0,
b.read(), n,
a.read(), m,
0.0,
&mut dst.write(), n);
}
}
/// Product of transposed `a` and transposed `b`; not implemented.
///
/// # Panics
///
/// Always panics via `unimplemented!()`.
fn matmul_tt(&self, _dst: &mut Self::Tensor, _a: &Self::Tensor, _b: &Self::Tensor) {
unimplemented!();
}
}
// impl BackendAxpy<f32> for Native {
// fn axpy(&self, dst: &mut Self::Tensor, scale: f32, x: &Self::Tensor) {
// let dst_size = dst.shape().size();
/// `axpy` for `NativeBlas` over `f32`, delegated to `blas::saxpy`.
impl<B> BackendAxpy<f32> for NativeBlas<f32, B>
where B: NativeBackend<f32>
{
/// Calls `blas::saxpy` with `scale`, `x` and `dst` over `dst.shape().size()`
/// elements; per the BLAS AXPY contract this adds `scale * x` to `dst`.
///
/// # Panics
///
/// Panics if `x` and `dst` do not have equal shapes.
fn axpy(&self, dst: &mut Self::Tensor, scale: f32, x: &Self::Tensor) {
let dst_size = dst.shape().size();
assert!(x.shape() == dst.shape());
// Both buffers are walked with an increment of 1.
// NOTE(review): `dst_size as i32` is an unchecked cast; presumably tensor
// sizes stay below `i32::MAX` — confirm.
unsafe {
blas::saxpy(
dst_size as i32,
scale,
x.read(),
1,
dst.write(),
1
);
}
}
}
// impl BackendScale<f32> for Native {
// fn scale(&self, dst: &mut Self::Tensor, scale: f32) {
// let dst_size = dst.shape().size();
/// In-place scaling for `NativeBlas` over `f32`, delegated to `blas::sscal`.
impl<B> BackendScale<f32> for NativeBlas<f32, B>
where B: NativeBackend<f32>
{
/// Calls `blas::sscal` with `scale` over the `dst.shape().size()` elements of
/// `dst`; per the BLAS SCAL contract this multiplies each element by `scale`.
fn scale(&self, dst: &mut Self::Tensor, scale: f32) {
let dst_size = dst.shape().size();
// The buffer is walked with an increment of 1.
// NOTE(review): `dst_size as i32` is an unchecked cast; presumably tensor
// sizes stay below `i32::MAX` — confirm.
unsafe {
blas::sscal(
dst_size as i32,
scale,
dst.write(),
1
);
}
}
}

View File

@ -340,3 +340,14 @@ impl <'a, N, T: BackendAvgPool2d<N>> BackendAvgPool2d<N> for &'a T {
(**self).avg_pool2d_backprop(dx, dy, x, conv_info)
}
}
/// Backend capability for a padding-aware 2-D tensor copy.
///
/// NOTE(review): the meaning of the two components of each padding pair
/// (presumably rows/columns) is not visible here — confirm against the
/// implementors.
pub trait BackendPaddingCopy2d<N>: Backend<N> {
/// Takes a writable tensor `y`, a read-only tensor `x` and one `(u32, u32)`
/// padding pair for each of them; presumably copies `x` into `y` while
/// honouring both paddings.
fn copy_with_padding2d(&self, y: &mut Self::Tensor, x: &Self::Tensor, y_paddings: (u32, u32), x_paddings: (u32, u32));
}
/// Lets a shared reference to a backend be used wherever a
/// `BackendPaddingCopy2d` implementor is expected.
impl <'a, N, T: BackendPaddingCopy2d<N>> BackendPaddingCopy2d<N> for &'a T {
/// Forwards the call unchanged to the referenced backend.
#[inline]
fn copy_with_padding2d(&self, y: &mut Self::Tensor, x: &Self::Tensor, y_paddings: (u32, u32), x_paddings: (u32, u32)) {
(**self).copy_with_padding2d(y, x, y_paddings, x_paddings)
}
}

42
yarnn/src/model.rs Normal file
View File

@ -0,0 +1,42 @@
/// High-level model interface over a backend `B` (element type `N`) and an
/// optimizer `O`.
pub trait Model<N, B, O>
where B: Backend<N>,
O: Optimizer<N, B>
{
/// Initialises the model for `backend`.
fn init(&mut self, backend: &B);
/// Runs the model on input `x`.
// NOTE(review): no return type is declared here, so the prediction is not
// handed back to the caller through this signature.
fn predict(&mut self, backend: &B, x: &B::Tensor);
/// Evaluates the model on inputs `x` against `y` and returns a `ConfusionMatrix`.
fn evaluate(&mut self, backend: &B, x: &B::Tensor, y: &B::Tensor) -> ConfusionMatrix;
/// Performs a training step on inputs `x` and `y` using `optimizer`.
fn train(&mut self, backend: &B, optimizer: &O, x: &B::Tensor, y: &B::Tensor);
}
/// Model built around a single layer `L` together with two layer contexts.
pub struct DefaultModel<N, B, O, L>
where B: Backend<N>,
O: Optimizer<N, B>,
L: AbstractLayer<N, B, O>
{
// The wrapped layer.
inner: L,
// Layer context kept for training (by name; not used in the visible code).
train_ctx: L::Context,
// Layer context used by `predict`.
evaluate_ctx: L::Context,
}
/// `Model` implementation for `DefaultModel`; only `predict` has a body so far.
// NOTE(review): this impl repeats none of the `where` bounds declared on
// `Model` and `DefaultModel` — confirm it is accepted as written.
impl<N, B, O, L> Model<N, B, O> for DefaultModel<N, B, O, L>
{
/// Currently does nothing.
fn init(&mut self, backend: &B) {
}
/// Runs the wrapped layer's `forward` on `x` with the evaluation context and
/// returns that context's `outputs()`.
// NOTE(review): the `Model` trait declares `predict` without a return type,
// while this method returns `&B::Tensor` — the two signatures disagree.
fn predict(&mut self, backend: &B, x: &B::Tensor) -> &B::Tensor {
// NOTE(review): `self.evaluate_ctx` is passed by value from behind
// `&mut self`; presumably a borrow is intended — confirm against `forward`.
self.inner.forward(backend, x, self.evaluate_ctx);
self.evaluate_ctx.outputs()
}
/// Not implemented yet: the body is empty although a `ConfusionMatrix` is
/// the declared return type.
fn evaluate(&mut self, backend: &B, x: &B::Tensor, y: &B::Tensor) -> ConfusionMatrix {
}
/// Currently does nothing.
fn train(&mut self, backend: &B, optimizer: &O, x: &B::Tensor, y: &B::Tensor) {
}
}

View File

@ -0,0 +1,145 @@
/// Accumulates a strided 3x3 "valid" sliding-window product of `x` and `w`
/// into `y`.
///
/// `x` is read as a row-major `x_rows` x `x_cols` image and `w` as a row-major
/// 3x3 kernel, applied without flipping. The output grid has
/// `(x_rows - 3) / s_row + 1` rows and `(x_cols - 3) / s_col + 1` columns.
/// Each window sum is multiplied by `alpha` and *added* to the value already
/// stored in `y`.
///
/// # Panics
///
/// Panics if `y`, `x` or `w` is shorter than the extent implied by the
/// dimensions, or if a stride is zero.
#[allow(dead_code)]
pub fn valid_conv2d_3x3(y: &mut [f32], x: &[f32], w: &[f32], alpha: f32,
                        x_rows: isize, x_cols: isize, s_row: isize, s_col: isize) {
    let out_rows = (x_rows - 3) / s_row + 1;
    let out_cols = (x_cols - 3) / s_col + 1;

    // Narrow every buffer up front so a too-short slice fails before the loops.
    let out = &mut y[..(out_rows * out_cols) as usize];
    let input = &x[..(x_rows * x_cols) as usize];
    let kernel = &w[..9];

    for out_r in 0..out_rows {
        for out_c in 0..out_cols {
            // Flat index of the window's top-left corner in the input image.
            let window = s_row * out_r * x_cols + s_col * out_c;

            // Taps are visited row by row, left to right.
            let mut acc = 0.0f32;
            for k_r in 0..3isize {
                let row_base = window + k_r * x_cols;
                for k_c in 0..3isize {
                    acc += input[(row_base + k_c) as usize] * kernel[(k_r * 3 + k_c) as usize];
                }
            }

            out[(out_r * out_cols + out_c) as usize] += alpha * acc;
        }
    }
}
/// Scatters every element of `x` through the reversed 3x3 kernel `w` into `y`.
///
/// `x` is read as a row-major `x_rows` x `x_cols` image. The output grid has
/// `(x_rows - 1) * s_row + 3` rows and `(x_cols - 1) * s_col + 3` columns.
/// For each input element, `alpha * x[..]` times the kernel taken in reverse
/// order (`w[8]` first, `w[0]` last) is *added* to the 3x3 output patch whose
/// top-left corner sits at `(s_row * row, s_col * col)`.
///
/// # Panics
///
/// Panics if `y`, `x` or `w` is shorter than the extent implied by the
/// dimensions.
#[allow(dead_code)]
pub fn full_xcorr2d_3x3(y: &mut [f32], x: &[f32], w: &[f32], alpha: f32,
                        x_rows: isize, x_cols: isize,
                        s_row: isize, s_col: isize) {
    let out_cols = (x_cols - 1) * s_col + 3;
    let out_rows = (x_rows - 1) * s_row + 3;

    // Narrow every buffer up front so a too-short slice fails before the loops.
    let out = &mut y[..(out_rows * out_cols) as usize];
    let input = &x[..(x_rows * x_cols) as usize];
    let kernel = &w[..9];

    for in_r in 0..x_rows {
        for in_c in 0..x_cols {
            // Flat index of the target patch's top-left corner in the output.
            let corner = s_row * in_r * out_cols + s_col * in_c;
            let scaled = alpha * input[(in_r * x_cols + in_c) as usize];

            for k_r in 0..3isize {
                let row_base = corner + k_r * out_cols;
                for k_c in 0..3isize {
                    // Patch position (k_r, k_c) takes the mirrored kernel tap.
                    let tap = 8 - (k_r * 3 + k_c);
                    out[(row_base + k_c) as usize] += scaled * kernel[tap as usize];
                }
            }
        }
    }
}
/// Batched forward pass built on `valid_conv2d_3x3`.
///
/// `x` holds `bs` samples of `x_channels` images (`x_rows` x `x_cols` each)
/// and `y` holds `bs` samples of `y_channels` output images, both laid out
/// contiguously sample by sample, channel by channel. `w` holds one 3x3 kernel
/// per output channel; the kernel is selected by output channel only, so every
/// input channel of a sample is filtered with the same kernel and the results
/// are accumulated into that output channel's image.
///
/// Results are added to whatever `y` already contains.
///
/// # Panics
///
/// Panics if `y`, `x` or `w` is shorter than the extent implied by the
/// dimensions, or if a stride is zero.
pub fn conv2d_forward_3x3(y: &mut [f32], x: &[f32], w: &[f32],
                          bs: isize, x_channels: isize, y_channels: isize,
                          x_rows: isize, x_cols: isize, s_row: isize, s_col: isize) {
    const KERNEL_LEN: isize = 9;

    let y_rows = (x_rows - 3) / s_row + 1;
    let y_cols = (x_cols - 3) / s_col + 1;

    // Element counts of one image plane and of one whole sample.
    let in_plane = x_rows * x_cols;
    let out_plane = y_rows * y_cols;
    let in_sample = x_channels * in_plane;
    let out_sample = y_channels * out_plane;

    let outputs = &mut y[..(bs * out_sample) as usize];
    let inputs = &x[..(bs * in_sample) as usize];
    let kernels = &w[..(y_channels * KERNEL_LEN) as usize];

    for sample in 0..bs {
        for in_ch in 0..x_channels {
            let in_start = (sample * in_sample + in_ch * in_plane) as usize;
            let in_img = &inputs[in_start..in_start + in_plane as usize];

            for out_ch in 0..y_channels {
                let out_start = (sample * out_sample + out_ch * out_plane) as usize;
                let out_img = &mut outputs[out_start..out_start + out_plane as usize];

                let k_start = (out_ch * KERNEL_LEN) as usize;
                let kernel = &kernels[k_start..k_start + KERNEL_LEN as usize];

                valid_conv2d_3x3(out_img, in_img, kernel, 1.0, x_rows, x_cols, s_row, s_col);
            }
        }
    }
}
/// Batched input-gradient pass built on `full_xcorr2d_3x3`.
///
/// `dy` holds `bs` samples of `y_channels` gradient images (`y_rows` x
/// `y_cols` each) and `dx` holds `bs` samples of `x_channels` images of
/// `(y_rows - 1) * s_row + 3` by `(y_cols - 1) * s_col + 3` elements, both
/// laid out contiguously sample by sample, channel by channel. `w` holds one
/// 3x3 kernel per output channel; the kernel is selected by output channel
/// only, and each output channel's gradient is scattered into every input
/// channel of the same sample.
///
/// Results are added to whatever `dx` already contains.
///
/// # Panics
///
/// Panics if `dx`, `dy` or `w` is shorter than the extent implied by the
/// dimensions.
pub fn conv2d_backward_3x3(dx: &mut [f32], dy: &[f32], w: &[f32],
                           bs: isize, x_channels: isize, y_channels: isize,
                           y_rows: isize, y_cols: isize,
                           s_row: isize, s_col: isize) {
    const KERNEL_LEN: isize = 9;

    let x_cols = (y_cols - 1) * s_col + 3;
    let x_rows = (y_rows - 1) * s_row + 3;

    // Element counts of one image plane and of one whole sample.
    let dx_plane = x_rows * x_cols;
    let dy_plane = y_rows * y_cols;
    let dx_sample = x_channels * dx_plane;
    let dy_sample = y_channels * dy_plane;

    let in_grads = &mut dx[..(bs * dx_sample) as usize];
    let out_grads = &dy[..(bs * dy_sample) as usize];
    let kernels = &w[..(y_channels * KERNEL_LEN) as usize];

    for sample in 0..bs {
        for out_ch in 0..y_channels {
            let dy_start = (sample * dy_sample + out_ch * dy_plane) as usize;
            let dy_img = &out_grads[dy_start..dy_start + dy_plane as usize];

            for in_ch in 0..x_channels {
                let dx_start = (sample * dx_sample + in_ch * dx_plane) as usize;
                let dx_img = &mut in_grads[dx_start..dx_start + dx_plane as usize];

                let k_start = (out_ch * KERNEL_LEN) as usize;
                let kernel = &kernels[k_start..k_start + KERNEL_LEN as usize];

                full_xcorr2d_3x3(dx_img, dy_img, kernel, 1.0, y_rows, y_cols, s_row, s_col);
            }
        }
    }
}

View File

@ -0,0 +1,183 @@
/// Accumulates a strided 5x5 "valid" sliding-window product of `x` and `w`
/// into `y`.
///
/// `x` is read as a row-major `x_rows` x `x_cols` image and `w` as a row-major
/// 5x5 kernel, applied without flipping. The output grid has
/// `(x_rows - 5) / s_row + 1` rows and `(x_cols - 5) / s_col + 1` columns.
/// Each window sum is multiplied by `alpha` and *added* to the value already
/// stored in `y`.
///
/// # Panics
///
/// Panics if `y`, `x` or `w` is shorter than the extent implied by the
/// dimensions, or if a stride is zero.
#[allow(dead_code)]
pub fn valid_conv2d_5x5(y: &mut [f32], x: &[f32], w: &[f32], alpha: f32,
                        x_rows: isize, x_cols: isize, s_row: isize, s_col: isize) {
    let out_rows = (x_rows - 5) / s_row + 1;
    let out_cols = (x_cols - 5) / s_col + 1;

    // Narrow every buffer up front so a too-short slice fails before the loops.
    let out = &mut y[..(out_rows * out_cols) as usize];
    let input = &x[..(x_rows * x_cols) as usize];
    let kernel = &w[..25];

    for out_r in 0..out_rows {
        for out_c in 0..out_cols {
            // Flat index of the window's top-left corner in the input image.
            let window = s_row * out_r * x_cols + s_col * out_c;

            // Taps are visited row by row, left to right.
            let mut acc = 0.0f32;
            for k_r in 0..5isize {
                let row_base = window + k_r * x_cols;
                for k_c in 0..5isize {
                    acc += input[(row_base + k_c) as usize] * kernel[(k_r * 5 + k_c) as usize];
                }
            }

            out[(out_r * out_cols + out_c) as usize] += alpha * acc;
        }
    }
}
/// Scatters every sample of `x` through a 5x5 kernel into the larger image `y`.
///
/// The output has `y_rows = (x_rows - 1) * s_row + 5` rows and
/// `y_cols = (x_cols - 1) * s_col + 5` columns. For each input sample the value
/// `alpha * x[..]` is added to a 5x5 window of `y` whose top-left corner sits at
/// `(s_row * x_y, s_col * x_x)`. The kernel is read in reverse order: `w[24]`
/// weights the window's top-left element and `w[0]` its bottom-right element.
///
/// `y` is accumulated into, never cleared.
///
/// # Panics
/// Panics if `y`, `x` or `w` is shorter than the sizes implied by the arguments.
#[allow(dead_code)]
pub fn full_xcorr2d_5x5(y: &mut [f32], x: &[f32], w: &[f32], alpha: f32,
                        x_rows: isize, x_cols: isize,
                        s_row: isize, s_col: isize) {
    let y_cols = (x_cols - 1) * s_col + 5;
    let y_rows = (x_rows - 1) * s_row + 5;

    let out = &mut y[0..(y_rows * y_cols) as usize];
    let input = &x[0..(x_rows * x_cols) as usize];
    let kernel = &w[0..25];

    for in_row in 0..x_rows {
        for in_col in 0..x_cols {
            // Top-left corner of the output window fed by this input sample.
            let window_start = s_row * in_row * y_cols + s_col * in_col;
            let scaled = alpha * input[(in_row * x_cols + in_col) as usize];
            for k_row in 0..5 {
                let row_start = window_start + k_row * y_cols;
                for k_col in 0..5 {
                    // Walk the kernel backwards: tap 24 first, tap 0 last.
                    let tap = kernel[(24 - (k_row * 5 + k_col)) as usize];
                    out[(row_start + k_col) as usize] += scaled * tap;
                }
            }
        }
    }
}
/// Forward pass of a 5x5 convolution over a whole batch, without padding.
///
/// `x` and `y` are indexed as contiguous `(batch, channel, row, col)` data.
/// The output spatial size is `(x - 5) / stride + 1` in each dimension.
/// For every `(batch, input channel, output channel)` triple the input image is
/// convolved with the output channel's kernel via `valid_conv2d_5x5`
/// (`alpha = 1.0`), which adds its result to the output image. `y` is therefore
/// accumulated into and is not cleared here.
///
/// Kernels are selected by output channel only, so the same 25 weights are
/// applied to every input channel.
///
/// # Panics
/// Panics if `y`, `x` or `w` is shorter than the sizes implied by the shape
/// arguments.
pub fn conv2d_forward_5x5(y: &mut [f32], x: &[f32], w: &[f32],
                          bs: isize, x_channels: isize, y_channels: isize,
                          x_rows: isize, x_cols: isize, s_row: isize, s_col: isize) {
    const KERNEL_LEN: isize = 25;

    let y_rows = (x_rows - 5) / s_row + 1;
    let y_cols = (x_cols - 5) / s_col + 1;

    let x_plane = x_rows * x_cols;
    let y_plane = y_rows * y_cols;
    let x_sample = x_channels * x_plane;
    let y_sample = y_channels * y_plane;

    // Trim every buffer to exactly the region that will be touched.
    let output = &mut y[0..(bs * y_sample) as usize];
    let input = &x[0..(bs * x_sample) as usize];
    let kernels = &w[0..(y_channels * KERNEL_LEN) as usize];

    for sample in 0..bs {
        for in_ch in 0..x_channels {
            let x_start = (sample * x_sample + in_ch * x_plane) as usize;
            let x_img = &input[x_start..x_start + x_plane as usize];

            for out_ch in 0..y_channels {
                let y_start = (sample * y_sample + out_ch * y_plane) as usize;
                let y_img = &mut output[y_start..y_start + y_plane as usize];

                let k_start = (out_ch * KERNEL_LEN) as usize;
                let kernel = &kernels[k_start..k_start + KERNEL_LEN as usize];

                valid_conv2d_5x5(y_img, x_img, kernel, 1.0, x_rows, x_cols, s_row, s_col);
            }
        }
    }
}
/// Computes the input gradient `dx` of a 5x5 convolution from the output gradient `dy`.
///
/// Both buffers are indexed as contiguous `(batch, channel, row, col)` data.
/// The input spatial size is reconstructed as `(y - 1) * stride + 5`. For every
/// `(batch, output channel, input channel)` triple the `dy` image is scattered
/// through the output channel's kernel into the `dx` image by
/// `full_xcorr2d_5x5` (`alpha = 1.0`), which adds to `dx` rather than
/// overwriting it.
///
/// Kernels are selected by output channel only, so the same 25 weights are used
/// for every input channel.
///
/// # Panics
/// Panics if `dx`, `dy` or `w` is shorter than the sizes implied by the shape
/// arguments.
pub fn conv2d_backward_5x5(dx: &mut [f32], dy: &[f32], w: &[f32],
                           bs: isize, x_channels: isize, y_channels: isize,
                           y_rows: isize, y_cols: isize,
                           s_row: isize, s_col: isize) {
    const KERNEL_LEN: isize = 25;

    // Spatial size of the input that produced a `y_rows x y_cols` output.
    let x_cols = (y_cols - 1) * s_col + 5;
    let x_rows = (y_rows - 1) * s_row + 5;

    let dx_plane = x_rows * x_cols;
    let dy_plane = y_rows * y_cols;
    let dx_sample = x_channels * dx_plane;
    let dy_sample = y_channels * dy_plane;

    // Trim every buffer to exactly the region that will be touched.
    let grad_in = &mut dx[0..(bs * dx_sample) as usize];
    let grad_out = &dy[0..(bs * dy_sample) as usize];
    let kernels = &w[0..(y_channels * KERNEL_LEN) as usize];

    for sample in 0..bs {
        for out_ch in 0..y_channels {
            let dy_start = (sample * dy_sample + out_ch * dy_plane) as usize;
            let dy_img = &grad_out[dy_start..dy_start + dy_plane as usize];

            // The kernel depends only on the output channel.
            let k_start = (out_ch * KERNEL_LEN) as usize;
            let kernel = &kernels[k_start..k_start + KERNEL_LEN as usize];

            for in_ch in 0..x_channels {
                let dx_start = (sample * dx_sample + in_ch * dx_plane) as usize;
                let dx_img = &mut grad_in[dx_start..dx_start + dx_plane as usize];
                full_xcorr2d_5x5(dx_img, dy_img, kernel, 1.0, y_rows, y_cols, s_row, s_col);
            }
        }
    }
}

View File

@ -1,3 +1,10 @@
mod kernel_3x3;
mod kernel_5x5;
pub use self::kernel_3x3::*;
pub use self::kernel_5x5::*;
#[allow(dead_code)]
pub fn valid_conv2d(y: &mut [f32], x: &[f32], w: &[f32], alpha: f32,
x_rows: isize, x_cols: isize,
@ -128,6 +135,7 @@ pub fn full_xcorr2d(y: &mut [f32], x: &[f32], w: &[f32], alpha: f32,
}
}
pub fn conv2d_forward(y: &mut [f32], x: &[f32], w: &[f32],
bs: isize, x_channels: isize, y_channels: isize,
x_rows: isize, x_cols: isize,

View File

@ -643,12 +643,27 @@ impl BackendConv2d<f32> for Native {
let _padding = conv_info.padding;
self.fill_scalar(y, 0.0);
conv2d_forward(
y.write(), x.read(), w.read(),
batch_size, x_channels, y_channels,
x_height, x_width, filter_height, filter_width,
stride_y as isize, stride_x as isize
)
if filter_height == 3 && filter_width == 3 {
conv2d_forward_3x3(
y.write(), x.read(), w.read(),
batch_size, x_channels, y_channels,
x_height, x_width, stride_y as isize, stride_x as isize
)
} else if filter_height == 5 && filter_width == 5 {
conv2d_forward_5x5(
y.write(), x.read(), w.read(),
batch_size, x_channels, y_channels,
x_height, x_width, stride_y as isize, stride_x as isize
)
} else {
conv2d_forward(
y.write(), x.read(), w.read(),
batch_size, x_channels, y_channels,
x_height, x_width, filter_height, filter_width,
stride_y as isize, stride_x as isize
)
}
}
fn conv2d_backward_input(&self, dx: &mut Self::Tensor, dy: &Self::Tensor, w: &Self::Tensor, conv_info: &Conv2dInfo) {
@ -673,13 +688,29 @@ impl BackendConv2d<f32> for Native {
self.fill_scalar(dx, 0.0);
conv2d_backward(
dx.write(), dy.read(), w.read(),
batch_size, dx_channels, dy_channels,
dy_height, dy_width,
filter_height, filter_width,
stride_y as isize, stride_x as isize
)
if filter_height == 3 && filter_width == 3 {
conv2d_backward_3x3(
dx.write(), dy.read(), w.read(),
batch_size, dx_channels, dy_channels,
dy_height, dy_width,
stride_y as isize, stride_x as isize
)
} else if filter_height == 5 && filter_width == 5 {
conv2d_backward_5x5(
dx.write(), dy.read(), w.read(),
batch_size, dx_channels, dy_channels,
dy_height, dy_width,
stride_y as isize, stride_x as isize
)
} else {
conv2d_backward(
dx.write(), dy.read(), w.read(),
batch_size, dx_channels, dy_channels,
dy_height, dy_width,
filter_height, filter_width,
stride_y as isize, stride_x as isize
)
}
}
fn conv2d_backward_filter(&self, dw: &mut Self::Tensor, x: &Self::Tensor, dy: &Self::Tensor, conv_info: &Conv2dInfo) {
@ -815,8 +846,52 @@ impl BackendMaxPool2d<f32> for Native {
}
impl BackendAvgPool2d<f32> for Native {
fn avg_pool2d(&self, _y: &mut Self::Tensor, _x: &Self::Tensor, _conv_info: &Conv2dInfo) {
unimplemented!()
/// Runs average pooling over every `(batch, channel)` plane of `x` into `y`.
///
/// The first four shape entries of both tensors are read as
/// `(batch, channels, rows, cols)`. Batch and channel counts must match, and
/// the output spatial size must equal `(x - pool) / stride + 1` in each
/// dimension; all of these are checked with assertions. Each plane is then
/// passed to `avgpool2d` together with the pool size from `conv_info.kernel`
/// and the strides from `conv_info.strides`.
fn avg_pool2d(&self, y: &mut Self::Tensor, x: &Self::Tensor, conv_info: &Conv2dInfo) {
    let x_shape = &x.shape().as_slice()[0..4];
    let y_shape = &y.shape().as_slice()[0..4];

    // Pooling never mixes batches or channels.
    assert_eq!(x_shape[0], y_shape[0]);
    assert_eq!(x_shape[1], y_shape[1]);

    let (stride_y, stride_x) = conv_info.strides;
    let (pool_y, pool_x) = conv_info.kernel;
    let (stride_y, stride_x) = (stride_y as isize, stride_x as isize);
    let (pool_y, pool_x) = (pool_y as isize, pool_x as isize);

    let batch_size = x_shape[0] as isize;
    let channels = x_shape[1] as isize;
    let x_rows = x_shape[2] as isize;
    let x_cols = x_shape[3] as isize;

    let y_rows = (x_rows - pool_y) / stride_y + 1;
    let y_cols = (x_cols - pool_x) / stride_x + 1;

    // The caller-provided output must already have the pooled size.
    assert_eq!(y_rows, y_shape[2] as isize);
    assert_eq!(y_cols, y_shape[3] as isize);

    let x_img_size = x_rows * x_cols;
    let y_img_size = y_rows * y_cols;
    let x_batch_size = x_img_size * channels;
    let y_batch_size = y_img_size * channels;

    let x_vals = &x.read()[0..(batch_size * channels * x_img_size) as usize];
    let y_vals = &mut y.write()[0..(batch_size * channels * y_img_size) as usize];

    for bi in 0..batch_size {
        for ch in 0..channels {
            let x_start = (bi * x_batch_size + ch * x_img_size) as usize;
            let y_start = (bi * y_batch_size + ch * y_img_size) as usize;

            let x_img = &x_vals[x_start..x_start + x_img_size as usize];
            let y_img = &mut y_vals[y_start..y_start + y_img_size as usize];

            avgpool2d(y_img, x_img, y_rows, y_cols, x_rows, x_cols,
                      pool_y, pool_x, stride_y, stride_x);
        }
    }
}
fn avg_pool2d_backprop(&self, _dx: &mut Self::Tensor, _dy: &Self::Tensor, _x: &Self::Tensor, _conv_info: &Conv2dInfo) {
@ -824,11 +899,118 @@ impl BackendAvgPool2d<f32> for Native {
}
}
impl BackendPaddingCopy2d<f32> for Native {
    /// Copies a rectangular window of `x` into a rectangular window of `y`.
    ///
    /// The first four shape entries of both tensors are read as
    /// `(batch, filters, rows, cols)`; batch and filter counts must match.
    ///
    /// * `y_paddings` - `(rows, cols)` offset of the window's top-left corner
    ///   inside each destination plane. Destination elements above or left of
    ///   this offset are left untouched.
    /// * `x_paddings` - `(rows, cols)` offset of the window's top-left corner
    ///   inside each source plane.
    ///
    /// The window extends until it reaches the edge of either plane, so the
    /// source is never read outside its own `rows x cols` plane. Destination
    /// elements outside the window keep whatever value they had before.
    fn copy_with_padding2d(&self, y: &mut Self::Tensor, x: &Self::Tensor, y_paddings: (u32, u32), x_paddings: (u32, u32)) {
        let y_shape = &y.shape().as_slice()[0..4];
        let x_shape = &x.shape().as_slice()[0..4];

        let y_batch_size = y_shape[0] as usize;
        let y_filters = y_shape[1] as usize;
        let y_rows = y_shape[2] as usize;
        let y_cols = y_shape[3] as usize;

        let x_batch_size = x_shape[0] as usize;
        let x_filters = x_shape[1] as usize;
        let x_rows = x_shape[2] as usize;
        let x_cols = x_shape[3] as usize;

        assert_eq!(y_batch_size, x_batch_size);
        assert_eq!(y_filters, x_filters);

        let y_filter_stride = y_rows * y_cols;
        let y_batch_stride = y_filters * y_filter_stride;
        let x_filter_stride = x_rows * x_cols;
        let x_batch_stride = x_filters * x_filter_stride;

        let y_size = y_batch_size * y_filters * y_rows * y_cols;
        let x_size = x_batch_size * x_filters * x_rows * x_cols;

        let (y_pad_rows, y_pad_cols) = (y_paddings.0 as usize, y_paddings.1 as usize);
        let (x_pad_rows, x_pad_cols) = (x_paddings.0 as usize, x_paddings.1 as usize);

        let y_s = &mut y.write()[0..y_size];
        let x_s = &x.read()[0..x_size];

        for batch in 0..y_batch_size {
            for filter in 0..y_filters {
                let y_base = batch * y_batch_stride + filter * y_filter_stride;
                let x_base = batch * x_batch_stride + filter * x_filter_stride;

                // Start at the destination offset; rows/cols before it are padding.
                for y_row in y_pad_rows..y_rows {
                    let x_row = y_row - y_pad_rows + x_pad_rows;
                    // Bound the *offset* source row: checking only
                    // `y_row - y_pad_rows` would let `x_paddings` push the read
                    // into the next plane (or past the end of the buffer).
                    if x_row >= x_rows {
                        break;
                    }
                    for y_col in y_pad_cols..y_cols {
                        let x_col = y_col - y_pad_cols + x_pad_cols;
                        // Same reasoning as for rows: never wrap into the next row.
                        if x_col >= x_cols {
                            break;
                        }
                        y_s[y_base + y_row * y_cols + y_col] = x_s[x_base + x_row * x_cols + x_col];
                    }
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
use crate::backend::*;
use super::{Native, NativeTensorF32};
use crate::tensor::{Tensor, TensorShape};
use crate::tensor::Tensor;
/// Exercises `copy_with_padding2d` in both directions: embedding a 3x3 plane
/// in the middle of a 5x5 one, and cropping the centre 3x3 out of a 5x5 plane.
#[test]
fn test_copy_with_padding2d() {
    let backend = Native;

    // Embed: the 3x3 source lands at offset (1, 1) of the 5x5 destination.
    // NOTE(review): the expected zero border presumably relies on freshly
    // created tensors being zero-initialised — confirm.
    let mut small_src = NativeTensorF32::new((1, 1, 3, 3));
    let mut padded_dst = NativeTensorF32::new((1, 1, 5, 5));
    backend.load_tensor_u8(&mut small_src, &[
        1, 2, 3,
        4, 5, 6,
        7, 8, 9,
    ]);

    // Crop: reading starts at offset (1, 1) of the 5x5 source.
    let mut large_src = NativeTensorF32::new((1, 1, 5, 5));
    let mut cropped_dst = NativeTensorF32::new((1, 1, 3, 3));
    backend.load_tensor_u8(&mut large_src, &[
        1, 2, 3, 4, 5,
        6, 7, 8, 9, 10,
        11, 12, 13, 14, 15,
        16, 17, 18, 19, 20,
        21, 22, 23, 24, 25,
    ]);

    backend.copy_with_padding2d(&mut padded_dst, &small_src, (1, 1), (0, 0));
    backend.copy_with_padding2d(&mut cropped_dst, &large_src, (0, 0), (1, 1));

    assert!(
        padded_dst.read() == &[
            0.0, 0.0, 0.0, 0.0, 0.0,
            0.0, 1.0, 2.0, 3.0, 0.0,
            0.0, 4.0, 5.0, 6.0, 0.0,
            0.0, 7.0, 8.0, 9.0, 0.0,
            0.0, 0.0, 0.0, 0.0, 0.0,
        ]
    );
    assert!(
        cropped_dst.read() == &[
            7.0, 8.0, 9.0,
            12.0, 13.0, 14.0,
            17.0, 18.0, 19.0,
        ]
    );
}
#[test]
fn test_softmax() {
@ -987,227 +1169,4 @@ mod tests {
a.read() == &[2.0, 4.0, 6.0, 8.0]
);
}
// #[test]
// fn test_conv2d() {
// let bac = Native;
// let mut x = NativeTensorF32::new((1, 12, 12, 1));
// let mut y = NativeTensorF32::new((1, 12, 12, 2));
// let mut f = NativeTensorF32::new((3, 3, 2));
// bac.load_tensor_u8(&mut x, &[
// 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
// 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
// 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
// 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
// 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59,
// 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71,
// 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83,
// 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95,
// 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107,
// 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119,
// 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131,
// 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143,
// ]);
// bac.load_tensor_f32(&mut f, &[
// 0.5, 0.1, 0.6, 0.3, 0.5, 0.8,
// 0.6, 0.4, 0.5, 0.9, 0.6, 0.4,
// 0.5, 0.8, 0.6, 0.3, 0.5, 0.1,
// ]);
// let info = Conv2dInfo {
// kernel: (3, 3),
// padding: PaddingKind::Valid,
// strides: (1, 1),
// };
// bac.conv2d_forward(&mut y, &x, &f, &info);
// assert!(
// slice_eq(y.read(), &[
// 14.3, 5.3, 22.5, 16.6, 25.8, 19.5, 29.1, 22.4, 32.4, 25.3, 35.7, 28.2,
// 39.0, 31.1, 42.3, 34.0, 45.6, 36.9, 48.9, 39.8, 52.2, 42.7, 36.3, 38.4,
// 41.2, 26.5, 63.7, 53.3, 68.6, 57.4, 73.5, 61.5, 78.4, 65.6, 83.3, 69.7,
// 88.2, 73.8, 93.1, 77.9, 98.0, 82.0, 102.9, 86.1, 107.8, 90.2, 74.3, 71.5,
// 80.8, 60.1, 122.5, 102.5, 127.4, 106.6, 132.3, 110.7, 137.2, 114.8, 142.1, 118.9,
// 147.0, 123.0, 151.9, 127.1, 156.8, 131.2, 161.7, 135.3, 166.6, 139.4, 113.9, 105.1,
// 120.4, 93.7, 181.3, 151.7, 186.2, 155.8, 191.1, 159.9, 196.0, 164.0, 200.9, 168.1,
// 205.8, 172.2, 210.7, 176.3, 215.6, 180.4, 220.5, 184.5, 225.4, 188.6, 153.5, 138.7,
// 160.0, 127.3, 240.1, 200.9, 245.0, 205.0, 249.9, 209.1, 254.8, 213.2, 259.7, 217.3,
// 264.6, 221.4, 269.5, 225.5, 274.4, 229.6, 279.3, 233.7, 284.2, 237.8, 193.1, 172.3,
// 199.6, 160.9, 298.9, 250.1, 303.8, 254.2, 308.7, 258.3, 313.6, 262.4, 318.5, 266.5,
// 323.4, 270.6, 328.3, 274.7, 333.2, 278.8, 338.1, 282.9, 343.0, 287.0, 232.7, 205.9,
// 239.2, 194.5, 357.7, 299.3, 362.6, 303.4, 367.5, 307.5, 372.4, 311.6, 377.3, 315.7,
// 382.2, 319.8, 387.1, 323.9, 392.0, 328.0, 396.9, 332.1, 401.8, 336.2, 272.3, 239.5,
// 278.8, 228.1, 416.5, 348.5, 421.4, 352.6, 426.3, 356.7, 431.2, 360.8, 436.1, 364.9,
// 441.0, 369.0, 445.9, 373.1, 450.8, 377.2, 455.7, 381.3, 460.6, 385.4, 311.9, 273.1,
// 318.4, 261.7, 475.3, 397.7, 480.2, 401.8, 485.1, 405.9, 490.0, 410.0, 494.9, 414.1,
// 499.8, 418.2, 504.7, 422.3, 509.6, 426.4, 514.5, 430.5, 519.4, 434.6, 351.5, 306.7,
// 358.0, 295.3, 534.1, 446.9, 539.0, 451.0, 543.9, 455.1, 548.8, 459.2, 553.7, 463.3,
// 558.6, 467.4, 563.5, 471.5, 568.4, 475.6, 573.3, 479.7, 578.2, 483.8, 391.1, 340.3,
// 397.6, 328.9, 592.9, 496.1, 597.8, 500.2, 602.7, 504.3, 607.6, 508.4, 612.5, 512.5,
// 617.4, 516.6, 622.3, 520.7, 627.2, 524.8, 632.1, 528.9, 637.0, 533.0, 430.7, 373.9,
// 278.3, 304.8, 419.7, 372.0, 423.0, 374.9, 426.3, 377.8, 429.6, 380.7, 432.9, 383.6,
// 436.2, 386.5, 439.5, 389.4, 442.8, 392.3, 446.1, 395.2, 449.4, 398.1, 300.3, 237.8
// ])
// );
// }
// #[test]
// fn test_conv2d_backward_input() {
// let bac = Native;
// let mut dx = NativeTensorF32::new((1, 12, 12, 1));
// let mut dy = NativeTensorF32::new((1, 12, 12, 2));
// let mut f = NativeTensorF32::new((3, 3, 2));
// bac.load_tensor_f32(&mut dy, &[
// 0.0, 0.0, 1.0, 0.1, 2.0, 0.2, 3.0, 0.3, 4.0, 0.4, 5.0, 0.5, 6.0, 0.6, 7.0, 0.7, 8.0, 0.8, 9.0, 0.9, 10.0, 1.0, 11.0, 1.1,
// 12.0, 1.2, 13.0, 1.3, 14.0, 1.4, 15.0, 1.5, 16.0, 1.6, 17.0, 1.7, 18.0, 1.8, 19.0, 1.9, 20.0, 2.0, 21.0, 2.1, 22.0, 2.2, 23.0, 2.3,
// 24.0, 2.4, 25.0, 2.5, 26.0, 2.6, 27.0, 2.7, 28.0, 2.8, 29.0, 2.9, 30.0, 3.0, 31.0, 3.1, 32.0, 3.2, 33.0, 3.3, 34.0, 3.4, 35.0, 3.5,
// 36.0, 3.6, 37.0, 3.7, 38.0, 3.8, 39.0, 3.9, 40.0, 4.0, 41.0, 4.1, 42.0, 4.2, 43.0, 4.3, 44.0, 4.4, 45.0, 4.5, 46.0, 4.6, 47.0, 4.7,
// 48.0, 4.8, 49.0, 4.9, 50.0, 5.0, 51.0, 5.1, 52.0, 5.2, 53.0, 5.3, 54.0, 5.4, 55.0, 5.5, 56.0, 5.6, 57.0, 5.7, 58.0, 5.8, 59.0, 5.9,
// 60.0, 6.0, 61.0, 6.1, 62.0, 6.2, 63.0, 6.3, 64.0, 6.4, 65.0, 6.5, 66.0, 6.6, 67.0, 6.7, 68.0, 6.8, 69.0, 6.9, 70.0, 7.0, 71.0, 7.1,
// 72.0, 7.2, 73.0, 7.3, 74.0, 7.4, 75.0, 7.5, 76.0, 7.6, 77.0, 7.7, 78.0, 7.8, 79.0, 7.9, 80.0, 8.0, 81.0, 8.1, 82.0, 8.2, 83.0, 8.3,
// 84.0, 8.4, 85.0, 8.5, 86.0, 8.6, 87.0, 8.7, 88.0, 8.8, 89.0, 8.9, 90.0, 9.0, 91.0, 9.1, 92.0, 9.2, 93.0, 9.3, 94.0, 9.4, 95.0, 9.5,
// 96.0, 9.6, 97.0, 9.7, 98.0, 9.8, 99.0, 9.9, 100.0, 10.0, 101.0, 10.1, 102.0, 10.2, 103.0, 10.3, 104.0, 10.4, 105.0, 10.5, 106.0, 10.6, 107.0, 10.7,
// 108.0, 10.8, 109.0, 10.9, 110.0, 11.0, 111.0, 11.1, 112.0, 11.2, 113.0, 11.3, 114.0, 11.4, 115.0, 11.5, 116.0, 11.6, 117.0, 11.7, 118.0, 11.8, 119.0, 11.9,
// 120.0, 12.0, 121.0, 12.1, 122.0, 12.2, 123.0, 12.3, 124.0, 12.4, 125.0, 12.5, 126.0, 12.6, 127.0, 12.7, 128.0, 12.8, 129.0, 12.9, 130.0, 13.0, 131.0, 13.1,
// 132.0, 13.2, 133.0, 13.3, 134.0, 13.4, 135.0, 13.5, 136.0, 13.6, 137.0, 13.7, 138.0, 13.8, 139.0, 13.9, 140.0, 14.0, 141.0, 14.1, 142.0, 14.2, 143.0, 14.3,
// ]);
// bac.load_tensor_f32(&mut f, &[
// 0.5, 0.1, 0.6, 0.3, 0.5, 0.8,
// 0.6, 0.4, 0.5, 0.9, 0.6, 0.4,
// 0.5, 0.8, 0.6, 0.3, 0.5, 0.1,
// ]);
// let info = Conv2dInfo {
// padding: PaddingKind::Valid,
// strides: (1, 1),
// kernel: (2, 2)
// };
// bac.conv2d_backward_input(&mut dx, &dy, &f, &info);
// assert!(
// slice_eq(dx.read(), &[
// 14.83, 24.16, 27.75, 31.34, 34.93, 38.52, 42.11, 45.70, 49.29, 52.88, 56.47, 40.14,
// 43.85, 69.03, 74.34, 79.65, 84.96, 90.27, 95.58, 100.89, 106.20, 111.51, 116.82, 81.45,
// 86.81, 132.75, 138.06, 143.37, 148.68, 153.99, 159.30, 164.61, 169.92, 175.23, 180.54, 124.41,
// 129.77, 196.47, 201.78, 207.09, 212.4, 217.71, 223.02, 228.33, 233.64, 238.95, 244.26, 167.37,
// 172.73, 260.19, 265.5, 270.81, 276.12, 281.43, 286.74, 292.05, 297.36, 302.67, 307.98, 210.33,
// 215.69, 323.91, 329.22, 334.53, 339.84, 345.15, 350.46, 355.77, 361.08, 366.39, 371.7, 253.29,
// 258.65, 387.63, 392.94, 398.25, 403.56, 408.87, 414.18, 419.49, 424.80, 430.11, 435.42, 296.25,
// 301.61, 451.35, 456.66, 461.97, 467.28, 472.59, 477.90, 483.21, 488.52, 493.83, 499.14, 339.21,
// 344.57, 515.07, 520.38, 525.69, 531.00, 536.31, 541.62, 546.93, 552.24, 557.55, 562.86, 382.17,
// 387.53, 578.79, 584.1, 589.41, 594.72, 600.03, 605.3401, 610.65, 615.96, 621.27, 626.58, 425.13,
// 430.49, 642.51, 647.82, 653.13, 658.44, 663.75, 669.06, 674.37, 679.68, 684.99, 690.30, 468.09,
// 308.78, 456.9, 460.49, 464.08, 467.67, 471.26, 474.85, 478.44, 482.03, 485.62, 489.21, 324.08
// ])
// );
// }
// #[test]
// fn test_maxpool2d() {
// let bac = Native;
// let mut x = NativeTensorF32::new((1, 12, 12, 1));
// let mut y = NativeTensorF32::new((1, 6, 6, 1));
// bac.load_tensor_u8(&mut x, &[
// 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
// 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
// 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
// 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
// 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59,
// 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71,
// 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83,
// 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95,
// 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107,
// 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119,
// 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131,
// 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143,
// ]);
// let info = Conv2dInfo {
// padding: PaddingKind::Valid,
// strides: (2, 2),
// kernel: (2, 2)
// };
// bac.max_pool2d(&mut y, &x, &info);
// assert!(
// slice_eq(y.read(), &[
// 13.0, 15.0, 17.0, 19.0, 21.0, 23.0,
// 37.0, 39.0, 41.0, 43.0, 45.0, 47.0,
// 61.0, 63.0, 65.0, 67.0, 69.0, 71.0,
// 85.0, 87.0, 89.0, 91.0, 93.0, 95.0,
// 109.0, 111.0, 113.0, 115.0, 117.0, 119.0,
// 133.0, 135.0, 137.0, 139.0, 141.0, 143.0,
// ])
// );
// }
/// Returns `true` when `a` and `b` have the same length and every pair of
/// corresponding elements is equal within an absolute tolerance of `0.0001`.
///
/// Exactly equal values (including matching infinities) always compare equal.
/// A NaN on either side never compares equal: a NaN difference is not
/// `> tolerance`, so a "reject if greater" check would silently accept it and
/// let a broken kernel pass its test.
fn slice_eq(a: &[f32], b: &[f32]) -> bool {
    const TOLERANCE: f32 = 0.0001;

    a.len() == b.len()
        && a.iter()
            .zip(b)
            // `p == q` handles infinities, whose difference would be NaN.
            .all(|(&p, &q)| p == q || (p - q).abs() <= TOLERANCE)
}
/// Debug helper that prints a tensor's shape and values to stdout.
///
/// Values are printed comma-separated. A line break followed by a two-space
/// indent is inserted before every element whose flat index is a multiple of
/// any stride except the last one. Strides come from `override_strides` when
/// given, otherwise from the shape's default strides.
fn print_tensor(t: &NativeTensorF32, override_strides: Option<&[u32]>) {
    let default_strides = t.shape.default_strides();
    let strides = match override_strides {
        Some(custom) => custom,
        None => default_strides.as_slice(),
    };
    let last_idx = strides.len() - 1;
    let indent = 2;

    println!("default stridses {} {}", t.shape.default_strides(), last_idx);
    print!("Tensor(shape={}, data=[", t.shape);

    for (idx, val) in t.read().iter().enumerate() {
        // Every stride is evaluated (no short-circuit across strides), exactly
        // like a plain loop over all of them would.
        let boundary_hits = strides
            .iter()
            .enumerate()
            .filter(|&(sidx, &s)| sidx != last_idx && idx % s as usize == 0)
            .count();

        if idx != 0 {
            print!(", ");
        }
        if boundary_hits > 0 {
            print!("\n{}", " ".repeat(indent));
        }
        print!("{}", val);
    }

    print!("\n])");
}
}

View File

@ -63,13 +63,12 @@ pub fn maxpool2d_backward(dx: &mut [f32], x: &[f32], dy: &[f32],
#[allow(dead_code)]
pub fn avgpool2d(y: &mut [f32], x: &[f32],
y_rows: isize, y_cols: isize,
x_rows: isize, x_cols: isize,
w_rows: isize, w_cols: isize,
s_row: isize, s_col: isize) {
let w_size = w_rows * w_cols;
let y_rows = (x_rows - w_rows) / s_row + 1;
let y_cols = (x_cols - w_cols) / s_col + 1;
let y = &mut y[0..(y_rows * y_cols) as usize];
let x = &x[0..(x_rows * x_cols) as usize];
@ -197,4 +196,30 @@ mod tests {
assert_eq!(dx, tt);
}
/// Pools a 6x6 ramp (1.0 ..= 36.0) with a 2x2 window and stride 2 and checks
/// that each output equals the mean of its four inputs.
#[test]
fn test_avgpool2d() {
    // Row-major 6x6 image holding 1.0, 2.0, ..., 36.0.
    let input: Vec<f32> = (1..=36).map(|v| v as f32).collect();
    let mut output = [0.0f32; 9];

    avgpool2d(&mut output, &input, 3, 3, 6, 6, 2, 2, 2, 2);

    let expected: [f32; 9] = [
        4.5, 6.5, 8.5,
        16.5, 18.5, 20.5,
        28.5, 30.5, 32.5,
    ];
    assert_eq!(output, expected)
}
}