Initial commit
All checks were successful
continuous-integration/drone/push Build is passing

This commit is contained in:
Andrey Tkachenko 2022-03-17 15:45:35 +04:00
commit 2ebfe7da2c
7 changed files with 365 additions and 0 deletions

14
.drone.yml Normal file
View File

@ -0,0 +1,14 @@
# Drone CI pipeline: builds every crate, then checks formatting.
kind: pipeline
name: default
steps:
# Compile all packages with verbose cargo output.
- name: build
image: hub.aidev.ru/rust-onnxruntime:latest
commands:
- cargo build --verbose --all
# Install rustfmt and fail the step if any file is not formatted.
- name: fmt-check
image: hub.aidev.ru/rust-onnxruntime:latest
commands:
- rustup component add rustfmt
- cargo fmt --all -- --check

24
.gitignore vendored Normal file
View File

@ -0,0 +1,24 @@
# ---> Rust
# Generated by Cargo
# will have compiled files and executables
debug/
target/
# Remove Cargo.lock from gitignore if creating an executable, leave it for libraries
# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html
Cargo.lock
# These are backup files generated by rustfmt
**/*.rs.bk
# MSVC Windows builds of rustc generate these, which store debugging information
*.pdb
# Added by cargo
#
# already existing elements were commented out
/target
#Cargo.lock

14
Cargo.toml Normal file
View File

@ -0,0 +1,14 @@
[package]
name = "object-detector"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
serde = "1.0"
serde_derive = "1.0"
ndarray = { version = "0.15" }
onnx-model = {git = "https://git.aidev.ru/andrey/onnx-model.git", branch="v1.10"}

2
README.md Normal file
View File

@ -0,0 +1,2 @@
# object-detector

58
src/detection.rs Normal file
View File

@ -0,0 +1,58 @@
use serde_derive::{Deserialize, Serialize};
// use crate::bbox::{BBox, Xywh};
/// A single detected object: a centre-based bounding box plus its score and class.
///
/// `(x, y)` is the centre of the box and `(w, h)` its width and height; the edge
/// coordinates are derived from these as `centre ± size / 2`.
/// NOTE(review): the detector scales these values by the frame size, so they are
/// presumably frame pixels — confirm with the producer of the model output.
#[derive(Serialize, Deserialize, Debug, Clone, Copy)]
pub struct Detection {
/// Horizontal centre of the box.
pub x: f32,
/// Vertical centre of the box.
pub y: f32,
/// Box width.
pub w: f32,
/// Box height.
pub h: f32,
/// Score of the winning class; serialised under the short key `p`.
#[serde(rename = "p")]
pub confidence: f32,
/// Index of the winning class; serialised under the short key `c`.
#[serde(rename = "c")]
pub class: i32,
}
impl Detection {
    /// Intersection-over-union of `self` and `other`.
    ///
    /// Every extent (both boxes and their overlap) is measured with the inclusive
    /// convention `max - min + 1`, so two identical boxes score exactly `1.0` and
    /// boxes that do not overlap score `0.0`.
    pub fn iou(&self, other: &Detection) -> f32 {
        // Overlap length along one axis; negative spans mean "no overlap" and clamp to zero.
        let overlap = |lo: f32, hi: f32| (hi - lo + 1.).max(0.);

        let own_area = (self.w + 1.) * (self.h + 1.);
        let other_area = (other.w + 1.) * (other.h + 1.);

        let left = self.xmin().max(other.xmin());
        let right = self.xmax().min(other.xmax());
        let top = self.ymin().max(other.ymin());
        let bottom = self.ymax().min(other.ymax());

        let shared = overlap(left, right) * overlap(top, bottom);
        shared / (own_area + other_area - shared)
    }

    /// Left edge of the box (`x - w / 2`).
    #[inline(always)]
    pub fn xmin(&self) -> f32 {
        let half_width = self.w / 2.;
        self.x - half_width
    }

    /// Right edge of the box (`x + w / 2`).
    #[inline(always)]
    pub fn xmax(&self) -> f32 {
        let half_width = self.w / 2.;
        self.x + half_width
    }

    /// Top edge of the box (`y - h / 2`).
    #[inline(always)]
    pub fn ymin(&self) -> f32 {
        let half_height = self.h / 2.;
        self.y - half_height
    }

    /// Bottom edge of the box (`y + h / 2`).
    #[inline(always)]
    pub fn ymax(&self) -> f32 {
        let half_height = self.h / 2.;
        self.y + half_height
    }
}

248
src/detector.rs Normal file
View File

@ -0,0 +1,248 @@
use crate::detection::Detection;
use ndarray::prelude::*;
use onnx_model::error::Error;
use onnx_model::*;
/// Dimension value that marks a model input axis as dynamic (no fixed size);
/// `get_model_input_size` compares the trailing input dimensions against it.
const MODEL_DYNAMIC_INPUT_DIMENSION: i64 = -1;
/// Tunable parameters that control how raw predictions are filtered.
pub struct YoloDetectorConfig {
    /// A prediction is kept only when its best class score is strictly above this value.
    pub confidence_threshold: f32,
    /// Boxes of the same group whose IoU with a stronger box exceeds this value are suppressed.
    pub iou_threshold: f32,
    /// Class indices that should be reported; predictions of any other class are dropped.
    pub classes: Vec<i32>,
    /// Optional lookup, indexed by class index, that selects the group a prediction is
    /// suppressed in; indices missing from the table fall back to the class index itself.
    pub class_map: Option<Vec<i32>>,
}

impl YoloDetectorConfig {
    /// Builds a configuration with the given score threshold and class filter.
    ///
    /// The IoU threshold starts at `0.2` and no class map is installed; both fields
    /// are public and can be overwritten afterwards.
    pub fn new(confidence_threshold: f32, classes: Vec<i32>) -> Self {
        // IoU limit used until the caller overrides `iou_threshold`.
        const DEFAULT_IOU_THRESHOLD: f32 = 0.2;

        YoloDetectorConfig {
            class_map: None,
            classes,
            iou_threshold: DEFAULT_IOU_THRESHOLD,
            confidence_threshold,
        }
    }
}
/// Object detector that runs an ONNX model and filters its raw predictions.
pub struct YoloDetector {
// Loaded inference model; queried for its input shape and executed in `detect`.
model: OnnxInferenceModel,
// Thresholds and class filters applied to the model output.
config: YoloDetectorConfig,
}
impl YoloDetector {
/// Loads the model found at `model_src` for `device` and pairs it with `config`.
///
/// # Errors
///
/// Propagates the error returned by `OnnxInferenceModel::new` when the model
/// cannot be created.
pub fn new(
    model_src: &str,
    config: YoloDetectorConfig,
    device: InferenceDevice,
) -> Result<Self, Error> {
    Ok(Self {
        model: OnnxInferenceModel::new(model_src, device)?,
        config,
    })
}
pub fn get_model_input_size(&self) -> Option<(u32, u32)> {
let mut input_dims = self.model.get_input_infos()[0].shape.dims.clone();
let input_height = input_dims.pop().unwrap();
let input_width = input_dims.pop().unwrap();
if input_height == MODEL_DYNAMIC_INPUT_DIMENSION
&& input_width == MODEL_DYNAMIC_INPUT_DIMENSION
{
None
} else {
Some((input_width as u32, input_height as u32))
}
}
/// Runs the model on a batch of frames and returns the detections of every frame.
///
/// `frames` is read with height on axis 2 and width on axis 3. `fw` and `fh` are the
/// width and height of the original frame the boxes are scaled to, and `with_crop`
/// selects the centred-crop mapping described on `postprocess`.
///
/// # Errors
///
/// Propagates errors from running the model and from post-processing.
///
/// # Panics
///
/// Panics when the model returns no output, or when the last output has fewer than
/// three dimensions or cannot be reshaped to its first three dimensions.
pub fn detect(
    &self,
    frames: ArrayView4<'_, f32>,
    fw: i32,
    fh: i32,
    with_crop: bool,
) -> Result<Vec<Vec<Detection>>, Error> {
    // Remember the network input size before the view is consumed by `run`.
    let (in_h, in_w) = {
        let dims = frames.shape();
        (dims[2], dims[3])
    };

    // Only the last output tensor is used.
    let mut outputs = self.model.run(&[frames.into_dyn()])?;
    let predictions = outputs.pop().unwrap();

    let dims = predictions.shape();
    let leading = [dims[0], dims[1], dims[2]];
    let predictions = predictions.into_shape(leading).unwrap();

    self.postprocess(predictions.view(), in_w, in_h, fw, fh, with_crop)
}
/// Turns raw model output into the final detections of every frame in the batch.
///
/// `view` is indexed as `[batch, prediction, value]`. Each prediction row is read as
/// `[center_x, center_y, width, height, class scores...]`; the class with the highest
/// positive score wins. The box values are multiplied by the size of the target
/// region, so they are presumably normalised to `0..=1` — NOTE(review): confirm
/// against the exported model.
///
/// With `with_crop` the boxes are mapped onto the largest centred region of the
/// frame that has the aspect ratio of the network input (`in_w` x `in_h`); otherwise
/// they are mapped onto the whole `frame_width` x `frame_height` frame.
///
/// A prediction is kept when its score is above `confidence_threshold`, its class is
/// listed in `config.classes`, and its area does not exceed a quarter of the frame.
/// Survivors are grouped by `config.class_map` (falling back to the class index) and
/// each group goes through non-maximum suppression. Groups are emitted in ascending
/// group order and may hold any `i32` id, not only `0..80`. The reported
/// `Detection::class` is always the original class index.
///
/// # Errors
///
/// Propagates errors from `non_maximum_supression`.
///
/// # Panics
///
/// Panics when a prediction row is not contiguous in memory or holds fewer than
/// four values.
fn postprocess(
    &self,
    view: ArrayView3<'_, f32>,
    in_w: usize,
    in_h: usize,
    frame_width: i32,
    frame_height: i32,
    with_crop: bool,
) -> Result<Vec<Vec<Detection>>, Error> {
    let shape = view.shape();
    let (nbatches, npreds, pred_size) = (shape[0], shape[1], shape[2]);

    // Offset and size of the frame region the network input corresponds to.
    let (ox, oy, ow, oh) = if with_crop {
        let in_a = in_h as f32 / in_w as f32;
        let frame_a = frame_height as f32 / frame_width as f32;
        if in_a > frame_a {
            let w = frame_height as f32 / in_a;
            ((frame_width as f32 - w) / 2.0, 0.0, w, frame_height as f32)
        } else {
            let h = frame_width as f32 * in_a;
            (0.0, (frame_height as f32 - h) / 2.0, frame_width as f32, h)
        }
    } else {
        (0.0, 0.0, frame_width as f32, frame_height as f32)
    };

    // Loop-invariant: boxes covering more than a quarter of the frame are discarded.
    let max_area = ((frame_width / 2) * (frame_height / 2)) as f32;

    let mut results: Vec<Vec<Detection>> = Vec::with_capacity(nbatches);
    for batch in 0..nbatches {
        let predictions = view.index_axis(Axis(0), batch);

        // Candidates keyed by suppression group. An ordered map grows with whatever
        // group ids occur and still yields the groups in ascending order.
        let mut groups: std::collections::BTreeMap<i32, Vec<Detection>> =
            std::collections::BTreeMap::new();

        for index in 0..npreds {
            let row = predictions.index_axis(Axis(0), index);
            let detection = row.as_slice().unwrap();
            let (center_x, center_y, width, height) =
                (detection[0], detection[1], detection[2], detection[3]);

            // First class with the highest strictly positive score.
            let mut best: Option<(i32, f32)> = None;
            for (idx, score) in detection[4..pred_size].iter().copied().enumerate() {
                if score > best.map_or(0.0, |(_, top)| top) {
                    best = Some((idx as i32, score));
                }
            }
            let (class_index, confidence) = match best {
                Some((class_index, confidence))
                    if confidence > self.config.confidence_threshold =>
                {
                    (class_index, confidence)
                }
                _ => continue,
            };
            if !self.config.classes.contains(&class_index) {
                continue;
            }

            let x = ox + center_x * ow;
            let y = oy + center_y * oh;
            let w = width * ow;
            let h = height * oh;
            if w * h > max_area {
                continue;
            }

            // The class map only decides which boxes compete during suppression.
            let group = self
                .config
                .class_map
                .as_ref()
                .and_then(|map| map.get(class_index as usize).copied())
                .unwrap_or(class_index);

            groups.entry(group).or_default().push(Detection {
                x,
                y,
                w,
                h,
                confidence,
                class: class_index,
            });
        }

        let mut kept: Vec<Detection> = Vec::new();
        for (_, mut dets) in groups {
            if dets.len() > 1 {
                // The returned indices refer to `dets` as reordered by the suppression step.
                let indices = self.non_maximum_supression(&mut dets)?;
                kept.extend(indices.into_iter().map(|idx| dets[idx as usize]));
            } else {
                kept.append(&mut dets);
            }
        }
        results.push(kept);
    }
    Ok(results)
}
// fn non_maximum_supression(&self, dets: &mut [Detection]) -> Result<Vec<i32>, Error> {
// let mut rects = core::Vector::new();
// let mut scores = core::Vector::new();
// for det in dets {
// rects.push(core::Rect2d::new(
// det.xmin as f64,
// det.ymin as f64,
// (det.xmax - det.xmin) as f64,
// (det.ymax - det.ymin) as f64
// ));
// scores.push(det.confidence);
// }
// let mut indices = core::Vector::<i32>::new();
// dnn::nms_boxes_f64(
// &rects,
// &scores,
// self.config.confidence_threshold,
// self.config.iou_threshold,
// &mut indices,
// 1.0,
// 0
// )?;
// Ok(indices.to_vec())
// }
/// Greedy non-maximum suppression over `dets`.
///
/// `dets` is sorted in place by descending confidence (detections with a NaN
/// confidence sort last instead of causing a panic). Walking from the strongest
/// detection down, every weaker detection whose IoU with a kept one exceeds
/// `config.iou_threshold` is suppressed.
///
/// Returns the indices of the kept detections in ascending order; they refer to
/// `dets` **after** sorting. An empty slice yields an empty list.
///
/// # Errors
///
/// Never fails at present; the `Result` is kept for interface compatibility.
fn non_maximum_supression(&self, dets: &mut [Detection]) -> Result<Vec<i32>, Error> {
    // Map NaN to the lowest possible score so the comparison below is a total order.
    let score = |det: &Detection| {
        if det.confidence.is_nan() {
            f32::NEG_INFINITY
        } else {
            det.confidence
        }
    };
    dets.sort_unstable_by(|a, b| {
        score(b)
            .partial_cmp(&score(a))
            .unwrap_or(std::cmp::Ordering::Equal)
    });

    let count = dets.len();
    let mut suppressed = vec![false; count];
    let mut kept = Vec::with_capacity(count);
    for current in 0..count {
        if suppressed[current] {
            continue;
        }
        kept.push(current as i32);
        // Only weaker detections that are still alive can be knocked out.
        for candidate in (current + 1)..count {
            if !suppressed[candidate]
                && dets[current].iou(&dets[candidate]) > self.config.iou_threshold
            {
                suppressed[candidate] = true;
            }
        }
    }
    Ok(kept)
}
}

5
src/lib.rs Normal file
View File

@ -0,0 +1,5 @@
mod detection;
mod detector;
pub use detection::*;
pub use detector::*;