deep-sort/examples/yolo.rs
Andrey Tkachenko ede337a2d8 Initial
2020-09-02 19:49:32 +04:00

211 lines
6.3 KiB
Rust

use anyhow::Error;
use grant_object_detector::{Detection, YoloDetector, YoloDetectorConfig};
use deep_sort::{sort, DeepSortConfig, DeepSort, Track};
use ndarray::prelude::*;
use opencv::{
dnn,
core::{self, Mat, Scalar, Vector},
highgui,
prelude::*,
videoio,
};
/// Per-channel multipliers used to derive a color from a label
/// (`2^11 - 1`, `2^15 - 1`, `2^20 - 1`).
const PALLETE: (i32, i32, i32) = (2047, 32767, 1048575);

/// Maps a label (for example a track id) to a deterministic color triple.
///
/// Each channel equals `(PALLETE.k * (label * label - label + 1)) % 255`.
/// The value is evaluated with modular arithmetic in `i64`, so the result is
/// exact for every `i32` label: no intermediate product can overflow, and each
/// channel is always in `0.0..=254.0`.
///
/// The previous `i32` implementation overflowed as soon as
/// `PALLETE.2 * c` exceeded `i32::MAX` (from label 46 onwards), which panics
/// in debug builds and can produce negative channels in release builds.
fn compute_color_for_labels(label: i32) -> (f64, f64, f64) {
    const MODULUS: i64 = 255;

    let label = i64::from(label);
    // For any i32 label, label^2 <= 2^62, so this fits in i64; it is also
    // always >= 1, hence `%` never yields a negative remainder here.
    let c = (label * label - label + 1) % MODULUS;

    // (p * c) mod m == ((p mod m) * (c mod m)) mod m; both factors are < 255.
    let channel = |p: i32| ((i64::from(p) % MODULUS) * c % MODULUS) as f64;

    (channel(PALLETE.0), channel(PALLETE.1), channel(PALLETE.2))
}
/// Detection + tracking pipeline: a YOLO detector whose output is fed to a
/// DeepSORT tracker.
pub struct YoloDeepSort {
/// Object detector; invoked by `detectect`.
yolo: YoloDetector,
/// Tracker; updated by `track` and queried for its current tracks.
deep_sort: DeepSort,
}
impl YoloDeepSort {
    /// Builds the pipeline from two model files.
    ///
    /// * `yolo` - path passed to `YoloDetector::new`.
    /// * `reid` - path passed to `DeepSortConfig::new` for the tracker.
    ///
    /// The execution device is whatever `onnx_model::get_cuda_if_available`
    /// returns.
    ///
    /// # Errors
    /// Propagates any error from constructing the detector or the tracker.
    fn new(yolo: &str, reid: &str) -> Result<Self, Error> {
        let device = onnx_model::get_cuda_if_available();

        // Class ids handed to the detector config
        // (presumably COCO indices for people/vehicles -- TODO confirm).
        let mut config = YoloDetectorConfig::new(vec![0, 2, 3, 5, 7]);
        config.confidence_threshold = 0.2;

        Ok(Self {
            yolo: YoloDetector::new(yolo, config, device)?,
            deep_sort: DeepSort::new(DeepSortConfig::new(reid.to_string()))?,
        })
    }

