diff --git a/Cargo.lock b/Cargo.lock index 68f5ae8..3f7dd15 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -445,6 +445,17 @@ dependencies = [ "autocfg", ] +[[package]] +name = "object-detector" +version = "0.1.0" +source = "git+https://git.aidev.ru/andrey/object-detector.git#2ebfe7da2ca1bef99a0afb785b36495484226d85" +dependencies = [ + "ndarray 0.15.3", + "onnx-model", + "serde", + "serde_derive", +] + [[package]] name = "onnx-model" version = "0.2.3" @@ -523,6 +534,7 @@ dependencies = [ "nalgebra", "ndarray 0.15.3", "num-traits", + "object-detector", "onnx-model", "serde", "serde_derive", diff --git a/Cargo.toml b/Cargo.toml index 4d2f802..bcf6bff 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,5 +10,6 @@ num-traits = "0.2" serde = "1.0" serde_derive = "1.0" thiserror = "1.0" +object-detector = { git = "https://git.aidev.ru/andrey/object-detector.git" } munkres = { version = "0.5", git = "https://git.aidev.ru/andrey/munkres-rs.git" } onnx-model = { git = "https://git.aidev.ru/andrey/onnx-model.git", branch = "v1.10" } diff --git a/src/bbox.rs b/src/bbox.rs index d70afb4..35be7a2 100644 --- a/src/bbox.rs +++ b/src/bbox.rs @@ -1,3 +1,4 @@ +use crate::Detection; use serde::{Deserialize, Serialize}; use serde_derive::{Deserialize, Serialize}; use std::marker::PhantomData; @@ -280,12 +281,8 @@ impl<'a> From<&'a BBox> for BBox { } } -impl<'a> From<&'a BBox> for BBox { - #[inline] - fn from(v: &'a BBox) -> Self { - Self( - [v.0[0], v.0[1], v.0[2] * v.0[3], v.0[3]], - Default::default(), - ) +impl From<&'_ Detection> for BBox { + fn from(det: &'_ Detection) -> BBox { + BBox::xywh(det.x, det.y, det.w, det.h) } } diff --git a/src/detection.rs b/src/detection.rs deleted file mode 100644 index e0793bb..0000000 --- a/src/detection.rs +++ /dev/null @@ -1,58 +0,0 @@ -use serde_derive::{Deserialize, Serialize}; - -use crate::bbox::{BBox, Xywh}; - -/// Contains (x,y) of the center and (width,height) of bbox -#[derive(Serialize, Deserialize, Debug, Clone, Copy)] -pub struct Detection { - pub x: f32, - pub y: f32, - pub w: f32, - pub h: f32, - #[serde(rename = "p")] - pub confidence: f32, - #[serde(rename = "c")] - pub class: i32, -} - -impl Detection { - pub fn iou(&self, other: &Detection) -> f32 { - let b1_area = (self.w + 1.) * (self.h + 1.); - let (xmin, xmax, ymin, ymax) = (self.xmin(), self.xmax(), self.ymin(), self.ymax()); - - let b2_area = (other.w + 1.) * (other.h + 1.); - - let i_xmin = xmin.max(other.xmin()); - let i_xmax = xmax.min(other.xmax()); - let i_ymin = ymin.max(other.ymin()); - let i_ymax = ymax.min(other.ymax()); - let i_area = (i_xmax - i_xmin + 1.).max(0.) * (i_ymax - i_ymin + 1.).max(0.); - - (i_area) / (b1_area + b2_area - i_area) - } - - #[inline(always)] - pub fn bbox(&self) -> BBox { - BBox::xywh(self.x, self.y, self.w, self.h) - } - - #[inline(always)] - pub fn xmax(&self) -> f32 { - self.x + self.w / 2. - } - - #[inline(always)] - pub fn ymax(&self) -> f32 { - self.y + self.h / 2. - } - - #[inline(always)] - pub fn xmin(&self) -> f32 { - self.x - self.w / 2. - } - - #[inline(always)] - pub fn ymin(&self) -> f32 { - self.y - self.h / 2. - } -} diff --git a/src/detector.rs b/src/detector.rs deleted file mode 100644 index 108c8ba..0000000 --- a/src/detector.rs +++ /dev/null @@ -1,248 +0,0 @@ -use crate::detection::Detection; -use crate::error::Error; - -use ndarray::prelude::*; -use onnx_model::*; - -const MODEL_DYNAMIC_INPUT_DIMENSION: i64 = -1; - -pub struct YoloDetectorConfig { - pub confidence_threshold: f32, - pub iou_threshold: f32, - pub classes: Vec, - pub class_map: Option>, -} - -impl YoloDetectorConfig { - pub fn new(confidence_threshold: f32, classes: Vec) -> Self { - Self { - confidence_threshold, - iou_threshold: 0.2, - classes, - class_map: None, - } - } -} - -pub struct YoloDetector { - model: OnnxInferenceModel, - config: YoloDetectorConfig, -} - -impl YoloDetector { - pub fn new( - model_src: &str, - config: YoloDetectorConfig, - device: InferenceDevice, - ) -> Result { - let model = OnnxInferenceModel::new(model_src, device)?; - - Ok(Self { model, config }) - } - - pub fn get_model_input_size(&self) -> Option<(u32, u32)> { - let mut input_dims = self.model.get_input_infos()[0].shape.dims.clone(); - let input_height = input_dims.pop().unwrap(); - let input_width = input_dims.pop().unwrap(); - if input_height == MODEL_DYNAMIC_INPUT_DIMENSION - && input_width == MODEL_DYNAMIC_INPUT_DIMENSION - { - None - } else { - Some((input_width as u32, input_height as u32)) - } - } - - pub fn detect( - &self, - frames: ArrayView4<'_, f32>, - fw: i32, - fh: i32, - with_crop: bool, - ) -> Result>, Error> { - let in_shape = frames.shape(); - let (in_w, in_h) = (in_shape[3], in_shape[2]); - let preditions = self.model.run(&[frames.into_dyn()])?.pop().unwrap(); - let shape = preditions.shape(); - let shape = [shape[0], shape[1], shape[2]]; - let arr = preditions.into_shape(shape).unwrap(); - let bboxes = self.postprocess(arr.view(), in_w, in_h, fw, fh, with_crop)?; - - Ok(bboxes) - } - - fn postprocess( - &self, - view: ArrayView3<'_, f32>, - in_w: usize, - in_h: usize, - frame_width: i32, - frame_height: i32, - with_crop: bool, - ) -> Result>, Error> { - let shape = view.shape(); - let nbatches = shape[0]; - let npreds = shape[1]; - let pred_size = shape[2]; - let mut results: Vec> = (0..nbatches).map(|_| vec![]).collect(); - - let (ox, oy, ow, oh) = if with_crop { - let in_a = in_h as f32 / in_w as f32; - let frame_a = frame_height as f32 / frame_width as f32; - - if in_a > frame_a { - let w = frame_height as f32 / in_a; - ((frame_width as f32 - w) / 2.0, 0.0, w, frame_height as f32) - } else { - let h = frame_width as f32 * in_a; - (0.0, (frame_height as f32 - h) / 2.0, frame_width as f32, h) - } - } else { - (0.0, 0.0, frame_width as f32, frame_height as f32) - }; - - // Extract the bounding boxes for which confidence is above the threshold. - for batch in 0..nbatches { - let results = &mut results[batch]; - - // The bounding boxes grouped by (maximum) class index. - let mut bboxes: Vec> = (0..80).map(|_| vec![]).collect(); - - for index in 0..npreds { - let x_0 = view.index_axis(Axis(0), batch); - let x_1 = x_0.index_axis(Axis(0), index); - let detection = x_1.as_slice().unwrap(); - - let (x, y, w, h) = match &detection[0..4] { - [center_x, center_y, width, height] => { - let center_x = ox + center_x * ow; - let center_y = oy + center_y * oh; - let width = width * ow as f32; - let height = height * oh as f32; - - (center_x, center_y, width, height) - } - - _ => unreachable!(), - }; - - let classes = &detection[4..pred_size]; - - let mut class_index = -1; - let mut confidence = 0.0; - - for (idx, val) in classes.iter().copied().enumerate() { - if val > confidence { - class_index = idx as i32; - confidence = val; - } - } - - if class_index > -1 && confidence > self.config.confidence_threshold { - if !self.config.classes.contains(&class_index) { - continue; - } - - if w * h > ((frame_width / 2) * (frame_height / 2)) as f32 { - continue; - } - - let mapped_class = match &self.config.class_map { - Some(map) => map - .get(class_index as usize) - .copied() - .unwrap_or(class_index), - None => class_index, - }; - - bboxes[mapped_class as usize].push(Detection { - x, - y, - w, - h, - confidence, - class: class_index as _, - }); - } - } - - for mut dets in bboxes.into_iter() { - if dets.is_empty() { - continue; - } - - if dets.len() == 1 { - results.append(&mut dets); - continue; - } - - let indices = self.non_maximum_supression(&mut dets)?; - - results.extend(dets.drain(..).enumerate().filter_map(|(idx, item)| { - if indices.contains(&(idx as i32)) { - Some(item) - } else { - None - } - })); - - // for (det, idx) in dets.into_iter().zip(indices) { - // if idx > -1 { - // results.push(det); - // } - // } - } - } - - Ok(results) - } - - // fn non_maximum_supression(&self, dets: &mut [Detection]) -> Result, Error> { - // let mut rects = core::Vector::new(); - // let mut scores = core::Vector::new(); - - // for det in dets { - // rects.push(core::Rect2d::new( - // det.xmin as f64, - // det.ymin as f64, - // (det.xmax - det.xmin) as f64, - // (det.ymax - det.ymin) as f64 - // )); - // scores.push(det.confidence); - // } - - // let mut indices = core::Vector::::new(); - // dnn::nms_boxes_f64( - // &rects, - // &scores, - // self.config.confidence_threshold, - // self.config.iou_threshold, - // &mut indices, - // 1.0, - // 0 - // )?; - - // Ok(indices.to_vec()) - // } - - fn non_maximum_supression(&self, dets: &mut [Detection]) -> Result, Error> { - dets.sort_unstable_by(|a, b| b.confidence.partial_cmp(&a.confidence).unwrap()); - - let mut retain: Vec<_> = (0..dets.len() as i32).collect(); - for idx in 0..dets.len() - 1 { - if retain[idx] != -1 { - for r in retain[idx + 1..].iter_mut() { - if *r != -1 { - let iou = dets[idx].iou(&dets[*r as usize]); - if iou > self.config.iou_threshold { - *r = -1; - } - } - } - } - } - - retain.retain(|&x| x > -1); - Ok(retain) - } -} diff --git a/src/frame.rs b/src/frame.rs index 9c91b0c..0eb6258 100644 --- a/src/frame.rs +++ b/src/frame.rs @@ -1,4 +1,4 @@ -use crate::detection::Detection; +use crate::Detection; pub struct Frame { pub dims: (u32, u32), diff --git a/src/lib.rs b/src/lib.rs index 4c7734f..c6f955a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,4 @@ pub mod bbox; -pub mod detection; -pub mod detector; pub mod error; pub mod frame; pub mod math; @@ -12,8 +10,8 @@ mod circular_queue; mod predictor; mod track; -pub use detection::Detection; pub use frame::Frame; +pub use object_detector::Detection; pub use track::Track; use error::Error; diff --git a/src/scene.rs b/src/scene.rs index 6c2c1b3..1cbc9ff 100644 --- a/src/scene.rs +++ b/src/scene.rs @@ -1,5 +1,6 @@ use std::sync::atomic::AtomicU32; +use crate::bbox::{BBox, Xywh}; use crate::tracker::Object; use crate::Detection; @@ -391,6 +392,8 @@ impl Participant { impl From<&Participant> for crate::Track { fn from(p: &Participant) -> crate::Track { + let bbox: BBox = p.last_detection().into(); + crate::Track { track_id: p.id as _, time_since_update: p.time_since_update as _, @@ -403,7 +406,7 @@ impl From<&Participant> for crate::Track { .unwrap_or(0) as _, confidence: p.hit_score_sum / p.hits_count as f32, iou_slip: p.iou_slip(), - bbox: p.last_detection().bbox().as_xyah(), + bbox: bbox.as_xyah(), velocity: Some(*p.velocity()), direction: Some((p.direction().re, p.direction().im)), curvature: Some(p.object.predictor.curvature),