From 2ebfe7da2ca1bef99a0afb785b36495484226d85 Mon Sep 17 00:00:00 2001 From: Andrey Tkachenko Date: Thu, 17 Mar 2022 15:45:35 +0400 Subject: [PATCH] Initial commit --- .drone.yml | 14 +++ .gitignore | 24 +++++ Cargo.toml | 14 +++ README.md | 2 + src/detection.rs | 58 +++++++++++ src/detector.rs | 248 +++++++++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 5 + 7 files changed, 365 insertions(+) create mode 100644 .drone.yml create mode 100644 .gitignore create mode 100644 Cargo.toml create mode 100644 README.md create mode 100644 src/detection.rs create mode 100644 src/detector.rs create mode 100644 src/lib.rs diff --git a/.drone.yml b/.drone.yml new file mode 100644 index 0000000..a449d87 --- /dev/null +++ b/.drone.yml @@ -0,0 +1,14 @@ +kind: pipeline +name: default + +steps: +- name: build + image: hub.aidev.ru/rust-onnxruntime:latest + commands: + - cargo build --verbose --all + +- name: fmt-check + image: hub.aidev.ru/rust-onnxruntime:latest + commands: + - rustup component add rustfmt + - cargo fmt --all -- --check diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..0f95942 --- /dev/null +++ b/.gitignore @@ -0,0 +1,24 @@ +# ---> Rust +# Generated by Cargo +# will have compiled files and executables +debug/ +target/ + +# Remove Cargo.lock from gitignore if creating an executable, leave it for libraries +# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html +Cargo.lock + +# These are backup files generated by rustfmt +**/*.rs.bk + +# MSVC Windows builds of rustc generate these, which store debugging information +*.pdb + + + +# Added by cargo +# +# already existing elements were commented out + +/target +#Cargo.lock diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..55f820e --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "object-detector" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +serde = "1.0" +serde_derive = "1.0" +ndarray = { version = "0.15" } +onnx-model = {git = "https://git.aidev.ru/andrey/onnx-model.git", branch="v1.10"} + + diff --git a/README.md b/README.md new file mode 100644 index 0000000..98a722a --- /dev/null +++ b/README.md @@ -0,0 +1,2 @@ +# object-detector + diff --git a/src/detection.rs b/src/detection.rs new file mode 100644 index 0000000..9277f24 --- /dev/null +++ b/src/detection.rs @@ -0,0 +1,58 @@ +use serde_derive::{Deserialize, Serialize}; + +// use crate::bbox::{BBox, Xywh}; + +/// Contains (x,y) of the center and (width,height) of bbox +#[derive(Serialize, Deserialize, Debug, Clone, Copy)] +pub struct Detection { + pub x: f32, + pub y: f32, + pub w: f32, + pub h: f32, + #[serde(rename = "p")] + pub confidence: f32, + #[serde(rename = "c")] + pub class: i32, +} + +impl Detection { + pub fn iou(&self, other: &Detection) -> f32 { + let b1_area = (self.w + 1.) * (self.h + 1.); + let (xmin, xmax, ymin, ymax) = (self.xmin(), self.xmax(), self.ymin(), self.ymax()); + + let b2_area = (other.w + 1.) * (other.h + 1.); + + let i_xmin = xmin.max(other.xmin()); + let i_xmax = xmax.min(other.xmax()); + let i_ymin = ymin.max(other.ymin()); + let i_ymax = ymax.min(other.ymax()); + let i_area = (i_xmax - i_xmin + 1.).max(0.) * (i_ymax - i_ymin + 1.).max(0.); + + (i_area) / (b1_area + b2_area - i_area) + } + + // #[inline(always)] + // pub fn bbox(&self) -> BBox { + // BBox::xywh(self.x, self.y, self.w, self.h) + // } + + #[inline(always)] + pub fn xmax(&self) -> f32 { + self.x + self.w / 2. + } + + #[inline(always)] + pub fn ymax(&self) -> f32 { + self.y + self.h / 2. + } + + #[inline(always)] + pub fn xmin(&self) -> f32 { + self.x - self.w / 2. + } + + #[inline(always)] + pub fn ymin(&self) -> f32 { + self.y - self.h / 2. + } +} diff --git a/src/detector.rs b/src/detector.rs new file mode 100644 index 0000000..eb4ea32 --- /dev/null +++ b/src/detector.rs @@ -0,0 +1,248 @@ +use crate::detection::Detection; + +use ndarray::prelude::*; +use onnx_model::error::Error; +use onnx_model::*; + +const MODEL_DYNAMIC_INPUT_DIMENSION: i64 = -1; + +pub struct YoloDetectorConfig { + pub confidence_threshold: f32, + pub iou_threshold: f32, + pub classes: Vec, + pub class_map: Option>, +} + +impl YoloDetectorConfig { + pub fn new(confidence_threshold: f32, classes: Vec) -> Self { + Self { + confidence_threshold, + iou_threshold: 0.2, + classes, + class_map: None, + } + } +} + +pub struct YoloDetector { + model: OnnxInferenceModel, + config: YoloDetectorConfig, +} + +impl YoloDetector { + pub fn new( + model_src: &str, + config: YoloDetectorConfig, + device: InferenceDevice, + ) -> Result { + let model = OnnxInferenceModel::new(model_src, device)?; + + Ok(Self { model, config }) + } + + pub fn get_model_input_size(&self) -> Option<(u32, u32)> { + let mut input_dims = self.model.get_input_infos()[0].shape.dims.clone(); + let input_height = input_dims.pop().unwrap(); + let input_width = input_dims.pop().unwrap(); + if input_height == MODEL_DYNAMIC_INPUT_DIMENSION + && input_width == MODEL_DYNAMIC_INPUT_DIMENSION + { + None + } else { + Some((input_width as u32, input_height as u32)) + } + } + + pub fn detect( + &self, + frames: ArrayView4<'_, f32>, + fw: i32, + fh: i32, + with_crop: bool, + ) -> Result>, Error> { + let in_shape = frames.shape(); + let (in_w, in_h) = (in_shape[3], in_shape[2]); + let preditions = self.model.run(&[frames.into_dyn()])?.pop().unwrap(); + let shape = preditions.shape(); + let shape = [shape[0], shape[1], shape[2]]; + let arr = preditions.into_shape(shape).unwrap(); + let bboxes = self.postprocess(arr.view(), in_w, in_h, fw, fh, with_crop)?; + + Ok(bboxes) + } + + fn postprocess( + &self, + view: ArrayView3<'_, f32>, + in_w: usize, + in_h: usize, + frame_width: i32, + frame_height: i32, + with_crop: bool, + ) -> Result>, Error> { + let shape = view.shape(); + let nbatches = shape[0]; + let npreds = shape[1]; + let pred_size = shape[2]; + let mut results: Vec> = (0..nbatches).map(|_| vec![]).collect(); + + let (ox, oy, ow, oh) = if with_crop { + let in_a = in_h as f32 / in_w as f32; + let frame_a = frame_height as f32 / frame_width as f32; + + if in_a > frame_a { + let w = frame_height as f32 / in_a; + ((frame_width as f32 - w) / 2.0, 0.0, w, frame_height as f32) + } else { + let h = frame_width as f32 * in_a; + (0.0, (frame_height as f32 - h) / 2.0, frame_width as f32, h) + } + } else { + (0.0, 0.0, frame_width as f32, frame_height as f32) + }; + + // Extract the bounding boxes for which confidence is above the threshold. + for batch in 0..nbatches { + let results = &mut results[batch]; + + // The bounding boxes grouped by (maximum) class index. + let mut bboxes: Vec> = (0..80).map(|_| vec![]).collect(); + + for index in 0..npreds { + let x_0 = view.index_axis(Axis(0), batch); + let x_1 = x_0.index_axis(Axis(0), index); + let detection = x_1.as_slice().unwrap(); + + let (x, y, w, h) = match &detection[0..4] { + [center_x, center_y, width, height] => { + let center_x = ox + center_x * ow; + let center_y = oy + center_y * oh; + let width = width * ow as f32; + let height = height * oh as f32; + + (center_x, center_y, width, height) + } + + _ => unreachable!(), + }; + + let classes = &detection[4..pred_size]; + + let mut class_index = -1; + let mut confidence = 0.0; + + for (idx, val) in classes.iter().copied().enumerate() { + if val > confidence { + class_index = idx as i32; + confidence = val; + } + } + + if class_index > -1 && confidence > self.config.confidence_threshold { + if !self.config.classes.contains(&class_index) { + continue; + } + + if w * h > ((frame_width / 2) * (frame_height / 2)) as f32 { + continue; + } + + let mapped_class = match &self.config.class_map { + Some(map) => map + .get(class_index as usize) + .copied() + .unwrap_or(class_index), + None => class_index, + }; + + bboxes[mapped_class as usize].push(Detection { + x, + y, + w, + h, + confidence, + class: class_index as _, + }); + } + } + + for mut dets in bboxes.into_iter() { + if dets.is_empty() { + continue; + } + + if dets.len() == 1 { + results.append(&mut dets); + continue; + } + + let indices = self.non_maximum_supression(&mut dets)?; + + results.extend(dets.drain(..).enumerate().filter_map(|(idx, item)| { + if indices.contains(&(idx as i32)) { + Some(item) + } else { + None + } + })); + + // for (det, idx) in dets.into_iter().zip(indices) { + // if idx > -1 { + // results.push(det); + // } + // } + } + } + + Ok(results) + } + + // fn non_maximum_supression(&self, dets: &mut [Detection]) -> Result, Error> { + // let mut rects = core::Vector::new(); + // let mut scores = core::Vector::new(); + + // for det in dets { + // rects.push(core::Rect2d::new( + // det.xmin as f64, + // det.ymin as f64, + // (det.xmax - det.xmin) as f64, + // (det.ymax - det.ymin) as f64 + // )); + // scores.push(det.confidence); + // } + + // let mut indices = core::Vector::::new(); + // dnn::nms_boxes_f64( + // &rects, + // &scores, + // self.config.confidence_threshold, + // self.config.iou_threshold, + // &mut indices, + // 1.0, + // 0 + // )?; + + // Ok(indices.to_vec()) + // } + + fn non_maximum_supression(&self, dets: &mut [Detection]) -> Result, Error> { + dets.sort_unstable_by(|a, b| b.confidence.partial_cmp(&a.confidence).unwrap()); + + let mut retain: Vec<_> = (0..dets.len() as i32).collect(); + for idx in 0..dets.len() - 1 { + if retain[idx] != -1 { + for r in retain[idx + 1..].iter_mut() { + if *r != -1 { + let iou = dets[idx].iou(&dets[*r as usize]); + if iou > self.config.iou_threshold { + *r = -1; + } + } + } + } + } + + retain.retain(|&x| x > -1); + Ok(retain) + } +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..109ed7a --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,5 @@ +mod detection; +mod detector; + +pub use detection::*; +pub use detector::*;