    /// Runs the detector on a batch of frames and returns one detection list
    /// per frame.
    ///
    /// Every non-empty frame is converted to a `3 x 416 x 416` `f32` blob
    /// (scaled by `1/255`, red/blue swapped, no crop) and copied into its slot
    /// of the batch tensor. Frames whose width is `<= 0` are skipped; their
    /// slot stays all-zero.
    ///
    /// The width/height passed to the detector are those of the last
    /// non-empty frame in `frames` (or `0, 0` if there is none).
    ///
    /// Note: the method name is misspelled; it is kept as-is so existing
    /// callers continue to compile.
    ///
    /// # Errors
    /// Propagates OpenCV errors (size query, blob creation, typed access),
    /// a shape error if the blob does not contain `3 * 416 * 416` values, and
    /// any error reported by the detector.
    pub fn detectect(&mut self, frames: &[Mat]) -> Result<Vec<Vec<Detection>>, Error> {
        const SIZE: usize = 416;

        let (mut frame_width, mut frame_height) = (0i32, 0i32);

        // Zero-initialised rather than uninitialised: a skipped frame must not
        // expose uninitialised memory to the detector.
        let mut inpt = Array4::<f32>::zeros([frames.len(), 3, SIZE, SIZE]);

        for (idx, frame) in frames.iter().enumerate() {
            let fsize = frame.size()?;
            if fsize.width <= 0 {
                continue;
            }
            frame_height = fsize.height;
            frame_width = fsize.width;

            // Create a 4D blob from a frame.
            let blob = dnn::blob_from_image(
                &frame,
                1.0 / 255.0,
                core::Size::new(SIZE as i32, SIZE as i32),
                core::Scalar::new(0., 0., 0., 0.),
                true,
                false,
                core::CV_32F,
            )?;

            let typed = blob.try_into_typed::<f32>()?;
            let view = aview1(typed.data_typed()?).into_shape([3, SIZE, SIZE])?;
            inpt.index_axis_mut(Axis(0), idx).assign(&view);
        }

        let detections = self.yolo.detect(inpt.view(), frame_width, frame_height)?;
        Ok(detections)
    }

