Initial commit
This commit is contained in:
commit
197c1587af
6
.gitignore
vendored
Normal file
6
.gitignore
vendored
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
.undodir
|
||||||
|
.gitkeep
|
||||||
|
/**/*.weights
|
||||||
|
/models/*.onnx
|
||||||
|
/target
|
||||||
|
__pycache__
|
799
Cargo.lock
generated
Normal file
799
Cargo.lock
generated
Normal file
@ -0,0 +1,799 @@
|
|||||||
|
# This file is automatically @generated by Cargo.
|
||||||
|
# It is not intended for manual editing.
|
||||||
|
version = 3
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "aho-corasick"
|
||||||
|
version = "0.7.18"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "1e37cfd5e7657ada45f742d6e99ca5788580b5c529dc78faf11ece6dc702656f"
|
||||||
|
dependencies = [
|
||||||
|
"memchr",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "ansi_term"
|
||||||
|
version = "0.11.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "ee49baf6cb617b853aa8d93bf420db2383fab46d314482ca2803b40d5fde979b"
|
||||||
|
dependencies = [
|
||||||
|
"winapi",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "approx"
|
||||||
|
version = "0.5.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "072df7202e63b127ab55acfe16ce97013d5b97bf160489336d3f1840fd78e99e"
|
||||||
|
dependencies = [
|
||||||
|
"num-traits",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "atty"
|
||||||
|
version = "0.2.14"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8"
|
||||||
|
dependencies = [
|
||||||
|
"hermit-abi",
|
||||||
|
"libc",
|
||||||
|
"winapi",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "autocfg"
|
||||||
|
version = "1.0.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "cdb031dd78e28731d87d56cc8ffef4a8f36ca26c38fe2de700543e627f8a464a"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "bindgen"
|
||||||
|
version = "0.56.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "2da379dbebc0b76ef63ca68d8fc6e71c0f13e59432e0987e508c1820e6ab5239"
|
||||||
|
dependencies = [
|
||||||
|
"bitflags",
|
||||||
|
"cexpr",
|
||||||
|
"clang-sys",
|
||||||
|
"clap",
|
||||||
|
"env_logger",
|
||||||
|
"lazy_static",
|
||||||
|
"lazycell",
|
||||||
|
"log",
|
||||||
|
"peeking_take_while",
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"regex",
|
||||||
|
"rustc-hash",
|
||||||
|
"shlex",
|
||||||
|
"which",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "bitflags"
|
||||||
|
version = "1.3.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "bytemuck"
|
||||||
|
version = "1.7.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "72957246c41db82b8ef88a5486143830adeb8227ef9837740bdec67724cf2c5b"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "cexpr"
|
||||||
|
version = "0.4.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "f4aedb84272dbe89af497cf81375129abda4fc0a9e7c5d317498c15cc30c0d27"
|
||||||
|
dependencies = [
|
||||||
|
"nom",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "cfg-if"
|
||||||
|
version = "0.1.10"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "4785bdd1c96b2a846b2bd7cc02e86b6b3dbf14e7e53446c4f54c92a361040822"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "cfg-if"
|
||||||
|
version = "1.0.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "clang-sys"
|
||||||
|
version = "1.3.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "fa66045b9cb23c2e9c1520732030608b02ee07e5cfaa5a521ec15ded7fa24c90"
|
||||||
|
dependencies = [
|
||||||
|
"glob",
|
||||||
|
"libc",
|
||||||
|
"libloading",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "clap"
|
||||||
|
version = "2.33.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "37e58ac78573c40708d45522f0d80fa2f01cc4f9b4e2bf749807255454312002"
|
||||||
|
dependencies = [
|
||||||
|
"ansi_term",
|
||||||
|
"atty",
|
||||||
|
"bitflags",
|
||||||
|
"strsim",
|
||||||
|
"textwrap",
|
||||||
|
"unicode-width",
|
||||||
|
"vec_map",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "crossbeam"
|
||||||
|
version = "0.7.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "69323bff1fb41c635347b8ead484a5ca6c3f11914d784170b158d8449ab07f8e"
|
||||||
|
dependencies = [
|
||||||
|
"cfg-if 0.1.10",
|
||||||
|
"crossbeam-channel",
|
||||||
|
"crossbeam-deque",
|
||||||
|
"crossbeam-epoch",
|
||||||
|
"crossbeam-queue",
|
||||||
|
"crossbeam-utils",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "crossbeam-channel"
|
||||||
|
version = "0.4.4"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "b153fe7cbef478c567df0f972e02e6d736db11affe43dfc9c56a9374d1adfb87"
|
||||||
|
dependencies = [
|
||||||
|
"crossbeam-utils",
|
||||||
|
"maybe-uninit",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "crossbeam-deque"
|
||||||
|
version = "0.7.4"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "c20ff29ded3204c5106278a81a38f4b482636ed4fa1e6cfbeef193291beb29ed"
|
||||||
|
dependencies = [
|
||||||
|
"crossbeam-epoch",
|
||||||
|
"crossbeam-utils",
|
||||||
|
"maybe-uninit",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "crossbeam-epoch"
|
||||||
|
version = "0.8.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "058ed274caafc1f60c4997b5fc07bf7dc7cca454af7c6e81edffe5f33f70dace"
|
||||||
|
dependencies = [
|
||||||
|
"autocfg",
|
||||||
|
"cfg-if 0.1.10",
|
||||||
|
"crossbeam-utils",
|
||||||
|
"lazy_static",
|
||||||
|
"maybe-uninit",
|
||||||
|
"memoffset",
|
||||||
|
"scopeguard",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "crossbeam-queue"
|
||||||
|
version = "0.2.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "774ba60a54c213d409d5353bda12d49cd68d14e45036a285234c8d6f91f92570"
|
||||||
|
dependencies = [
|
||||||
|
"cfg-if 0.1.10",
|
||||||
|
"crossbeam-utils",
|
||||||
|
"maybe-uninit",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "crossbeam-utils"
|
||||||
|
version = "0.7.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "c3c7c73a2d1e9fc0886a08b93e98eb643461230d5f1925e4036204d5f2e261a8"
|
||||||
|
dependencies = [
|
||||||
|
"autocfg",
|
||||||
|
"cfg-if 0.1.10",
|
||||||
|
"lazy_static",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "env_logger"
|
||||||
|
version = "0.8.4"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "a19187fea3ac7e84da7dacf48de0c45d63c6a76f9490dae389aead16c243fce3"
|
||||||
|
dependencies = [
|
||||||
|
"atty",
|
||||||
|
"humantime",
|
||||||
|
"log",
|
||||||
|
"regex",
|
||||||
|
"termcolor",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "fixedbitset"
|
||||||
|
version = "0.3.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "45e780567ed7abc415d12fd464571d265eb4a5710ddc97cdb1a31a4c35bb479d"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "glob"
|
||||||
|
version = "0.3.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "9b919933a397b79c37e33b77bb2aa3dc8eb6e165ad809e58ff75bc7db2e34574"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "heck"
|
||||||
|
version = "0.3.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "6d621efb26863f0e9924c6ac577e8275e5e6b77455db64ffa6c65c904e9e132c"
|
||||||
|
dependencies = [
|
||||||
|
"unicode-segmentation",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "hermit-abi"
|
||||||
|
version = "0.1.19"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "62b467343b94ba476dcb2500d242dadbb39557df889310ac77c5d99100aaac33"
|
||||||
|
dependencies = [
|
||||||
|
"libc",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "humantime"
|
||||||
|
version = "2.1.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "lazy_static"
|
||||||
|
version = "1.4.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "lazycell"
|
||||||
|
version = "1.3.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "libc"
|
||||||
|
version = "0.2.108"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "8521a1b57e76b1ec69af7599e75e38e7b7fad6610f037db8c79b127201b5d119"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "libloading"
|
||||||
|
version = "0.7.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "afe203d669ec979b7128619bae5a63b7b42e9203c1b29146079ee05e2f604b52"
|
||||||
|
dependencies = [
|
||||||
|
"cfg-if 1.0.0",
|
||||||
|
"winapi",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "log"
|
||||||
|
version = "0.4.14"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "51b9bbe6c47d51fc3e1a9b945965946b4c44142ab8792c50835a980d362c2710"
|
||||||
|
dependencies = [
|
||||||
|
"cfg-if 1.0.0",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "matrixmultiply"
|
||||||
|
version = "0.2.4"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "916806ba0031cd542105d916a97c8572e1fa6dd79c9c51e7eb43a09ec2dd84c1"
|
||||||
|
dependencies = [
|
||||||
|
"rawpointer",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "matrixmultiply"
|
||||||
|
version = "0.3.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "add85d4dd35074e6fedc608f8c8f513a3548619a9024b751949ef0e8e45a4d84"
|
||||||
|
dependencies = [
|
||||||
|
"rawpointer",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "maybe-uninit"
|
||||||
|
version = "2.0.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "60302e4db3a61da70c0cb7991976248362f30319e88850c487b9b95bbf059e00"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "memchr"
|
||||||
|
version = "2.4.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "308cc39be01b73d0d18f82a0e7b2a3df85245f84af96fdddc5d202d27e47b86a"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "memoffset"
|
||||||
|
version = "0.5.6"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "043175f069eda7b85febe4a74abbaeff828d9f8b448515d3151a14a3542811aa"
|
||||||
|
dependencies = [
|
||||||
|
"autocfg",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "munkres"
|
||||||
|
version = "0.5.2"
|
||||||
|
source = "git+https://github.com/andreytkachenko/munkres-rs#f989a06df80f30c79d71540bec1e54cf5fcf9e69"
|
||||||
|
dependencies = [
|
||||||
|
"fixedbitset",
|
||||||
|
"ndarray 0.14.0",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "nalgebra"
|
||||||
|
version = "0.29.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "d506eb7e08d6329505faa8a3a00a5dcc6de9f76e0c77e4b75763ae3c770831ff"
|
||||||
|
dependencies = [
|
||||||
|
"approx",
|
||||||
|
"matrixmultiply 0.3.2",
|
||||||
|
"nalgebra-macros",
|
||||||
|
"num-complex 0.4.0",
|
||||||
|
"num-rational",
|
||||||
|
"num-traits",
|
||||||
|
"simba",
|
||||||
|
"typenum",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "nalgebra-macros"
|
||||||
|
version = "0.1.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "01fcc0b8149b4632adc89ac3b7b31a12fb6099a0317a4eb2ebff574ef7de7218"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"syn",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "ndarray"
|
||||||
|
version = "0.14.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "6c0d5c9540a691d153064dc47a4db2504587a75eae07bf1d73f7a596ebc73c04"
|
||||||
|
dependencies = [
|
||||||
|
"matrixmultiply 0.2.4",
|
||||||
|
"num-complex 0.3.1",
|
||||||
|
"num-integer",
|
||||||
|
"num-traits",
|
||||||
|
"rawpointer",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "ndarray"
|
||||||
|
version = "0.15.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "08e854964160a323e65baa19a0b1a027f76d590faba01f05c0cbc3187221a8c9"
|
||||||
|
dependencies = [
|
||||||
|
"matrixmultiply 0.3.2",
|
||||||
|
"num-complex 0.4.0",
|
||||||
|
"num-integer",
|
||||||
|
"num-traits",
|
||||||
|
"rawpointer",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "nom"
|
||||||
|
version = "5.1.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "ffb4262d26ed83a1c0a33a38fe2bb15797329c85770da05e6b828ddb782627af"
|
||||||
|
dependencies = [
|
||||||
|
"memchr",
|
||||||
|
"version_check",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "num-complex"
|
||||||
|
version = "0.3.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "747d632c0c558b87dbabbe6a82f3b4ae03720d0646ac5b7b4dae89394be5f2c5"
|
||||||
|
dependencies = [
|
||||||
|
"num-traits",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "num-complex"
|
||||||
|
version = "0.4.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "26873667bbbb7c5182d4a37c1add32cdf09f841af72da53318fdb81543c15085"
|
||||||
|
dependencies = [
|
||||||
|
"num-traits",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "num-integer"
|
||||||
|
version = "0.1.44"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "d2cc698a63b549a70bc047073d2949cce27cd1c7b0a4a862d08a8031bc2801db"
|
||||||
|
dependencies = [
|
||||||
|
"autocfg",
|
||||||
|
"num-traits",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "num-rational"
|
||||||
|
version = "0.4.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "d41702bd167c2df5520b384281bc111a4b5efcf7fbc4c9c222c815b07e0a6a6a"
|
||||||
|
dependencies = [
|
||||||
|
"autocfg",
|
||||||
|
"num-integer",
|
||||||
|
"num-traits",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "num-traits"
|
||||||
|
version = "0.2.14"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "9a64b1ec5cda2586e284722486d802acf1f7dbdc623e2bfc57e65ca1cd099290"
|
||||||
|
dependencies = [
|
||||||
|
"autocfg",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "onnx-model"
|
||||||
|
version = "0.2.3"
|
||||||
|
source = "git+https://github.com/andreytkachenko/onnx-model?branch=v1.8#b1317664cdf34a4404f14283ef5beac8f03c4280"
|
||||||
|
dependencies = [
|
||||||
|
"lazy_static",
|
||||||
|
"ndarray 0.15.3",
|
||||||
|
"onnxruntime",
|
||||||
|
"smallstr",
|
||||||
|
"smallvec",
|
||||||
|
"thiserror",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "onnxruntime"
|
||||||
|
version = "0.1.0"
|
||||||
|
source = "git+https://github.com/andreytkachenko/onnxruntime-rs.git?branch=v1.8#d4b8f290d51a6068edf915e357db1839a3a59d4a"
|
||||||
|
dependencies = [
|
||||||
|
"bindgen",
|
||||||
|
"crossbeam",
|
||||||
|
"heck",
|
||||||
|
"lazy_static",
|
||||||
|
"structopt",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "paste"
|
||||||
|
version = "1.0.6"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "0744126afe1a6dd7f394cb50a716dbe086cb06e255e53d8d0185d82828358fb5"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "peeking_take_while"
|
||||||
|
version = "0.1.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "19b17cddbe7ec3f8bc800887bab5e717348c95ea2ca0b1bf0837fb964dc67099"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "proc-macro-error"
|
||||||
|
version = "1.0.4"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro-error-attr",
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"syn",
|
||||||
|
"version_check",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "proc-macro-error-attr"
|
||||||
|
version = "1.0.4"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"version_check",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "proc-macro2"
|
||||||
|
version = "1.0.32"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "ba508cc11742c0dc5c1659771673afbab7a0efab23aa17e854cbab0837ed0b43"
|
||||||
|
dependencies = [
|
||||||
|
"unicode-xid",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "qtrack"
|
||||||
|
version = "0.1.0"
|
||||||
|
dependencies = [
|
||||||
|
"munkres",
|
||||||
|
"nalgebra",
|
||||||
|
"ndarray 0.15.3",
|
||||||
|
"num-traits",
|
||||||
|
"onnx-model",
|
||||||
|
"serde",
|
||||||
|
"serde_derive",
|
||||||
|
"thiserror",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "quote"
|
||||||
|
version = "1.0.10"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "38bc8cc6a5f2e3655e0899c1b848643b2562f853f114bfec7be120678e3ace05"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "rawpointer"
|
||||||
|
version = "0.2.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "regex"
|
||||||
|
version = "1.5.4"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "d07a8629359eb56f1e2fb1652bb04212c072a87ba68546a04065d525673ac461"
|
||||||
|
dependencies = [
|
||||||
|
"aho-corasick",
|
||||||
|
"memchr",
|
||||||
|
"regex-syntax",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "regex-syntax"
|
||||||
|
version = "0.6.25"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "f497285884f3fcff424ffc933e56d7cbca511def0c9831a7f9b5f6153e3cc89b"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "rustc-hash"
|
||||||
|
version = "1.1.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "safe_arch"
|
||||||
|
version = "0.6.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "794821e4ccb0d9f979512f9c1973480123f9bd62a90d74ab0f9426fcf8f4a529"
|
||||||
|
dependencies = [
|
||||||
|
"bytemuck",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "scopeguard"
|
||||||
|
version = "1.1.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "serde"
|
||||||
|
version = "1.0.130"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "f12d06de37cf59146fbdecab66aa99f9fe4f78722e3607577a5375d66bd0c913"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "serde_derive"
|
||||||
|
version = "1.0.130"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "d7bc1a1ab1961464eae040d96713baa5a724a8152c1222492465b54322ec508b"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"syn",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "shlex"
|
||||||
|
version = "0.1.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "7fdf1b9db47230893d76faad238fd6097fd6d6a9245cd7a4d90dbd639536bbd2"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "simba"
|
||||||
|
version = "0.6.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "f0b7840f121a46d63066ee7a99fc81dcabbc6105e437cae43528cea199b5a05f"
|
||||||
|
dependencies = [
|
||||||
|
"approx",
|
||||||
|
"num-complex 0.4.0",
|
||||||
|
"num-traits",
|
||||||
|
"paste",
|
||||||
|
"wide",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "smallstr"
|
||||||
|
version = "0.2.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "1e922794d168678729ffc7e07182721a14219c65814e66e91b839a272fe5ae4f"
|
||||||
|
dependencies = [
|
||||||
|
"smallvec",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "smallvec"
|
||||||
|
version = "1.7.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "1ecab6c735a6bb4139c0caafd0cc3635748bbb3acf4550e8138122099251f309"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "strsim"
|
||||||
|
version = "0.8.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "8ea5119cdb4c55b55d432abb513a0429384878c15dde60cc77b1c99de1a95a6a"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "structopt"
|
||||||
|
version = "0.3.25"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "40b9788f4202aa75c240ecc9c15c65185e6a39ccdeb0fd5d008b98825464c87c"
|
||||||
|
dependencies = [
|
||||||
|
"clap",
|
||||||
|
"lazy_static",
|
||||||
|
"structopt-derive",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "structopt-derive"
|
||||||
|
version = "0.4.18"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "dcb5ae327f9cc13b68763b5749770cb9e048a99bd9dfdfa58d0cf05d5f64afe0"
|
||||||
|
dependencies = [
|
||||||
|
"heck",
|
||||||
|
"proc-macro-error",
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"syn",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "syn"
|
||||||
|
version = "1.0.81"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "f2afee18b8beb5a596ecb4a2dce128c719b4ba399d34126b9e4396e3f9860966"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"unicode-xid",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "termcolor"
|
||||||
|
version = "1.1.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "2dfed899f0eb03f32ee8c6a0aabdb8a7949659e3466561fc0adf54e26d88c5f4"
|
||||||
|
dependencies = [
|
||||||
|
"winapi-util",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "textwrap"
|
||||||
|
version = "0.11.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "d326610f408c7a4eb6f51c37c330e496b08506c9457c9d34287ecc38809fb060"
|
||||||
|
dependencies = [
|
||||||
|
"unicode-width",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "thiserror"
|
||||||
|
version = "1.0.30"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "854babe52e4df1653706b98fcfc05843010039b406875930a70e4d9644e5c417"
|
||||||
|
dependencies = [
|
||||||
|
"thiserror-impl",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "thiserror-impl"
|
||||||
|
version = "1.0.30"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "aa32fd3f627f367fe16f893e2597ae3c05020f8bba2666a4e6ea73d377e5714b"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"syn",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "typenum"
|
||||||
|
version = "1.14.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "b63708a265f51345575b27fe43f9500ad611579e764c79edbc2037b1121959ec"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "unicode-segmentation"
|
||||||
|
version = "1.8.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "8895849a949e7845e06bd6dc1aa51731a103c42707010a5b591c0038fb73385b"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "unicode-width"
|
||||||
|
version = "0.1.9"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "3ed742d4ea2bd1176e236172c8429aaf54486e7ac098db29ffe6529e0ce50973"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "unicode-xid"
|
||||||
|
version = "0.2.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "8ccb82d61f80a663efe1f787a51b16b5a51e3314d6ac365b08639f52387b33f3"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "vec_map"
|
||||||
|
version = "0.8.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "f1bddf1187be692e79c5ffeab891132dfb0f236ed36a43c7ed39f1165ee20191"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "version_check"
|
||||||
|
version = "0.9.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "5fecdca9a5291cc2b8dcf7dc02453fee791a280f3743cb0905f8822ae463b3fe"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "which"
|
||||||
|
version = "3.1.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "d011071ae14a2f6671d0b74080ae0cd8ebf3a6f8c9589a2cd45f23126fe29724"
|
||||||
|
dependencies = [
|
||||||
|
"libc",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "wide"
|
||||||
|
version = "0.7.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "d2f2548e954f6619da26c140d020e99e59a2ca872a11f1e6250b829e8c96c893"
|
||||||
|
dependencies = [
|
||||||
|
"bytemuck",
|
||||||
|
"safe_arch",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "winapi"
|
||||||
|
version = "0.3.9"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419"
|
||||||
|
dependencies = [
|
||||||
|
"winapi-i686-pc-windows-gnu",
|
||||||
|
"winapi-x86_64-pc-windows-gnu",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "winapi-i686-pc-windows-gnu"
|
||||||
|
version = "0.4.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "winapi-util"
|
||||||
|
version = "0.1.5"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "70ec6ce85bb158151cae5e5c87f95a8e97d2c0c4b001223f33a334e3ce5de178"
|
||||||
|
dependencies = [
|
||||||
|
"winapi",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "winapi-x86_64-pc-windows-gnu"
|
||||||
|
version = "0.4.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
|
14
Cargo.toml
Normal file
14
Cargo.toml
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
[package]
|
||||||
|
name = "qtrack"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2021"
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
nalgebra = "0.29"
|
||||||
|
ndarray = "0.15"
|
||||||
|
num-traits = "0.2"
|
||||||
|
serde = "1.0"
|
||||||
|
serde_derive = "1.0"
|
||||||
|
thiserror = "1.0"
|
||||||
|
munkres = { version = "0.5", git = "https://github.com/andreytkachenko/munkres-rs" }
|
||||||
|
onnx-model = { git = "https://github.com/andreytkachenko/onnx-model", branch = "v1.8" }
|
21
LICENSE
Normal file
21
LICENSE
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
MIT License
|
||||||
|
|
||||||
|
Copyright (c) 2021 Andrey Tkachenko
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
0
models/README.md
Normal file
0
models/README.md
Normal file
2
scripts/README.md
Normal file
2
scripts/README.md
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
donwload weights
|
||||||
|
https://drive.google.com/file/d/1sWNozS0emz7bmQTUWDLvsubLGnCwUiIS/view
|
0
scripts/__init__.py
Normal file
0
scripts/__init__.py
Normal file
0
scripts/libs/__init__.py
Normal file
0
scripts/libs/__init__.py
Normal file
259
scripts/libs/config.py
Normal file
259
scripts/libs/config.py
Normal file
@ -0,0 +1,259 @@
|
|||||||
|
|
||||||
|
import torch
|
||||||
|
from libs.torch_utils import convert2cpu
|
||||||
|
|
||||||
|
|
||||||
|
def parse_cfg(cfgfile):
|
||||||
|
blocks = []
|
||||||
|
fp = open(cfgfile, 'r')
|
||||||
|
block = None
|
||||||
|
line = fp.readline()
|
||||||
|
while line != '':
|
||||||
|
line = line.rstrip()
|
||||||
|
if line == '' or line[0] == '#':
|
||||||
|
line = fp.readline()
|
||||||
|
continue
|
||||||
|
elif line[0] == '[':
|
||||||
|
if block:
|
||||||
|
blocks.append(block)
|
||||||
|
block = dict()
|
||||||
|
block['type'] = line.lstrip('[').rstrip(']')
|
||||||
|
# set default value
|
||||||
|
if block['type'] == 'convolutional':
|
||||||
|
block['batch_normalize'] = 0
|
||||||
|
else:
|
||||||
|
key, value = line.split('=')
|
||||||
|
key = key.strip()
|
||||||
|
if key == 'type':
|
||||||
|
key = '_type'
|
||||||
|
value = value.strip()
|
||||||
|
block[key] = value
|
||||||
|
line = fp.readline()
|
||||||
|
|
||||||
|
if block:
|
||||||
|
blocks.append(block)
|
||||||
|
fp.close()
|
||||||
|
return blocks
|
||||||
|
|
||||||
|
|
||||||
|
def print_cfg(blocks):
|
||||||
|
print('layer filters size input output');
|
||||||
|
prev_width = 416
|
||||||
|
prev_height = 416
|
||||||
|
prev_filters = 3
|
||||||
|
out_filters = []
|
||||||
|
out_widths = []
|
||||||
|
out_heights = []
|
||||||
|
ind = -2
|
||||||
|
for block in blocks:
|
||||||
|
ind = ind + 1
|
||||||
|
if block['type'] == 'net':
|
||||||
|
prev_width = int(block['width'])
|
||||||
|
prev_height = int(block['height'])
|
||||||
|
continue
|
||||||
|
elif block['type'] == 'convolutional':
|
||||||
|
filters = int(block['filters'])
|
||||||
|
kernel_size = int(block['size'])
|
||||||
|
stride = int(block['stride'])
|
||||||
|
is_pad = int(block['pad'])
|
||||||
|
pad = (kernel_size - 1) // 2 if is_pad else 0
|
||||||
|
width = (prev_width + 2 * pad - kernel_size) // stride + 1
|
||||||
|
height = (prev_height + 2 * pad - kernel_size) // stride + 1
|
||||||
|
print('%5d %-6s %4d %d x %d / %d %3d x %3d x%4d -> %3d x %3d x%4d' % (
|
||||||
|
ind, 'conv', filters, kernel_size, kernel_size, stride, prev_width, prev_height, prev_filters, width,
|
||||||
|
height, filters))
|
||||||
|
prev_width = width
|
||||||
|
prev_height = height
|
||||||
|
prev_filters = filters
|
||||||
|
out_widths.append(prev_width)
|
||||||
|
out_heights.append(prev_height)
|
||||||
|
out_filters.append(prev_filters)
|
||||||
|
elif block['type'] == 'maxpool':
|
||||||
|
pool_size = int(block['size'])
|
||||||
|
stride = int(block['stride'])
|
||||||
|
width = prev_width // stride
|
||||||
|
height = prev_height // stride
|
||||||
|
print('%5d %-6s %d x %d / %d %3d x %3d x%4d -> %3d x %3d x%4d' % (
|
||||||
|
ind, 'max', pool_size, pool_size, stride, prev_width, prev_height, prev_filters, width, height,
|
||||||
|
filters))
|
||||||
|
prev_width = width
|
||||||
|
prev_height = height
|
||||||
|
prev_filters = filters
|
||||||
|
out_widths.append(prev_width)
|
||||||
|
out_heights.append(prev_height)
|
||||||
|
out_filters.append(prev_filters)
|
||||||
|
elif block['type'] == 'avgpool':
|
||||||
|
width = 1
|
||||||
|
height = 1
|
||||||
|
print('%5d %-6s %3d x %3d x%4d -> %3d' % (
|
||||||
|
ind, 'avg', prev_width, prev_height, prev_filters, prev_filters))
|
||||||
|
prev_width = width
|
||||||
|
prev_height = height
|
||||||
|
prev_filters = filters
|
||||||
|
out_widths.append(prev_width)
|
||||||
|
out_heights.append(prev_height)
|
||||||
|
out_filters.append(prev_filters)
|
||||||
|
elif block['type'] == 'softmax':
|
||||||
|
print('%5d %-6s -> %3d' % (ind, 'softmax', prev_filters))
|
||||||
|
out_widths.append(prev_width)
|
||||||
|
out_heights.append(prev_height)
|
||||||
|
out_filters.append(prev_filters)
|
||||||
|
elif block['type'] == 'cost':
|
||||||
|
print('%5d %-6s -> %3d' % (ind, 'cost', prev_filters))
|
||||||
|
out_widths.append(prev_width)
|
||||||
|
out_heights.append(prev_height)
|
||||||
|
out_filters.append(prev_filters)
|
||||||
|
elif block['type'] == 'reorg':
|
||||||
|
stride = int(block['stride'])
|
||||||
|
filters = stride * stride * prev_filters
|
||||||
|
width = prev_width // stride
|
||||||
|
height = prev_height // stride
|
||||||
|
print('%5d %-6s / %d %3d x %3d x%4d -> %3d x %3d x%4d' % (
|
||||||
|
ind, 'reorg', stride, prev_width, prev_height, prev_filters, width, height, filters))
|
||||||
|
prev_width = width
|
||||||
|
prev_height = height
|
||||||
|
prev_filters = filters
|
||||||
|
out_widths.append(prev_width)
|
||||||
|
out_heights.append(prev_height)
|
||||||
|
out_filters.append(prev_filters)
|
||||||
|
elif block['type'] == 'upsample':
|
||||||
|
stride = int(block['stride'])
|
||||||
|
filters = prev_filters
|
||||||
|
width = prev_width * stride
|
||||||
|
height = prev_height * stride
|
||||||
|
print('%5d %-6s * %d %3d x %3d x%4d -> %3d x %3d x%4d' % (
|
||||||
|
ind, 'upsample', stride, prev_width, prev_height, prev_filters, width, height, filters))
|
||||||
|
prev_width = width
|
||||||
|
prev_height = height
|
||||||
|
prev_filters = filters
|
||||||
|
out_widths.append(prev_width)
|
||||||
|
out_heights.append(prev_height)
|
||||||
|
out_filters.append(prev_filters)
|
||||||
|
elif block['type'] == 'route':
|
||||||
|
layers = block['layers'].split(',')
|
||||||
|
layers = [int(i) if int(i) > 0 else int(i) + ind for i in layers]
|
||||||
|
if len(layers) == 1:
|
||||||
|
print('%5d %-6s %d' % (ind, 'route', layers[0]))
|
||||||
|
prev_width = out_widths[layers[0]]
|
||||||
|
prev_height = out_heights[layers[0]]
|
||||||
|
prev_filters = out_filters[layers[0]]
|
||||||
|
elif len(layers) == 2:
|
||||||
|
print('%5d %-6s %d %d' % (ind, 'route', layers[0], layers[1]))
|
||||||
|
prev_width = out_widths[layers[0]]
|
||||||
|
prev_height = out_heights[layers[0]]
|
||||||
|
assert (prev_width == out_widths[layers[1]])
|
||||||
|
assert (prev_height == out_heights[layers[1]])
|
||||||
|
prev_filters = out_filters[layers[0]] + out_filters[layers[1]]
|
||||||
|
elif len(layers) == 4:
|
||||||
|
print('%5d %-6s %d %d %d %d' % (ind, 'route', layers[0], layers[1], layers[2], layers[3]))
|
||||||
|
prev_width = out_widths[layers[0]]
|
||||||
|
prev_height = out_heights[layers[0]]
|
||||||
|
assert (prev_width == out_widths[layers[1]] == out_widths[layers[2]] == out_widths[layers[3]])
|
||||||
|
assert (prev_height == out_heights[layers[1]] == out_heights[layers[2]] == out_heights[layers[3]])
|
||||||
|
prev_filters = out_filters[layers[0]] + out_filters[layers[1]] + out_filters[layers[2]] + out_filters[
|
||||||
|
layers[3]]
|
||||||
|
else:
|
||||||
|
print("route error !!! {} {} {}".format(sys._getframe().f_code.co_filename,
|
||||||
|
sys._getframe().f_code.co_name, sys._getframe().f_lineno))
|
||||||
|
|
||||||
|
out_widths.append(prev_width)
|
||||||
|
out_heights.append(prev_height)
|
||||||
|
out_filters.append(prev_filters)
|
||||||
|
elif block['type'] in ['region', 'yolo']:
|
||||||
|
print('%5d %-6s' % (ind, 'detection'))
|
||||||
|
out_widths.append(prev_width)
|
||||||
|
out_heights.append(prev_height)
|
||||||
|
out_filters.append(prev_filters)
|
||||||
|
elif block['type'] == 'shortcut':
|
||||||
|
from_id = int(block['from'])
|
||||||
|
from_id = from_id if from_id > 0 else from_id + ind
|
||||||
|
print('%5d %-6s %d' % (ind, 'shortcut', from_id))
|
||||||
|
prev_width = out_widths[from_id]
|
||||||
|
prev_height = out_heights[from_id]
|
||||||
|
prev_filters = out_filters[from_id]
|
||||||
|
out_widths.append(prev_width)
|
||||||
|
out_heights.append(prev_height)
|
||||||
|
out_filters.append(prev_filters)
|
||||||
|
elif block['type'] == 'connected':
|
||||||
|
filters = int(block['output'])
|
||||||
|
print('%5d %-6s %d -> %3d' % (ind, 'connected', prev_filters, filters))
|
||||||
|
prev_filters = filters
|
||||||
|
out_widths.append(1)
|
||||||
|
out_heights.append(1)
|
||||||
|
out_filters.append(prev_filters)
|
||||||
|
else:
|
||||||
|
print('unknown type %s' % (block['type']))
|
||||||
|
|
||||||
|
|
||||||
|
def load_conv(buf, start, conv_model):
|
||||||
|
num_w = conv_model.weight.numel()
|
||||||
|
num_b = conv_model.bias.numel()
|
||||||
|
conv_model.bias.data.copy_(torch.from_numpy(buf[start:start + num_b]));
|
||||||
|
start = start + num_b
|
||||||
|
conv_model.weight.data.copy_(torch.from_numpy(buf[start:start + num_w]).reshape(conv_model.weight.data.shape));
|
||||||
|
start = start + num_w
|
||||||
|
return start
|
||||||
|
|
||||||
|
|
||||||
|
def save_conv(fp, conv_model):
|
||||||
|
if conv_model.bias.is_cuda:
|
||||||
|
convert2cpu(conv_model.bias.data).numpy().tofile(fp)
|
||||||
|
convert2cpu(conv_model.weight.data).numpy().tofile(fp)
|
||||||
|
else:
|
||||||
|
conv_model.bias.data.numpy().tofile(fp)
|
||||||
|
conv_model.weight.data.numpy().tofile(fp)
|
||||||
|
|
||||||
|
|
||||||
|
def load_conv_bn(buf, start, conv_model, bn_model):
|
||||||
|
num_w = conv_model.weight.numel()
|
||||||
|
num_b = bn_model.bias.numel()
|
||||||
|
bn_model.bias.data.copy_(torch.from_numpy(buf[start:start + num_b]));
|
||||||
|
start = start + num_b
|
||||||
|
bn_model.weight.data.copy_(torch.from_numpy(buf[start:start + num_b]));
|
||||||
|
start = start + num_b
|
||||||
|
bn_model.running_mean.copy_(torch.from_numpy(buf[start:start + num_b]));
|
||||||
|
start = start + num_b
|
||||||
|
bn_model.running_var.copy_(torch.from_numpy(buf[start:start + num_b]));
|
||||||
|
start = start + num_b
|
||||||
|
conv_model.weight.data.copy_(torch.from_numpy(buf[start:start + num_w]).reshape(conv_model.weight.data.shape));
|
||||||
|
start = start + num_w
|
||||||
|
return start
|
||||||
|
|
||||||
|
|
||||||
|
def save_conv_bn(fp, conv_model, bn_model):
|
||||||
|
if bn_model.bias.is_cuda:
|
||||||
|
convert2cpu(bn_model.bias.data).numpy().tofile(fp)
|
||||||
|
convert2cpu(bn_model.weight.data).numpy().tofile(fp)
|
||||||
|
convert2cpu(bn_model.running_mean).numpy().tofile(fp)
|
||||||
|
convert2cpu(bn_model.running_var).numpy().tofile(fp)
|
||||||
|
convert2cpu(conv_model.weight.data).numpy().tofile(fp)
|
||||||
|
else:
|
||||||
|
bn_model.bias.data.numpy().tofile(fp)
|
||||||
|
bn_model.weight.data.numpy().tofile(fp)
|
||||||
|
bn_model.running_mean.numpy().tofile(fp)
|
||||||
|
bn_model.running_var.numpy().tofile(fp)
|
||||||
|
conv_model.weight.data.numpy().tofile(fp)
|
||||||
|
|
||||||
|
|
||||||
|
def load_fc(buf, start, fc_model):
|
||||||
|
num_w = fc_model.weight.numel()
|
||||||
|
num_b = fc_model.bias.numel()
|
||||||
|
fc_model.bias.data.copy_(torch.from_numpy(buf[start:start + num_b]));
|
||||||
|
start = start + num_b
|
||||||
|
fc_model.weight.data.copy_(torch.from_numpy(buf[start:start + num_w]));
|
||||||
|
start = start + num_w
|
||||||
|
return start
|
||||||
|
|
||||||
|
|
||||||
|
def save_fc(fp, fc_model):
|
||||||
|
fc_model.bias.data.numpy().tofile(fp)
|
||||||
|
fc_model.weight.data.numpy().tofile(fp)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
import sys
|
||||||
|
|
||||||
|
blocks = parse_cfg('cfg/yolo.cfg')
|
||||||
|
if len(sys.argv) == 2:
|
||||||
|
blocks = parse_cfg(sys.argv[1])
|
||||||
|
print_cfg(blocks)
|
516
scripts/libs/models.py
Normal file
516
scripts/libs/models.py
Normal file
@ -0,0 +1,516 @@
|
|||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from libs.region_loss import RegionLoss
|
||||||
|
from libs.yolo_layer import YoloLayer
|
||||||
|
from libs.config import *
|
||||||
|
from libs.torch_utils import *
|
||||||
|
|
||||||
|
|
||||||
|
class Mish(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = x * (torch.tanh(F.softplus(x)))
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class MaxPoolDark(nn.Module):
|
||||||
|
def __init__(self, size=2, stride=1):
|
||||||
|
super(MaxPoolDark, self).__init__()
|
||||||
|
self.size = size
|
||||||
|
self.stride = stride
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
'''
|
||||||
|
darknet output_size = (input_size + p - k) / s +1
|
||||||
|
p : padding = k - 1
|
||||||
|
k : size
|
||||||
|
s : stride
|
||||||
|
torch output_size = (input_size + 2*p -k) / s +1
|
||||||
|
p : padding = k//2
|
||||||
|
'''
|
||||||
|
p = self.size // 2
|
||||||
|
if ((x.shape[2] - 1) // self.stride) != ((x.shape[2] + 2 * p - self.size) // self.stride):
|
||||||
|
padding1 = (self.size - 1) // 2
|
||||||
|
padding2 = padding1 + 1
|
||||||
|
else:
|
||||||
|
padding1 = (self.size - 1) // 2
|
||||||
|
padding2 = padding1
|
||||||
|
if ((x.shape[3] - 1) // self.stride) != ((x.shape[3] + 2 * p - self.size) // self.stride):
|
||||||
|
padding3 = (self.size - 1) // 2
|
||||||
|
padding4 = padding3 + 1
|
||||||
|
else:
|
||||||
|
padding3 = (self.size - 1) // 2
|
||||||
|
padding4 = padding3
|
||||||
|
x = F.max_pool2d(F.pad(x, (padding3, padding4, padding1, padding2), mode='replicate'),
|
||||||
|
self.size, stride=self.stride)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Upsample_expand(nn.Module):
|
||||||
|
def __init__(self, stride=2):
|
||||||
|
super(Upsample_expand, self).__init__()
|
||||||
|
self.stride = stride
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
assert (x.data.dim() == 4)
|
||||||
|
|
||||||
|
x = x.view(x.size(0), x.size(1), x.size(2), 1, x.size(3), 1).\
|
||||||
|
expand(x.size(0), x.size(1), x.size(2), self.stride, x.size(3), self.stride).contiguous().\
|
||||||
|
view(x.size(0), x.size(1), x.size(2) * self.stride, x.size(3) * self.stride)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Upsample_interpolate(nn.Module):
|
||||||
|
def __init__(self, stride):
|
||||||
|
super(Upsample_interpolate, self).__init__()
|
||||||
|
self.stride = stride
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
assert (x.data.dim() == 4)
|
||||||
|
|
||||||
|
out = F.interpolate(x, size=(x.size(2) * self.stride, x.size(3) * self.stride), mode='nearest')
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class Reorg(nn.Module):
|
||||||
|
def __init__(self, stride=2):
|
||||||
|
super(Reorg, self).__init__()
|
||||||
|
self.stride = stride
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
stride = self.stride
|
||||||
|
assert (x.data.dim() == 4)
|
||||||
|
B = x.data.size(0)
|
||||||
|
C = x.data.size(1)
|
||||||
|
H = x.data.size(2)
|
||||||
|
W = x.data.size(3)
|
||||||
|
assert (H % stride == 0)
|
||||||
|
assert (W % stride == 0)
|
||||||
|
ws = stride
|
||||||
|
hs = stride
|
||||||
|
x = x.view(B, C, H / hs, hs, W / ws, ws).transpose(3, 4).contiguous()
|
||||||
|
x = x.view(B, C, H / hs * W / ws, hs * ws).transpose(2, 3).contiguous()
|
||||||
|
x = x.view(B, C, hs * ws, H / hs, W / ws).transpose(1, 2).contiguous()
|
||||||
|
x = x.view(B, hs * ws * C, H / hs, W / ws)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class GlobalAvgPool2d(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(GlobalAvgPool2d, self).__init__()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
N = x.data.size(0)
|
||||||
|
C = x.data.size(1)
|
||||||
|
H = x.data.size(2)
|
||||||
|
W = x.data.size(3)
|
||||||
|
x = F.avg_pool2d(x, (H, W))
|
||||||
|
x = x.view(N, C)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
# for route and shortcut
|
||||||
|
class EmptyModule(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(EmptyModule, self).__init__()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
# support route shortcut and reorg
|
||||||
|
class Darknet(nn.Module):
|
||||||
|
def __init__(self, cfgfile, inference=False):
|
||||||
|
super(Darknet, self).__init__()
|
||||||
|
self.inference = inference
|
||||||
|
self.training = not self.inference
|
||||||
|
|
||||||
|
self.blocks = parse_cfg(cfgfile)
|
||||||
|
self.width = int(self.blocks[0]['width'])
|
||||||
|
self.height = int(self.blocks[0]['height'])
|
||||||
|
|
||||||
|
self.models = self.create_network(self.blocks) # merge conv, bn,leaky
|
||||||
|
self.loss = self.models[len(self.models) - 1]
|
||||||
|
|
||||||
|
if self.blocks[(len(self.blocks) - 1)]['type'] == 'region':
|
||||||
|
self.anchors = self.loss.anchors
|
||||||
|
self.num_anchors = self.loss.num_anchors
|
||||||
|
self.anchor_step = self.loss.anchor_step
|
||||||
|
self.num_classes = self.loss.num_classes
|
||||||
|
|
||||||
|
self.header = torch.IntTensor([0, 0, 0, 0])
|
||||||
|
self.seen = 0
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
ind = -2
|
||||||
|
self.loss = None
|
||||||
|
outputs = dict()
|
||||||
|
out_boxes = []
|
||||||
|
for block in self.blocks:
|
||||||
|
ind = ind + 1
|
||||||
|
# if ind > 0:
|
||||||
|
# return x
|
||||||
|
|
||||||
|
if block['type'] == 'net':
|
||||||
|
continue
|
||||||
|
elif block['type'] in ['convolutional', 'maxpool', 'reorg', 'upsample', 'avgpool', 'softmax', 'connected']:
|
||||||
|
x = self.models[ind](x)
|
||||||
|
outputs[ind] = x
|
||||||
|
elif block['type'] == 'route':
|
||||||
|
layers = block['layers'].split(',')
|
||||||
|
layers = [int(i) if int(i) > 0 else int(i) + ind for i in layers]
|
||||||
|
if len(layers) == 1:
|
||||||
|
if 'groups' not in block.keys() or int(block['groups']) == 1:
|
||||||
|
x = outputs[layers[0]]
|
||||||
|
outputs[ind] = x
|
||||||
|
else:
|
||||||
|
groups = int(block['groups'])
|
||||||
|
group_id = int(block['group_id'])
|
||||||
|
_, b, _, _ = outputs[layers[0]].shape
|
||||||
|
x = outputs[layers[0]][:, b // groups * group_id:b // groups * (group_id + 1)]
|
||||||
|
outputs[ind] = x
|
||||||
|
elif len(layers) == 2:
|
||||||
|
x1 = outputs[layers[0]]
|
||||||
|
x2 = outputs[layers[1]]
|
||||||
|
x = torch.cat((x1, x2), 1)
|
||||||
|
outputs[ind] = x
|
||||||
|
elif len(layers) == 4:
|
||||||
|
x1 = outputs[layers[0]]
|
||||||
|
x2 = outputs[layers[1]]
|
||||||
|
x3 = outputs[layers[2]]
|
||||||
|
x4 = outputs[layers[3]]
|
||||||
|
x = torch.cat((x1, x2, x3, x4), 1)
|
||||||
|
outputs[ind] = x
|
||||||
|
else:
|
||||||
|
print("rounte number > 2 ,is {}".format(len(layers)))
|
||||||
|
|
||||||
|
elif block['type'] == 'shortcut':
|
||||||
|
from_layer = int(block['from'])
|
||||||
|
activation = block['activation']
|
||||||
|
from_layer = from_layer if from_layer > 0 else from_layer + ind
|
||||||
|
x1 = outputs[from_layer]
|
||||||
|
x2 = outputs[ind - 1]
|
||||||
|
x = x1 + x2
|
||||||
|
if activation == 'leaky':
|
||||||
|
x = F.leaky_relu(x, 0.1, inplace=True)
|
||||||
|
elif activation == 'relu':
|
||||||
|
x = F.relu(x, inplace=True)
|
||||||
|
outputs[ind] = x
|
||||||
|
elif block['type'] == 'region':
|
||||||
|
continue
|
||||||
|
if self.loss:
|
||||||
|
self.loss = self.loss + self.models[ind](x)
|
||||||
|
else:
|
||||||
|
self.loss = self.models[ind](x)
|
||||||
|
outputs[ind] = None
|
||||||
|
elif block['type'] == 'yolo':
|
||||||
|
# if self.training:
|
||||||
|
# pass
|
||||||
|
# else:
|
||||||
|
# boxes = self.models[ind](x)
|
||||||
|
# out_boxes.append(boxes)
|
||||||
|
boxes = self.models[ind](x)
|
||||||
|
out_boxes.append(boxes)
|
||||||
|
elif block['type'] == 'cost':
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
print('unknown type %s' % (block['type']))
|
||||||
|
|
||||||
|
if self.training:
|
||||||
|
return out_boxes
|
||||||
|
else:
|
||||||
|
return get_region_boxes(out_boxes)
|
||||||
|
|
||||||
|
def print_network(self):
|
||||||
|
print_cfg(self.blocks)
|
||||||
|
|
||||||
|
def create_network(self, blocks):
|
||||||
|
models = nn.ModuleList()
|
||||||
|
|
||||||
|
prev_filters = 3
|
||||||
|
out_filters = []
|
||||||
|
prev_stride = 1
|
||||||
|
out_strides = []
|
||||||
|
conv_id = 0
|
||||||
|
for block in blocks:
|
||||||
|
if block['type'] == 'net':
|
||||||
|
prev_filters = int(block['channels'])
|
||||||
|
continue
|
||||||
|
elif block['type'] == 'convolutional':
|
||||||
|
conv_id = conv_id + 1
|
||||||
|
batch_normalize = int(block['batch_normalize'])
|
||||||
|
filters = int(block['filters'])
|
||||||
|
kernel_size = int(block['size'])
|
||||||
|
stride = int(block['stride'])
|
||||||
|
is_pad = int(block['pad'])
|
||||||
|
pad = (kernel_size - 1) // 2 if is_pad else 0
|
||||||
|
activation = block['activation']
|
||||||
|
model = nn.Sequential()
|
||||||
|
if batch_normalize:
|
||||||
|
model.add_module('conv{0}'.format(conv_id),
|
||||||
|
nn.Conv2d(prev_filters, filters, kernel_size, stride, pad, bias=False))
|
||||||
|
model.add_module('bn{0}'.format(conv_id), nn.BatchNorm2d(filters))
|
||||||
|
# model.add_module('bn{0}'.format(conv_id), BN2d(filters))
|
||||||
|
else:
|
||||||
|
model.add_module('conv{0}'.format(conv_id),
|
||||||
|
nn.Conv2d(prev_filters, filters, kernel_size, stride, pad))
|
||||||
|
if activation == 'leaky':
|
||||||
|
model.add_module('leaky{0}'.format(conv_id), nn.LeakyReLU(0.1, inplace=True))
|
||||||
|
elif activation == 'relu':
|
||||||
|
model.add_module('relu{0}'.format(conv_id), nn.ReLU(inplace=True))
|
||||||
|
elif activation == 'mish':
|
||||||
|
model.add_module('mish{0}'.format(conv_id), Mish())
|
||||||
|
else:
|
||||||
|
print("convalution havn't activate {}".format(activation))
|
||||||
|
|
||||||
|
prev_filters = filters
|
||||||
|
out_filters.append(prev_filters)
|
||||||
|
prev_stride = stride * prev_stride
|
||||||
|
out_strides.append(prev_stride)
|
||||||
|
models.append(model)
|
||||||
|
elif block['type'] == 'maxpool':
|
||||||
|
pool_size = int(block['size'])
|
||||||
|
stride = int(block['stride'])
|
||||||
|
if stride == 1 and pool_size % 2:
|
||||||
|
# You can use Maxpooldark instead, here is convenient to convert onnx.
|
||||||
|
# Example: [maxpool] size=3 stride=1
|
||||||
|
model = nn.MaxPool2d(kernel_size=pool_size, stride=stride, padding=pool_size // 2)
|
||||||
|
elif stride == pool_size:
|
||||||
|
# You can use Maxpooldark instead, here is convenient to convert onnx.
|
||||||
|
# Example: [maxpool] size=2 stride=2
|
||||||
|
model = nn.MaxPool2d(kernel_size=pool_size, stride=stride, padding=0)
|
||||||
|
else:
|
||||||
|
model = MaxPoolDark(pool_size, stride)
|
||||||
|
out_filters.append(prev_filters)
|
||||||
|
prev_stride = stride * prev_stride
|
||||||
|
out_strides.append(prev_stride)
|
||||||
|
models.append(model)
|
||||||
|
elif block['type'] == 'avgpool':
|
||||||
|
model = GlobalAvgPool2d()
|
||||||
|
out_filters.append(prev_filters)
|
||||||
|
models.append(model)
|
||||||
|
elif block['type'] == 'softmax':
|
||||||
|
model = nn.Softmax()
|
||||||
|
out_strides.append(prev_stride)
|
||||||
|
out_filters.append(prev_filters)
|
||||||
|
models.append(model)
|
||||||
|
elif block['type'] == 'cost':
|
||||||
|
if block['_type'] == 'sse':
|
||||||
|
model = nn.MSELoss(reduction='mean')
|
||||||
|
elif block['_type'] == 'L1':
|
||||||
|
model = nn.L1Loss(reduction='mean')
|
||||||
|
elif block['_type'] == 'smooth':
|
||||||
|
model = nn.SmoothL1Loss(reduction='mean')
|
||||||
|
out_filters.append(1)
|
||||||
|
out_strides.append(prev_stride)
|
||||||
|
models.append(model)
|
||||||
|
elif block['type'] == 'reorg':
|
||||||
|
stride = int(block['stride'])
|
||||||
|
prev_filters = stride * stride * prev_filters
|
||||||
|
out_filters.append(prev_filters)
|
||||||
|
prev_stride = prev_stride * stride
|
||||||
|
out_strides.append(prev_stride)
|
||||||
|
models.append(Reorg(stride))
|
||||||
|
elif block['type'] == 'upsample':
|
||||||
|
stride = int(block['stride'])
|
||||||
|
out_filters.append(prev_filters)
|
||||||
|
prev_stride = prev_stride // stride
|
||||||
|
out_strides.append(prev_stride)
|
||||||
|
|
||||||
|
models.append(Upsample_expand(stride))
|
||||||
|
# models.append(Upsample_interpolate(stride))
|
||||||
|
|
||||||
|
elif block['type'] == 'route':
|
||||||
|
layers = block['layers'].split(',')
|
||||||
|
ind = len(models)
|
||||||
|
layers = [int(i) if int(i) > 0 else int(i) + ind for i in layers]
|
||||||
|
if len(layers) == 1:
|
||||||
|
if 'groups' not in block.keys() or int(block['groups']) == 1:
|
||||||
|
prev_filters = out_filters[layers[0]]
|
||||||
|
prev_stride = out_strides[layers[0]]
|
||||||
|
else:
|
||||||
|
prev_filters = out_filters[layers[0]] // int(block['groups'])
|
||||||
|
prev_stride = out_strides[layers[0]] // int(block['groups'])
|
||||||
|
elif len(layers) == 2:
|
||||||
|
assert (layers[0] == ind - 1 or layers[1] == ind - 1)
|
||||||
|
prev_filters = out_filters[layers[0]] + out_filters[layers[1]]
|
||||||
|
prev_stride = out_strides[layers[0]]
|
||||||
|
elif len(layers) == 4:
|
||||||
|
assert (layers[0] == ind - 1)
|
||||||
|
prev_filters = out_filters[layers[0]] + out_filters[layers[1]] + out_filters[layers[2]] + \
|
||||||
|
out_filters[layers[3]]
|
||||||
|
prev_stride = out_strides[layers[0]]
|
||||||
|
else:
|
||||||
|
print("route error!!!")
|
||||||
|
|
||||||
|
out_filters.append(prev_filters)
|
||||||
|
out_strides.append(prev_stride)
|
||||||
|
models.append(EmptyModule())
|
||||||
|
elif block['type'] == 'shortcut':
|
||||||
|
ind = len(models)
|
||||||
|
prev_filters = out_filters[ind - 1]
|
||||||
|
out_filters.append(prev_filters)
|
||||||
|
prev_stride = out_strides[ind - 1]
|
||||||
|
out_strides.append(prev_stride)
|
||||||
|
models.append(EmptyModule())
|
||||||
|
elif block['type'] == 'connected':
|
||||||
|
filters = int(block['output'])
|
||||||
|
if block['activation'] == 'linear':
|
||||||
|
model = nn.Linear(prev_filters, filters)
|
||||||
|
elif block['activation'] == 'leaky':
|
||||||
|
model = nn.Sequential(
|
||||||
|
nn.Linear(prev_filters, filters),
|
||||||
|
nn.LeakyReLU(0.1, inplace=True))
|
||||||
|
elif block['activation'] == 'relu':
|
||||||
|
model = nn.Sequential(
|
||||||
|
nn.Linear(prev_filters, filters),
|
||||||
|
nn.ReLU(inplace=True))
|
||||||
|
prev_filters = filters
|
||||||
|
out_filters.append(prev_filters)
|
||||||
|
out_strides.append(prev_stride)
|
||||||
|
models.append(model)
|
||||||
|
elif block['type'] == 'region':
|
||||||
|
loss = RegionLoss()
|
||||||
|
anchors = block['anchors'].split(',')
|
||||||
|
loss.anchors = [float(i) for i in anchors]
|
||||||
|
loss.num_classes = int(block['classes'])
|
||||||
|
loss.num_anchors = int(block['num'])
|
||||||
|
loss.anchor_step = len(loss.anchors) // loss.num_anchors
|
||||||
|
loss.object_scale = float(block['object_scale'])
|
||||||
|
loss.noobject_scale = float(block['noobject_scale'])
|
||||||
|
loss.class_scale = float(block['class_scale'])
|
||||||
|
loss.coord_scale = float(block['coord_scale'])
|
||||||
|
out_filters.append(prev_filters)
|
||||||
|
out_strides.append(prev_stride)
|
||||||
|
models.append(loss)
|
||||||
|
elif block['type'] == 'yolo':
|
||||||
|
yolo_layer = YoloLayer()
|
||||||
|
anchors = block['anchors'].split(',')
|
||||||
|
anchor_mask = block['mask'].split(',')
|
||||||
|
yolo_layer.anchor_mask = [int(i) for i in anchor_mask]
|
||||||
|
yolo_layer.anchors = [float(i) for i in anchors]
|
||||||
|
yolo_layer.num_classes = int(block['classes'])
|
||||||
|
self.num_classes = yolo_layer.num_classes
|
||||||
|
yolo_layer.num_anchors = int(block['num'])
|
||||||
|
yolo_layer.anchor_step = len(yolo_layer.anchors) // yolo_layer.num_anchors
|
||||||
|
yolo_layer.stride = prev_stride
|
||||||
|
yolo_layer.scale_x_y = float(block['scale_x_y'])
|
||||||
|
# yolo_layer.object_scale = float(block['object_scale'])
|
||||||
|
# yolo_layer.noobject_scale = float(block['noobject_scale'])
|
||||||
|
# yolo_layer.class_scale = float(block['class_scale'])
|
||||||
|
# yolo_layer.coord_scale = float(block['coord_scale'])
|
||||||
|
out_filters.append(prev_filters)
|
||||||
|
out_strides.append(prev_stride)
|
||||||
|
models.append(yolo_layer)
|
||||||
|
else:
|
||||||
|
print('unknown type %s' % (block['type']))
|
||||||
|
|
||||||
|
return models
|
||||||
|
|
||||||
|
def load_weights(self, weightfile):
|
||||||
|
fp = open(weightfile, 'rb')
|
||||||
|
header = np.fromfile(fp, count=5, dtype=np.int32)
|
||||||
|
self.header = torch.from_numpy(header)
|
||||||
|
self.seen = self.header[3]
|
||||||
|
buf = np.fromfile(fp, dtype=np.float32)
|
||||||
|
fp.close()
|
||||||
|
|
||||||
|
start = 0
|
||||||
|
ind = -2
|
||||||
|
for block in self.blocks:
|
||||||
|
if start >= buf.size:
|
||||||
|
break
|
||||||
|
ind = ind + 1
|
||||||
|
if block['type'] == 'net':
|
||||||
|
continue
|
||||||
|
elif block['type'] == 'convolutional':
|
||||||
|
model = self.models[ind]
|
||||||
|
batch_normalize = int(block['batch_normalize'])
|
||||||
|
if batch_normalize:
|
||||||
|
start = load_conv_bn(buf, start, model[0], model[1])
|
||||||
|
else:
|
||||||
|
start = load_conv(buf, start, model[0])
|
||||||
|
elif block['type'] == 'connected':
|
||||||
|
model = self.models[ind]
|
||||||
|
if block['activation'] != 'linear':
|
||||||
|
start = load_fc(buf, start, model[0])
|
||||||
|
else:
|
||||||
|
start = load_fc(buf, start, model)
|
||||||
|
elif block['type'] == 'maxpool':
|
||||||
|
pass
|
||||||
|
elif block['type'] == 'reorg':
|
||||||
|
pass
|
||||||
|
elif block['type'] == 'upsample':
|
||||||
|
pass
|
||||||
|
elif block['type'] == 'route':
|
||||||
|
pass
|
||||||
|
elif block['type'] == 'shortcut':
|
||||||
|
pass
|
||||||
|
elif block['type'] == 'region':
|
||||||
|
pass
|
||||||
|
elif block['type'] == 'yolo':
|
||||||
|
pass
|
||||||
|
elif block['type'] == 'avgpool':
|
||||||
|
pass
|
||||||
|
elif block['type'] == 'softmax':
|
||||||
|
pass
|
||||||
|
elif block['type'] == 'cost':
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
print('unknown type %s' % (block['type']))
|
||||||
|
|
||||||
|
# def save_weights(self, outfile, cutoff=0):
|
||||||
|
# if cutoff <= 0:
|
||||||
|
# cutoff = len(self.blocks) - 1
|
||||||
|
#
|
||||||
|
# fp = open(outfile, 'wb')
|
||||||
|
# self.header[3] = self.seen
|
||||||
|
# header = self.header
|
||||||
|
# header.numpy().tofile(fp)
|
||||||
|
#
|
||||||
|
# ind = -1
|
||||||
|
# for blockId in range(1, cutoff + 1):
|
||||||
|
# ind = ind + 1
|
||||||
|
# block = self.blocks[blockId]
|
||||||
|
# if block['type'] == 'convolutional':
|
||||||
|
# model = self.models[ind]
|
||||||
|
# batch_normalize = int(block['batch_normalize'])
|
||||||
|
# if batch_normalize:
|
||||||
|
# save_conv_bn(fp, model[0], model[1])
|
||||||
|
# else:
|
||||||
|
# save_conv(fp, model[0])
|
||||||
|
# elif block['type'] == 'connected':
|
||||||
|
# model = self.models[ind]
|
||||||
|
# if block['activation'] != 'linear':
|
||||||
|
# save_fc(fc, model)
|
||||||
|
# else:
|
||||||
|
# save_fc(fc, model[0])
|
||||||
|
# elif block['type'] == 'maxpool':
|
||||||
|
# pass
|
||||||
|
# elif block['type'] == 'reorg':
|
||||||
|
# pass
|
||||||
|
# elif block['type'] == 'upsample':
|
||||||
|
# pass
|
||||||
|
# elif block['type'] == 'route':
|
||||||
|
# pass
|
||||||
|
# elif block['type'] == 'shortcut':
|
||||||
|
# pass
|
||||||
|
# elif block['type'] == 'region':
|
||||||
|
# pass
|
||||||
|
# elif block['type'] == 'yolo':
|
||||||
|
# pass
|
||||||
|
# elif block['type'] == 'avgpool':
|
||||||
|
# pass
|
||||||
|
# elif block['type'] == 'softmax':
|
||||||
|
# pass
|
||||||
|
# elif block['type'] == 'cost':
|
||||||
|
# pass
|
||||||
|
# else:
|
||||||
|
# print('unknown type %s' % (block['type']))
|
||||||
|
# fp.close()
|
||||||
|
|
197
scripts/libs/region_loss.py
Normal file
197
scripts/libs/region_loss.py
Normal file
@ -0,0 +1,197 @@
|
|||||||
|
import math
|
||||||
|
from torch.autograd.variable import Variable
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from libs.torch_utils import *
|
||||||
|
|
||||||
|
|
||||||
|
def build_targets(pred_boxes, target, anchors, num_anchors, num_classes, nH, nW, noobject_scale, object_scale,
|
||||||
|
sil_thresh, seen):
|
||||||
|
nB = target.size(0)
|
||||||
|
nA = num_anchors
|
||||||
|
nC = num_classes
|
||||||
|
anchor_step = len(anchors) / num_anchors
|
||||||
|
conf_mask = torch.ones(nB, nA, nH, nW) * noobject_scale
|
||||||
|
coord_mask = torch.zeros(nB, nA, nH, nW)
|
||||||
|
cls_mask = torch.zeros(nB, nA, nH, nW)
|
||||||
|
tx = torch.zeros(nB, nA, nH, nW)
|
||||||
|
ty = torch.zeros(nB, nA, nH, nW)
|
||||||
|
tw = torch.zeros(nB, nA, nH, nW)
|
||||||
|
th = torch.zeros(nB, nA, nH, nW)
|
||||||
|
tconf = torch.zeros(nB, nA, nH, nW)
|
||||||
|
tcls = torch.zeros(nB, nA, nH, nW)
|
||||||
|
|
||||||
|
nAnchors = nA * nH * nW
|
||||||
|
nPixels = nH * nW
|
||||||
|
for b in range(nB):
|
||||||
|
cur_pred_boxes = pred_boxes[b * nAnchors:(b + 1) * nAnchors].t()
|
||||||
|
cur_ious = torch.zeros(nAnchors)
|
||||||
|
for t in range(50):
|
||||||
|
if target[b][t * 5 + 1] == 0:
|
||||||
|
break
|
||||||
|
gx = target[b][t * 5 + 1] * nW
|
||||||
|
gy = target[b][t * 5 + 2] * nH
|
||||||
|
gw = target[b][t * 5 + 3] * nW
|
||||||
|
gh = target[b][t * 5 + 4] * nH
|
||||||
|
cur_gt_boxes = torch.FloatTensor([gx, gy, gw, gh]).repeat(nAnchors, 1).t()
|
||||||
|
cur_ious = torch.max(cur_ious, bbox_ious(cur_pred_boxes, cur_gt_boxes, x1y1x2y2=False))
|
||||||
|
conf_mask[b][cur_ious > sil_thresh] = 0
|
||||||
|
if seen < 12800:
|
||||||
|
if anchor_step == 4:
|
||||||
|
tx = torch.FloatTensor(anchors).view(nA, anchor_step).index_select(1, torch.LongTensor([2])).view(1, nA, 1,
|
||||||
|
1).repeat(
|
||||||
|
nB, 1, nH, nW)
|
||||||
|
ty = torch.FloatTensor(anchors).view(num_anchors, anchor_step).index_select(1, torch.LongTensor([2])).view(
|
||||||
|
1, nA, 1, 1).repeat(nB, 1, nH, nW)
|
||||||
|
else:
|
||||||
|
tx.fill_(0.5)
|
||||||
|
ty.fill_(0.5)
|
||||||
|
tw.zero_()
|
||||||
|
th.zero_()
|
||||||
|
coord_mask.fill_(1)
|
||||||
|
|
||||||
|
nGT = 0
|
||||||
|
nCorrect = 0
|
||||||
|
for b in range(nB):
|
||||||
|
for t in range(50):
|
||||||
|
if target[b][t * 5 + 1] == 0:
|
||||||
|
break
|
||||||
|
nGT = nGT + 1
|
||||||
|
best_iou = 0.0
|
||||||
|
best_n = -1
|
||||||
|
min_dist = 10000
|
||||||
|
gx = target[b][t * 5 + 1] * nW
|
||||||
|
gy = target[b][t * 5 + 2] * nH
|
||||||
|
gi = int(gx)
|
||||||
|
gj = int(gy)
|
||||||
|
gw = target[b][t * 5 + 3] * nW
|
||||||
|
gh = target[b][t * 5 + 4] * nH
|
||||||
|
gt_box = [0, 0, gw, gh]
|
||||||
|
for n in range(nA):
|
||||||
|
aw = anchors[anchor_step * n]
|
||||||
|
ah = anchors[anchor_step * n + 1]
|
||||||
|
anchor_box = [0, 0, aw, ah]
|
||||||
|
iou = bbox_iou(anchor_box, gt_box, x1y1x2y2=False)
|
||||||
|
if anchor_step == 4:
|
||||||
|
ax = anchors[anchor_step * n + 2]
|
||||||
|
ay = anchors[anchor_step * n + 3]
|
||||||
|
dist = pow(((gi + ax) - gx), 2) + pow(((gj + ay) - gy), 2)
|
||||||
|
if iou > best_iou:
|
||||||
|
best_iou = iou
|
||||||
|
best_n = n
|
||||||
|
elif anchor_step == 4 and iou == best_iou and dist < min_dist:
|
||||||
|
best_iou = iou
|
||||||
|
best_n = n
|
||||||
|
min_dist = dist
|
||||||
|
|
||||||
|
gt_box = [gx, gy, gw, gh]
|
||||||
|
pred_box = pred_boxes[b * nAnchors + best_n * nPixels + gj * nW + gi]
|
||||||
|
|
||||||
|
coord_mask[b][best_n][gj][gi] = 1
|
||||||
|
cls_mask[b][best_n][gj][gi] = 1
|
||||||
|
conf_mask[b][best_n][gj][gi] = object_scale
|
||||||
|
tx[b][best_n][gj][gi] = target[b][t * 5 + 1] * nW - gi
|
||||||
|
ty[b][best_n][gj][gi] = target[b][t * 5 + 2] * nH - gj
|
||||||
|
tw[b][best_n][gj][gi] = math.log(gw / anchors[anchor_step * best_n])
|
||||||
|
th[b][best_n][gj][gi] = math.log(gh / anchors[anchor_step * best_n + 1])
|
||||||
|
iou = bbox_iou(gt_box, pred_box, x1y1x2y2=False) # best_iou
|
||||||
|
tconf[b][best_n][gj][gi] = iou
|
||||||
|
tcls[b][best_n][gj][gi] = target[b][t * 5]
|
||||||
|
if iou > 0.5:
|
||||||
|
nCorrect = nCorrect + 1
|
||||||
|
|
||||||
|
return nGT, nCorrect, coord_mask, conf_mask, cls_mask, tx, ty, tw, th, tconf, tcls
|
||||||
|
|
||||||
|
|
||||||
|
class RegionLoss(nn.Module):
|
||||||
|
def __init__(self, num_classes=0, anchors=[], num_anchors=1):
|
||||||
|
super(RegionLoss, self).__init__()
|
||||||
|
self.num_classes = num_classes
|
||||||
|
self.anchors = anchors
|
||||||
|
self.num_anchors = num_anchors
|
||||||
|
self.anchor_step = len(anchors) / num_anchors
|
||||||
|
self.coord_scale = 1
|
||||||
|
self.noobject_scale = 1
|
||||||
|
self.object_scale = 5
|
||||||
|
self.class_scale = 1
|
||||||
|
self.thresh = 0.6
|
||||||
|
self.seen = 0
|
||||||
|
|
||||||
|
def forward(self, output, target):
|
||||||
|
# output : BxAs*(4+1+num_classes)*H*W
|
||||||
|
t0 = time.time()
|
||||||
|
nB = output.data.size(0)
|
||||||
|
nA = self.num_anchors
|
||||||
|
nC = self.num_classes
|
||||||
|
nH = output.data.size(2)
|
||||||
|
nW = output.data.size(3)
|
||||||
|
|
||||||
|
output = output.view(nB, nA, (5 + nC), nH, nW)
|
||||||
|
x = F.sigmoid(output.index_select(2, Variable(torch.cuda.LongTensor([0]))).view(nB, nA, nH, nW))
|
||||||
|
y = F.sigmoid(output.index_select(2, Variable(torch.cuda.LongTensor([1]))).view(nB, nA, nH, nW))
|
||||||
|
w = output.index_select(2, Variable(torch.cuda.LongTensor([2]))).view(nB, nA, nH, nW)
|
||||||
|
h = output.index_select(2, Variable(torch.cuda.LongTensor([3]))).view(nB, nA, nH, nW)
|
||||||
|
conf = F.sigmoid(output.index_select(2, Variable(torch.cuda.LongTensor([4]))).view(nB, nA, nH, nW))
|
||||||
|
cls = output.index_select(2, Variable(torch.linspace(5, 5 + nC - 1, nC).long().cuda()))
|
||||||
|
cls = cls.view(nB * nA, nC, nH * nW).transpose(1, 2).contiguous().view(nB * nA * nH * nW, nC)
|
||||||
|
t1 = time.time()
|
||||||
|
|
||||||
|
pred_boxes = torch.cuda.FloatTensor(4, nB * nA * nH * nW)
|
||||||
|
grid_x = torch.linspace(0, nW - 1, nW).repeat(nH, 1).repeat(nB * nA, 1, 1).view(nB * nA * nH * nW).cuda()
|
||||||
|
grid_y = torch.linspace(0, nH - 1, nH).repeat(nW, 1).t().repeat(nB * nA, 1, 1).view(nB * nA * nH * nW).cuda()
|
||||||
|
anchor_w = torch.Tensor(self.anchors).view(nA, self.anchor_step).index_select(1, torch.LongTensor([0])).cuda()
|
||||||
|
anchor_h = torch.Tensor(self.anchors).view(nA, self.anchor_step).index_select(1, torch.LongTensor([1])).cuda()
|
||||||
|
anchor_w = anchor_w.repeat(nB, 1).repeat(1, 1, nH * nW).view(nB * nA * nH * nW)
|
||||||
|
anchor_h = anchor_h.repeat(nB, 1).repeat(1, 1, nH * nW).view(nB * nA * nH * nW)
|
||||||
|
pred_boxes[0] = x.data + grid_x
|
||||||
|
pred_boxes[1] = y.data + grid_y
|
||||||
|
pred_boxes[2] = torch.exp(w.data) * anchor_w
|
||||||
|
pred_boxes[3] = torch.exp(h.data) * anchor_h
|
||||||
|
pred_boxes = convert2cpu(pred_boxes.transpose(0, 1).contiguous().view(-1, 4))
|
||||||
|
t2 = time.time()
|
||||||
|
|
||||||
|
nGT, nCorrect, coord_mask, conf_mask, cls_mask, tx, ty, tw, th, tconf, tcls = build_targets(pred_boxes,
|
||||||
|
target.data,
|
||||||
|
self.anchors, nA,
|
||||||
|
nC, \
|
||||||
|
nH, nW,
|
||||||
|
self.noobject_scale,
|
||||||
|
self.object_scale,
|
||||||
|
self.thresh,
|
||||||
|
self.seen)
|
||||||
|
cls_mask = (cls_mask == 1)
|
||||||
|
nProposals = int((conf > 0.25).sum().data[0])
|
||||||
|
|
||||||
|
tx = Variable(tx.cuda())
|
||||||
|
ty = Variable(ty.cuda())
|
||||||
|
tw = Variable(tw.cuda())
|
||||||
|
th = Variable(th.cuda())
|
||||||
|
tconf = Variable(tconf.cuda())
|
||||||
|
tcls = Variable(tcls.view(-1)[cls_mask].long().cuda())
|
||||||
|
|
||||||
|
coord_mask = Variable(coord_mask.cuda())
|
||||||
|
conf_mask = Variable(conf_mask.cuda().sqrt())
|
||||||
|
cls_mask = Variable(cls_mask.view(-1, 1).repeat(1, nC).cuda())
|
||||||
|
cls = cls[cls_mask].view(-1, nC)
|
||||||
|
|
||||||
|
t3 = time.time()
|
||||||
|
|
||||||
|
loss_x = self.coord_scale * nn.MSELoss(reduction='sum')(x * coord_mask, tx * coord_mask) / 2.0
|
||||||
|
loss_y = self.coord_scale * nn.MSELoss(reduction='sum')(y * coord_mask, ty * coord_mask) / 2.0
|
||||||
|
loss_w = self.coord_scale * nn.MSELoss(reduction='sum')(w * coord_mask, tw * coord_mask) / 2.0
|
||||||
|
loss_h = self.coord_scale * nn.MSELoss(reduction='sum')(h * coord_mask, th * coord_mask) / 2.0
|
||||||
|
loss_conf = nn.MSELoss(reduction='sum')(conf * conf_mask, tconf * conf_mask) / 2.0
|
||||||
|
loss_cls = self.class_scale * nn.CrossEntropyLoss(reduction='sum')(cls, tcls)
|
||||||
|
loss = loss_x + loss_y + loss_w + loss_h + loss_conf + loss_cls
|
||||||
|
t4 = time.time()
|
||||||
|
if False:
|
||||||
|
print('-----------------------------------')
|
||||||
|
print(' activation : %f' % (t1 - t0))
|
||||||
|
print(' create pred_boxes : %f' % (t2 - t1))
|
||||||
|
print(' build targets : %f' % (t3 - t2))
|
||||||
|
print(' create loss : %f' % (t4 - t3))
|
||||||
|
print(' total : %f' % (t4 - t0))
|
||||||
|
print('%d: nGT %d, recall %d, proposals %d, loss: x %f, y %f, w %f, h %f, conf %f, cls %f, total %f' % (
|
||||||
|
self.seen, nGT, nCorrect, nProposals, loss_x.data[0], loss_y.data[0], loss_w.data[0], loss_h.data[0],
|
||||||
|
loss_conf.data[0], loss_cls.data[0], loss.data[0]))
|
||||||
|
return loss
|
234
scripts/libs/torch_utils.py
Normal file
234
scripts/libs/torch_utils.py
Normal file
@ -0,0 +1,234 @@
|
|||||||
|
import time
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
def bbox_iou(box1, box2, x1y1x2y2=True):
|
||||||
|
|
||||||
|
# print('iou box1:', box1)
|
||||||
|
# print('iou box2:', box2)
|
||||||
|
|
||||||
|
if x1y1x2y2:
|
||||||
|
mx = min(box1[0], box2[0])
|
||||||
|
Mx = max(box1[2], box2[2])
|
||||||
|
my = min(box1[1], box2[1])
|
||||||
|
My = max(box1[3], box2[3])
|
||||||
|
w1 = box1[2] - box1[0]
|
||||||
|
h1 = box1[3] - box1[1]
|
||||||
|
w2 = box2[2] - box2[0]
|
||||||
|
h2 = box2[3] - box2[1]
|
||||||
|
else:
|
||||||
|
w1 = box1[2]
|
||||||
|
h1 = box1[3]
|
||||||
|
w2 = box2[2]
|
||||||
|
h2 = box2[3]
|
||||||
|
|
||||||
|
mx = min(box1[0], box2[0])
|
||||||
|
Mx = max(box1[0] + w1, box2[0] + w2)
|
||||||
|
my = min(box1[1], box2[1])
|
||||||
|
My = max(box1[1] + h1, box2[1] + h2)
|
||||||
|
uw = Mx - mx
|
||||||
|
uh = My - my
|
||||||
|
cw = w1 + w2 - uw
|
||||||
|
ch = h1 + h2 - uh
|
||||||
|
carea = 0
|
||||||
|
if cw <= 0 or ch <= 0:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
area1 = w1 * h1
|
||||||
|
area2 = w2 * h2
|
||||||
|
carea = cw * ch
|
||||||
|
uarea = area1 + area2 - carea
|
||||||
|
return carea / uarea
|
||||||
|
|
||||||
|
def bbox_ious(boxes1, boxes2, x1y1x2y2=True):
|
||||||
|
if x1y1x2y2:
|
||||||
|
mx = torch.min(boxes1[0], boxes2[0])
|
||||||
|
Mx = torch.max(boxes1[2], boxes2[2])
|
||||||
|
my = torch.min(boxes1[1], boxes2[1])
|
||||||
|
My = torch.max(boxes1[3], boxes2[3])
|
||||||
|
w1 = boxes1[2] - boxes1[0]
|
||||||
|
h1 = boxes1[3] - boxes1[1]
|
||||||
|
w2 = boxes2[2] - boxes2[0]
|
||||||
|
h2 = boxes2[3] - boxes2[1]
|
||||||
|
else:
|
||||||
|
mx = torch.min(boxes1[0] - boxes1[2] / 2.0, boxes2[0] - boxes2[2] / 2.0)
|
||||||
|
Mx = torch.max(boxes1[0] + boxes1[2] / 2.0, boxes2[0] + boxes2[2] / 2.0)
|
||||||
|
my = torch.min(boxes1[1] - boxes1[3] / 2.0, boxes2[1] - boxes2[3] / 2.0)
|
||||||
|
My = torch.max(boxes1[1] + boxes1[3] / 2.0, boxes2[1] + boxes2[3] / 2.0)
|
||||||
|
w1 = boxes1[2]
|
||||||
|
h1 = boxes1[3]
|
||||||
|
w2 = boxes2[2]
|
||||||
|
h2 = boxes2[3]
|
||||||
|
uw = Mx - mx
|
||||||
|
uh = My - my
|
||||||
|
cw = w1 + w2 - uw
|
||||||
|
ch = h1 + h2 - uh
|
||||||
|
mask = ((cw <= 0) + (ch <= 0) > 0)
|
||||||
|
area1 = w1 * h1
|
||||||
|
area2 = w2 * h2
|
||||||
|
carea = cw * ch
|
||||||
|
carea[mask] = 0
|
||||||
|
uarea = area1 + area2 - carea
|
||||||
|
return carea / uarea
|
||||||
|
|
||||||
|
|
||||||
|
def get_region_boxes(boxes_and_confs):
|
||||||
|
|
||||||
|
# print('Getting boxes from boxes and confs ...')
|
||||||
|
|
||||||
|
boxes_list = []
|
||||||
|
confs_list = []
|
||||||
|
|
||||||
|
for item in boxes_and_confs:
|
||||||
|
boxes_list.append(item[0])
|
||||||
|
confs_list.append(item[1])
|
||||||
|
|
||||||
|
# boxes: [batch, num1 + num2 + num3, 1, 4]
|
||||||
|
# confs: [batch, num1 + num2 + num3, num_classes]
|
||||||
|
boxes = torch.cat(boxes_list, dim=1)
|
||||||
|
confs = torch.cat(confs_list, dim=1)
|
||||||
|
|
||||||
|
return [boxes, confs]
|
||||||
|
|
||||||
|
|
||||||
|
def convert2cpu(gpu_matrix):
|
||||||
|
return torch.FloatTensor(gpu_matrix.size()).copy_(gpu_matrix)
|
||||||
|
|
||||||
|
|
||||||
|
def convert2cpu_long(gpu_matrix):
|
||||||
|
return torch.LongTensor(gpu_matrix.size()).copy_(gpu_matrix)
|
||||||
|
|
||||||
|
|
||||||
|
def nms_cpu(boxes, confs, nms_thresh=0.5, min_mode=False):
|
||||||
|
# print(boxes.shape)
|
||||||
|
x1 = boxes[:, 0]
|
||||||
|
y1 = boxes[:, 1]
|
||||||
|
x2 = boxes[:, 2]
|
||||||
|
y2 = boxes[:, 3]
|
||||||
|
|
||||||
|
areas = (x2 - x1) * (y2 - y1)
|
||||||
|
order = confs.argsort()[::-1]
|
||||||
|
|
||||||
|
keep = []
|
||||||
|
while order.size > 0:
|
||||||
|
idx_self = order[0]
|
||||||
|
idx_other = order[1:]
|
||||||
|
|
||||||
|
keep.append(idx_self)
|
||||||
|
|
||||||
|
xx1 = np.maximum(x1[idx_self], x1[idx_other])
|
||||||
|
yy1 = np.maximum(y1[idx_self], y1[idx_other])
|
||||||
|
xx2 = np.minimum(x2[idx_self], x2[idx_other])
|
||||||
|
yy2 = np.minimum(y2[idx_self], y2[idx_other])
|
||||||
|
|
||||||
|
w = np.maximum(0.0, xx2 - xx1)
|
||||||
|
h = np.maximum(0.0, yy2 - yy1)
|
||||||
|
inter = w * h
|
||||||
|
|
||||||
|
if min_mode:
|
||||||
|
over = inter / np.minimum(areas[order[0]], areas[order[1:]])
|
||||||
|
else:
|
||||||
|
over = inter / (areas[order[0]] + areas[order[1:]] - inter)
|
||||||
|
|
||||||
|
inds = np.where(over <= nms_thresh)[0]
|
||||||
|
order = order[inds + 1]
|
||||||
|
|
||||||
|
return np.array(keep)
|
||||||
|
|
||||||
|
def post_processing(img, conf_thresh, nms_thresh, output):
|
||||||
|
|
||||||
|
# anchors = [12, 16, 19, 36, 40, 28, 36, 75, 76, 55, 72, 146, 142, 110, 192, 243, 459, 401]
|
||||||
|
# num_anchors = 9
|
||||||
|
# anchor_masks = [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
|
||||||
|
# strides = [8, 16, 32]
|
||||||
|
# anchor_step = len(anchors) // num_anchors
|
||||||
|
|
||||||
|
# [batch, num, 1, 4]
|
||||||
|
box_array = output[0]
|
||||||
|
# [batch, num, num_classes]
|
||||||
|
confs = output[1]
|
||||||
|
|
||||||
|
t1 = time.time()
|
||||||
|
|
||||||
|
if type(box_array).__name__ != 'ndarray':
|
||||||
|
box_array = box_array.cpu().detach().numpy()
|
||||||
|
confs = confs.cpu().detach().numpy()
|
||||||
|
|
||||||
|
num_classes = confs.shape[2]
|
||||||
|
|
||||||
|
# [batch, num, 4]
|
||||||
|
box_array = box_array[:, :, 0]
|
||||||
|
|
||||||
|
# [batch, num, num_classes] --> [batch, num]
|
||||||
|
max_conf = np.max(confs, axis=2)
|
||||||
|
max_id = np.argmax(confs, axis=2)
|
||||||
|
|
||||||
|
t2 = time.time()
|
||||||
|
|
||||||
|
bboxes_batch = []
|
||||||
|
for i in range(box_array.shape[0]):
|
||||||
|
|
||||||
|
argwhere = max_conf[i] > conf_thresh
|
||||||
|
l_box_array = box_array[i, argwhere, :]
|
||||||
|
l_max_conf = max_conf[i, argwhere]
|
||||||
|
l_max_id = max_id[i, argwhere]
|
||||||
|
|
||||||
|
bboxes = []
|
||||||
|
# nms for each class
|
||||||
|
for j in range(num_classes):
|
||||||
|
|
||||||
|
cls_argwhere = l_max_id == j
|
||||||
|
ll_box_array = l_box_array[cls_argwhere, :]
|
||||||
|
ll_max_conf = l_max_conf[cls_argwhere]
|
||||||
|
ll_max_id = l_max_id[cls_argwhere]
|
||||||
|
|
||||||
|
keep = nms_cpu(ll_box_array, ll_max_conf, nms_thresh)
|
||||||
|
|
||||||
|
if (keep.size > 0):
|
||||||
|
ll_box_array = ll_box_array[keep, :]
|
||||||
|
ll_max_conf = ll_max_conf[keep]
|
||||||
|
ll_max_id = ll_max_id[keep]
|
||||||
|
|
||||||
|
for k in range(ll_box_array.shape[0]):
|
||||||
|
bboxes.append([ll_box_array[k, 0], ll_box_array[k, 1], ll_box_array[k, 2], ll_box_array[k, 3], ll_max_conf[k], ll_max_conf[k], ll_max_id[k]])
|
||||||
|
|
||||||
|
bboxes_batch.append(bboxes)
|
||||||
|
|
||||||
|
t3 = time.time()
|
||||||
|
|
||||||
|
print('-----------------------------------')
|
||||||
|
print(' max and argmax : %f' % (t2 - t1))
|
||||||
|
print(' nms : %f' % (t3 - t2))
|
||||||
|
print('Post processing total : %f' % (t3 - t1))
|
||||||
|
print('-----------------------------------')
|
||||||
|
|
||||||
|
return bboxes_batch
|
||||||
|
|
||||||
|
def do_detect(model, img, conf_thresh, nms_thresh, use_cuda=1):
|
||||||
|
model.eval()
|
||||||
|
t0 = time.time()
|
||||||
|
|
||||||
|
if type(img) == np.ndarray and len(img.shape) == 3: # cv2 image
|
||||||
|
img = torch.from_numpy(img.transpose(2, 0, 1)).float().div(255.0).unsqueeze(0)
|
||||||
|
elif type(img) == np.ndarray and len(img.shape) == 4:
|
||||||
|
img = torch.from_numpy(img.transpose(0, 3, 1, 2)).float().div(255.0)
|
||||||
|
else:
|
||||||
|
print("unknow image type")
|
||||||
|
exit(-1)
|
||||||
|
|
||||||
|
if use_cuda:
|
||||||
|
img = img.cuda()
|
||||||
|
img = torch.autograd.Variable(img)
|
||||||
|
|
||||||
|
t1 = time.time()
|
||||||
|
|
||||||
|
output = model(img)
|
||||||
|
|
||||||
|
t2 = time.time()
|
||||||
|
|
||||||
|
print('-----------------------------------')
|
||||||
|
print(' Preprocess : %f' % (t1 - t0))
|
||||||
|
print(' Model Inference : %f' % (t2 - t1))
|
||||||
|
print('-----------------------------------')
|
||||||
|
|
||||||
|
return utils.post_processing(img, conf_thresh, nms_thresh, output)
|
323
scripts/libs/yolo_layer.py
Normal file
323
scripts/libs/yolo_layer.py
Normal file
@ -0,0 +1,323 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from libs.torch_utils import *
|
||||||
|
|
||||||
|
def yolo_forward(output, conf_thresh, num_classes, anchors, num_anchors, scale_x_y, only_objectness=1,
|
||||||
|
validation=False):
|
||||||
|
# Output would be invalid if it does not satisfy this assert
|
||||||
|
# assert (output.size(1) == (5 + num_classes) * num_anchors)
|
||||||
|
|
||||||
|
# print(output.size())
|
||||||
|
|
||||||
|
# Slice the second dimension (channel) of output into:
|
||||||
|
# [ 2, 2, 1, num_classes, 2, 2, 1, num_classes, 2, 2, 1, num_classes ]
|
||||||
|
# And then into
|
||||||
|
# bxy = [ 6 ] bwh = [ 6 ] det_conf = [ 3 ] cls_conf = [ num_classes * 3 ]
|
||||||
|
batch = output.size(0)
|
||||||
|
H = output.size(2)
|
||||||
|
W = output.size(3)
|
||||||
|
|
||||||
|
bxy_list = []
|
||||||
|
bwh_list = []
|
||||||
|
det_confs_list = []
|
||||||
|
cls_confs_list = []
|
||||||
|
|
||||||
|
for i in range(num_anchors):
|
||||||
|
begin = i * (5 + num_classes)
|
||||||
|
end = (i + 1) * (5 + num_classes)
|
||||||
|
|
||||||
|
bxy_list.append(output[:, begin : begin + 2])
|
||||||
|
bwh_list.append(output[:, begin + 2 : begin + 4])
|
||||||
|
det_confs_list.append(output[:, begin + 4 : begin + 5])
|
||||||
|
cls_confs_list.append(output[:, begin + 5 : end])
|
||||||
|
|
||||||
|
# Shape: [batch, num_anchors * 2, H, W]
|
||||||
|
bxy = torch.cat(bxy_list, dim=1)
|
||||||
|
# Shape: [batch, num_anchors * 2, H, W]
|
||||||
|
bwh = torch.cat(bwh_list, dim=1)
|
||||||
|
|
||||||
|
# Shape: [batch, num_anchors, H, W]
|
||||||
|
det_confs = torch.cat(det_confs_list, dim=1)
|
||||||
|
# Shape: [batch, num_anchors * H * W]
|
||||||
|
det_confs = det_confs.view(batch, num_anchors * H * W)
|
||||||
|
|
||||||
|
# Shape: [batch, num_anchors * num_classes, H, W]
|
||||||
|
cls_confs = torch.cat(cls_confs_list, dim=1)
|
||||||
|
# Shape: [batch, num_anchors, num_classes, H * W]
|
||||||
|
cls_confs = cls_confs.view(batch, num_anchors, num_classes, H * W)
|
||||||
|
# Shape: [batch, num_anchors, num_classes, H * W] --> [batch, num_anchors * H * W, num_classes]
|
||||||
|
cls_confs = cls_confs.permute(0, 1, 3, 2).reshape(batch, num_anchors * H * W, num_classes)
|
||||||
|
|
||||||
|
# Apply sigmoid(), exp() and softmax() to slices
|
||||||
|
#
|
||||||
|
bxy = torch.sigmoid(bxy) * scale_x_y - 0.5 * (scale_x_y - 1)
|
||||||
|
bwh = torch.exp(bwh)
|
||||||
|
det_confs = torch.sigmoid(det_confs)
|
||||||
|
cls_confs = torch.sigmoid(cls_confs)
|
||||||
|
|
||||||
|
# Prepare C-x, C-y, P-w, P-h (None of them are torch related)
|
||||||
|
grid_x = np.expand_dims(np.expand_dims(np.expand_dims(np.linspace(0, W - 1, W), axis=0).repeat(H, 0), axis=0), axis=0)
|
||||||
|
grid_y = np.expand_dims(np.expand_dims(np.expand_dims(np.linspace(0, H - 1, H), axis=1).repeat(W, 1), axis=0), axis=0)
|
||||||
|
# grid_x = torch.linspace(0, W - 1, W).reshape(1, 1, 1, W).repeat(1, 1, H, 1)
|
||||||
|
# grid_y = torch.linspace(0, H - 1, H).reshape(1, 1, H, 1).repeat(1, 1, 1, W)
|
||||||
|
|
||||||
|
anchor_w = []
|
||||||
|
anchor_h = []
|
||||||
|
for i in range(num_anchors):
|
||||||
|
anchor_w.append(anchors[i * 2])
|
||||||
|
anchor_h.append(anchors[i * 2 + 1])
|
||||||
|
|
||||||
|
device = None
|
||||||
|
cuda_check = output.is_cuda
|
||||||
|
if cuda_check:
|
||||||
|
device = output.get_device()
|
||||||
|
|
||||||
|
bx_list = []
|
||||||
|
by_list = []
|
||||||
|
bw_list = []
|
||||||
|
bh_list = []
|
||||||
|
|
||||||
|
# Apply C-x, C-y, P-w, P-h
|
||||||
|
for i in range(num_anchors):
|
||||||
|
ii = i * 2
|
||||||
|
# Shape: [batch, 1, H, W]
|
||||||
|
bx = bxy[:, ii : ii + 1] + torch.tensor(grid_x, device=device, dtype=torch.float32) # grid_x.to(device=device, dtype=torch.float32)
|
||||||
|
# Shape: [batch, 1, H, W]
|
||||||
|
by = bxy[:, ii + 1 : ii + 2] + torch.tensor(grid_y, device=device, dtype=torch.float32) # grid_y.to(device=device, dtype=torch.float32)
|
||||||
|
# Shape: [batch, 1, H, W]
|
||||||
|
bw = bwh[:, ii : ii + 1] * anchor_w[i]
|
||||||
|
# Shape: [batch, 1, H, W]
|
||||||
|
bh = bwh[:, ii + 1 : ii + 2] * anchor_h[i]
|
||||||
|
|
||||||
|
bx_list.append(bx)
|
||||||
|
by_list.append(by)
|
||||||
|
bw_list.append(bw)
|
||||||
|
bh_list.append(bh)
|
||||||
|
|
||||||
|
|
||||||
|
########################################
|
||||||
|
# Figure out bboxes from slices #
|
||||||
|
########################################
|
||||||
|
|
||||||
|
# Shape: [batch, num_anchors, H, W]
|
||||||
|
bx = torch.cat(bx_list, dim=1)
|
||||||
|
# Shape: [batch, num_anchors, H, W]
|
||||||
|
by = torch.cat(by_list, dim=1)
|
||||||
|
# Shape: [batch, num_anchors, H, W]
|
||||||
|
bw = torch.cat(bw_list, dim=1)
|
||||||
|
# Shape: [batch, num_anchors, H, W]
|
||||||
|
bh = torch.cat(bh_list, dim=1)
|
||||||
|
|
||||||
|
# Shape: [batch, 2 * num_anchors, H, W]
|
||||||
|
bx_bw = torch.cat((bx, bw), dim=1)
|
||||||
|
# Shape: [batch, 2 * num_anchors, H, W]
|
||||||
|
by_bh = torch.cat((by, bh), dim=1)
|
||||||
|
|
||||||
|
# normalize coordinates to [0, 1]
|
||||||
|
bx_bw /= W
|
||||||
|
by_bh /= H
|
||||||
|
|
||||||
|
# Shape: [batch, num_anchors * H * W, 1]
|
||||||
|
bx = bx_bw[:, :num_anchors].view(batch, num_anchors * H * W, 1)
|
||||||
|
by = by_bh[:, :num_anchors].view(batch, num_anchors * H * W, 1)
|
||||||
|
bw = bx_bw[:, num_anchors:].view(batch, num_anchors * H * W, 1)
|
||||||
|
bh = by_bh[:, num_anchors:].view(batch, num_anchors * H * W, 1)
|
||||||
|
|
||||||
|
bx1 = bx - bw * 0.5
|
||||||
|
by1 = by - bh * 0.5
|
||||||
|
bx2 = bx1 + bw
|
||||||
|
by2 = by1 + bh
|
||||||
|
|
||||||
|
# Shape: [batch, num_anchors * h * w, 4] -> [batch, num_anchors * h * w, 1, 4]
|
||||||
|
boxes = torch.cat((bx1, by1, bx2, by2), dim=2).view(batch, num_anchors * H * W, 1, 4)
|
||||||
|
# boxes = boxes.repeat(1, 1, num_classes, 1)
|
||||||
|
|
||||||
|
# boxes: [batch, num_anchors * H * W, 1, 4]
|
||||||
|
# cls_confs: [batch, num_anchors * H * W, num_classes]
|
||||||
|
# det_confs: [batch, num_anchors * H * W]
|
||||||
|
|
||||||
|
det_confs = det_confs.view(batch, num_anchors * H * W, 1)
|
||||||
|
confs = cls_confs * det_confs
|
||||||
|
|
||||||
|
# boxes: [batch, num_anchors * H * W, 1, 4]
|
||||||
|
# confs: [batch, num_anchors * H * W, num_classes]
|
||||||
|
|
||||||
|
return boxes, confs
|
||||||
|
|
||||||
|
|
||||||
|
def yolo_forward_dynamic(output, conf_thresh, num_classes, anchors, num_anchors, scale_x_y, only_objectness=1,
|
||||||
|
validation=False):
|
||||||
|
# Output would be invalid if it does not satisfy this assert
|
||||||
|
# assert (output.size(1) == (5 + num_classes) * num_anchors)
|
||||||
|
|
||||||
|
# print(output.size())
|
||||||
|
|
||||||
|
# Slice the second dimension (channel) of output into:
|
||||||
|
# [ 2, 2, 1, num_classes, 2, 2, 1, num_classes, 2, 2, 1, num_classes ]
|
||||||
|
# And then into
|
||||||
|
# bxy = [ 6 ] bwh = [ 6 ] det_conf = [ 3 ] cls_conf = [ num_classes * 3 ]
|
||||||
|
# batch = output.size(0)
|
||||||
|
# H = output.size(2)
|
||||||
|
# W = output.size(3)
|
||||||
|
|
||||||
|
bxy_list = []
|
||||||
|
bwh_list = []
|
||||||
|
det_confs_list = []
|
||||||
|
cls_confs_list = []
|
||||||
|
|
||||||
|
for i in range(num_anchors):
|
||||||
|
begin = i * (5 + num_classes)
|
||||||
|
end = (i + 1) * (5 + num_classes)
|
||||||
|
|
||||||
|
bxy_list.append(output[:, begin : begin + 2])
|
||||||
|
bwh_list.append(output[:, begin + 2 : begin + 4])
|
||||||
|
det_confs_list.append(output[:, begin + 4 : begin + 5])
|
||||||
|
cls_confs_list.append(output[:, begin + 5 : end])
|
||||||
|
|
||||||
|
# Shape: [batch, num_anchors * 2, H, W]
|
||||||
|
bxy = torch.cat(bxy_list, dim=1)
|
||||||
|
# Shape: [batch, num_anchors * 2, H, W]
|
||||||
|
bwh = torch.cat(bwh_list, dim=1)
|
||||||
|
|
||||||
|
# Shape: [batch, num_anchors, H, W]
|
||||||
|
det_confs = torch.cat(det_confs_list, dim=1)
|
||||||
|
# Shape: [batch, num_anchors * H * W]
|
||||||
|
det_confs = det_confs.view(output.size(0), num_anchors * output.size(2) * output.size(3))
|
||||||
|
|
||||||
|
# Shape: [batch, num_anchors * num_classes, H, W]
|
||||||
|
cls_confs = torch.cat(cls_confs_list, dim=1)
|
||||||
|
# Shape: [batch, num_anchors, num_classes, H * W]
|
||||||
|
cls_confs = cls_confs.view(output.size(0), num_anchors, num_classes, output.size(2) * output.size(3))
|
||||||
|
# Shape: [batch, num_anchors, num_classes, H * W] --> [batch, num_anchors * H * W, num_classes]
|
||||||
|
cls_confs = cls_confs.permute(0, 1, 3, 2).reshape(output.size(0), num_anchors * output.size(2) * output.size(3), num_classes)
|
||||||
|
|
||||||
|
# Apply sigmoid(), exp() and softmax() to slices
|
||||||
|
#
|
||||||
|
bxy = torch.sigmoid(bxy) * scale_x_y - 0.5 * (scale_x_y - 1)
|
||||||
|
bwh = torch.exp(bwh)
|
||||||
|
det_confs = torch.sigmoid(det_confs)
|
||||||
|
cls_confs = torch.sigmoid(cls_confs)
|
||||||
|
|
||||||
|
# Prepare C-x, C-y, P-w, P-h (None of them are torch related)
|
||||||
|
grid_x = np.expand_dims(np.expand_dims(np.expand_dims(np.linspace(0, output.size(3) - 1, output.size(3)), axis=0).repeat(output.size(2), 0), axis=0), axis=0)
|
||||||
|
grid_y = np.expand_dims(np.expand_dims(np.expand_dims(np.linspace(0, output.size(2) - 1, output.size(2)), axis=1).repeat(output.size(3), 1), axis=0), axis=0)
|
||||||
|
# grid_x = torch.linspace(0, W - 1, W).reshape(1, 1, 1, W).repeat(1, 1, H, 1)
|
||||||
|
# grid_y = torch.linspace(0, H - 1, H).reshape(1, 1, H, 1).repeat(1, 1, 1, W)
|
||||||
|
|
||||||
|
anchor_w = []
|
||||||
|
anchor_h = []
|
||||||
|
for i in range(num_anchors):
|
||||||
|
anchor_w.append(anchors[i * 2])
|
||||||
|
anchor_h.append(anchors[i * 2 + 1])
|
||||||
|
|
||||||
|
device = None
|
||||||
|
cuda_check = output.is_cuda
|
||||||
|
if cuda_check:
|
||||||
|
device = output.get_device()
|
||||||
|
|
||||||
|
bx_list = []
|
||||||
|
by_list = []
|
||||||
|
bw_list = []
|
||||||
|
bh_list = []
|
||||||
|
|
||||||
|
# Apply C-x, C-y, P-w, P-h
|
||||||
|
for i in range(num_anchors):
|
||||||
|
ii = i * 2
|
||||||
|
# Shape: [batch, 1, H, W]
|
||||||
|
bx = bxy[:, ii : ii + 1] + torch.tensor(grid_x, device=device, dtype=torch.float32) # grid_x.to(device=device, dtype=torch.float32)
|
||||||
|
# Shape: [batch, 1, H, W]
|
||||||
|
by = bxy[:, ii + 1 : ii + 2] + torch.tensor(grid_y, device=device, dtype=torch.float32) # grid_y.to(device=device, dtype=torch.float32)
|
||||||
|
# Shape: [batch, 1, H, W]
|
||||||
|
bw = bwh[:, ii : ii + 1] * anchor_w[i]
|
||||||
|
# Shape: [batch, 1, H, W]
|
||||||
|
bh = bwh[:, ii + 1 : ii + 2] * anchor_h[i]
|
||||||
|
|
||||||
|
bx_list.append(bx)
|
||||||
|
by_list.append(by)
|
||||||
|
bw_list.append(bw)
|
||||||
|
bh_list.append(bh)
|
||||||
|
|
||||||
|
|
||||||
|
########################################
|
||||||
|
# Figure out bboxes from slices #
|
||||||
|
########################################
|
||||||
|
|
||||||
|
# Shape: [batch, num_anchors, H, W]
|
||||||
|
bx = torch.cat(bx_list, dim=1)
|
||||||
|
# Shape: [batch, num_anchors, H, W]
|
||||||
|
by = torch.cat(by_list, dim=1)
|
||||||
|
# Shape: [batch, num_anchors, H, W]
|
||||||
|
bw = torch.cat(bw_list, dim=1)
|
||||||
|
# Shape: [batch, num_anchors, H, W]
|
||||||
|
bh = torch.cat(bh_list, dim=1)
|
||||||
|
|
||||||
|
# Shape: [batch, 2 * num_anchors, H, W]
|
||||||
|
bx_bw = torch.cat((bx, bw), dim=1)
|
||||||
|
# Shape: [batch, 2 * num_anchors, H, W]
|
||||||
|
by_bh = torch.cat((by, bh), dim=1)
|
||||||
|
|
||||||
|
# normalize coordinates to [0, 1]
|
||||||
|
bx_bw /= output.size(3)
|
||||||
|
by_bh /= output.size(2)
|
||||||
|
|
||||||
|
# Shape: [batch, num_anchors * H * W, 1]
|
||||||
|
bx = bx_bw[:, :num_anchors].view(output.size(0), num_anchors * output.size(2) * output.size(3), 1)
|
||||||
|
by = by_bh[:, :num_anchors].view(output.size(0), num_anchors * output.size(2) * output.size(3), 1)
|
||||||
|
bw = bx_bw[:, num_anchors:].view(output.size(0), num_anchors * output.size(2) * output.size(3), 1)
|
||||||
|
bh = by_bh[:, num_anchors:].view(output.size(0), num_anchors * output.size(2) * output.size(3), 1)
|
||||||
|
|
||||||
|
bx1 = bx - bw * 0.5
|
||||||
|
by1 = by - bh * 0.5
|
||||||
|
bx2 = bx1 + bw
|
||||||
|
by2 = by1 + bh
|
||||||
|
|
||||||
|
# Shape: [batch, num_anchors * h * w, 4] -> [batch, num_anchors * h * w, 1, 4]
|
||||||
|
boxes = torch.cat((bx1, by1, bx2, by2), dim=2).view(output.size(0), num_anchors * output.size(2) * output.size(3), 1, 4)
|
||||||
|
# boxes = boxes.repeat(1, 1, num_classes, 1)
|
||||||
|
|
||||||
|
# boxes: [batch, num_anchors * H * W, 1, 4]
|
||||||
|
# cls_confs: [batch, num_anchors * H * W, num_classes]
|
||||||
|
# det_confs: [batch, num_anchors * H * W]
|
||||||
|
|
||||||
|
det_confs = det_confs.view(output.size(0), num_anchors * output.size(2) * output.size(3), 1)
|
||||||
|
confs = cls_confs * det_confs
|
||||||
|
|
||||||
|
# boxes: [batch, num_anchors * H * W, 1, 4]
|
||||||
|
# confs: [batch, num_anchors * H * W, num_classes]
|
||||||
|
|
||||||
|
return boxes, confs
|
||||||
|
|
||||||
|
class YoloLayer(nn.Module):
|
||||||
|
''' Yolo layer
|
||||||
|
model_out: while inference,is post-processing inside or outside the model
|
||||||
|
true:outside
|
||||||
|
'''
|
||||||
|
def __init__(self, anchor_mask=[], num_classes=0, anchors=[], num_anchors=1, stride=32, model_out=False):
|
||||||
|
super(YoloLayer, self).__init__()
|
||||||
|
self.anchor_mask = anchor_mask
|
||||||
|
self.num_classes = num_classes
|
||||||
|
self.anchors = anchors
|
||||||
|
self.num_anchors = num_anchors
|
||||||
|
self.anchor_step = len(anchors) // num_anchors
|
||||||
|
self.coord_scale = 1
|
||||||
|
self.noobject_scale = 1
|
||||||
|
self.object_scale = 5
|
||||||
|
self.class_scale = 1
|
||||||
|
self.thresh = 0.6
|
||||||
|
self.stride = stride
|
||||||
|
self.seen = 0
|
||||||
|
self.scale_x_y = 1
|
||||||
|
|
||||||
|
self.model_out = model_out
|
||||||
|
|
||||||
|
def forward(self, output, target=None):
|
||||||
|
if self.training:
|
||||||
|
return output
|
||||||
|
masked_anchors = []
|
||||||
|
for m in self.anchor_mask:
|
||||||
|
masked_anchors += self.anchors[m * self.anchor_step:(m + 1) * self.anchor_step]
|
||||||
|
masked_anchors = [anchor / self.stride for anchor in masked_anchors]
|
||||||
|
|
||||||
|
return yolo_forward_dynamic(output, self.thresh, self.num_classes, masked_anchors, len(self.anchor_mask),scale_x_y=self.scale_x_y)
|
34
scripts/make_yolo_onnx.py
Normal file
34
scripts/make_yolo_onnx.py
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
import torch
|
||||||
|
from libs.models import Darknet
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def convert_to_onnx(cfgfile: str, weightsfile: str, IMAGE_WIDTH: int, IMAGE_HEIGHT: int):
|
||||||
|
model = Darknet(cfgfile)
|
||||||
|
model.load_weights(weightsfile)
|
||||||
|
|
||||||
|
x = torch.randn((1, 3, IMAGE_HEIGHT, IMAGE_WIDTH), requires_grad=True)
|
||||||
|
|
||||||
|
onnx_filename = '../models/yolo.onnx'
|
||||||
|
|
||||||
|
print('Export the onnx model ...')
|
||||||
|
|
||||||
|
torch.onnx.export(model,
|
||||||
|
x,
|
||||||
|
onnx_filename,
|
||||||
|
export_params=True,
|
||||||
|
opset_version=11,
|
||||||
|
do_constant_folding=True,
|
||||||
|
input_names=['input'],
|
||||||
|
output_names=['boxes', 'confs'],
|
||||||
|
dynamic_axes={
|
||||||
|
'input': {0: 'batch_size'},
|
||||||
|
'boxes': {0: 'batch_size'},
|
||||||
|
'confs': {0: 'batch_size'},
|
||||||
|
})
|
||||||
|
|
||||||
|
print('Onnx model exporting done!')
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
convert_to_onnx("./yolov4-mish.cfg", "./yolov4-mish.weights", 512, 512)
|
1161
scripts/yolov4-mish.cfg
Normal file
1161
scripts/yolov4-mish.cfg
Normal file
File diff suppressed because it is too large
Load Diff
291
src/bbox.rs
Normal file
291
src/bbox.rs
Normal file
@ -0,0 +1,291 @@
|
|||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use serde_derive::{Deserialize, Serialize};
|
||||||
|
use std::marker::PhantomData;
|
||||||
|
|
||||||
|
pub trait BBoxFormat: std::fmt::Debug {}
|
||||||
|
|
||||||
|
/// Left-top-width-height format, contains left top corner and width-height
|
||||||
|
#[derive(Serialize, Deserialize, Debug, Copy, Clone, PartialEq)]
|
||||||
|
pub struct Ltwh;
|
||||||
|
impl BBoxFormat for Ltwh {}
|
||||||
|
|
||||||
|
/// X-y-aspect_ratio-height format, contains coordinates of the center of bbox and aspect_ratio-height
|
||||||
|
#[derive(Serialize, Deserialize, Debug, Copy, Clone, PartialEq)]
|
||||||
|
pub struct Xyah;
|
||||||
|
impl BBoxFormat for Xyah {}
|
||||||
|
|
||||||
|
/// Left-top-right-bottom format, contains left top and right bottom corners
|
||||||
|
#[derive(Serialize, Deserialize, Debug, Copy, Clone, PartialEq)]
|
||||||
|
pub struct Ltrb;
|
||||||
|
impl BBoxFormat for Ltrb {}
|
||||||
|
|
||||||
|
/// X-y-width-height format, contains coordinates of the center of bbox and width-height
|
||||||
|
#[derive(Serialize, Deserialize, Debug, Copy, Clone, PartialEq)]
|
||||||
|
pub struct Xywh;
|
||||||
|
impl BBoxFormat for Xywh {}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
|
||||||
|
pub struct BBox<F: BBoxFormat + Serialize + Deserialize<'static> + PartialEq>(
|
||||||
|
[f32; 4],
|
||||||
|
PhantomData<F>,
|
||||||
|
);
|
||||||
|
|
||||||
|
impl<F: BBoxFormat + Serialize + Deserialize<'static> + PartialEq> From<BBox<F>> for [f32; 4] {
|
||||||
|
fn from(bbox: BBox<F>) -> Self {
|
||||||
|
bbox.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<F: BBoxFormat + Serialize + Deserialize<'static> + PartialEq> BBox<F> {
|
||||||
|
#[inline]
|
||||||
|
pub fn as_slice(&self) -> &[f32; 4] {
|
||||||
|
&self.0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use carefully when you REALLY sure that slice have needed format
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn assigned(slice: &[f32; 4]) -> Self {
|
||||||
|
BBox(*slice, Default::default())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl BBox<Ltwh> {
|
||||||
|
#[inline]
|
||||||
|
pub fn ltwh(x1: f32, x2: f32, x3: f32, x4: f32) -> Self {
|
||||||
|
BBox([x1, x2, x3, x4], Default::default())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn left(&self) -> f32 {
|
||||||
|
self.0[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn top(&self) -> f32 {
|
||||||
|
self.0[1]
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn width(&self) -> f32 {
|
||||||
|
self.0[2]
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn height(&self) -> f32 {
|
||||||
|
self.0[3]
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn as_xyah(&self) -> BBox<Xyah> {
|
||||||
|
self.into()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn as_ltrb(&self) -> BBox<Ltrb> {
|
||||||
|
self.into()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl BBox<Ltrb> {
|
||||||
|
#[inline]
|
||||||
|
pub fn ltrb(x1: f32, x2: f32, x3: f32, x4: f32) -> Self {
|
||||||
|
BBox([x1, x2, x3, x4], Default::default())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn as_ltwh(&self) -> BBox<Ltwh> {
|
||||||
|
self.into()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn as_xyah(&self) -> BBox<Xyah> {
|
||||||
|
self.into()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn left(&self) -> f32 {
|
||||||
|
self.0[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn top(&self) -> f32 {
|
||||||
|
self.0[1]
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn right(&self) -> f32 {
|
||||||
|
self.0[2]
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn bottom(&self) -> f32 {
|
||||||
|
self.0[3]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl BBox<Xyah> {
|
||||||
|
#[inline]
|
||||||
|
pub fn xyah(x1: f32, x2: f32, x3: f32, x4: f32) -> Self {
|
||||||
|
BBox([x1, x2, x3, x4], Default::default())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn as_ltrb(&self) -> BBox<Ltrb> {
|
||||||
|
self.into()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn as_ltwh(&self) -> BBox<Ltwh> {
|
||||||
|
self.into()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn cx(&self) -> f32 {
|
||||||
|
self.0[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn cy(&self) -> f32 {
|
||||||
|
self.0[1]
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn aspect_ratio(&self) -> f32 {
|
||||||
|
self.0[2]
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn height(&self) -> f32 {
|
||||||
|
self.0[3]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl BBox<Xywh> {
|
||||||
|
#[inline]
|
||||||
|
pub fn xywh(x1: f32, x2: f32, x3: f32, x4: f32) -> Self {
|
||||||
|
BBox([x1, x2, x3, x4], Default::default())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn as_xyah(&self) -> BBox<Xyah> {
|
||||||
|
self.into()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn cx(&self) -> f32 {
|
||||||
|
self.0[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn cy(&self) -> f32 {
|
||||||
|
self.0[1]
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn width(&self) -> f32 {
|
||||||
|
self.0[2]
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn height(&self) -> f32 {
|
||||||
|
self.0[3]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> From<&'a BBox<Ltwh>> for BBox<Xyah> {
|
||||||
|
#[inline]
|
||||||
|
fn from(v: &'a BBox<Ltwh>) -> Self {
|
||||||
|
Self(
|
||||||
|
[
|
||||||
|
v.0[0] + v.0[2] / 2.0,
|
||||||
|
v.0[1] + v.0[3] / 2.0,
|
||||||
|
v.0[2] / v.0[3],
|
||||||
|
v.0[3],
|
||||||
|
],
|
||||||
|
Default::default(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> From<&'a BBox<Ltrb>> for BBox<Xyah> {
|
||||||
|
#[inline]
|
||||||
|
fn from(v: &'a BBox<Ltrb>) -> Self {
|
||||||
|
Self(
|
||||||
|
[
|
||||||
|
v.0[0] + (v.0[2] - v.0[0]) / 2.0,
|
||||||
|
v.0[1] + (v.0[3] - v.0[1]) / 2.0,
|
||||||
|
(v.0[2] - v.0[0]) / (v.0[3] - v.0[1]),
|
||||||
|
v.0[3] - v.0[1],
|
||||||
|
],
|
||||||
|
Default::default(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> From<&'a BBox<Ltwh>> for BBox<Ltrb> {
|
||||||
|
#[inline]
|
||||||
|
fn from(v: &'a BBox<Ltwh>) -> Self {
|
||||||
|
Self(
|
||||||
|
[v.0[0], v.0[1], v.0[2] + v.0[0], v.0[3] + v.0[1]],
|
||||||
|
Default::default(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> From<&'a BBox<Xyah>> for BBox<Ltrb> {
|
||||||
|
#[inline]
|
||||||
|
fn from(v: &'a BBox<Xyah>) -> Self {
|
||||||
|
Self(
|
||||||
|
[
|
||||||
|
v.0[0] - v.0[2] * v.0[3] / 2.,
|
||||||
|
v.0[1] - v.0[3] / 2.,
|
||||||
|
v.0[0] + v.0[2] * v.0[3] / 2.,
|
||||||
|
v.0[1] + v.0[3] / 2.,
|
||||||
|
],
|
||||||
|
Default::default(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> From<&'a BBox<Ltrb>> for BBox<Ltwh> {
|
||||||
|
#[inline]
|
||||||
|
fn from(v: &'a BBox<Ltrb>) -> Self {
|
||||||
|
Self(
|
||||||
|
[v.0[0], v.0[1], v.0[2] - v.0[0], v.0[3] - v.0[1]],
|
||||||
|
Default::default(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> From<&'a BBox<Xyah>> for BBox<Ltwh> {
|
||||||
|
#[inline]
|
||||||
|
fn from(v: &'a BBox<Xyah>) -> Self {
|
||||||
|
let height = v.0[3];
|
||||||
|
let width = v.0[2] * height;
|
||||||
|
|
||||||
|
Self(
|
||||||
|
[v.0[0] - width / 2.0, v.0[1] - height / 2.0, width, height],
|
||||||
|
Default::default(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> From<&'a BBox<Xywh>> for BBox<Xyah> {
|
||||||
|
#[inline]
|
||||||
|
fn from(v: &'a BBox<Xywh>) -> Self {
|
||||||
|
Self(
|
||||||
|
[v.0[0], v.0[1], v.0[2] / v.0[3], v.0[3]],
|
||||||
|
Default::default(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> From<&'a BBox<Xyah>> for BBox<Xywh> {
|
||||||
|
#[inline]
|
||||||
|
fn from(v: &'a BBox<Xyah>) -> Self {
|
||||||
|
Self(
|
||||||
|
[v.0[0], v.0[1], v.0[2] * v.0[3], v.0[3]],
|
||||||
|
Default::default(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
100
src/circular_queue.rs
Normal file
100
src/circular_queue.rs
Normal file
@ -0,0 +1,100 @@
|
|||||||
|
use std::collections::VecDeque;
|
||||||
|
use std::fmt;
|
||||||
|
|
||||||
|
pub struct CircularQueue<T> {
|
||||||
|
deque: VecDeque<T>,
|
||||||
|
capacity: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: Clone> Clone for CircularQueue<T> {
|
||||||
|
fn clone(&self) -> Self {
|
||||||
|
Self {
|
||||||
|
deque: self.deque.clone(),
|
||||||
|
capacity: self.capacity,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: fmt::Debug> fmt::Debug for CircularQueue<T> {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
self.deque.fmt(f)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T> CircularQueue<T> {
|
||||||
|
#[inline]
|
||||||
|
pub fn with_capacity(cap: usize) -> Self {
|
||||||
|
Self {
|
||||||
|
deque: VecDeque::with_capacity(cap),
|
||||||
|
capacity: cap,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn push(&mut self, item: T) -> Option<T> {
|
||||||
|
let poped = if self.is_full() {
|
||||||
|
self.deque.pop_back()
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
self.deque.push_front(item);
|
||||||
|
|
||||||
|
poped
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn len(&self) -> usize {
|
||||||
|
self.deque.len()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn is_empty(&self) -> bool {
|
||||||
|
self.deque.is_empty()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn is_full(&self) -> bool {
|
||||||
|
self.deque.len() == self.capacity
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn capacity(&self) -> usize {
|
||||||
|
self.capacity
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn pop(&mut self) -> Option<T> {
|
||||||
|
self.deque.pop_front()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn clear(&mut self) {
|
||||||
|
self.deque.clear()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn top_mut(&mut self) -> Option<&mut T> {
|
||||||
|
self.deque.front_mut()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn iter(&self) -> impl Iterator<Item = &'_ T> {
|
||||||
|
self.deque.iter()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn iter_mut(&mut self) -> impl Iterator<Item = &'_ mut T> {
|
||||||
|
self.deque.iter_mut()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn asc_iter(&self) -> impl Iterator<Item = &'_ T> {
|
||||||
|
self.deque.iter().rev()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn asc_iter_mut(&mut self) -> impl Iterator<Item = &'_ mut T> {
|
||||||
|
self.deque.iter_mut().rev()
|
||||||
|
}
|
||||||
|
}
|
58
src/detection.rs
Normal file
58
src/detection.rs
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
use serde_derive::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
use crate::bbox::{BBox, Xywh};
|
||||||
|
|
||||||
|
/// Contains (x,y) of the center and (width,height) of bbox
|
||||||
|
#[derive(Serialize, Deserialize, Debug, Clone, Copy)]
|
||||||
|
pub struct Detection {
|
||||||
|
pub x: f32,
|
||||||
|
pub y: f32,
|
||||||
|
pub w: f32,
|
||||||
|
pub h: f32,
|
||||||
|
#[serde(rename = "p")]
|
||||||
|
pub confidence: f32,
|
||||||
|
#[serde(rename = "c")]
|
||||||
|
pub class: i32,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Detection {
|
||||||
|
pub fn iou(&self, other: &Detection) -> f32 {
|
||||||
|
let b1_area = (self.w + 1.) * (self.h + 1.);
|
||||||
|
let (xmin, xmax, ymin, ymax) = (self.xmin(), self.xmax(), self.ymin(), self.ymax());
|
||||||
|
|
||||||
|
let b2_area = (other.w + 1.) * (other.h + 1.);
|
||||||
|
|
||||||
|
let i_xmin = xmin.max(other.xmin());
|
||||||
|
let i_xmax = xmax.min(other.xmax());
|
||||||
|
let i_ymin = ymin.max(other.ymin());
|
||||||
|
let i_ymax = ymax.min(other.ymax());
|
||||||
|
let i_area = (i_xmax - i_xmin + 1.).max(0.) * (i_ymax - i_ymin + 1.).max(0.);
|
||||||
|
|
||||||
|
(i_area) / (b1_area + b2_area - i_area)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn bbox(&self) -> BBox<Xywh> {
|
||||||
|
BBox::xywh(self.x, self.y, self.w, self.h)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn xmax(&self) -> f32 {
|
||||||
|
self.x + self.w / 2.
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn ymax(&self) -> f32 {
|
||||||
|
self.y + self.h / 2.
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn xmin(&self) -> f32 {
|
||||||
|
self.x - self.w / 2.
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn ymin(&self) -> f32 {
|
||||||
|
self.y - self.h / 2.
|
||||||
|
}
|
||||||
|
}
|
248
src/detector.rs
Normal file
248
src/detector.rs
Normal file
@ -0,0 +1,248 @@
|
|||||||
|
use crate::detection::Detection;
|
||||||
|
use crate::error::Error;
|
||||||
|
|
||||||
|
use ndarray::prelude::*;
|
||||||
|
use onnx_model::*;
|
||||||
|
|
||||||
|
const MODEL_DYNAMIC_INPUT_DIMENSION: i64 = -1;
|
||||||
|
|
||||||
|
pub struct YoloDetectorConfig {
|
||||||
|
pub confidence_threshold: f32,
|
||||||
|
pub iou_threshold: f32,
|
||||||
|
pub classes: Vec<i32>,
|
||||||
|
pub class_map: Option<Vec<i32>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl YoloDetectorConfig {
|
||||||
|
pub fn new(confidence_threshold: f32, classes: Vec<i32>) -> Self {
|
||||||
|
Self {
|
||||||
|
confidence_threshold,
|
||||||
|
iou_threshold: 0.2,
|
||||||
|
classes,
|
||||||
|
class_map: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct YoloDetector {
|
||||||
|
model: OnnxInferenceModel,
|
||||||
|
config: YoloDetectorConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl YoloDetector {
|
||||||
|
pub fn new(
|
||||||
|
model_src: &str,
|
||||||
|
config: YoloDetectorConfig,
|
||||||
|
device: InferenceDevice,
|
||||||
|
) -> Result<Self, Error> {
|
||||||
|
let model = OnnxInferenceModel::new(model_src, device)?;
|
||||||
|
|
||||||
|
Ok(Self { model, config })
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_model_input_size(&self) -> Option<(u32, u32)> {
|
||||||
|
let mut input_dims = self.model.get_input_infos()[0].shape.dims.clone();
|
||||||
|
let input_height = input_dims.pop().unwrap();
|
||||||
|
let input_width = input_dims.pop().unwrap();
|
||||||
|
if input_height == MODEL_DYNAMIC_INPUT_DIMENSION
|
||||||
|
&& input_width == MODEL_DYNAMIC_INPUT_DIMENSION
|
||||||
|
{
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some((input_width as u32, input_height as u32))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn detect(
|
||||||
|
&self,
|
||||||
|
frames: ArrayView4<'_, f32>,
|
||||||
|
fw: i32,
|
||||||
|
fh: i32,
|
||||||
|
with_crop: bool,
|
||||||
|
) -> Result<Vec<Vec<Detection>>, Error> {
|
||||||
|
let in_shape = frames.shape();
|
||||||
|
let (in_w, in_h) = (in_shape[3], in_shape[2]);
|
||||||
|
let preditions = self.model.run(&[frames.into_dyn()])?.pop().unwrap();
|
||||||
|
let shape = preditions.shape();
|
||||||
|
let shape = [shape[0], shape[1], shape[2]];
|
||||||
|
let arr = preditions.into_shape(shape).unwrap();
|
||||||
|
let bboxes = self.postprocess(arr.view(), in_w, in_h, fw, fh, with_crop)?;
|
||||||
|
|
||||||
|
Ok(bboxes)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn postprocess(
|
||||||
|
&self,
|
||||||
|
view: ArrayView3<'_, f32>,
|
||||||
|
in_w: usize,
|
||||||
|
in_h: usize,
|
||||||
|
frame_width: i32,
|
||||||
|
frame_height: i32,
|
||||||
|
with_crop: bool,
|
||||||
|
) -> Result<Vec<Vec<Detection>>, Error> {
|
||||||
|
let shape = view.shape();
|
||||||
|
let nbatches = shape[0];
|
||||||
|
let npreds = shape[1];
|
||||||
|
let pred_size = shape[2];
|
||||||
|
let mut results: Vec<Vec<Detection>> = (0..nbatches).map(|_| vec![]).collect();
|
||||||
|
|
||||||
|
let (ox, oy, ow, oh) = if with_crop {
|
||||||
|
let in_a = in_h as f32 / in_w as f32;
|
||||||
|
let frame_a = frame_height as f32 / frame_width as f32;
|
||||||
|
|
||||||
|
if in_a > frame_a {
|
||||||
|
let w = frame_height as f32 / in_a;
|
||||||
|
((frame_width as f32 - w) / 2.0, 0.0, w, frame_height as f32)
|
||||||
|
} else {
|
||||||
|
let h = frame_width as f32 * in_a;
|
||||||
|
(0.0, (frame_height as f32 - h) / 2.0, frame_width as f32, h)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
(0.0, 0.0, frame_width as f32, frame_height as f32)
|
||||||
|
};
|
||||||
|
|
||||||
|
// Extract the bounding boxes for which confidence is above the threshold.
|
||||||
|
for batch in 0..nbatches {
|
||||||
|
let results = &mut results[batch];
|
||||||
|
|
||||||
|
// The bounding boxes grouped by (maximum) class index.
|
||||||
|
let mut bboxes: Vec<Vec<Detection>> = (0..80).map(|_| vec![]).collect();
|
||||||
|
|
||||||
|
for index in 0..npreds {
|
||||||
|
let x_0 = view.index_axis(Axis(0), batch);
|
||||||
|
let x_1 = x_0.index_axis(Axis(0), index);
|
||||||
|
let detection = x_1.as_slice().unwrap();
|
||||||
|
|
||||||
|
let (x, y, w, h) = match &detection[0..4] {
|
||||||
|
[center_x, center_y, width, height] => {
|
||||||
|
let center_x = ox + center_x * ow;
|
||||||
|
let center_y = oy + center_y * oh;
|
||||||
|
let width = width * ow as f32;
|
||||||
|
let height = height * oh as f32;
|
||||||
|
|
||||||
|
(center_x, center_y, width, height)
|
||||||
|
}
|
||||||
|
|
||||||
|
_ => unreachable!(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let classes = &detection[4..pred_size];
|
||||||
|
|
||||||
|
let mut class_index = -1;
|
||||||
|
let mut confidence = 0.0;
|
||||||
|
|
||||||
|
for (idx, val) in classes.iter().copied().enumerate() {
|
||||||
|
if val > confidence {
|
||||||
|
class_index = idx as i32;
|
||||||
|
confidence = val;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if class_index > -1 && confidence > self.config.confidence_threshold {
|
||||||
|
if !self.config.classes.contains(&class_index) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if w * h > ((frame_width / 2) * (frame_height / 2)) as f32 {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mapped_class = match &self.config.class_map {
|
||||||
|
Some(map) => map
|
||||||
|
.get(class_index as usize)
|
||||||
|
.copied()
|
||||||
|
.unwrap_or(class_index),
|
||||||
|
None => class_index,
|
||||||
|
};
|
||||||
|
|
||||||
|
bboxes[mapped_class as usize].push(Detection {
|
||||||
|
x,
|
||||||
|
y,
|
||||||
|
w,
|
||||||
|
h,
|
||||||
|
confidence,
|
||||||
|
class: class_index as _,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for mut dets in bboxes.into_iter() {
|
||||||
|
if dets.is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if dets.len() == 1 {
|
||||||
|
results.append(&mut dets);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let indices = self.non_maximum_supression(&mut dets)?;
|
||||||
|
|
||||||
|
results.extend(dets.drain(..).enumerate().filter_map(|(idx, item)| {
|
||||||
|
if indices.contains(&(idx as i32)) {
|
||||||
|
Some(item)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}));
|
||||||
|
|
||||||
|
// for (det, idx) in dets.into_iter().zip(indices) {
|
||||||
|
// if idx > -1 {
|
||||||
|
// results.push(det);
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(results)
|
||||||
|
}
|
||||||
|
|
||||||
|
// fn non_maximum_supression(&self, dets: &mut [Detection]) -> Result<Vec<i32>, Error> {
|
||||||
|
// let mut rects = core::Vector::new();
|
||||||
|
// let mut scores = core::Vector::new();
|
||||||
|
|
||||||
|
// for det in dets {
|
||||||
|
// rects.push(core::Rect2d::new(
|
||||||
|
// det.xmin as f64,
|
||||||
|
// det.ymin as f64,
|
||||||
|
// (det.xmax - det.xmin) as f64,
|
||||||
|
// (det.ymax - det.ymin) as f64
|
||||||
|
// ));
|
||||||
|
// scores.push(det.confidence);
|
||||||
|
// }
|
||||||
|
|
||||||
|
// let mut indices = core::Vector::<i32>::new();
|
||||||
|
// dnn::nms_boxes_f64(
|
||||||
|
// &rects,
|
||||||
|
// &scores,
|
||||||
|
// self.config.confidence_threshold,
|
||||||
|
// self.config.iou_threshold,
|
||||||
|
// &mut indices,
|
||||||
|
// 1.0,
|
||||||
|
// 0
|
||||||
|
// )?;
|
||||||
|
|
||||||
|
// Ok(indices.to_vec())
|
||||||
|
// }
|
||||||
|
|
||||||
|
fn non_maximum_supression(&self, dets: &mut [Detection]) -> Result<Vec<i32>, Error> {
|
||||||
|
dets.sort_unstable_by(|a, b| b.confidence.partial_cmp(&a.confidence).unwrap());
|
||||||
|
|
||||||
|
let mut retain: Vec<_> = (0..dets.len() as i32).collect();
|
||||||
|
for idx in 0..dets.len() - 1 {
|
||||||
|
if retain[idx] != -1 {
|
||||||
|
for r in retain[idx + 1..].iter_mut() {
|
||||||
|
if *r != -1 {
|
||||||
|
let iou = dets[idx].iou(&dets[*r as usize]);
|
||||||
|
if iou > self.config.iou_threshold {
|
||||||
|
*r = -1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
retain.retain(|&x| x > -1);
|
||||||
|
Ok(retain)
|
||||||
|
}
|
||||||
|
}
|
7
src/error.rs
Normal file
7
src/error.rs
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
use thiserror::Error;
|
||||||
|
|
||||||
|
#[derive(Debug, Error)]
|
||||||
|
pub enum Error {
|
||||||
|
#[error("OnnxModel Error: {0}")]
|
||||||
|
OnnxModelError(#[from] onnx_model::error::Error),
|
||||||
|
}
|
24
src/frame.rs
Normal file
24
src/frame.rs
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
use crate::detection::Detection;
|
||||||
|
|
||||||
|
pub struct Frame {
|
||||||
|
pub dims: (u32, u32),
|
||||||
|
pub detections: Vec<Detection>,
|
||||||
|
pub timestamp: f32, // in seconds
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Frame {
|
||||||
|
#[inline]
|
||||||
|
pub fn len(&self) -> usize {
|
||||||
|
self.detections.len()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn iter(&self) -> impl Iterator<Item = &Detection> {
|
||||||
|
self.detections.iter()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn is_empty(&self) -> bool {
|
||||||
|
self.detections.is_empty()
|
||||||
|
}
|
||||||
|
}
|
97
src/lib.rs
Normal file
97
src/lib.rs
Normal file
@ -0,0 +1,97 @@
|
|||||||
|
pub mod bbox;
|
||||||
|
pub mod detection;
|
||||||
|
pub mod detector;
|
||||||
|
pub mod error;
|
||||||
|
pub mod frame;
|
||||||
|
pub mod math;
|
||||||
|
pub mod rolling_avg;
|
||||||
|
pub mod scene;
|
||||||
|
pub mod tracker;
|
||||||
|
|
||||||
|
mod circular_queue;
|
||||||
|
mod predictor;
|
||||||
|
mod track;
|
||||||
|
|
||||||
|
pub use detection::Detection;
|
||||||
|
pub use frame::Frame;
|
||||||
|
pub use track::Track;
|
||||||
|
|
||||||
|
use error::Error;
|
||||||
|
use nalgebra as na;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::{fmt, rc::Rc};
|
||||||
|
|
||||||
|
pub trait Float:
|
||||||
|
num_traits::FromPrimitive + na::ComplexField + Copy + fmt::Debug + PartialEq + 'static
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T> Float for T where
|
||||||
|
T: num_traits::FromPrimitive + na::ComplexField + Copy + fmt::Debug + PartialEq + 'static
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait Tracking {
|
||||||
|
fn predict(&mut self, src: &str);
|
||||||
|
fn update(&mut self, dets: &[Frame], src: &str) -> Result<(), error::Error>;
|
||||||
|
fn tracks(&self, src: &str) -> Rc<[Track]>;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct QuadraticTracker {
|
||||||
|
scenes: HashMap<String, scene::Scene>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl QuadraticTracker {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
scenes: HashMap::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for QuadraticTracker {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl crate::Tracking for QuadraticTracker {
|
||||||
|
#[inline]
|
||||||
|
fn predict(&mut self, _src: &str) {}
|
||||||
|
|
||||||
|
fn update(&mut self, frames: &[Frame], src: &str) -> Result<(), Error> {
|
||||||
|
for frame in frames {
|
||||||
|
let item = self.scenes.get_mut(src);
|
||||||
|
let scene = if let Some(scene) = item {
|
||||||
|
scene
|
||||||
|
} else {
|
||||||
|
let padding = 10.0;
|
||||||
|
let (fw, fh) = frame.dims;
|
||||||
|
let poly = vec![
|
||||||
|
na::Point2::new(padding, padding),
|
||||||
|
na::Point2::new(fw as f32 - padding - padding, padding),
|
||||||
|
na::Point2::new(fw as f32 - padding - padding, fh as f32 - padding - padding),
|
||||||
|
na::Point2::new(padding, fh as f32 - padding - padding),
|
||||||
|
];
|
||||||
|
|
||||||
|
self.scenes
|
||||||
|
.entry(src.to_string())
|
||||||
|
.or_insert_with(|| scene::Scene::new(poly))
|
||||||
|
};
|
||||||
|
|
||||||
|
scene.update_time(frame.timestamp);
|
||||||
|
let mapping = scene.map_detections(frame.timestamp, &frame.detections);
|
||||||
|
scene.update(mapping);
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
fn tracks(&self, src: &str) -> Rc<[Track]> {
|
||||||
|
if let Some(ctx) = self.scenes.get(src) {
|
||||||
|
return ctx.tracks().into_boxed_slice().into();
|
||||||
|
}
|
||||||
|
|
||||||
|
Rc::new([])
|
||||||
|
}
|
||||||
|
}
|
74
src/math.rs
Normal file
74
src/math.rs
Normal file
@ -0,0 +1,74 @@
|
|||||||
|
use nalgebra as na;
|
||||||
|
use num_traits::Float;
|
||||||
|
|
||||||
|
pub fn linear_ls<T: na::ComplexField + Float>(
|
||||||
|
x: na::DVector<T>,
|
||||||
|
y: na::DVector<T>,
|
||||||
|
) -> Option<na::Matrix2x1<T>> {
|
||||||
|
let n = T::from(x.len()).unwrap();
|
||||||
|
|
||||||
|
let s_x = x.sum() + T::from(f32::EPSILON).unwrap();
|
||||||
|
let x2 = x.map(|x| x * x);
|
||||||
|
let s_x2 = x2.sum() + T::from(f32::EPSILON).unwrap();
|
||||||
|
let s_xy = x.zip_map(&y, |x, y| x * y).sum();
|
||||||
|
let s_y = y.sum();
|
||||||
|
|
||||||
|
let a = na::Matrix2::new(s_x2, s_x, s_x, n);
|
||||||
|
let b = na::Matrix2x1::new(s_xy, s_y);
|
||||||
|
|
||||||
|
let qr_result = a.qr();
|
||||||
|
let qty = qr_result.q().transpose() * b;
|
||||||
|
let beta_hat = qr_result.r().solve_upper_triangular(&qty)?;
|
||||||
|
|
||||||
|
Some(beta_hat)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn quadratic_ls<T: na::ComplexField + Float>(
|
||||||
|
x: &na::DVector<T>,
|
||||||
|
y: &na::DVector<T>,
|
||||||
|
) -> std::option::Option<na::Matrix3x1<T>> {
|
||||||
|
let n = T::from(x.len()).unwrap();
|
||||||
|
|
||||||
|
let s_x1 = x.sum();
|
||||||
|
let x2 = x.map(|x| x * x);
|
||||||
|
let s_x2 = x2.sum();
|
||||||
|
let x3 = x2.zip_map(x, |a, b| a * b);
|
||||||
|
let s_x3 = x3.sum();
|
||||||
|
let x4 = x3.zip_map(x, |a, b| a * b);
|
||||||
|
let s_x4 = x4.sum();
|
||||||
|
let s_x2y = x2.zip_map(y, |x, y| x * y).sum();
|
||||||
|
let s_xy = x.zip_map(y, |x, y| x * y).sum();
|
||||||
|
let s_y = y.sum();
|
||||||
|
|
||||||
|
let a = na::Matrix3::new(s_x4, s_x3, s_x2, s_x3, s_x2, s_x1, s_x2, s_x1, n);
|
||||||
|
let b = na::Matrix3x1::new(s_x2y, s_xy, s_y);
|
||||||
|
|
||||||
|
let qr_result = a.qr();
|
||||||
|
let qty = qr_result.q().transpose() * b;
|
||||||
|
|
||||||
|
qr_result.r().solve_upper_triangular(&qty)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn gauss(x: f32, c: f32) -> f32 {
|
||||||
|
(-((x * x) / (2.0 * c * c))).exp()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn approx_b<F, FN: FnOnce(F) -> F + Copy>(l: F, a: F, fx: FN) -> F
|
||||||
|
where
|
||||||
|
F: na::RealField + Float,
|
||||||
|
{
|
||||||
|
let mut dx = l;
|
||||||
|
|
||||||
|
let count = if dx < F::from(0.0001).unwrap() { 0 } else { 8 };
|
||||||
|
|
||||||
|
for _ in 0..count {
|
||||||
|
let b = a + dx;
|
||||||
|
let nl = na::distance(&na::Point2::new(a, fx(a)), &na::Point2::new(b, fx(b)));
|
||||||
|
let dl = l / nl;
|
||||||
|
|
||||||
|
dx *= dl;
|
||||||
|
}
|
||||||
|
|
||||||
|
a + dx
|
||||||
|
}
|
247
src/predictor.rs
Normal file
247
src/predictor.rs
Normal file
@ -0,0 +1,247 @@
|
|||||||
|
use super::math::{approx_b, quadratic_ls};
|
||||||
|
use nalgebra as na;
|
||||||
|
use nalgebra::Normed;
|
||||||
|
use num_traits::Float;
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct Predictor<F>
|
||||||
|
where
|
||||||
|
F: na::RealField + Float,
|
||||||
|
{
|
||||||
|
pub min_a: F,
|
||||||
|
pub curvature: F,
|
||||||
|
pub direction: na::Complex<F>,
|
||||||
|
pub extremum: na::Complex<F>,
|
||||||
|
pub mean: na::Complex<F>,
|
||||||
|
pub variance: na::Complex<F>,
|
||||||
|
pub has_linear: bool,
|
||||||
|
pub has_quadratic: bool,
|
||||||
|
pub use_quadratic: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<F> Predictor<F>
|
||||||
|
where
|
||||||
|
F: na::RealField + Float,
|
||||||
|
{
|
||||||
|
pub fn new(use_quadratic: bool, min_a: F) -> Self {
|
||||||
|
Self {
|
||||||
|
min_a,
|
||||||
|
curvature: F::zero(),
|
||||||
|
direction: na::Complex::new(F::zero(), F::zero()),
|
||||||
|
extremum: na::Complex::new(F::zero(), F::zero()),
|
||||||
|
mean: na::Complex::new(F::zero(), F::zero()),
|
||||||
|
variance: na::Complex::new(F::zero(), F::zero()),
|
||||||
|
has_linear: false,
|
||||||
|
has_quadratic: false,
|
||||||
|
use_quadratic,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn reset(&mut self) {
|
||||||
|
self.curvature = F::zero();
|
||||||
|
self.direction = na::Complex::new(F::zero(), F::zero());
|
||||||
|
self.extremum = na::Complex::new(F::zero(), F::zero());
|
||||||
|
self.mean = na::Complex::new(F::zero(), F::zero());
|
||||||
|
self.variance = na::Complex::new(F::zero(), F::zero());
|
||||||
|
self.has_linear = false;
|
||||||
|
self.has_quadratic = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn predict_dist(&self, from: na::Point2<F>, dist: F) -> na::Point2<F> {
|
||||||
|
if self.has_quadratic {
|
||||||
|
let pt = na::Complex::new(from.x - self.extremum.re, from.y - self.extremum.im)
|
||||||
|
* self.direction.conj();
|
||||||
|
|
||||||
|
let fx = move |x| self.curvature * x * x;
|
||||||
|
let x = approx_b(dist, pt.re, fx);
|
||||||
|
let c = na::Complex::new(x, fx(x)) * self.direction + self.extremum;
|
||||||
|
|
||||||
|
na::Point2::new(c.re, c.im)
|
||||||
|
} else if self.has_linear {
|
||||||
|
let offset = self.direction * dist;
|
||||||
|
|
||||||
|
from + na::Vector2::new(offset.re, offset.im)
|
||||||
|
} else {
|
||||||
|
from
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn project_real(&self, pt: na::Complex<F>) -> (na::Complex<F>, F) {
|
||||||
|
let mut curr = na::Complex::new(pt.re, self.curvature * pt.re * pt.re);
|
||||||
|
let mut dist = pt.im - curr.im;
|
||||||
|
|
||||||
|
let epsilon = F::from(0.000001).unwrap();
|
||||||
|
let two = F::from(2).unwrap();
|
||||||
|
|
||||||
|
for _ in 0..8 {
|
||||||
|
println!();
|
||||||
|
|
||||||
|
let tan = two * (curr.im / (curr.re + epsilon));
|
||||||
|
let c = na::Unit::new_normalize(na::Complex::new(F::one(), tan)).into_inner();
|
||||||
|
let proj = (pt - curr) * c.conj();
|
||||||
|
|
||||||
|
dist = proj.im;
|
||||||
|
let proj = na::Complex::new(proj.re, F::zero()) * c + curr;
|
||||||
|
|
||||||
|
curr = na::Complex::new(proj.re, self.curvature * proj.re * proj.re);
|
||||||
|
}
|
||||||
|
|
||||||
|
(curr, dist)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn project_image(&self, pt: na::Complex<F>) -> (na::Complex<F>, F) {
|
||||||
|
let f = na::Complex::new(F::zero(), F::one() / self.curvature);
|
||||||
|
let d = pt - f;
|
||||||
|
let (a, b) = (d.im / d.re, f.im);
|
||||||
|
|
||||||
|
let four = F::from(4).unwrap();
|
||||||
|
let two = F::from(2).unwrap();
|
||||||
|
|
||||||
|
let p1 = -(Float::sqrt(a * a + four * self.curvature * b) - a) / (two * self.curvature);
|
||||||
|
let p2 = (Float::sqrt(a * a + four * self.curvature * b) + a) / (two * self.curvature);
|
||||||
|
|
||||||
|
let p1 = na::Complex::new(p1, self.curvature * p1 * p1);
|
||||||
|
let p2 = na::Complex::new(p2, self.curvature * p2 * p2);
|
||||||
|
|
||||||
|
let d1 = (p1 - pt).norm();
|
||||||
|
let d2 = (p2 - pt).norm();
|
||||||
|
|
||||||
|
if d1 < d2 {
|
||||||
|
(p1, -d1)
|
||||||
|
} else {
|
||||||
|
(p2, -d2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn project(&self, pt: na::Point2<F>) -> (na::Point2<F>, F) {
|
||||||
|
if self.has_quadratic {
|
||||||
|
let mut pt = na::Complex::new(pt.x - self.extremum.re, pt.y - self.extremum.im)
|
||||||
|
* self.direction.conj();
|
||||||
|
|
||||||
|
let epsilon = F::from(0.001).unwrap();
|
||||||
|
|
||||||
|
if Float::abs(pt.re) < epsilon {
|
||||||
|
if pt.re.is_negative() {
|
||||||
|
pt.re = -epsilon;
|
||||||
|
} else {
|
||||||
|
pt.re = epsilon;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let (curr, dist) = if pt.im.is_sign_positive() == self.curvature.is_sign_positive() {
|
||||||
|
self.project_real(pt)
|
||||||
|
} else {
|
||||||
|
self.project_image(pt)
|
||||||
|
};
|
||||||
|
|
||||||
|
let curr = curr * self.direction + self.extremum;
|
||||||
|
|
||||||
|
(na::Point2::new(curr.re, curr.im), dist)
|
||||||
|
} else if self.has_linear {
|
||||||
|
let npt =
|
||||||
|
na::Complex::new(pt.x - self.mean.re, pt.y - self.mean.im) * self.direction.conj();
|
||||||
|
let dist = npt.im;
|
||||||
|
let npt = na::Complex::new(npt.re, F::zero()) * self.direction;
|
||||||
|
|
||||||
|
(
|
||||||
|
na::Point2::new(npt.re + self.mean.re, npt.im + self.mean.im),
|
||||||
|
dist,
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
(
|
||||||
|
pt,
|
||||||
|
na::distance(&pt, &na::Point2::new(self.mean.re, self.mean.im)),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn update<'a, I: IntoIterator<Item = (&'a na::Point2<F>, F)>>(&mut self, points: I) {
|
||||||
|
self.has_linear = false;
|
||||||
|
self.has_quadratic = false;
|
||||||
|
self.curvature = F::zero();
|
||||||
|
|
||||||
|
// Calculating Mean
|
||||||
|
|
||||||
|
let iter = points.into_iter();
|
||||||
|
let mut x = Vec::with_capacity(iter.size_hint().0);
|
||||||
|
let mut y = Vec::with_capacity(iter.size_hint().0);
|
||||||
|
let mut w = Vec::with_capacity(iter.size_hint().0);
|
||||||
|
|
||||||
|
self.mean = na::Complex::new(F::zero(), F::zero());
|
||||||
|
self.variance = na::Complex::new(F::zero(), F::zero());
|
||||||
|
let mut wsum = F::zero();
|
||||||
|
let mut dir = na::Complex::new(F::zero(), F::zero());
|
||||||
|
let mut prev = na::Complex::new(F::zero(), F::zero());
|
||||||
|
|
||||||
|
for (idx, (p, w_)) in iter.enumerate() {
|
||||||
|
if idx == 0 {
|
||||||
|
prev = na::Complex::new(p.x, p.y);
|
||||||
|
} else {
|
||||||
|
let pt = na::Complex::new(p.x, p.y);
|
||||||
|
dir += na::Unit::new_normalize(prev - pt).into_inner();
|
||||||
|
prev = pt;
|
||||||
|
}
|
||||||
|
|
||||||
|
x.push(p.x);
|
||||||
|
y.push(p.y);
|
||||||
|
w.push(w_);
|
||||||
|
|
||||||
|
self.mean.re += p.x * w_;
|
||||||
|
self.mean.im += p.y * w_;
|
||||||
|
|
||||||
|
wsum += w_;
|
||||||
|
}
|
||||||
|
|
||||||
|
self.mean /= wsum;
|
||||||
|
dir = na::Unit::new_normalize(dir).into_inner();
|
||||||
|
|
||||||
|
let n = x.len();
|
||||||
|
let (x, y): (na::DVector<F>, na::DVector<F>) = (x.into(), y.into());
|
||||||
|
let var = (x.variance(), y.variance());
|
||||||
|
|
||||||
|
if n < 3 || Float::max(var.0, var.1) < F::from(25).unwrap() {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
self.direction = dir;
|
||||||
|
self.has_linear = true;
|
||||||
|
|
||||||
|
// Fitting Polyline
|
||||||
|
|
||||||
|
if !self.use_quadratic || x.len() < 8 {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let tr = self.direction.conj();
|
||||||
|
let (mut x, mut y) = (x, y);
|
||||||
|
// let qmean = na::Complex::new(x.mean(), y.mean());
|
||||||
|
|
||||||
|
x.iter_mut().zip(y.iter_mut()).for_each(|(x, y)| {
|
||||||
|
let c = (na::Complex::new(*x, *y) - self.mean) * tr;
|
||||||
|
|
||||||
|
*x = c.re;
|
||||||
|
*y = c.im;
|
||||||
|
});
|
||||||
|
|
||||||
|
if x.variance() < F::from(45).unwrap() {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let params = if let Some(m) = quadratic_ls(&x, &y) {
|
||||||
|
m
|
||||||
|
} else {
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
|
||||||
|
if Float::abs(params[0]) < self.min_a {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let x0 = -params[1] / (F::from(2).unwrap() * params[0]);
|
||||||
|
let y0 = params[0] * x0 * x0 + params[1] * x0 + params[2];
|
||||||
|
|
||||||
|
self.extremum = self.mean + na::Complex::new(x0, y0) * self.direction;
|
||||||
|
self.curvature = params[0] * F::from(0.75).unwrap();
|
||||||
|
self.has_quadratic = true;
|
||||||
|
}
|
||||||
|
}
|
102
src/rolling_avg.rs
Normal file
102
src/rolling_avg.rs
Normal file
@ -0,0 +1,102 @@
|
|||||||
|
use crate::circular_queue::CircularQueue;
|
||||||
|
use nalgebra as na;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct RollingAvg {
|
||||||
|
min_dist: f32,
|
||||||
|
curr: (f32, na::Point2<f32>),
|
||||||
|
history: CircularQueue<(f32, na::Point2<f32>, f32)>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RollingAvg {
|
||||||
|
pub fn new(min_dist: f32, hcount: usize) -> Self {
|
||||||
|
Self {
|
||||||
|
min_dist,
|
||||||
|
curr: (0.0, na::Point2::new(0.0, 0.0)),
|
||||||
|
history: CircularQueue::with_capacity(hcount),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn clear(&mut self) {
|
||||||
|
self.curr = (0.0, na::Point2::new(0.0, 0.0));
|
||||||
|
self.history.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn push(&mut self, ts: f32, pos: na::Point2<f32>) -> bool {
|
||||||
|
if na::distance(&self.curr.1, &pos) < self.min_dist && ts - self.curr.0 < 5. {
|
||||||
|
if let Some(top) = self.history.iter_mut().next() {
|
||||||
|
let weight = top.2;
|
||||||
|
top.0 = (top.0 * weight + ts) / (weight + 1.0);
|
||||||
|
top.1 = ((top.1.coords * weight + pos.coords) / (weight + 1.0)).into();
|
||||||
|
top.2 += 1.0;
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
self.history.push((ts, pos, 1.0));
|
||||||
|
self.curr = (ts, pos);
|
||||||
|
|
||||||
|
loop {
|
||||||
|
let update = {
|
||||||
|
let mut iter = self.history.iter();
|
||||||
|
if let (Some(curr), Some(top), Some(prev)) = (iter.next(), iter.next(), iter.next())
|
||||||
|
{
|
||||||
|
let q1 = na::Unit::new_normalize(top.1.coords - curr.1.coords);
|
||||||
|
let q2 = na::Unit::new_normalize(prev.1.coords - top.1.coords);
|
||||||
|
|
||||||
|
q1.dot(&q2) < -0.0
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if update {
|
||||||
|
self.history.pop();
|
||||||
|
|
||||||
|
if let Some(top) = self.history.top_mut() {
|
||||||
|
let weight = top.2;
|
||||||
|
top.0 = (top.0 * weight + ts) / (weight + 1.0);
|
||||||
|
top.1 = ((top.1.coords * weight + pos.coords) / (weight + 1.0)).into();
|
||||||
|
top.2 += 1.0;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
true
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn velocity(&self) -> Option<f32> {
|
||||||
|
let mut iter = self.history.iter();
|
||||||
|
|
||||||
|
let top = iter.next()?;
|
||||||
|
let prev = iter.next()?;
|
||||||
|
|
||||||
|
let dl = na::distance(&top.1, &prev.1);
|
||||||
|
let dt = top.0 - prev.0;
|
||||||
|
|
||||||
|
Some(dl / dt)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn top(&self) -> Option<&(f32, na::Point2<f32>, f32)> {
|
||||||
|
self.history.iter().next()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn iter(&self) -> impl Iterator<Item = &(f32, na::Point2<f32>, f32)> {
|
||||||
|
self.history.iter()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn iter_points(&self) -> impl Iterator<Item = (&na::Point2<f32>, f32)> {
|
||||||
|
self.history.iter().map(|(_, x, w)| (x, *w))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn num_points(&self) -> usize {
|
||||||
|
self.history.len()
|
||||||
|
}
|
||||||
|
}
|
683
src/scene.rs
Normal file
683
src/scene.rs
Normal file
@ -0,0 +1,683 @@
|
|||||||
|
use std::sync::atomic::AtomicU32;
|
||||||
|
|
||||||
|
use crate::tracker::Object;
|
||||||
|
use crate::Detection;
|
||||||
|
|
||||||
|
use nalgebra as na;
|
||||||
|
|
||||||
|
use crate::circular_queue::CircularQueue;
|
||||||
|
use munkres::{solve_assignment, WeightMatrix, Weights};
|
||||||
|
|
||||||
|
static SEQ_ID: AtomicU32 = AtomicU32::new(1);
|
||||||
|
|
||||||
|
const SECONDS_IN_FRAME: f32 = 0.04;
|
||||||
|
const CONFIRM_SECONDS_RATIO: f32 = 0.4;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
enum IndexedSliceKind {
|
||||||
|
All,
|
||||||
|
Indexes(Vec<usize>),
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct IndexedSlice<'a, T> {
|
||||||
|
pub slice: &'a [T],
|
||||||
|
kind: IndexedSliceKind,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, T> Clone for IndexedSlice<'a, T> {
|
||||||
|
fn clone(&self) -> Self {
|
||||||
|
Self {
|
||||||
|
slice: self.slice,
|
||||||
|
kind: self.kind.clone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, T> IndexedSlice<'a, T> {
|
||||||
|
pub fn new(slice: &'a [T]) -> Self {
|
||||||
|
Self {
|
||||||
|
slice,
|
||||||
|
kind: IndexedSliceKind::All,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn new_with_indexes(slice: &'a [T], idx: Vec<usize>) -> Self {
|
||||||
|
Self {
|
||||||
|
slice,
|
||||||
|
kind: IndexedSliceKind::Indexes(idx),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn get_index(&self, idx: usize) -> usize {
|
||||||
|
match &self.kind {
|
||||||
|
IndexedSliceKind::All => idx,
|
||||||
|
IndexedSliceKind::Indexes(idxs) => idxs[idx],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#[inline]
|
||||||
|
pub fn get(&self, idx: usize) -> Option<&'a T> {
|
||||||
|
match &self.kind {
|
||||||
|
IndexedSliceKind::All => self.slice.get(idx),
|
||||||
|
IndexedSliceKind::Indexes(idxs) => self.slice.get(*idxs.get(idx)?),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn len(&self) -> usize {
|
||||||
|
match &self.kind {
|
||||||
|
IndexedSliceKind::All => self.slice.len(),
|
||||||
|
IndexedSliceKind::Indexes(idxs) => idxs.len(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn is_empty(&self) -> bool {
|
||||||
|
match &self.kind {
|
||||||
|
IndexedSliceKind::All => self.slice.is_empty(),
|
||||||
|
IndexedSliceKind::Indexes(idxs) => idxs.is_empty(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn all_indexes(&self) -> Vec<usize> {
|
||||||
|
match &self.kind {
|
||||||
|
IndexedSliceKind::All => (0..self.slice.len()).collect(),
|
||||||
|
IndexedSliceKind::Indexes(idxs) => idxs.clone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, T> std::ops::Index<usize> for IndexedSlice<'a, T> {
|
||||||
|
type Output = T;
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
fn index(&self, index: usize) -> &Self::Output {
|
||||||
|
let idx = match &self.kind {
|
||||||
|
IndexedSliceKind::All => index,
|
||||||
|
IndexedSliceKind::Indexes(idxs) => unsafe { *idxs.get_unchecked(index) },
|
||||||
|
};
|
||||||
|
|
||||||
|
&self.slice[idx]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub enum IndexedSliceIter<'a, T> {
|
||||||
|
All(std::slice::Iter<'a, T>),
|
||||||
|
Indexes((&'a [T], std::vec::IntoIter<usize>)),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, T> Iterator for IndexedSliceIter<'a, T> {
|
||||||
|
type Item = &'a T;
|
||||||
|
fn next(&mut self) -> Option<&'a T> {
|
||||||
|
match self {
|
||||||
|
IndexedSliceIter::All(it) => it.next(),
|
||||||
|
IndexedSliceIter::Indexes((slice, it)) => slice.get(it.next()?),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, T> IntoIterator for IndexedSlice<'a, T> {
|
||||||
|
type Item = &'a T;
|
||||||
|
type IntoIter = IndexedSliceIter<'a, T>;
|
||||||
|
fn into_iter(self) -> Self::IntoIter {
|
||||||
|
match self.kind {
|
||||||
|
IndexedSliceKind::All => IndexedSliceIter::All(self.slice.iter()),
|
||||||
|
IndexedSliceKind::Indexes(idxs) => {
|
||||||
|
IndexedSliceIter::Indexes((self.slice, idxs.into_iter()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn in_bounds(p: na::Point2<f32>, poly: &[na::Point2<f32>]) -> bool {
|
||||||
|
let n = poly.len();
|
||||||
|
let mut inside = false;
|
||||||
|
let mut p1 = poly[0];
|
||||||
|
let mut xints = 0.0;
|
||||||
|
|
||||||
|
for i in 1..=n {
|
||||||
|
let p2 = poly[i % n];
|
||||||
|
|
||||||
|
if p.y > f32::min(p1.y, p2.y) && p.y <= f32::max(p1.y, p2.y) && p.x <= f32::max(p1.x, p2.x)
|
||||||
|
{
|
||||||
|
if (p1.y - p2.y).abs() > f32::EPSILON {
|
||||||
|
xints = (p.y - p1.y) * (p2.x - p1.x) / (p2.y - p1.y) + p1.x;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (p1.x - p2.x).abs() < f32::EPSILON || p.x <= xints {
|
||||||
|
inside = !inside;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
p1 = p2;
|
||||||
|
}
|
||||||
|
|
||||||
|
inside
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct DetectionsMapping<'a> {
|
||||||
|
timestamp: f32,
|
||||||
|
detections: &'a [Detection],
|
||||||
|
cnd_matched: Vec<(usize, usize, f32)>,
|
||||||
|
veh_matched: Vec<(usize, usize, f32)>,
|
||||||
|
veh_missed: IndexedSlice<'a, Detection>,
|
||||||
|
|
||||||
|
ped_matched: Vec<(usize, usize, f32)>,
|
||||||
|
ped_missed: IndexedSlice<'a, Detection>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct Accumulator {
|
||||||
|
pub ts: f32,
|
||||||
|
pub det: Detection,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Accumulator {
|
||||||
|
pub fn new(ts: f32, det: Detection) -> Self {
|
||||||
|
Self { ts, det }
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn lerp_ts(&mut self, next_ts: f32, factor: f32) {
|
||||||
|
self.ts = self.ts * (1.0 - factor) + next_ts * factor;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn lerp_w(&mut self, next_w: f32, factor: f32) {
|
||||||
|
self.det.w = self.det.w * (1.0 - factor) + next_w * factor;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn lerp_h(&mut self, next_h: f32, factor: f32) {
|
||||||
|
self.det.h = self.det.h * (1.0 - factor) + next_h * factor;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn lerp_x(&mut self, next_x: f32, factor: f32) {
|
||||||
|
self.det.x = self.det.x * (1.0 - factor) + next_x * factor;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn lerp_y(&mut self, next_y: f32, factor: f32) {
|
||||||
|
self.det.y = self.det.y * (1.0 - factor) + next_y * factor;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct Participant {
|
||||||
|
pub id: u32,
|
||||||
|
pub match_score_sum: f32,
|
||||||
|
pub hit_score_sum: f32,
|
||||||
|
pub hits_count: u32,
|
||||||
|
pub last_hit_score: f32,
|
||||||
|
pub last_update_diff: f32,
|
||||||
|
pub last_update_sec: f32,
|
||||||
|
pub time_since_update: f32,
|
||||||
|
pub object: Object,
|
||||||
|
pub class_votes: [u32; 12],
|
||||||
|
pub detections: CircularQueue<(f32, Detection)>,
|
||||||
|
pub accumulator: Accumulator,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Participant {
|
||||||
|
pub fn new(ts_sec: f32, det: &Detection) -> Self {
|
||||||
|
let mut detections = CircularQueue::with_capacity(16);
|
||||||
|
detections.push((ts_sec, *det));
|
||||||
|
|
||||||
|
let mut class_votes = [0; 12];
|
||||||
|
|
||||||
|
if det.class >= 0 && det.class < 12 {
|
||||||
|
class_votes[det.class as usize] += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut object = Object::new(5.0, 0.01, true);
|
||||||
|
object.update(
|
||||||
|
0,
|
||||||
|
ts_sec,
|
||||||
|
na::Point2::new(det.x, det.y),
|
||||||
|
na::Point2::new(det.x, det.y),
|
||||||
|
);
|
||||||
|
|
||||||
|
Self {
|
||||||
|
id: 0,
|
||||||
|
match_score_sum: det.confidence,
|
||||||
|
hit_score_sum: 0.0,
|
||||||
|
last_hit_score: 0.0,
|
||||||
|
hits_count: 1,
|
||||||
|
last_update_diff: 0.0,
|
||||||
|
time_since_update: 0.,
|
||||||
|
last_update_sec: ts_sec,
|
||||||
|
object,
|
||||||
|
class_votes,
|
||||||
|
detections,
|
||||||
|
accumulator: Accumulator::new(ts_sec, *det),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn upgrade(&mut self) {
|
||||||
|
self.id = SEQ_ID.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn downgrade(&mut self) {
|
||||||
|
self.id = 0;
|
||||||
|
self.hit_score_sum = 0.0;
|
||||||
|
self.last_hit_score = 0.0;
|
||||||
|
self.hits_count = 0;
|
||||||
|
self.last_update_diff = 0.0;
|
||||||
|
self.match_score_sum = 0.0;
|
||||||
|
self.object.reset();
|
||||||
|
}
|
||||||
|
|
||||||
|
fn update_smothed(&mut self, score: f32, id: u32, det: &Detection) {
|
||||||
|
self.match_score_sum += self.accumulator.det.confidence;
|
||||||
|
self.hits_count += 1;
|
||||||
|
self.hit_score_sum += score;
|
||||||
|
self.last_hit_score = score;
|
||||||
|
|
||||||
|
if self.accumulator.det.class >= 0 && self.accumulator.det.class < 12 {
|
||||||
|
self.class_votes[self.accumulator.det.class as usize] += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
self.last_update_diff = self.accumulator.ts - self.last_update_sec;
|
||||||
|
self.last_update_sec = self.accumulator.ts;
|
||||||
|
|
||||||
|
self.object.update(
|
||||||
|
id,
|
||||||
|
self.accumulator.ts,
|
||||||
|
na::Point2::new(self.accumulator.det.x, self.accumulator.det.y),
|
||||||
|
na::Point2::new(det.x, det.y),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn update(&mut self, ts_sec: f32, det: &Detection, score: f32, id: u32) {
|
||||||
|
self.detections.push((ts_sec, *det));
|
||||||
|
|
||||||
|
let aw2 = self.accumulator.det.w / 2.0;
|
||||||
|
let ah2 = self.accumulator.det.h / 2.0;
|
||||||
|
|
||||||
|
let acc_l = self.accumulator.det.x - aw2;
|
||||||
|
let acc_r = self.accumulator.det.x + aw2;
|
||||||
|
let acc_t = self.accumulator.det.y - ah2;
|
||||||
|
let acc_b = self.accumulator.det.y + ah2;
|
||||||
|
|
||||||
|
self.accumulator.lerp_w(det.w, 0.05);
|
||||||
|
self.accumulator.lerp_h(det.h, 0.05);
|
||||||
|
|
||||||
|
let dw2 = det.w / 2.0;
|
||||||
|
let dh2 = det.h / 2.0;
|
||||||
|
|
||||||
|
let det_l = det.x - dw2;
|
||||||
|
let det_r = det.x + dw2;
|
||||||
|
let det_t = det.y - dh2;
|
||||||
|
let det_b = det.y + dh2;
|
||||||
|
|
||||||
|
let d_l = (det_l - acc_l).abs();
|
||||||
|
let d_r = (det_r - acc_r).abs();
|
||||||
|
let d_t = (det_t - acc_t).abs();
|
||||||
|
let d_b = (det_b - acc_b).abs();
|
||||||
|
|
||||||
|
let left = d_l < d_r;
|
||||||
|
let top = d_t < d_b;
|
||||||
|
|
||||||
|
let aw = self.accumulator.det.w;
|
||||||
|
let ah = self.accumulator.det.h;
|
||||||
|
|
||||||
|
let (x, y) = match (left, top) {
|
||||||
|
(true, true) => {
|
||||||
|
let cx = (acc_l + det_l) * 0.5 + aw * 0.5;
|
||||||
|
let cy = (acc_t + det_t) * 0.5 + ah * 0.5;
|
||||||
|
|
||||||
|
(cx, cy)
|
||||||
|
}
|
||||||
|
(true, false) => {
|
||||||
|
let cx = (acc_l + det_l) * 0.5 + aw * 0.5;
|
||||||
|
let cy = (acc_b + det_b) * 0.5 - ah * 0.5;
|
||||||
|
|
||||||
|
(cx, cy)
|
||||||
|
}
|
||||||
|
(false, true) => {
|
||||||
|
let cx = (acc_r + det_r) * 0.5 - aw * 0.5;
|
||||||
|
let cy = (acc_t + det_t) * 0.5 + ah * 0.5;
|
||||||
|
|
||||||
|
(cx, cy)
|
||||||
|
}
|
||||||
|
(false, false) => {
|
||||||
|
let cx = (acc_r + det_r) * 0.5 - aw * 0.5;
|
||||||
|
let cy = (acc_b + det_b) * 0.5 - ah * 0.5;
|
||||||
|
|
||||||
|
(cx, cy)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// println!("({}, {}) ({}, {})", det.x, det.y, x, y);
|
||||||
|
let p = 0.9;
|
||||||
|
self.accumulator.lerp_x(x, p);
|
||||||
|
self.accumulator.lerp_y(y, p);
|
||||||
|
self.accumulator.lerp_ts(ts_sec, 0.5 * p);
|
||||||
|
|
||||||
|
self.accumulator.det.confidence = score;
|
||||||
|
self.accumulator.det.class = det.class;
|
||||||
|
|
||||||
|
self.update_smothed(score, id, det);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn last_detection(&self) -> &Detection {
|
||||||
|
&self.detections.iter().next().unwrap().1
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn iou_slip(&self) -> f32 {
|
||||||
|
let mut detections = self.detections.iter();
|
||||||
|
let last_detection = detections.next().unwrap().1;
|
||||||
|
if let Some(detection) = detections.next() {
|
||||||
|
last_detection.iou(&detection.1)
|
||||||
|
} else {
|
||||||
|
0.
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn prediction(&self, ts_sec: f32) -> na::Point2<f32> {
|
||||||
|
self.object.predict(0, ts_sec)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn velocity(&self) -> &f32 {
|
||||||
|
&self.object.vel
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn direction(&self) -> &na::Complex<f32> {
|
||||||
|
&self.object.predictor.direction
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<&Participant> for crate::Track {
|
||||||
|
fn from(p: &Participant) -> crate::Track {
|
||||||
|
crate::Track {
|
||||||
|
track_id: p.id as _,
|
||||||
|
time_since_update: p.time_since_update as _,
|
||||||
|
class: p
|
||||||
|
.class_votes
|
||||||
|
.iter()
|
||||||
|
.enumerate()
|
||||||
|
.max_by_key(|x| x.1)
|
||||||
|
.map(|(i, _)| i)
|
||||||
|
.unwrap_or(0) as _,
|
||||||
|
confidence: p.hit_score_sum / p.hits_count as f32,
|
||||||
|
iou_slip: p.iou_slip(),
|
||||||
|
bbox: p.last_detection().bbox().as_xyah(),
|
||||||
|
velocity: Some(*p.velocity()),
|
||||||
|
direction: Some((p.direction().re, p.direction().im)),
|
||||||
|
curvature: Some(p.object.predictor.curvature),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct Scene {
|
||||||
|
pub bounds: Vec<na::Point2<f32>>,
|
||||||
|
pub tracks: Vec<Participant>,
|
||||||
|
pub peds: Vec<Participant>,
|
||||||
|
confirm_seconds: f32,
|
||||||
|
last_second: f32,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Scene {
|
||||||
|
pub fn new(bounds: Vec<na::Point2<f32>>) -> Self {
|
||||||
|
Self {
|
||||||
|
bounds,
|
||||||
|
tracks: Vec::with_capacity(64),
|
||||||
|
peds: Vec::with_capacity(32),
|
||||||
|
confirm_seconds: SECONDS_IN_FRAME / CONFIRM_SECONDS_RATIO,
|
||||||
|
last_second: 0.,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn assignment<'a>(
|
||||||
|
&self,
|
||||||
|
ts_sec: f32,
|
||||||
|
threshhold: f32,
|
||||||
|
dets: IndexedSlice<'a, Detection>,
|
||||||
|
objs: IndexedSlice<'_, Participant>,
|
||||||
|
) -> (Vec<(usize, usize, f32)>, IndexedSlice<'a, Detection>) {
|
||||||
|
let mut missed: Vec<_> = (0..dets.len()).collect();
|
||||||
|
|
||||||
|
let mut assignments = if !objs.is_empty() {
|
||||||
|
let n = dets.len().max(objs.len());
|
||||||
|
|
||||||
|
if n > 256 {
|
||||||
|
panic!("Confusion matrix is too big!");
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut mat = WeightMatrix::from_fn(n, |(r, c)| {
|
||||||
|
if r < objs.len() && c < dets.len() {
|
||||||
|
let obj = &objs[r];
|
||||||
|
let det = &dets[c];
|
||||||
|
let pos = na::Point2::new(det.x, det.y);
|
||||||
|
|
||||||
|
1.0 - obj
|
||||||
|
.object
|
||||||
|
.probability(obj.id, ts_sec, pos, &obj.accumulator.det)
|
||||||
|
} else {
|
||||||
|
100000.0
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let mat2 = mat.clone();
|
||||||
|
let res = solve_assignment(&mut mat);
|
||||||
|
|
||||||
|
if let Ok(inner) = res {
|
||||||
|
let mut assignments = Vec::new();
|
||||||
|
|
||||||
|
for i in inner {
|
||||||
|
if i.row < objs.len() && i.column < dets.len() {
|
||||||
|
let score = 1.0 - mat2.element_at(i);
|
||||||
|
|
||||||
|
if score > threshhold {
|
||||||
|
assignments.push((i.row, i.column, score));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
missed.retain(|&x| {
|
||||||
|
if assignments.iter().any(|&(_, p, _)| p == x) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// for i in 0..n {
|
||||||
|
// if mat2.element_at(munkres::Position { row: i, column: x }) < 0.80 {
|
||||||
|
// return false;
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
true
|
||||||
|
});
|
||||||
|
|
||||||
|
assignments
|
||||||
|
} else {
|
||||||
|
println!("WARNING: assignement could not be solved!");
|
||||||
|
Vec::new()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Vec::new()
|
||||||
|
};
|
||||||
|
|
||||||
|
missed.iter_mut().for_each(|x| *x = dets.get_index(*x));
|
||||||
|
|
||||||
|
assignments.iter_mut().for_each(|(x, y, _)| {
|
||||||
|
*x = objs.get_index(*x);
|
||||||
|
*y = dets.get_index(*y);
|
||||||
|
});
|
||||||
|
|
||||||
|
(
|
||||||
|
assignments,
|
||||||
|
IndexedSlice::new_with_indexes(dets.slice, missed),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn update_time(&mut self, ts_sec: f32) {
|
||||||
|
if ts_sec > self.last_second {
|
||||||
|
self.confirm_seconds = (ts_sec - self.last_second) / CONFIRM_SECONDS_RATIO;
|
||||||
|
self.last_second = ts_sec;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn map_detections<'a>(
|
||||||
|
&self,
|
||||||
|
ts_sec: f32,
|
||||||
|
detections: &'a [Detection],
|
||||||
|
) -> DetectionsMapping<'a> {
|
||||||
|
let mut peds = Vec::new();
|
||||||
|
let mut vehicles = Vec::new();
|
||||||
|
for (idx, p) in detections.iter().enumerate() {
|
||||||
|
if p.class == 0 {
|
||||||
|
peds.push(idx);
|
||||||
|
} else {
|
||||||
|
vehicles.push(idx);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut confirmed = Vec::new();
|
||||||
|
let mut pending = Vec::new();
|
||||||
|
let mut unconfirmed = Vec::new();
|
||||||
|
for (idx, p) in self.tracks.iter().enumerate() {
|
||||||
|
if p.id > 0 {
|
||||||
|
if ts_sec - p.last_update_sec < self.confirm_seconds {
|
||||||
|
confirmed.push(idx);
|
||||||
|
} else {
|
||||||
|
pending.push(idx);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
unconfirmed.push(idx);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let confirmed_tracks = IndexedSlice::new_with_indexes(&self.tracks, confirmed);
|
||||||
|
let veh_dets = IndexedSlice::new_with_indexes(detections, vehicles);
|
||||||
|
let (track_confirmed_matched, dets_confirmed_missed) =
|
||||||
|
self.assignment(ts_sec, 0.2, veh_dets, confirmed_tracks);
|
||||||
|
|
||||||
|
let pending_tracks = IndexedSlice::new_with_indexes(&self.tracks, pending);
|
||||||
|
let (track_pending_matched, dets_pending_missed) =
|
||||||
|
self.assignment(ts_sec, 0.2, dets_confirmed_missed, pending_tracks);
|
||||||
|
|
||||||
|
let unconfirmed_tracks = IndexedSlice::new_with_indexes(&self.tracks, unconfirmed);
|
||||||
|
let (unconfirmed_matched, dets_missed) =
|
||||||
|
self.assignment(ts_sec, 0.005, dets_pending_missed, unconfirmed_tracks);
|
||||||
|
|
||||||
|
let ped_tracks = IndexedSlice::new(&self.peds);
|
||||||
|
let ped_dets = IndexedSlice::new_with_indexes(detections, peds);
|
||||||
|
let (ped_matched, ped_missed) = self.assignment(ts_sec, 0.0, ped_dets, ped_tracks);
|
||||||
|
|
||||||
|
DetectionsMapping {
|
||||||
|
timestamp: ts_sec,
|
||||||
|
detections,
|
||||||
|
cnd_matched: unconfirmed_matched,
|
||||||
|
veh_matched: track_confirmed_matched
|
||||||
|
.into_iter()
|
||||||
|
.chain(track_pending_matched.into_iter())
|
||||||
|
.collect(),
|
||||||
|
veh_missed: dets_missed,
|
||||||
|
ped_matched,
|
||||||
|
ped_missed,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn update(&mut self, mapping: DetectionsMapping<'_>) {
|
||||||
|
static UPGRADE_HIT_SCORE_SUM_THRESHOLD: f32 = 12.0;
|
||||||
|
static DOWNGRADE_TIME_SINCE_UPDATE_THRESHOLD: f32 = 16.0;
|
||||||
|
static MAX_DURATION_MISSING_OBJECT: f32 = 2.0;
|
||||||
|
|
||||||
|
let time = mapping.timestamp;
|
||||||
|
let dets = mapping.detections;
|
||||||
|
|
||||||
|
for (i, j, score) in mapping.veh_matched {
|
||||||
|
let id = self.tracks[i].id;
|
||||||
|
self.tracks[i].update(time, &dets[j], score, id);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (i, j, score) in mapping.cnd_matched {
|
||||||
|
self.tracks[i].update(time, &dets[j], score, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
for c in &mut self.tracks {
|
||||||
|
if c.id == 0
|
||||||
|
&& c.hit_score_sum > UPGRADE_HIT_SCORE_SUM_THRESHOLD
|
||||||
|
&& in_bounds(c.object.pos, &self.bounds)
|
||||||
|
{
|
||||||
|
c.upgrade();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for t in &mut self.tracks {
|
||||||
|
t.time_since_update = time - t.last_update_sec;
|
||||||
|
if t.id == 0 && t.time_since_update > 0.4 {
|
||||||
|
t.hit_score_sum *= 0.75;
|
||||||
|
}
|
||||||
|
|
||||||
|
if t.id > 0
|
||||||
|
&& (t.time_since_update > DOWNGRADE_TIME_SINCE_UPDATE_THRESHOLD
|
||||||
|
|| t.object.predict_distance(t.id, time) > 400.0
|
||||||
|
|| !in_bounds(t.object.pos, &self.bounds)
|
||||||
|
|| !in_bounds(t.prediction(time), &self.bounds))
|
||||||
|
{
|
||||||
|
t.downgrade();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let bounds = self.bounds.clone();
|
||||||
|
|
||||||
|
self.tracks.retain(|t| {
|
||||||
|
if t.id == 0 && !in_bounds(t.prediction(time), &bounds) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
t.id > 0 || t.time_since_update < MAX_DURATION_MISSING_OBJECT
|
||||||
|
});
|
||||||
|
|
||||||
|
for det in mapping.veh_missed {
|
||||||
|
self.tracks.push(Participant::new(time, det));
|
||||||
|
}
|
||||||
|
|
||||||
|
for (i, j, score) in mapping.ped_matched {
|
||||||
|
let id = self.peds[i].id;
|
||||||
|
self.peds[i].update(time, &dets[j], score, id);
|
||||||
|
}
|
||||||
|
|
||||||
|
for c in &mut self.peds {
|
||||||
|
if c.id == 0
|
||||||
|
&& c.hit_score_sum > UPGRADE_HIT_SCORE_SUM_THRESHOLD
|
||||||
|
&& in_bounds(c.object.pos, &self.bounds)
|
||||||
|
{
|
||||||
|
c.upgrade();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for t in &mut self.peds {
|
||||||
|
let dt = time - t.last_update_sec;
|
||||||
|
|
||||||
|
if t.id > 0
|
||||||
|
&& (dt > DOWNGRADE_TIME_SINCE_UPDATE_THRESHOLD
|
||||||
|
|| !in_bounds(t.object.pos, &self.bounds)
|
||||||
|
|| !in_bounds(t.prediction(time), &self.bounds))
|
||||||
|
{
|
||||||
|
t.downgrade();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
self.peds.retain(|t| {
|
||||||
|
let dt = time - t.last_update_sec;
|
||||||
|
|
||||||
|
t.id > 0 || dt < MAX_DURATION_MISSING_OBJECT
|
||||||
|
});
|
||||||
|
|
||||||
|
for det in mapping.ped_missed {
|
||||||
|
self.peds.push(Participant::new(time, det));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn tracks(&self) -> Vec<crate::Track> {
|
||||||
|
self.tracks
|
||||||
|
.iter()
|
||||||
|
.chain(self.peds.iter())
|
||||||
|
.filter(|t| t.id > 0 && t.time_since_update < self.confirm_seconds)
|
||||||
|
.map(Into::into)
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
}
|
20
src/track.rs
Normal file
20
src/track.rs
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
use crate::bbox::{BBox, Xyah};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct Track {
|
||||||
|
pub track_id: i32,
|
||||||
|
pub time_since_update: i32,
|
||||||
|
pub class: i32,
|
||||||
|
pub confidence: f32,
|
||||||
|
pub iou_slip: f32,
|
||||||
|
pub bbox: BBox<Xyah>,
|
||||||
|
|
||||||
|
// in px
|
||||||
|
pub velocity: Option<f32>,
|
||||||
|
|
||||||
|
// (x,y)
|
||||||
|
pub direction: Option<(f32, f32)>,
|
||||||
|
|
||||||
|
// a-coeff from parabolic curve fitting for this track's trajectory
|
||||||
|
pub curvature: Option<f32>,
|
||||||
|
}
|
121
src/tracker.rs
Normal file
121
src/tracker.rs
Normal file
@ -0,0 +1,121 @@
|
|||||||
|
use super::math;
|
||||||
|
use super::{predictor::Predictor, rolling_avg::RollingAvg};
|
||||||
|
use crate::Detection;
|
||||||
|
use nalgebra as na;
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct Object {
|
||||||
|
pub predictor: Predictor<f32>,
|
||||||
|
pub history: RollingAvg,
|
||||||
|
pub prev_vel: f32,
|
||||||
|
pub vel: f32,
|
||||||
|
pub vel_updated_at: f32,
|
||||||
|
pub pos: na::Point2<f32>,
|
||||||
|
pub ts: f32,
|
||||||
|
pub initialized: bool,
|
||||||
|
pub correction: f32,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Object {
|
||||||
|
pub fn new(md: f32, correction: f32, use_quadratic: bool) -> Self {
|
||||||
|
Self {
|
||||||
|
predictor: Predictor::new(use_quadratic, 0.00025),
|
||||||
|
history: RollingAvg::new(md, (80.0 / md).round() as usize),
|
||||||
|
pos: na::Point2::new(0.0, 0.0),
|
||||||
|
prev_vel: 0.0,
|
||||||
|
vel_updated_at: 0.0,
|
||||||
|
vel: 0.0,
|
||||||
|
ts: 0.0,
|
||||||
|
initialized: false,
|
||||||
|
correction,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn reset(&mut self) {
|
||||||
|
self.predictor.reset();
|
||||||
|
self.pos = na::Point2::new(0.0, 0.0);
|
||||||
|
self.vel_updated_at = 0.0;
|
||||||
|
self.vel = 0.0;
|
||||||
|
self.prev_vel = 0.0;
|
||||||
|
self.ts = 0.0;
|
||||||
|
self.initialized = false;
|
||||||
|
self.history.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn predict_distance(&self, _id: u32, ts: f32) -> f32 {
|
||||||
|
let pred_dist = self.vel * (ts - self.ts).max(0.033).abs();
|
||||||
|
|
||||||
|
pred_dist + self.vel * self.correction
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn predict(&self, id: u32, ts: f32) -> na::Point2<f32> {
|
||||||
|
self.predictor
|
||||||
|
.predict_dist(self.pos, self.predict_distance(id, ts))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn probability(&self, id: u32, ts: f32, pp: na::Point2<f32>, avg: &Detection) -> f32 {
|
||||||
|
let size = avg.w.min(avg.h);
|
||||||
|
let pred_dist = self.predict_distance(id, ts);
|
||||||
|
|
||||||
|
let (pt, d) = self.predictor.project(pp);
|
||||||
|
let (dir, mut dist) = na::Unit::new_and_get(pt - self.pos);
|
||||||
|
let dir = na::Complex::new(dir.x, dir.y);
|
||||||
|
let dt = dir * self.predictor.direction.conj();
|
||||||
|
|
||||||
|
if dt.re < 0.0 {
|
||||||
|
dist = -dist;
|
||||||
|
}
|
||||||
|
|
||||||
|
let gap = pred_dist * 0.5 + size * 0.6;
|
||||||
|
let angle = math::gauss(d.abs(), (pred_dist / 4.0).max(17.0));
|
||||||
|
let speed = math::gauss((pred_dist - dist.abs()).abs(), gap);
|
||||||
|
|
||||||
|
angle * speed
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn update(&mut self, id: u32, ts: f32, pp: na::Point2<f32>, rp: na::Point2<f32>) {
|
||||||
|
if self.history.push(ts, rp) {
|
||||||
|
self.predictor.update(self.history.iter_points());
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(vel) = self.history.velocity() {
|
||||||
|
self.prev_vel = self.vel;
|
||||||
|
|
||||||
|
if self.initialized {
|
||||||
|
// self.vel = self.vel * 0.9 + vel * 0.1;
|
||||||
|
|
||||||
|
let dt = ts - self.vel_updated_at;
|
||||||
|
let dvel = (vel - self.vel) / dt;
|
||||||
|
let dvel = dvel.clamp(-10_000.0, 10_000.0);
|
||||||
|
|
||||||
|
// println!("{}: {}", id, dvel);
|
||||||
|
|
||||||
|
self.vel += (dvel * dt) * 0.2;
|
||||||
|
self.vel_updated_at = ts;
|
||||||
|
} else {
|
||||||
|
self.vel = vel;
|
||||||
|
self.initialized = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// if self.vel < 5.0 {
|
||||||
|
// self.predictor.has_linear = false;
|
||||||
|
// self.predictor.has_quadratic = false;
|
||||||
|
// }
|
||||||
|
|
||||||
|
let (pos, _) = self.predictor.project(pp);
|
||||||
|
|
||||||
|
if self.predictor.has_quadratic || self.predictor.has_linear {
|
||||||
|
self.pos = (pos.coords * 0.2 + self.predict(id, ts).coords * 0.8).into();
|
||||||
|
} else if self.initialized {
|
||||||
|
self.pos = (self.pos.coords * 0.9 + pos.coords * 0.1).into();
|
||||||
|
} else {
|
||||||
|
self.pos = pos;
|
||||||
|
}
|
||||||
|
|
||||||
|
self.ts = ts;
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user