    /// Feeds `frames` and their `detections` to the tracker and returns the
    /// tracker's current tracks.
    ///
    /// The returned slice borrows from `self`.
    ///
    /// # Errors
    /// Propagates any error from `DeepSort::update`.
    pub fn track(&mut self, frames: &[Mat], detections: &[Vec<Detection>]) -> Result<&[Track], Error> {
        self.deep_sort.update(frames, detections)?;
        Ok(self.deep_sort.tracks())
    }
}
/// Outlines a single raw detection on `frame`.
///
/// The rectangle spans `(xmin, ymin)` to `(xmax, ymax)` of `det` and is drawn
/// with a 1-pixel `LINE_8` border in the fixed color `(255, 255, 0)`.
///
/// # Errors
/// Returns the OpenCV error if drawing fails.
fn draw_pred(frame: &mut Mat, det: Detection) -> opencv::Result<()> {
    let (width, height) = (det.xmax - det.xmin, det.ymax - det.ymin);
    let outline = core::Scalar::new(255.0, 255.0, 0.0, 0.0);

    // Draw a bounding box; the drawing call's result is this function's result.
    opencv::imgproc::rectangle(
        frame,
        core::Rect::new(det.xmin, det.ymin, width, height),
        outline,
        1,
        opencv::imgproc::LINE_8,
        0,
    )
}
/// Outlines one tracked object on `frame`.
///
/// * `bbox` - left/top/width/height box; each component is cast to `i32`.
/// * `track_id` - only referenced by the disabled label-drawing code below.
/// * `color` - values for the first three scalar channels (the fourth is `0`).
///
/// # Errors
/// Returns the OpenCV error if drawing fails.
fn draw_track(frame: &mut Mat, bbox: sort::BBox<sort::Ltwh>, track_id: i32, color: (f64, f64, f64)) -> opencv::Result<()> {
    let (c0, c1, c2) = color;
    let (left, top) = (bbox.left() as i32, bbox.top() as i32);
    let (width, height) = (bbox.width() as i32, bbox.height() as i32);
    let rect = opencv::core::Rect::new(left, top, width, height);

    // Draw a bounding box.
    opencv::imgproc::rectangle(
        frame,
        rect,
        core::Scalar::new(c0, c1, c2, 0.0),
        1,
        opencv::imgproc::LINE_8,
        0,
    )?;

    // Label rendering (track id on a filled background above the box) is
    // currently disabled:
    //
    // let label = format!("[{}]", track_id);
    // let mut base_line = 0;
    // let label_size = opencv::imgproc::get_text_size(&label, opencv::imgproc::FONT_HERSHEY_SIMPLEX, 0.6, 1, &mut base_line)?;
    // let label_rect = core::Rect::new(
    //     rect.x,
    //     rect.y - label_size.height - 8,
    //     label_size.width + 8,
    //     label_size.height + 8
    // );
    // opencv::imgproc::rectangle(frame, label_rect, core::Scalar::new(0.0, 255.0, 0.0, 0.0), opencv::imgproc::FILLED, opencv::imgproc::LINE_8, 0)?;
    // let pt = core::Point::new(rect.x, rect.y - 8);
    // opencv::imgproc::put_text(
    //     frame,
    //     &label,
    //     pt,
    //     opencv::imgproc::FONT_HERSHEY_SIMPLEX,
    //     0.6,
    //     core::Scalar::new(0.0, 0.0, 0.0, 0.0),
    //     1,
    //     opencv::imgproc::LINE_8,
    //     false
    // )?;

    Ok(())
}
/// Runs detection + tracking on a video file and shows the annotated frames.
///
/// Usage: `yolo [YOLO_MODEL] [REID_MODEL] [VIDEO]` -- every positional
/// argument is optional and falls back to the built-in default path.
///
/// The loop ends when the video has no more frames or when a key is pressed
/// in the preview window.
///
/// # Errors
/// Fails if the models cannot be loaded, the video cannot be opened, or any
/// OpenCV / detector / tracker call reports an error.
fn main() -> Result<(), anyhow::Error> {
    const DEFAULT_YOLO_MODEL: &str = "/home/andrey/workspace/ssl/yolov4/yolov4_416.onnx";
    // Other re-identification models that have been tried:
    // "/home/andrey/workspace/ssl/reid/onnx_model.onnx",
    // "/home/andrey/workspace/ssl/deep_sort_pytorch/deep_sort/deep/reid.onnx",
    const DEFAULT_REID_MODEL: &str = "/home/andrey/workspace/ssl/grant/models/model-96.onnx";
    const DEFAULT_VIDEO: &str = "../videoplayback_6.avi";

    let mut args = std::env::args().skip(1);
    let yolo_model = args.next().unwrap_or_else(|| DEFAULT_YOLO_MODEL.to_string());
    let reid_model = args.next().unwrap_or_else(|| DEFAULT_REID_MODEL.to_string());
    let video_path = args.next().unwrap_or_else(|| DEFAULT_VIDEO.to_string());

    let mut tracker = YoloDeepSort::new(&yolo_model, &reid_model)?;

    let window = "video capture";
    // Flag value 1 is OpenCV's WINDOW_AUTOSIZE.
    highgui::named_window(window, 1)?;

    let mut cam = videoio::VideoCapture::from_file(&video_path, videoio::CAP_ANY)?;
    // cam.set(videoio::CAP_PROP_POS_FRAMES, 150.0);
    let opened = videoio::VideoCapture::is_opened(&cam)?;
    if !opened {
        anyhow::bail!("Unable to open video file {:?}", video_path);
    }

    let mut frames = [core::Mat::default()?];
    loop {
        let begin = std::time::Instant::now();

        // `read` reports `false` when no frame could be grabbed (end of
        // stream); stop instead of processing an empty frame.
        if !cam.read(&mut frames[0])? {
            break;
        }

        let detections = tracker.detectect(&frames)?;
        // for d in detections.iter().cloned() {
        //     for d in d {
        //         draw_pred(&mut frames[0], d);
        //     }
        // }

        let tracks = tracker.track(&frames, detections.as_slice())?;
        for t in tracks.iter().filter(|t| t.is_confirmed() && t.time_since_update <= 1) {
            // Propagate drawing failures instead of silently dropping them.
            draw_track(&mut frames[0], t.bbox().as_ltwh(), t.track_id, compute_color_for_labels(t.track_id))?;
        }

        // Frames per second for this iteration, from the full-resolution
        // elapsed time (not truncated to whole milliseconds).
        let label = format!("{:?}", 1.0 / begin.elapsed().as_secs_f32());
        opencv::imgproc::put_text(
            &mut frames[0],
            &label,
            core::Point::new(30, 30),
            opencv::imgproc::FONT_HERSHEY_SIMPLEX,
            0.6,
            core::Scalar::new(0.0, 255.0, 0.0, 0.0),
            1,
            opencv::imgproc::LINE_8,
            false,
        )?;

        highgui::imshow(window, &frames[0])?;

        // Any key press ends the loop; 255 is treated as "no key".
        let key = highgui::wait_key(10)?;
        if key > 0 && key != 255 {
            break;
        }
    }

    Ok(())
}