Initial
This commit is contained in:
commit
ede337a2d8
1
.gitignore
vendored
Normal file
1
.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
|||||||
|
/target
|
22
Cargo.toml
Normal file
22
Cargo.toml
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
[package]
|
||||||
|
name = "deep-sort"
|
||||||
|
version = "0.1.16"
|
||||||
|
authors = ["Andrey Tkachenko <andreytkachenko64@gmail.com>"]
|
||||||
|
edition = "2018"
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
grant-object-detector = {path = "../grant-object-detector"}
|
||||||
|
ndarray-linalg = {version = "0.12.0", features = ["openblas"]}
|
||||||
|
munkres = "0.5.1"
|
||||||
|
openblas-src = {version = "0.9.0", features = ["static"]}
|
||||||
|
onnx-model = {path = "../onnx-model"}
|
||||||
|
opencv = "=0.41.0"
|
||||||
|
err-derive = "0.2.4"
|
||||||
|
|
||||||
|
[dependencies.ndarray]
|
||||||
|
version = "0.13.1"
|
||||||
|
features = ["blas"]
|
||||||
|
default-features = false
|
||||||
|
|
||||||
|
[dev-dependencies]
|
||||||
|
anyhow = "1.0.31"
|
2
README.md
Normal file
2
README.md
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
# DeepSort Rust Implementation
|
||||||
|
See the arXiv preprint for more information.
|
607
examples/demo.rs
Normal file
607
examples/demo.rs
Normal file
@ -0,0 +1,607 @@
|
|||||||
|
use opencv::{
|
||||||
|
dnn,
|
||||||
|
core::{self, Mat, Scalar, Vector},
|
||||||
|
highgui,
|
||||||
|
prelude::*,
|
||||||
|
videoio,
|
||||||
|
};
|
||||||
|
use deep_sort::{
|
||||||
|
deep::ImageEncoder,
|
||||||
|
sort,
|
||||||
|
};
|
||||||
|
use ndarray::prelude::*;
|
||||||
|
|
||||||
|
const CHANNELS: usize = 24;
|
||||||
|
const CONFIDENCE_THRESHOLD: f32 = 0.6;
|
||||||
|
const NMS_THRESHOLD: f32 = 0.4;
|
||||||
|
pub const NAMES: [&'static str; 80] = [
|
||||||
|
"person",
|
||||||
|
"bicycle",
|
||||||
|
"car",
|
||||||
|
"motorbike",
|
||||||
|
"aeroplane",
|
||||||
|
"bus",
|
||||||
|
"train",
|
||||||
|
"truck",
|
||||||
|
"boat",
|
||||||
|
"traffic light",
|
||||||
|
"fire hydrant",
|
||||||
|
"stop sign",
|
||||||
|
"parking meter",
|
||||||
|
"bench",
|
||||||
|
"bird",
|
||||||
|
"cat",
|
||||||
|
"dog",
|
||||||
|
"horse",
|
||||||
|
"sheep",
|
||||||
|
"cow",
|
||||||
|
"elephant",
|
||||||
|
"bear",
|
||||||
|
"zebra",
|
||||||
|
"giraffe",
|
||||||
|
"backpack",
|
||||||
|
"umbrella",
|
||||||
|
"handbag",
|
||||||
|
"tie",
|
||||||
|
"suitcase",
|
||||||
|
"frisbee",
|
||||||
|
"skis",
|
||||||
|
"snowboard",
|
||||||
|
"sports ball",
|
||||||
|
"kite",
|
||||||
|
"baseball bat",
|
||||||
|
"baseball glove",
|
||||||
|
"skateboard",
|
||||||
|
"surfboard",
|
||||||
|
"tennis racket",
|
||||||
|
"bottle",
|
||||||
|
"wine glass",
|
||||||
|
"cup",
|
||||||
|
"fork",
|
||||||
|
"knife",
|
||||||
|
"spoon",
|
||||||
|
"bowl",
|
||||||
|
"banana",
|
||||||
|
"apple",
|
||||||
|
"sandwich",
|
||||||
|
"orange",
|
||||||
|
"broccoli",
|
||||||
|
"carrot",
|
||||||
|
"hot dog",
|
||||||
|
"pizza",
|
||||||
|
"donut",
|
||||||
|
"cake",
|
||||||
|
"chair",
|
||||||
|
"sofa",
|
||||||
|
"pottedplant",
|
||||||
|
"bed",
|
||||||
|
"diningtable",
|
||||||
|
"toilet",
|
||||||
|
"tvmonitor",
|
||||||
|
"laptop",
|
||||||
|
"mouse",
|
||||||
|
"remote",
|
||||||
|
"keyboard",
|
||||||
|
"cell phone",
|
||||||
|
"microwave",
|
||||||
|
"oven",
|
||||||
|
"toaster",
|
||||||
|
"sink",
|
||||||
|
"refrigerator",
|
||||||
|
"book",
|
||||||
|
"clock",
|
||||||
|
"vase",
|
||||||
|
"scissors",
|
||||||
|
"teddy bear",
|
||||||
|
"hair drier",
|
||||||
|
"toothbrush",
|
||||||
|
];
|
||||||
|
|
||||||
|
fn run() -> opencv::Result<()> {
|
||||||
|
let mut encoder = ImageEncoder::new("/home/andrey/workspace/ssl/deep_sort_pytorch/deep_sort/deep/reid1.onnx")?;
|
||||||
|
let max_cosine_distance = 0.2;
|
||||||
|
let nn_budget = 100;
|
||||||
|
let max_age = 70;
|
||||||
|
let max_iou_distance = 0.2;
|
||||||
|
let n_init = 3;
|
||||||
|
let kind = sort::NearestNeighborMetricKind::CosineDistance;
|
||||||
|
let metric = sort::NearestNeighborDistanceMetric::new(kind, max_cosine_distance, Some(nn_budget));
|
||||||
|
let mut tracker = sort::Tracker::new(metric, max_iou_distance, max_age, n_init);
|
||||||
|
|
||||||
|
let model = "/home/andrey/workspace/ssl/yolov3/yolov3.weights";
|
||||||
|
let config = "/home/andrey/workspace/ssl/yolov3/yolov3.cfg";
|
||||||
|
let framework = "";
|
||||||
|
|
||||||
|
let mut net = dnn::read_net(model, config, framework).unwrap();
|
||||||
|
net.set_preferable_backend(dnn::DNN_BACKEND_DEFAULT);
|
||||||
|
net.set_preferable_target(dnn::DNN_TARGET_CPU);
|
||||||
|
|
||||||
|
let layer_names = net.get_layer_names()?;
|
||||||
|
let last_layer_id = net.get_layer_id(&layer_names.get(layer_names.len() - 1)?)?;
|
||||||
|
let last_layer = net.get_layer(dnn::DictValue::from_i32(last_layer_id)?)?;
|
||||||
|
let last_layer_type = last_layer.typ();
|
||||||
|
|
||||||
|
let out_names = net.get_unconnected_out_layers_names().unwrap();
|
||||||
|
|
||||||
|
let window = "video capture";
|
||||||
|
highgui::named_window(window, 1)?;
|
||||||
|
|
||||||
|
let mut cam = videoio::VideoCapture::new(0, videoio::CAP_ANY)?; // 0 is the default camera
|
||||||
|
|
||||||
|
let opened = videoio::VideoCapture::is_opened(&cam)?;
|
||||||
|
if !opened {
|
||||||
|
panic!("Unable to open default camera!");
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut outs = core::Vector::<core::Mat>::new(); //core::Mat::default()?;
|
||||||
|
|
||||||
|
let mut frame = core::Mat::default()?;
|
||||||
|
let mut flag = -1i64;
|
||||||
|
|
||||||
|
loop {
|
||||||
|
cam.read(&mut frame)?;
|
||||||
|
|
||||||
|
// flag += 1;
|
||||||
|
// if flag % 5 != 0 {
|
||||||
|
// continue;
|
||||||
|
// }
|
||||||
|
|
||||||
|
let fsize = frame.size()?;
|
||||||
|
|
||||||
|
if fsize.width <= 0 {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let frame_height = fsize.height;
|
||||||
|
let frame_width = fsize.width;
|
||||||
|
|
||||||
|
// Create a 4D blob from a frame.
|
||||||
|
let inp_width = 416;
|
||||||
|
let inp_height = 416;
|
||||||
|
let blob = dnn::blob_from_image(
|
||||||
|
&frame,
|
||||||
|
1.0 / 255.0,
|
||||||
|
core::Size::new(inp_width, inp_height),
|
||||||
|
core::Scalar::new(0., 0., 0., 0.),
|
||||||
|
true,
|
||||||
|
false,
|
||||||
|
core::CV_32F)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Run a model
|
||||||
|
net.set_input(&blob, "", 1.0, core::Scalar::new(0.,0.,0.,0.));
|
||||||
|
net.forward(&mut outs, &out_names).unwrap();
|
||||||
|
|
||||||
|
let fsize = frame.size()?;
|
||||||
|
|
||||||
|
let frame_height = fsize.height;
|
||||||
|
let frame_width = fsize.width;
|
||||||
|
// let mut objects = vec![];
|
||||||
|
|
||||||
|
match last_layer_type.as_str() {
|
||||||
|
"Region" => {
|
||||||
|
let mut detections = vec![];
|
||||||
|
let bboxes = detect(&outs)?;
|
||||||
|
|
||||||
|
for bbox in bboxes {
|
||||||
|
let rect = bbox.cv_rect(frame.cols(), frame.rows());
|
||||||
|
|
||||||
|
let roi = Mat::roi(&frame, rect)?;
|
||||||
|
let blob = dnn::blob_from_image(
|
||||||
|
&roi,
|
||||||
|
1.0 / 255.0,
|
||||||
|
core::Size::new(64, 128),
|
||||||
|
core::Scalar::new(0., 0., 0., 0.),
|
||||||
|
true,
|
||||||
|
false,
|
||||||
|
core::CV_32F)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let code = encoder.encode_batch(&blob)?.get(0)?;
|
||||||
|
let core = code.into_typed::<f32>()?;
|
||||||
|
let feature = arr1(core.data_typed()?);
|
||||||
|
|
||||||
|
detections.push(sort::Detection {
|
||||||
|
bbox: sort::BBox::ltwh(
|
||||||
|
rect.x as f32,
|
||||||
|
rect.y as f32,
|
||||||
|
rect.width as f32,
|
||||||
|
rect.height as f32
|
||||||
|
),
|
||||||
|
confidence: bbox.class_confidence,
|
||||||
|
feature: Some(feature)
|
||||||
|
});
|
||||||
|
|
||||||
|
draw_pred(&mut frame, bbox)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
tracker.predict();
|
||||||
|
tracker.update(detections.as_slice());
|
||||||
|
|
||||||
|
for t in tracker.tracks().iter().filter(|t|t.is_confirmed()) {
|
||||||
|
draw_track(&mut frame, t.bbox().as_ltwh(), t.track_id);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
_ => panic!("unknown last layer type"),
|
||||||
|
}
|
||||||
|
|
||||||
|
highgui::imshow(window, &mut frame)?;
|
||||||
|
|
||||||
|
let key = highgui::wait_key(10)?;
|
||||||
|
if key > 0 && key != 255 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn draw_pred(frame: &mut Mat, bbox: BBox) -> opencv::Result<()> {
|
||||||
|
let rect = bbox.cv_rect(frame.cols(), frame.rows());
|
||||||
|
|
||||||
|
// Draw a bounding box.
|
||||||
|
opencv::imgproc::rectangle(
|
||||||
|
frame,
|
||||||
|
rect,
|
||||||
|
core::Scalar::new(255.0, 255.0, 0.0, 0.0),
|
||||||
|
1,
|
||||||
|
opencv::imgproc::LINE_8,
|
||||||
|
0
|
||||||
|
)?;
|
||||||
|
|
||||||
|
// let label = format!("{} {:2}", NAMES[bbox.class_index], bbox.class_confidence);
|
||||||
|
// let mut base_line = 0;
|
||||||
|
// let label_size = opencv::imgproc::get_text_size(&label, opencv::imgproc::FONT_HERSHEY_SIMPLEX, 0.6, 1, &mut base_line)?;
|
||||||
|
|
||||||
|
// let label_rect = core::Rect::new(
|
||||||
|
// rect.x,
|
||||||
|
// rect.y - label_size.height - 8,
|
||||||
|
// label_size.width + 8,
|
||||||
|
// label_size.height + 8
|
||||||
|
// );
|
||||||
|
|
||||||
|
// opencv::imgproc::rectangle(frame, label_rect, core::Scalar::new(255.0, 255.0, 0.0, 0.0), opencv::imgproc::FILLED, opencv::imgproc::LINE_8, 0)?;
|
||||||
|
|
||||||
|
// let pt = core::Point::new(rect.x, rect.y - 8);
|
||||||
|
// opencv::imgproc::put_text(
|
||||||
|
// frame,
|
||||||
|
// &label,
|
||||||
|
// pt,
|
||||||
|
// opencv::imgproc::FONT_HERSHEY_SIMPLEX,
|
||||||
|
// 0.6,
|
||||||
|
// core::Scalar::new(0.0, 0.0, 0.0, 0.0),
|
||||||
|
// 1,
|
||||||
|
// opencv::imgproc::LINE_8,
|
||||||
|
// false
|
||||||
|
// )?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn draw_track(frame: &mut Mat, bbox: sort::BBox<sort::Ltwh>, track_id: i32) -> opencv::Result<()> {
|
||||||
|
let rect = opencv::core::Rect::new(
|
||||||
|
bbox.left() as i32,
|
||||||
|
bbox.top() as i32,
|
||||||
|
bbox.width() as i32,
|
||||||
|
bbox.height() as i32,
|
||||||
|
);
|
||||||
|
|
||||||
|
// Draw a bounding box.
|
||||||
|
opencv::imgproc::rectangle(
|
||||||
|
frame,
|
||||||
|
rect,
|
||||||
|
core::Scalar::new(0.0, 255.0, 0.0, 0.0),
|
||||||
|
1,
|
||||||
|
opencv::imgproc::LINE_8,
|
||||||
|
0
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let label = format!("[{}]", track_id);
|
||||||
|
let mut base_line = 0;
|
||||||
|
let label_size = opencv::imgproc::get_text_size(&label, opencv::imgproc::FONT_HERSHEY_SIMPLEX, 0.6, 1, &mut base_line)?;
|
||||||
|
|
||||||
|
let label_rect = core::Rect::new(
|
||||||
|
rect.x,
|
||||||
|
rect.y - label_size.height - 8,
|
||||||
|
label_size.width + 8,
|
||||||
|
label_size.height + 8
|
||||||
|
);
|
||||||
|
|
||||||
|
opencv::imgproc::rectangle(frame, label_rect, core::Scalar::new(0.0, 255.0, 0.0, 0.0), opencv::imgproc::FILLED, opencv::imgproc::LINE_8, 0)?;
|
||||||
|
|
||||||
|
let pt = core::Point::new(rect.x, rect.y - 8);
|
||||||
|
opencv::imgproc::put_text(
|
||||||
|
frame,
|
||||||
|
&label,
|
||||||
|
pt,
|
||||||
|
opencv::imgproc::FONT_HERSHEY_SIMPLEX,
|
||||||
|
0.6,
|
||||||
|
core::Scalar::new(0.0, 0.0, 0.0, 0.0),
|
||||||
|
1,
|
||||||
|
opencv::imgproc::LINE_8,
|
||||||
|
false
|
||||||
|
)?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
pub struct BBox {
|
||||||
|
xmin: f32,
|
||||||
|
ymin: f32,
|
||||||
|
xmax: f32,
|
||||||
|
ymax: f32,
|
||||||
|
confidence: f32,
|
||||||
|
class_index: usize,
|
||||||
|
class_confidence: f32,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl BBox {
|
||||||
|
pub fn cv_rect(&self, frame_width: i32, frame_height: i32) -> opencv::core::Rect {
|
||||||
|
let frame_width_f = frame_width as f32;
|
||||||
|
let frame_height_f = frame_height as f32;
|
||||||
|
|
||||||
|
let left = ((self.xmin * frame_width_f) as i32).max(0).min(frame_width);
|
||||||
|
let top = ((self.ymin * frame_height_f) as i32).max(0).min(frame_height);
|
||||||
|
let mut width = (((self.xmax - self.xmin) * frame_width_f) as i32).max(0);
|
||||||
|
let mut height = (((self.ymax - self.ymin) * frame_height_f) as i32).max(0);
|
||||||
|
|
||||||
|
if left + width > frame_width {
|
||||||
|
width = frame_width - left;
|
||||||
|
}
|
||||||
|
|
||||||
|
if top + height > frame_height {
|
||||||
|
height = frame_height - top;
|
||||||
|
}
|
||||||
|
|
||||||
|
core::Rect::new(
|
||||||
|
left,
|
||||||
|
top,
|
||||||
|
width,
|
||||||
|
height,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn detect(pred_events: &Vector<Mat>) -> opencv::Result<Vec<BBox>> {
|
||||||
|
|
||||||
|
// The bounding boxes grouped by (maximum) class index.
|
||||||
|
let mut bboxes: Vec<(core::Vector<core::Rect2d>, core::Vector<f32>, Vec<BBox>)> = (0 .. 80)
|
||||||
|
.map(|_| (core::Vector::new(), core::Vector::new(), vec![]))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
for pred_event in pred_events {
|
||||||
|
let fsize = pred_event.size()?;
|
||||||
|
let npreds = pred_event.rows();
|
||||||
|
let pred_size = pred_event.cols();
|
||||||
|
let nclasses = (pred_size - 5) as usize;
|
||||||
|
|
||||||
|
// Extract the bounding boxes for which confidence is above the threshold.
|
||||||
|
for index in 0 .. npreds {
|
||||||
|
let pred = pred_event.row(index)?.into_typed::<f32>()?;
|
||||||
|
let detection = pred.data_typed()?;
|
||||||
|
|
||||||
|
let (center_x, center_y, width, height, confidence) = match &detection[0 .. 5] {
|
||||||
|
&[a,b,c,d,e] => (a,b,c,d,e),
|
||||||
|
_ => unreachable!()
|
||||||
|
};
|
||||||
|
|
||||||
|
let classes = &detection[5..];
|
||||||
|
|
||||||
|
if confidence > CONFIDENCE_THRESHOLD {
|
||||||
|
let mut class_index = -1;
|
||||||
|
let mut score = 0.0;
|
||||||
|
|
||||||
|
for (idx, &val) in classes.iter().enumerate() {
|
||||||
|
if val > score {
|
||||||
|
class_index = idx as i32;
|
||||||
|
score = val;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if class_index > -1 && score > 0. {
|
||||||
|
let entry = &mut bboxes[class_index as usize];
|
||||||
|
|
||||||
|
entry.0.push(core::Rect2d::new(
|
||||||
|
(center_x - width / 2.) as f64,
|
||||||
|
(center_y - height / 2.) as f64,
|
||||||
|
width as f64,
|
||||||
|
height as f64,
|
||||||
|
));
|
||||||
|
entry.1.push(score);
|
||||||
|
entry.2.push(BBox {
|
||||||
|
xmin: center_x - width / 2.,
|
||||||
|
ymin: center_y - height / 2.,
|
||||||
|
xmax: center_x + width / 2.,
|
||||||
|
ymax: center_y + height / 2.,
|
||||||
|
|
||||||
|
confidence,
|
||||||
|
class_index: class_index as _,
|
||||||
|
class_confidence: score,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut events = vec![];
|
||||||
|
|
||||||
|
for (rects, scores, bboxes) in bboxes.iter_mut() {
|
||||||
|
if bboxes.is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut indices = core::Vector::<i32>::new();
|
||||||
|
dnn::nms_boxes_f64(
|
||||||
|
&rects,
|
||||||
|
&scores,
|
||||||
|
CONFIDENCE_THRESHOLD,
|
||||||
|
NMS_THRESHOLD,
|
||||||
|
&mut indices,
|
||||||
|
1.0,
|
||||||
|
0
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let mut indices = indices.to_vec();
|
||||||
|
|
||||||
|
events.extend(bboxes.drain(..)
|
||||||
|
.enumerate()
|
||||||
|
.filter_map(|(idx, item)| if indices.contains(&(idx as i32)) {Some(item)} else {None}));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Perform non-maximum suppression.
|
||||||
|
// for (idx, (_, _, bboxes_for_class)) in bboxes.iter_mut().enumerate() {
|
||||||
|
// if bboxes_for_class.is_empty() {
|
||||||
|
// continue;
|
||||||
|
// }
|
||||||
|
|
||||||
|
// bboxes_for_class.sort_unstable_by(|b1, b2| b2.confidence.partial_cmp(&b1.confidence).unwrap());
|
||||||
|
// let mut current_index = 0;
|
||||||
|
|
||||||
|
// for index in 0 .. bboxes_for_class.len() {
|
||||||
|
// let mut drop = false;
|
||||||
|
// for prev_index in 0..current_index {
|
||||||
|
// let iou = iou(&bboxes_for_class[prev_index], &bboxes_for_class[index]);
|
||||||
|
// if iou > NMS_THRESHOLD {
|
||||||
|
// drop = true;
|
||||||
|
// break;
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// if !drop {
|
||||||
|
// bboxes_for_class.swap(current_index, index);
|
||||||
|
// current_index += 1;
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// bboxes_for_class.truncate(current_index);
|
||||||
|
// }
|
||||||
|
|
||||||
|
for (class_index, (_, _, bboxes_for_class)) in bboxes.into_iter().enumerate() {
|
||||||
|
if bboxes_for_class.is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let clamp = |x| if x < 0.0 { 0.0 } else if x > 1.0 { 1.0 } else { x };
|
||||||
|
|
||||||
|
for bbox in bboxes_for_class {
|
||||||
|
events.push(bbox);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(events)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Intersection over union of two bounding boxes.
|
||||||
|
fn iou(b1: &BBox, b2: &BBox) -> f32 {
|
||||||
|
let b1_area = (b1.xmax - b1.xmin + 1.) * (b1.ymax - b1.ymin + 1.);
|
||||||
|
let b2_area = (b2.xmax - b2.xmin + 1.) * (b2.ymax - b2.ymin + 1.);
|
||||||
|
let i_xmin = b1.xmin.max(b2.xmin);
|
||||||
|
let i_xmax = b1.xmax.min(b2.xmax);
|
||||||
|
let i_ymin = b1.ymin.max(b2.ymin);
|
||||||
|
let i_ymax = b1.ymax.min(b2.ymax);
|
||||||
|
let i_area = (i_xmax - i_xmin + 1.).max(0.) * (i_ymax - i_ymin + 1.).max(0.);
|
||||||
|
i_area / (b1_area + b2_area - i_area)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn postprocess(frame: &mut Mat, outs: &core::Vector<Mat>, last_layer_type: &str) -> opencv::Result<()> {
|
||||||
|
let fsize = frame.size()?;
|
||||||
|
|
||||||
|
let frame_height = fsize.height;
|
||||||
|
let frame_width = fsize.width;
|
||||||
|
// let mut objects = vec![];
|
||||||
|
|
||||||
|
match last_layer_type {
|
||||||
|
"Region" => {
|
||||||
|
// let bboxes = detect(&outs)?;
|
||||||
|
|
||||||
|
// for bbox in bboxes {
|
||||||
|
// draw_pred(frame, bbox);
|
||||||
|
// }
|
||||||
|
},
|
||||||
|
|
||||||
|
_ => panic!("unknown last layer type"),
|
||||||
|
}
|
||||||
|
|
||||||
|
// classIds = []
|
||||||
|
// confidences = []
|
||||||
|
// boxes = []
|
||||||
|
// if lastLayer.type == 'DetectionOutput':
|
||||||
|
// # Network produces output blob with a shape 1x1xNx7 where N is a number of
|
||||||
|
// # detections and an every detection is a vector of values
|
||||||
|
// # [batchId, classId, confidence, left, top, right, bottom]
|
||||||
|
// for out in outs:
|
||||||
|
// for detection in out[0, 0]:
|
||||||
|
// confidence = detection[2]
|
||||||
|
// if confidence > confThreshold:
|
||||||
|
// left = int(detection[3])
|
||||||
|
// top = int(detection[4])
|
||||||
|
// right = int(detection[5])
|
||||||
|
// bottom = int(detection[6])
|
||||||
|
// width = right - left + 1
|
||||||
|
// height = bottom - top + 1
|
||||||
|
// if width <= 2 or height <= 2:
|
||||||
|
// left = int(detection[3] * frameWidth)
|
||||||
|
// top = int(detection[4] * frameHeight)
|
||||||
|
// right = int(detection[5] * frameWidth)
|
||||||
|
// bottom = int(detection[6] * frameHeight)
|
||||||
|
// width = right - left + 1
|
||||||
|
// height = bottom - top + 1
|
||||||
|
// classIds.append(int(detection[1]) - 1) # Skip background label
|
||||||
|
// confidences.append(float(confidence))
|
||||||
|
// boxes.append([left, top, width, height])
|
||||||
|
// elif lastLayer.type == 'Region':
|
||||||
|
// # Network produces output blob with a shape NxC where N is a number of
|
||||||
|
// # detected objects and C is a number of classes + 4 where the first 4
|
||||||
|
// # numbers are [center_x, center_y, width, height]
|
||||||
|
// for out in outs:
|
||||||
|
// for detection in out:
|
||||||
|
// scores = detection[5:]
|
||||||
|
// classId = np.argmax(scores)
|
||||||
|
// confidence = scores[classId]
|
||||||
|
// if confidence > confThreshold:
|
||||||
|
// center_x = int(detection[0] * frameWidth)
|
||||||
|
// center_y = int(detection[1] * frameHeight)
|
||||||
|
// width = int(detection[2] * frameWidth)
|
||||||
|
// height = int(detection[3] * frameHeight)
|
||||||
|
// left = int(center_x - width / 2)
|
||||||
|
// top = int(center_y - height / 2)
|
||||||
|
// classIds.append(classId)
|
||||||
|
// confidences.append(float(confidence))
|
||||||
|
// boxes.append([left, top, width, height])
|
||||||
|
// else:
|
||||||
|
// print('Unknown output layer type: ' + lastLayer.type)
|
||||||
|
// exit()
|
||||||
|
|
||||||
|
// # NMS is used inside Region layer only on DNN_BACKEND_OPENCV for another backends we need NMS in sample
|
||||||
|
// # or NMS is required if number of outputs > 1
|
||||||
|
// if len(outNames) > 1 or lastLayer.type == 'Region' and args.backend != cv.dnn.DNN_BACKEND_OPENCV:
|
||||||
|
// indices = []
|
||||||
|
// classIds = np.array(classIds)
|
||||||
|
// boxes = np.array(boxes)
|
||||||
|
// confidences = np.array(confidences)
|
||||||
|
// unique_classes = set(classIds)
|
||||||
|
// for cl in unique_classes:
|
||||||
|
// class_indices = np.where(classIds == cl)[0]
|
||||||
|
// conf = confidences[class_indices]
|
||||||
|
// box = boxes[class_indices].tolist()
|
||||||
|
// nms_indices = cv.dnn.NMSBoxes(box, conf, confThreshold, nmsThreshold)
|
||||||
|
// nms_indices = nms_indices[:, 0] if len(nms_indices) else []
|
||||||
|
// indices.extend(class_indices[nms_indices])
|
||||||
|
// else:
|
||||||
|
// indices = np.arange(0, len(classIds))
|
||||||
|
|
||||||
|
// for i in indices:
|
||||||
|
// box = boxes[i]
|
||||||
|
// left = box[0]
|
||||||
|
// top = box[1]
|
||||||
|
// width = box[2]
|
||||||
|
// height = box[3]
|
||||||
|
// drawPred(classIds[i], confidences[i], left, top, left + width, top + height)
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
run().unwrap()
|
||||||
|
}
|
636
examples/file.rs
Normal file
636
examples/file.rs
Normal file
@ -0,0 +1,636 @@
|
|||||||
|
use opencv::{
|
||||||
|
dnn,
|
||||||
|
core::{self, Mat, Scalar, Vector},
|
||||||
|
highgui,
|
||||||
|
prelude::*,
|
||||||
|
videoio,
|
||||||
|
};
|
||||||
|
use deep_sort::{
|
||||||
|
deep::{ImageEncoder, ORT_ENV},
|
||||||
|
sort,
|
||||||
|
};
|
||||||
|
use ndarray::prelude::*;
|
||||||
|
use onnxruntime::*;
|
||||||
|
use std::borrow::Borrow;
|
||||||
|
|
||||||
|
const SIZE: usize = 416;
|
||||||
|
const OUT_SIZE: usize = 10647; //6300; //16128 10647;
|
||||||
|
|
||||||
|
const CHANNELS: usize = 24;
|
||||||
|
const CONFIDENCE_THRESHOLD: f32 = 0.5;
|
||||||
|
const NMS_THRESHOLD: f32 = 0.4;
|
||||||
|
pub const NAMES: [&'static str; 80] = [
|
||||||
|
"person",
|
||||||
|
"bicycle",
|
||||||
|
"car",
|
||||||
|
"motorbike",
|
||||||
|
"aeroplane",
|
||||||
|
"bus",
|
||||||
|
"train",
|
||||||
|
"truck",
|
||||||
|
"boat",
|
||||||
|
"traffic light",
|
||||||
|
"fire hydrant",
|
||||||
|
"stop sign",
|
||||||
|
"parking meter",
|
||||||
|
"bench",
|
||||||
|
"bird",
|
||||||
|
"cat",
|
||||||
|
"dog",
|
||||||
|
"horse",
|
||||||
|
"sheep",
|
||||||
|
"cow",
|
||||||
|
"elephant",
|
||||||
|
"bear",
|
||||||
|
"zebra",
|
||||||
|
"giraffe",
|
||||||
|
"backpack",
|
||||||
|
"umbrella",
|
||||||
|
"handbag",
|
||||||
|
"tie",
|
||||||
|
"suitcase",
|
||||||
|
"frisbee",
|
||||||
|
"skis",
|
||||||
|
"snowboard",
|
||||||
|
"sports ball",
|
||||||
|
"kite",
|
||||||
|
"baseball bat",
|
||||||
|
"baseball glove",
|
||||||
|
"skateboard",
|
||||||
|
"surfboard",
|
||||||
|
"tennis racket",
|
||||||
|
"bottle",
|
||||||
|
"wine glass",
|
||||||
|
"cup",
|
||||||
|
"fork",
|
||||||
|
"knife",
|
||||||
|
"spoon",
|
||||||
|
"bowl",
|
||||||
|
"banana",
|
||||||
|
"apple",
|
||||||
|
"sandwich",
|
||||||
|
"orange",
|
||||||
|
"broccoli",
|
||||||
|
"carrot",
|
||||||
|
"hot dog",
|
||||||
|
"pizza",
|
||||||
|
"donut",
|
||||||
|
"cake",
|
||||||
|
"chair",
|
||||||
|
"sofa",
|
||||||
|
"pottedplant",
|
||||||
|
"bed",
|
||||||
|
"diningtable",
|
||||||
|
"toilet",
|
||||||
|
"tvmonitor",
|
||||||
|
"laptop",
|
||||||
|
"mouse",
|
||||||
|
"remote",
|
||||||
|
"keyboard",
|
||||||
|
"cell phone",
|
||||||
|
"microwave",
|
||||||
|
"oven",
|
||||||
|
"toaster",
|
||||||
|
"sink",
|
||||||
|
"refrigerator",
|
||||||
|
"book",
|
||||||
|
"clock",
|
||||||
|
"vase",
|
||||||
|
"scissors",
|
||||||
|
"teddy bear",
|
||||||
|
"hair drier",
|
||||||
|
"toothbrush",
|
||||||
|
];
|
||||||
|
|
||||||
|
pub enum Target {
|
||||||
|
Cpu,
|
||||||
|
Cuda,
|
||||||
|
Tensorrt,
|
||||||
|
Movidus
|
||||||
|
}
|
||||||
|
|
||||||
|
struct Inference {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Inference {
|
||||||
|
pub fn new(model: &str, config: &str) -> opencv::Result<Self> {
|
||||||
|
unimplemented!()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn set_preferable_target(&mut self, target: Target) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn set_input(&mut self, name: &str, tensor: Tensor<f32>) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_output(&self, name: &str) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward(&mut self) {
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub enum Error {
|
||||||
|
OpenCv(opencv::Error),
|
||||||
|
OnnxRuntime(onnxruntime::Error),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Display for Error {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
match self {
|
||||||
|
Self::OpenCv(i) => writeln!(f, "Error: {}", i),
|
||||||
|
Self::OnnxRuntime(i) => writeln!(f, "Error: {}", i),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::error::Error for Error {}
|
||||||
|
|
||||||
|
impl From<opencv::Error> for Error {
|
||||||
|
fn from(err: opencv::Error) -> Self {
|
||||||
|
Self::OpenCv(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
impl From<onnxruntime::Error> for Error {
|
||||||
|
fn from(err: onnxruntime::Error) -> Self {
|
||||||
|
Self::OnnxRuntime(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
fn run() -> std::result::Result<(), Error> {
|
||||||
|
let max_cosine_distance = 0.2;
|
||||||
|
let nn_budget = 100;
|
||||||
|
let max_age = 70;
|
||||||
|
let max_iou_distance = 0.7;
|
||||||
|
let n_init = 3;
|
||||||
|
let kind = sort::NearestNeighborMetricKind::CosineDistance;
|
||||||
|
let metric = sort::NearestNeighborDistanceMetric::new(kind, max_cosine_distance, Some(nn_budget));
|
||||||
|
let mut tracker = sort::Tracker::new(metric, max_iou_distance, max_age, n_init);
|
||||||
|
|
||||||
|
let window = "video capture";
|
||||||
|
highgui::named_window(window, 1)?;
|
||||||
|
|
||||||
|
let mut cam = videoio::VideoCapture::from_file("./videoplayback.mp4", videoio::CAP_ANY)?; // 0 is the default camera
|
||||||
|
cam.set(videoio::CAP_PROP_POS_FRAMES, 150.0);
|
||||||
|
|
||||||
|
let opened = videoio::VideoCapture::is_opened(&cam)?;
|
||||||
|
if !opened {
|
||||||
|
panic!("Unable to open default camera!");
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut frame = core::Mat::default()?;
|
||||||
|
let mut flag = -1i64;
|
||||||
|
|
||||||
|
let (in_tx, in_rx): (std::sync::mpsc::Sender<(Tensor<f32>, Mat)>, _) = std::sync::mpsc::channel();
|
||||||
|
let (out_tx, out_rx) = std::sync::mpsc::channel();
|
||||||
|
let (bb_tx, bb_rx) = std::sync::mpsc::channel();
|
||||||
|
|
||||||
|
std::thread::spawn(move || {
|
||||||
|
let mut so = SessionOptions::new().unwrap();
|
||||||
|
let ro = RunOptions::new();
|
||||||
|
|
||||||
|
// so.set_execution_mode(ExecutionMode::Parallel).unwrap();
|
||||||
|
// so.add_tensorrt(0);
|
||||||
|
so.add_cuda(0);
|
||||||
|
// so.add_cpu(true);
|
||||||
|
|
||||||
|
let session = Session::new(&ORT_ENV, "/home/andrey/workspace/ssl/yolov4/yolov4_416.onnx", &so).unwrap();
|
||||||
|
let mut out_vals = Tensor::<f32>::init(&[1, OUT_SIZE as _, 84], 0.0).unwrap();
|
||||||
|
|
||||||
|
let input = std::ffi::CStr::from_bytes_with_nul(b"input\0").unwrap();
|
||||||
|
let output = std::ffi::CStr::from_bytes_with_nul(b"output\0").unwrap();
|
||||||
|
|
||||||
|
while let Ok((in_vals, frame)) = in_rx.recv() {
|
||||||
|
let in_vals: Tensor<f32> = in_vals;
|
||||||
|
session
|
||||||
|
.run_mut(&ro, &[input], &[in_vals.as_ref()], &[output], &mut [out_vals.as_mut()])
|
||||||
|
.expect("run");
|
||||||
|
|
||||||
|
let xx: &[f32] = out_vals.borrow();
|
||||||
|
|
||||||
|
let arr = Array3::from_shape_vec([1, OUT_SIZE, 84], xx.to_vec()).unwrap();
|
||||||
|
|
||||||
|
out_tx.send((arr, frame)).unwrap();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
std::thread::spawn(move || {
|
||||||
|
let mut encoder = ImageEncoder::new("/home/andrey/workspace/ssl/deep_sort_pytorch/deep_sort/deep/reid.onnx").unwrap();
|
||||||
|
|
||||||
|
while let Ok((out_vals, frame)) = out_rx.recv() {
|
||||||
|
let mut preds = vec![];
|
||||||
|
let mut tracks = vec![];
|
||||||
|
|
||||||
|
// let mut detections = vec![];
|
||||||
|
// let out_shape = out_vals.shape();
|
||||||
|
let bboxes = detect(out_vals.view()).unwrap();
|
||||||
|
|
||||||
|
let nboxes = bboxes.len();
|
||||||
|
// let mut in_vals = Array4::from_elem([nboxes, 3, 128, 64], 0.0);
|
||||||
|
|
||||||
|
// for (index, bbox) in bboxes.iter().enumerate() {
|
||||||
|
// let rect = bbox.cv_rect(frame.cols(), frame.rows());
|
||||||
|
|
||||||
|
// let roi = Mat::roi(&frame, rect).unwrap();
|
||||||
|
// let blob = dnn::blob_from_image(
|
||||||
|
// &roi,
|
||||||
|
// 1.0 / 255.0,
|
||||||
|
// core::Size::new(64, 128),
|
||||||
|
// core::Scalar::new(0., 0., 0., 0.),
|
||||||
|
// true,
|
||||||
|
// false,
|
||||||
|
// core::CV_32F)
|
||||||
|
// .unwrap();
|
||||||
|
|
||||||
|
// let core = blob.into_typed::<f32>().unwrap();
|
||||||
|
// let data: &[f32] = core.data_typed().unwrap();
|
||||||
|
|
||||||
|
// let a = aview1(data).into_shape((3, 128, 64)).unwrap();
|
||||||
|
// in_vals.index_axis_mut(Axis(0), index)
|
||||||
|
// .assign(&a)
|
||||||
|
// }
|
||||||
|
|
||||||
|
// let t = TensorView::new(in_vals.shape(), in_vals.as_slice().unwrap());
|
||||||
|
// let code = encoder.encode_batch(t).unwrap();
|
||||||
|
// let features = aview1(code.borrow()).into_shape((nboxes, 512)).unwrap();
|
||||||
|
|
||||||
|
for (i, bbox) in bboxes.into_iter().enumerate() {
|
||||||
|
let rect = bbox.cv_rect(frame.cols(), frame.rows());
|
||||||
|
// let feature = features.index_axis(Axis(0), i);
|
||||||
|
|
||||||
|
if bbox.class_index <= 8 {
|
||||||
|
// detections.push(sort::Detection {
|
||||||
|
// bbox: sort::BBox::ltwh(
|
||||||
|
// rect.x as f32,
|
||||||
|
// rect.y as f32,
|
||||||
|
// rect.width as f32,
|
||||||
|
// rect.height as f32
|
||||||
|
// ),
|
||||||
|
// confidence: bbox.confidence,
|
||||||
|
// feature: Some(feature.into_owned())
|
||||||
|
// });
|
||||||
|
|
||||||
|
preds.push(bbox);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// tracker.predict();
|
||||||
|
// tracker.update(detections.as_slice());
|
||||||
|
|
||||||
|
// for t in tracker.tracks().iter().filter(|t| t.is_confirmed() && t.time_since_update <= 1) {
|
||||||
|
// tracks.push((t.bbox().as_ltwh(), t.track_id));
|
||||||
|
// // draw_track(&mut frame, t.bbox().as_ltwh(), t.track_id);
|
||||||
|
// }
|
||||||
|
|
||||||
|
bb_tx.send((preds, tracks)).unwrap();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
loop {
|
||||||
|
let begin = std::time::Instant::now();
|
||||||
|
cam.read(&mut frame)?;
|
||||||
|
|
||||||
|
let fsize = frame.size()?;
|
||||||
|
|
||||||
|
if fsize.width <= 0 {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let frame_height = fsize.height;
|
||||||
|
let frame_width = fsize.width;
|
||||||
|
|
||||||
|
// Create a 4D blob from a frame.
|
||||||
|
let inp_width = SIZE as _;
|
||||||
|
let inp_height = SIZE as _;
|
||||||
|
let blob = dnn::blob_from_image(
|
||||||
|
&frame,
|
||||||
|
1.0 / 255.0,
|
||||||
|
core::Size::new(inp_width, inp_height),
|
||||||
|
core::Scalar::new(0., 0., 0., 0.),
|
||||||
|
true,
|
||||||
|
false,
|
||||||
|
core::CV_32F)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let core = blob.try_into_typed::<f32>()?;
|
||||||
|
let data: &[f32] = core.data_typed()?;
|
||||||
|
let in_vals = Tensor::new(&[1, 3, SIZE as _, SIZE as _], data.to_vec()).unwrap();
|
||||||
|
|
||||||
|
// // Run a model
|
||||||
|
|
||||||
|
in_tx.send((in_vals, frame.clone()?)).unwrap();
|
||||||
|
let (preds, tracks) = bb_rx.recv().unwrap();
|
||||||
|
|
||||||
|
for p in preds {
|
||||||
|
draw_pred(&mut frame, p)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (t, i) in tracks {
|
||||||
|
draw_track(&mut frame, t, i);
|
||||||
|
}
|
||||||
|
|
||||||
|
// let mut objects = vec![];
|
||||||
|
|
||||||
|
// let mut detections = vec![];
|
||||||
|
// // let out_shape = out_vals.dims();
|
||||||
|
// let bboxes = detect(aview1(out_vals.borrow())
|
||||||
|
// .into_shape([out_shape[0] as usize, out_shape[1] as usize, out_shape[2]as usize]).unwrap())?;
|
||||||
|
|
||||||
|
// for bbox in bboxes {
|
||||||
|
// let rect = bbox.cv_rect(frame.cols(), frame.rows());
|
||||||
|
|
||||||
|
// let roi = Mat::roi(&frame, rect)?;
|
||||||
|
// let blob = dnn::blob_from_image(
|
||||||
|
// &roi,
|
||||||
|
// 1.0 / 255.0,
|
||||||
|
// core::Size::new(64, 128),
|
||||||
|
// core::Scalar::new(0., 0., 0., 0.),
|
||||||
|
// true,
|
||||||
|
// false,
|
||||||
|
// core::CV_32F)
|
||||||
|
// .unwrap();
|
||||||
|
|
||||||
|
// let code = encoder.encode_batch(&blob)?.get(0)?;
|
||||||
|
// let core = code.into_typed::<f32>()?;
|
||||||
|
// let feature = arr1(core.data_typed()?);
|
||||||
|
|
||||||
|
// if bbox.class_index <= 8 {
|
||||||
|
// detections.push(sort::Detection {
|
||||||
|
// bbox: sort::BBox::ltwh(
|
||||||
|
// rect.x as f32,
|
||||||
|
// rect.y as f32,
|
||||||
|
// rect.width as f32,
|
||||||
|
// rect.height as f32
|
||||||
|
// ),
|
||||||
|
// confidence: bbox.confidence,
|
||||||
|
// feature: Some(feature)
|
||||||
|
// });
|
||||||
|
|
||||||
|
// draw_pred(&mut frame, bbox)?;
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// tracker.predict();
|
||||||
|
// tracker.update(detections.as_slice());
|
||||||
|
|
||||||
|
// for t in tracker.tracks().iter().filter(|t| t.is_confirmed() && t.time_since_update <= 1) {
|
||||||
|
// draw_track(&mut frame, t.bbox().as_ltwh(), t.track_id);
|
||||||
|
// }
|
||||||
|
|
||||||
|
let diff = std::time::Instant::now() - begin;
|
||||||
|
let label = format!("{:?}", 1.0 / ((diff.as_millis() as f32) * 0.001));
|
||||||
|
|
||||||
|
opencv::imgproc::put_text(
|
||||||
|
&mut frame,
|
||||||
|
&label,
|
||||||
|
core::Point::new(30, 30),
|
||||||
|
opencv::imgproc::FONT_HERSHEY_SIMPLEX,
|
||||||
|
0.6,
|
||||||
|
core::Scalar::new(0.0, 255.0, 0.0, 0.0),
|
||||||
|
1,
|
||||||
|
opencv::imgproc::LINE_8,
|
||||||
|
false
|
||||||
|
)?;
|
||||||
|
|
||||||
|
highgui::imshow(window, &mut frame)?;
|
||||||
|
|
||||||
|
let key = highgui::wait_key(10)?;
|
||||||
|
if key > 0 && key != 255 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
fn detect(pred_events: &ArrayView3<'_, f32>) -> opencv::Result<Vec<BBox>> {
|
||||||
|
|
||||||
|
// The bounding boxes grouped by (maximum) class index.
|
||||||
|
let mut bboxes: Vec<(core::Vector<core::Rect2d>, core::Vector<f32>, Vec<BBox>)> = (0 .. 80)
|
||||||
|
.map(|_| (core::Vector::new(), core::Vector::new(), vec![]))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
for pred_event in pred_events {
|
||||||
|
let fsize = pred_event.size()?;
|
||||||
|
let npreds = pred_event.rows();
|
||||||
|
let pred_size = pred_event.cols();
|
||||||
|
let nclasses = (pred_size - 5) as usize;
|
||||||
|
|
||||||
|
// Extract the bounding boxes for which confidence is above the threshold.
|
||||||
|
for index in 0 .. npreds {
|
||||||
|
let pred = pred_event.row(index)?.try_into_typed::<f32>()?;
|
||||||
|
let detection = pred.data_typed()?;
|
||||||
|
|
||||||
|
let (center_x, center_y, width, height, confidence) = match &detection[0 .. 5] {
|
||||||
|
&[a,b,c,d,e] => (a,b,c,d,e),
|
||||||
|
_ => unreachable!()
|
||||||
|
};
|
||||||
|
|
||||||
|
let classes = &detection[5..];
|
||||||
|
|
||||||
|
if confidence > CONFIDENCE_THRESHOLD {
|
||||||
|
let mut class_index = -1;
|
||||||
|
let mut score = 0.0;
|
||||||
|
|
||||||
|
for (idx, &val) in classes.iter().enumerate() {
|
||||||
|
if val > score {
|
||||||
|
class_index = idx as i32;
|
||||||
|
score = val;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if class_index > -1 && score > 0. {
|
||||||
|
let entry = &mut bboxes[class_index as usize];
|
||||||
|
|
||||||
|
entry.0.push(core::Rect2d::new(
|
||||||
|
(center_x - width / 2.) as f64,
|
||||||
|
(center_y - height / 2.) as f64,
|
||||||
|
width as f64,
|
||||||
|
height as f64,
|
||||||
|
));
|
||||||
|
entry.1.push(score);
|
||||||
|
entry.2.push(BBox {
|
||||||
|
xmin: center_x - width / 2.,
|
||||||
|
ymin: center_y - height / 2.,
|
||||||
|
xmax: center_x + width / 2.,
|
||||||
|
ymax: center_y + height / 2.,
|
||||||
|
|
||||||
|
confidence,
|
||||||
|
class_index: class_index as _,
|
||||||
|
class_confidence: score,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut events = vec![];
|
||||||
|
|
||||||
|
for (rects, scores, bboxes) in bboxes.iter_mut() {
|
||||||
|
if bboxes.is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut indices = core::Vector::<i32>::new();
|
||||||
|
dnn::nms_boxes_f64(
|
||||||
|
&rects,
|
||||||
|
&scores,
|
||||||
|
CONFIDENCE_THRESHOLD,
|
||||||
|
NMS_THRESHOLD,
|
||||||
|
&mut indices,
|
||||||
|
1.0,
|
||||||
|
0
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let mut indices = indices.to_vec();
|
||||||
|
|
||||||
|
events.extend(bboxes.drain(..)
|
||||||
|
.enumerate()
|
||||||
|
.filter_map(|(idx, item)| if indices.contains(&(idx as i32)) {Some(item)} else {None}));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Perform non-maximum suppression.
|
||||||
|
// for (idx, (_, _, bboxes_for_class)) in bboxes.iter_mut().enumerate() {
|
||||||
|
// if bboxes_for_class.is_empty() {
|
||||||
|
// continue;
|
||||||
|
// }
|
||||||
|
|
||||||
|
// bboxes_for_class.sort_unstable_by(|b1, b2| b2.confidence.partial_cmp(&b1.confidence).unwrap());
|
||||||
|
// let mut current_index = 0;
|
||||||
|
|
||||||
|
// for index in 0 .. bboxes_for_class.len() {
|
||||||
|
// let mut drop = false;
|
||||||
|
// for prev_index in 0..current_index {
|
||||||
|
// let iou = iou(&bboxes_for_class[prev_index], &bboxes_for_class[index]);
|
||||||
|
// if iou > NMS_THRESHOLD {
|
||||||
|
// drop = true;
|
||||||
|
// break;
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// if !drop {
|
||||||
|
// bboxes_for_class.swap(current_index, index);
|
||||||
|
// current_index += 1;
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// bboxes_for_class.truncate(current_index);
|
||||||
|
// }
|
||||||
|
|
||||||
|
for (class_index, (_, _, bboxes_for_class)) in bboxes.into_iter().enumerate() {
|
||||||
|
if bboxes_for_class.is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let clamp = |x| if x < 0.0 { 0.0 } else if x > 1.0 { 1.0 } else { x };
|
||||||
|
|
||||||
|
for bbox in bboxes_for_class {
|
||||||
|
events.push(bbox);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(events)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn draw_pred(frame: &mut Mat, bbox: BBox) -> opencv::Result<()> {
|
||||||
|
let rect = bbox.cv_rect(frame.cols(), frame.rows());
|
||||||
|
|
||||||
|
// Draw a bounding box.
|
||||||
|
opencv::imgproc::rectangle(
|
||||||
|
frame,
|
||||||
|
rect,
|
||||||
|
core::Scalar::new(255.0, 255.0, 0.0, 0.0),
|
||||||
|
1,
|
||||||
|
opencv::imgproc::LINE_8,
|
||||||
|
0
|
||||||
|
)?;
|
||||||
|
|
||||||
|
// let label = format!("{} {:2}", NAMES[bbox.class_index], bbox.class_confidence);
|
||||||
|
// let mut base_line = 0;
|
||||||
|
// let label_size = opencv::imgproc::get_text_size(&label, opencv::imgproc::FONT_HERSHEY_SIMPLEX, 0.6, 1, &mut base_line)?;
|
||||||
|
|
||||||
|
// let label_rect = core::Rect::new(
|
||||||
|
// rect.x,
|
||||||
|
// rect.y - label_size.height - 8,
|
||||||
|
// label_size.width + 8,
|
||||||
|
// label_size.height + 8
|
||||||
|
// );
|
||||||
|
|
||||||
|
// opencv::imgproc::rectangle(frame, label_rect, core::Scalar::new(255.0, 255.0, 0.0, 0.0), opencv::imgproc::FILLED, opencv::imgproc::LINE_8, 0)?;
|
||||||
|
|
||||||
|
// let pt = core::Point::new(rect.x, rect.y - 8);
|
||||||
|
// opencv::imgproc::put_text(
|
||||||
|
// frame,
|
||||||
|
// &label,
|
||||||
|
// pt,
|
||||||
|
// opencv::imgproc::FONT_HERSHEY_SIMPLEX,
|
||||||
|
// 0.6,
|
||||||
|
// core::Scalar::new(0.0, 0.0, 0.0, 0.0),
|
||||||
|
// 1,
|
||||||
|
// opencv::imgproc::LINE_8,
|
||||||
|
// false
|
||||||
|
// )?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn draw_track(frame: &mut Mat, bbox: sort::BBox<sort::Ltwh>, track_id: i32) -> opencv::Result<()> {
|
||||||
|
let rect = opencv::core::Rect::new(
|
||||||
|
bbox.left() as i32,
|
||||||
|
bbox.top() as i32,
|
||||||
|
bbox.width() as i32,
|
||||||
|
bbox.height() as i32,
|
||||||
|
);
|
||||||
|
|
||||||
|
// Draw a bounding box.
|
||||||
|
opencv::imgproc::rectangle(
|
||||||
|
frame,
|
||||||
|
rect,
|
||||||
|
core::Scalar::new(0.0, 255.0, 0.0, 0.0),
|
||||||
|
1,
|
||||||
|
opencv::imgproc::LINE_8,
|
||||||
|
0
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let label = format!("[{}]", track_id);
|
||||||
|
let mut base_line = 0;
|
||||||
|
let label_size = opencv::imgproc::get_text_size(&label, opencv::imgproc::FONT_HERSHEY_SIMPLEX, 0.6, 1, &mut base_line)?;
|
||||||
|
|
||||||
|
let label_rect = core::Rect::new(
|
||||||
|
rect.x,
|
||||||
|
rect.y - label_size.height - 8,
|
||||||
|
label_size.width + 8,
|
||||||
|
label_size.height + 8
|
||||||
|
);
|
||||||
|
|
||||||
|
opencv::imgproc::rectangle(frame, label_rect, core::Scalar::new(0.0, 255.0, 0.0, 0.0), opencv::imgproc::FILLED, opencv::imgproc::LINE_8, 0)?;
|
||||||
|
|
||||||
|
let pt = core::Point::new(rect.x, rect.y - 8);
|
||||||
|
opencv::imgproc::put_text(
|
||||||
|
frame,
|
||||||
|
&label,
|
||||||
|
pt,
|
||||||
|
opencv::imgproc::FONT_HERSHEY_SIMPLEX,
|
||||||
|
0.6,
|
||||||
|
core::Scalar::new(0.0, 0.0, 0.0, 0.0),
|
||||||
|
1,
|
||||||
|
opencv::imgproc::LINE_8,
|
||||||
|
false
|
||||||
|
)?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
run().unwrap()
|
||||||
|
}
|
211
examples/yolo.rs
Normal file
211
examples/yolo.rs
Normal file
@ -0,0 +1,211 @@
|
|||||||
|
use anyhow::Error;
|
||||||
|
use grant_object_detector::{Detection, YoloDetector, YoloDetectorConfig};
|
||||||
|
use deep_sort::{sort, DeepSortConfig, DeepSort, Track};
|
||||||
|
use ndarray::prelude::*;
|
||||||
|
use opencv::{
|
||||||
|
dnn,
|
||||||
|
core::{self, Mat, Scalar, Vector},
|
||||||
|
highgui,
|
||||||
|
prelude::*,
|
||||||
|
videoio,
|
||||||
|
};
|
||||||
|
|
||||||
|
const PALLETE: (i32, i32, i32) = (2047, 32767, 1048575);
|
||||||
|
|
||||||
|
fn compute_color_for_labels(label: i32) -> (f64, f64, f64) {
|
||||||
|
let c = label * label - label + 1;
|
||||||
|
|
||||||
|
(
|
||||||
|
((PALLETE.0 * c) % 255) as _,
|
||||||
|
((PALLETE.1 * c) % 255) as _,
|
||||||
|
((PALLETE.2 * c) % 255) as _,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct YoloDeepSort {
|
||||||
|
yolo: YoloDetector,
|
||||||
|
deep_sort: DeepSort,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl YoloDeepSort {
|
||||||
|
fn new(yolo: &str, reid: &str) -> Result<Self, Error> {
|
||||||
|
let device = onnx_model::get_cuda_if_available();
|
||||||
|
let mut config = YoloDetectorConfig::new(vec![0, 2, 3, 5, 7]);
|
||||||
|
config.confidence_threshold = 0.2;
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
yolo: YoloDetector::new(yolo, config, device)?,
|
||||||
|
deep_sort: DeepSort::new(DeepSortConfig::new(reid.to_string()))?,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn detectect(&mut self, frames: &[Mat]) -> Result<Vec<Vec<Detection>>, Error> {
|
||||||
|
let (mut frame_width, mut frame_height) = (0i32, 0i32);
|
||||||
|
const SIZE: usize = 416;
|
||||||
|
|
||||||
|
let mut inpt = unsafe { Array4::uninitialized([frames.len(), 3, SIZE, SIZE]) };
|
||||||
|
for (idx, frame) in frames.iter().enumerate() {
|
||||||
|
let fsize = frame.size()?;
|
||||||
|
|
||||||
|
if fsize.width <= 0 {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
frame_height = fsize.height;
|
||||||
|
frame_width = fsize.width;
|
||||||
|
|
||||||
|
// Create a 4D blob from a frame.
|
||||||
|
let inp_width = SIZE as _;
|
||||||
|
let inp_height = SIZE as _;
|
||||||
|
let blob = dnn::blob_from_image(
|
||||||
|
&frame,
|
||||||
|
1.0 / 255.0,
|
||||||
|
core::Size::new(inp_width, inp_height),
|
||||||
|
core::Scalar::new(0., 0., 0., 0.),
|
||||||
|
true,
|
||||||
|
false,
|
||||||
|
core::CV_32F)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let core = blob.try_into_typed::<f32>()?;
|
||||||
|
let view = aview1(core.data_typed()?).into_shape([3, SIZE, SIZE]).unwrap();
|
||||||
|
inpt.index_axis_mut(Axis(0), idx).assign(&view);
|
||||||
|
}
|
||||||
|
|
||||||
|
let detections = self.yolo.detect(inpt.view(), frame_width, frame_height)?;
|
||||||
|
|
||||||
|
Ok(detections)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn track(&mut self, frames: &[Mat], detections: &[Vec<Detection>]) -> Result<&[Track], Error> {
|
||||||
|
self.deep_sort.update(frames, detections)?;
|
||||||
|
|
||||||
|
Ok(self.deep_sort.tracks())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn draw_pred(frame: &mut Mat, det: Detection) -> opencv::Result<()> {
|
||||||
|
let rect = core::Rect::new(det.xmin, det.ymin, det.xmax - det.xmin, det.ymax - det.ymin);
|
||||||
|
|
||||||
|
// Draw a bounding box.
|
||||||
|
opencv::imgproc::rectangle(
|
||||||
|
frame,
|
||||||
|
rect,
|
||||||
|
core::Scalar::new(255.0, 255.0, 0.0, 0.0),
|
||||||
|
1,
|
||||||
|
opencv::imgproc::LINE_8,
|
||||||
|
0
|
||||||
|
)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
fn draw_track(frame: &mut Mat, bbox: sort::BBox<sort::Ltwh>, track_id: i32, color: (f64, f64, f64)) -> opencv::Result<()> {
|
||||||
|
let rect = opencv::core::Rect::new(
|
||||||
|
bbox.left() as i32,
|
||||||
|
bbox.top() as i32,
|
||||||
|
bbox.width() as i32,
|
||||||
|
bbox.height() as i32,
|
||||||
|
);
|
||||||
|
|
||||||
|
// Draw a bounding box.
|
||||||
|
opencv::imgproc::rectangle(
|
||||||
|
frame,
|
||||||
|
rect,
|
||||||
|
core::Scalar::new(color.0, color.1, color.2, 0.0),
|
||||||
|
1,
|
||||||
|
opencv::imgproc::LINE_8,
|
||||||
|
0
|
||||||
|
)?;
|
||||||
|
|
||||||
|
// let label = format!("[{}]", track_id);
|
||||||
|
// let mut base_line = 0;
|
||||||
|
// let label_size = opencv::imgproc::get_text_size(&label, opencv::imgproc::FONT_HERSHEY_SIMPLEX, 0.6, 1, &mut base_line)?;
|
||||||
|
|
||||||
|
// let label_rect = core::Rect::new(
|
||||||
|
// rect.x,
|
||||||
|
// rect.y - label_size.height - 8,
|
||||||
|
// label_size.width + 8,
|
||||||
|
// label_size.height + 8
|
||||||
|
// );
|
||||||
|
|
||||||
|
// opencv::imgproc::rectangle(frame, label_rect, core::Scalar::new(0.0, 255.0, 0.0, 0.0), opencv::imgproc::FILLED, opencv::imgproc::LINE_8, 0)?;
|
||||||
|
|
||||||
|
// let pt = core::Point::new(rect.x, rect.y - 8);
|
||||||
|
// opencv::imgproc::put_text(
|
||||||
|
// frame,
|
||||||
|
// &label,
|
||||||
|
// pt,
|
||||||
|
// opencv::imgproc::FONT_HERSHEY_SIMPLEX,
|
||||||
|
// 0.6,
|
||||||
|
// core::Scalar::new(0.0, 0.0, 0.0, 0.0),
|
||||||
|
// 1,
|
||||||
|
// opencv::imgproc::LINE_8,
|
||||||
|
// false
|
||||||
|
// )?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
fn main() -> Result<(), anyhow::Error> {
|
||||||
|
let mut tracker = YoloDeepSort::new(
|
||||||
|
"/home/andrey/workspace/ssl/yolov4/yolov4_416.onnx",
|
||||||
|
// "/home/andrey/workspace/ssl/reid/onnx_model.onnx",
|
||||||
|
"/home/andrey/workspace/ssl/grant/models/model-96.onnx",
|
||||||
|
// "/home/andrey/workspace/ssl/deep_sort_pytorch/deep_sort/deep/reid.onnx",
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let window = "video capture";
|
||||||
|
highgui::named_window(window, 1)?;
|
||||||
|
|
||||||
|
let mut cam = videoio::VideoCapture::from_file("../videoplayback_6.avi", videoio::CAP_ANY)?; // 0 is the default camera
|
||||||
|
// cam.set(videoio::CAP_PROP_POS_FRAMES, 150.0);
|
||||||
|
|
||||||
|
let opened = videoio::VideoCapture::is_opened(&cam)?;
|
||||||
|
if !opened {
|
||||||
|
panic!("Unable to open default camera!");
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut frames = [core::Mat::default()?];
|
||||||
|
loop {
|
||||||
|
let begin = std::time::Instant::now();
|
||||||
|
cam.read(&mut frames[0])?;
|
||||||
|
|
||||||
|
let detections = tracker.detectect(&frames)?;
|
||||||
|
// for d in detections.iter().cloned() {
|
||||||
|
// for d in d {
|
||||||
|
// draw_pred(&mut frames[0], d);
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
let tracks = tracker.track(&frames, detections.as_slice())?;
|
||||||
|
for t in tracks.iter().filter(|t| t.is_confirmed() && t.time_since_update <= 1) {
|
||||||
|
draw_track(&mut frames[0], t.bbox().as_ltwh(), t.track_id, compute_color_for_labels(t.track_id));
|
||||||
|
}
|
||||||
|
|
||||||
|
let diff = std::time::Instant::now() - begin;
|
||||||
|
let label = format!("{:?}", 1.0 / ((diff.as_millis() as f32) * 0.001));
|
||||||
|
|
||||||
|
opencv::imgproc::put_text(
|
||||||
|
&mut frames[0],
|
||||||
|
&label,
|
||||||
|
core::Point::new(30, 30),
|
||||||
|
opencv::imgproc::FONT_HERSHEY_SIMPLEX,
|
||||||
|
0.6,
|
||||||
|
core::Scalar::new(0.0, 255.0, 0.0, 0.0),
|
||||||
|
1,
|
||||||
|
opencv::imgproc::LINE_8,
|
||||||
|
false
|
||||||
|
)?;
|
||||||
|
|
||||||
|
highgui::imshow(window, &mut frames[0])?;
|
||||||
|
|
||||||
|
let key = highgui::wait_key(10)?;
|
||||||
|
if key > 0 && key != 255 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
38
src/deep/image_encoder.rs
Normal file
38
src/deep/image_encoder.rs
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
use crate::error::Error;
|
||||||
|
use onnx_model::{OnnxInferenceModel, OnnxInferenceDevice, TensorView};
|
||||||
|
use ndarray::prelude::*;
|
||||||
|
|
||||||
|
pub struct ImageEncoder {
|
||||||
|
model: OnnxInferenceModel,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ImageEncoder {
|
||||||
|
pub fn new(model_filename: &str, device: OnnxInferenceDevice) -> Result<Self, Error> {
|
||||||
|
Ok(Self {
|
||||||
|
model: OnnxInferenceModel::new(model_filename, device)?,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn input_shape(&self) -> &[i64] {
|
||||||
|
let inpt = &self.model.get_input_infos()[0];
|
||||||
|
|
||||||
|
&inpt.shape.dims[1..]
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn output_shape(&self) -> &[i64] {
|
||||||
|
let otpt = &self.model.get_output_infos()[0];
|
||||||
|
|
||||||
|
&otpt.shape.dims[1..]
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn encode_batch(&mut self, in_vals: ArrayView4<'_, f32>) -> std::result::Result<Array2<f32>, Error> {
|
||||||
|
let inpt = TensorView::new(in_vals.shape(), in_vals.as_slice().unwrap());
|
||||||
|
let otpt = self.model.run(&[inpt])?.pop().unwrap();
|
||||||
|
let shape = otpt.dims().clone();
|
||||||
|
let features = Array2::from_shape_vec([shape[0] as usize, shape[1] as usize], otpt.to_vec()).unwrap();
|
||||||
|
|
||||||
|
Ok(features)
|
||||||
|
}
|
||||||
|
}
|
3
src/deep/mod.rs
Normal file
3
src/deep/mod.rs
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
mod image_encoder;
|
||||||
|
|
||||||
|
pub use image_encoder::ImageEncoder;
|
13
src/error.rs
Normal file
13
src/error.rs
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
use err_derive::Error;
|
||||||
|
|
||||||
|
#[derive(Debug, Error)]
|
||||||
|
pub enum Error {
|
||||||
|
#[error(display = "OnnxModel Error: {}", _0)]
|
||||||
|
OnnxModelError(onnx_model::error::Error),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<onnx_model::error::Error> for Error {
|
||||||
|
fn from(err: onnx_model::error::Error) -> Self {
|
||||||
|
Self::OnnxModelError(err)
|
||||||
|
}
|
||||||
|
}
|
165
src/lib.rs
Normal file
165
src/lib.rs
Normal file
@ -0,0 +1,165 @@
|
|||||||
|
pub mod deep;
|
||||||
|
pub mod sort;
|
||||||
|
pub mod error;
|
||||||
|
|
||||||
|
|
||||||
|
pub use sort::Track;
|
||||||
|
use deep::ImageEncoder;
|
||||||
|
use sort::{Tracker, NearestNeighborMetricKind, NearestNeighborDistanceMetric};
|
||||||
|
use ndarray::prelude::*;
|
||||||
|
use opencv::core::{self, Mat, Rect};
|
||||||
|
use opencv::core::MatTrait;
|
||||||
|
|
||||||
|
use opencv::dnn;
|
||||||
|
use error::Error;
|
||||||
|
use grant_object_detector as detector;
|
||||||
|
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
fn get_detection_cv_rect(det: &detector::Detection, frame_width: i32, frame_height: i32) -> Rect {
|
||||||
|
let left = det.xmin.max(0).min(frame_width);
|
||||||
|
let top = det.ymin.max(0).min(frame_height);
|
||||||
|
|
||||||
|
let mut width = (det.xmax - det.xmin).max(0);
|
||||||
|
let mut height = (det.ymax - det.ymin).max(0);
|
||||||
|
|
||||||
|
if left + width > frame_width {
|
||||||
|
width = frame_width - left;
|
||||||
|
}
|
||||||
|
|
||||||
|
if top + height > frame_height {
|
||||||
|
height = frame_height - top;
|
||||||
|
}
|
||||||
|
|
||||||
|
Rect::new(
|
||||||
|
left,
|
||||||
|
top,
|
||||||
|
width,
|
||||||
|
height,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct DeepSortConfig {
|
||||||
|
pub reid_model_path: String,
|
||||||
|
pub max_cosine_distance: f32,
|
||||||
|
pub nn_budget: Option<usize>,
|
||||||
|
pub max_age: i32,
|
||||||
|
pub max_iou_distance: f32,
|
||||||
|
pub n_init: i32,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DeepSortConfig {
|
||||||
|
pub fn new(reid_model_path: String) -> Self {
|
||||||
|
Self {
|
||||||
|
reid_model_path,
|
||||||
|
max_cosine_distance: 0.2,
|
||||||
|
nn_budget: Some(100),
|
||||||
|
max_age: 70,
|
||||||
|
max_iou_distance: 0.7,
|
||||||
|
n_init: 3
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct DeepSort {
|
||||||
|
device: onnx_model::OnnxInferenceDevice,
|
||||||
|
encoder: ImageEncoder,
|
||||||
|
sample_tracker: Tracker<NearestNeighborDistanceMetric>,
|
||||||
|
trackers: HashMap<String, Tracker<NearestNeighborDistanceMetric>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DeepSort {
|
||||||
|
pub fn new(config: DeepSortConfig) -> Result<Self, Error> {
|
||||||
|
let metric = NearestNeighborDistanceMetric::new(
|
||||||
|
NearestNeighborMetricKind::CosineDistance,
|
||||||
|
config.max_cosine_distance,
|
||||||
|
config.nn_budget
|
||||||
|
);
|
||||||
|
|
||||||
|
let device = onnx_model::get_cuda_if_available();
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
device,
|
||||||
|
sample_tracker: Tracker::new(
|
||||||
|
metric,
|
||||||
|
config.max_iou_distance,
|
||||||
|
config.max_age,
|
||||||
|
config.n_init
|
||||||
|
),
|
||||||
|
encoder: ImageEncoder::new(&config.reid_model_path, device)?,
|
||||||
|
trackers: HashMap::<String, Tracker<NearestNeighborDistanceMetric>>::new(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn tracks(&self, src_url: String) -> &[sort::Track] {
|
||||||
|
self.trackers.get(&src_url)
|
||||||
|
.unwrap()
|
||||||
|
.tracks()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn update(&mut self, frames: &[Mat], dets: &[Vec<detector::Detection>], src_url: String) -> Result<(), Error> {
|
||||||
|
let total_count = dets.iter().map(|i|i.len()).sum::<usize>();
|
||||||
|
|
||||||
|
let tracker = self.trackers.entry(src_url.clone()).or_insert(self.sample_tracker.clone());
|
||||||
|
|
||||||
|
if total_count == 0 {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
let input_shape = self.encoder.input_shape();
|
||||||
|
let frag_channs = input_shape[0] as usize;
|
||||||
|
let frag_height = input_shape[1] as usize;
|
||||||
|
let frag_width = input_shape[2] as usize;
|
||||||
|
|
||||||
|
let mut idets = unsafe { Array4::uninitialized([total_count, frag_channs, frag_height, frag_width]) };
|
||||||
|
let mut index = 0usize;
|
||||||
|
for (frame, dets) in frames.iter().zip(dets.iter()) {
|
||||||
|
for det in dets {
|
||||||
|
let rect = get_detection_cv_rect(&det, frame.cols(), frame.rows());
|
||||||
|
let roi = Mat::roi(&frame, rect).unwrap();
|
||||||
|
let blob = dnn::blob_from_image(
|
||||||
|
&roi,
|
||||||
|
1.0 / 255.0,
|
||||||
|
core::Size::new(frag_width as i32, frag_height as i32),
|
||||||
|
core::Scalar::new(0., 0., 0., 0.),
|
||||||
|
true,
|
||||||
|
false,
|
||||||
|
core::CV_32F).unwrap();
|
||||||
|
|
||||||
|
let core = blob.try_into_typed::<f32>().unwrap();
|
||||||
|
let data: &[f32] = core.data_typed().unwrap();
|
||||||
|
let a = aview1(data).into_shape((frag_channs, frag_height, frag_width)).unwrap();
|
||||||
|
idets.index_axis_mut(Axis(0), index).assign(&a);
|
||||||
|
index += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let features = self.encoder.encode_batch(idets.view())?;
|
||||||
|
// let features = Array2::zeros((dets.len(), 512));
|
||||||
|
let mut detections = vec![];
|
||||||
|
|
||||||
|
for dets in dets.iter() {
|
||||||
|
detections.clear();
|
||||||
|
for (feature, det) in features.axis_iter(Axis(0)).zip(dets.iter()) {
|
||||||
|
let x = sort::Detection {
|
||||||
|
bbox: sort::BBox::ltwh(
|
||||||
|
det.xmin as _,
|
||||||
|
det.ymin as _,
|
||||||
|
(det.xmax - det.xmin) as _,
|
||||||
|
(det.ymax - det.ymin) as _,
|
||||||
|
),
|
||||||
|
confidence: det.confidence,
|
||||||
|
feature: Some(feature.into_owned())
|
||||||
|
};
|
||||||
|
|
||||||
|
detections.push(x);
|
||||||
|
}
|
||||||
|
|
||||||
|
tracker.predict();
|
||||||
|
tracker.update(detections.as_slice());
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
17
src/sort/detection.rs
Normal file
17
src/sort/detection.rs
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
use ndarray::prelude::*;
|
||||||
|
use crate::sort::{BBox, Ltwh};
|
||||||
|
|
||||||
|
///
|
||||||
|
/// This class represents a bounding box detection in a single image.
|
||||||
|
/// Parameters
|
||||||
|
///
|
||||||
|
/// ltwh : BBox in format `(x, y, w, h)`.
|
||||||
|
/// confidence : f32 - Detector confidence score.
|
||||||
|
/// feature : Vec<Array1<f32>> A feature vector that describes the object contained in this image.
|
||||||
|
///
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct Detection {
|
||||||
|
pub bbox: BBox<Ltwh>,
|
||||||
|
pub confidence: f32,
|
||||||
|
pub feature: Option<Array1<f32>>
|
||||||
|
}
|
89
src/sort/iou_matching.rs
Normal file
89
src/sort/iou_matching.rs
Normal file
@ -0,0 +1,89 @@
|
|||||||
|
use ndarray::prelude::*;
|
||||||
|
|
||||||
|
use crate::sort::linear_assignment::INFTY_COST;
|
||||||
|
use crate::sort::{Track, Detection, BBox, Ltwh};
|
||||||
|
|
||||||
|
/// Computer intersection over union.
|
||||||
|
/// Parameters
|
||||||
|
/// ----------
|
||||||
|
/// bbox : ndarray
|
||||||
|
/// A bounding box in format `(top left x, top left y, width, height)`.
|
||||||
|
/// candidates : ndarray
|
||||||
|
/// A matrix of candidate bounding boxes (one per row) in the same format
|
||||||
|
/// as `bbox`.
|
||||||
|
/// Returns
|
||||||
|
/// -------
|
||||||
|
/// ndarray
|
||||||
|
/// The intersection over union in [0, 1] between the `bbox` and each
|
||||||
|
/// candidate. A higher score means a larger fraction of the `bbox` is
|
||||||
|
/// occluded by the candidate.
|
||||||
|
pub fn iou(bbox: &BBox<Ltwh>, candidates: &[BBox<Ltwh>]) -> Array1<f32> {
|
||||||
|
let bbox_area = bbox.width() * bbox.height();
|
||||||
|
|
||||||
|
candidates
|
||||||
|
.iter()
|
||||||
|
.map(|c_ltwh| {
|
||||||
|
let b1 = bbox.as_ltrb();
|
||||||
|
let b2 = c_ltwh.as_ltrb();
|
||||||
|
|
||||||
|
let i_xmin = b1.left().max(b2.left());
|
||||||
|
let i_ymin = b1.top().max(b2.top());
|
||||||
|
|
||||||
|
let i_xmax = b1.right().min(b2.right());
|
||||||
|
let i_ymax = b1.bottom().min(b2.bottom());
|
||||||
|
|
||||||
|
let intersection_area = ((i_xmax - i_xmin).max(0.0) * (i_ymax - i_ymin).max(0.0));
|
||||||
|
let candidate_area = c_ltwh.width() * c_ltwh.height();
|
||||||
|
|
||||||
|
intersection_area / (bbox_area + candidate_area - intersection_area)
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
///
|
||||||
|
/// An intersection over union distance metric.
|
||||||
|
/// Parameters
|
||||||
|
/// ----------
|
||||||
|
/// tracks : List[deep_sort.track.Track]
|
||||||
|
/// A list of tracks.
|
||||||
|
/// detections : List[deep_sort.detection.Detection]
|
||||||
|
/// A list of detections.
|
||||||
|
/// track_indices : Optional[List[int]]
|
||||||
|
/// A list of indices to tracks that should be matched. Defaults to
|
||||||
|
/// all `tracks`.
|
||||||
|
/// detection_indices : Optional[List[int]]
|
||||||
|
/// A list of indices to detections that should be matched. Defaults
|
||||||
|
/// to all `detections`.
|
||||||
|
/// Returns
|
||||||
|
/// -------
|
||||||
|
/// ndarray
|
||||||
|
/// Returns a cost matrix of shape
|
||||||
|
/// len(track_indices), len(detection_indices) where entry (i, j) is
|
||||||
|
/// `1 - iou(tracks[track_indices[i]], detections[detection_indices[j]])`.
|
||||||
|
///
|
||||||
|
pub fn iou_cost(tracks: &[Track], detections: &[Detection], track_indices: &[usize], detection_indices: &[usize]) -> Array2<f32> {
|
||||||
|
let track_n = track_indices.len();
|
||||||
|
let det_n = detection_indices.len();
|
||||||
|
let n = track_n.max(det_n);
|
||||||
|
|
||||||
|
let mut cost_matrix = Array2::from_elem((n, n), 1.0);
|
||||||
|
|
||||||
|
for (row, &track_idx) in track_indices.iter().enumerate() {
|
||||||
|
let track = &tracks[track_idx];
|
||||||
|
|
||||||
|
if track.time_since_update > 1 {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let bbox = track.bbox().as_ltwh();
|
||||||
|
let candidates: Vec<_> = detection_indices
|
||||||
|
.iter()
|
||||||
|
.map(|&i| detections[i].bbox.clone())
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
|
||||||
|
cost_matrix.slice_mut(s![row, ..det_n]).assign(&(1.0 - iou(&bbox, &candidates)));
|
||||||
|
}
|
||||||
|
|
||||||
|
cost_matrix
|
||||||
|
}
|
318
src/sort/kalman_filter.rs
Normal file
318
src/sort/kalman_filter.rs
Normal file
@ -0,0 +1,318 @@
|
|||||||
|
use crate::sort::{BBox, Xyah};
|
||||||
|
use ndarray::prelude::*;
|
||||||
|
use ndarray_linalg::cholesky::*;
|
||||||
|
use ndarray_linalg::triangular::*;
|
||||||
|
|
||||||
|
///
|
||||||
|
/// Table for the 0.95 quantile of the chi-square distribution with N degrees of
|
||||||
|
/// freedom (contains values for N=1, ..., 9). Taken from MATLAB/Octave's chi2inv
|
||||||
|
/// function and used as Mahalanobis gating threshold.
|
||||||
|
///
|
||||||
|
pub const CHI_2_INV_95: [f32; 9] = [
|
||||||
|
3.8415, // 1
|
||||||
|
5.9915, // 2
|
||||||
|
7.8147, // 3
|
||||||
|
9.4877, // 4
|
||||||
|
11.070, // 5
|
||||||
|
12.592, // 6
|
||||||
|
14.067, // 7
|
||||||
|
15.507, // 8
|
||||||
|
16.919, // 9
|
||||||
|
];
|
||||||
|
|
||||||
|
/// A simple Kalman filter for tracking bounding boxes in image space.
|
||||||
|
///
|
||||||
|
/// The 8-dimensional state space
|
||||||
|
///
|
||||||
|
/// x, y, a, h, vx, vy, va, vh
|
||||||
|
///
|
||||||
|
/// contains the bounding box center position (x, y), aspect ratio a, height h,
|
||||||
|
/// and their respective velocities.
|
||||||
|
///
|
||||||
|
/// Object motion follows a constant velocity model. The bounding box location
|
||||||
|
/// (x, y, a, h) is taken as direct observation of the state space (linear
|
||||||
|
/// observation model).
|
||||||
|
///
|
||||||
|
/// update_mat: Array2<f32> of shape (4, 8)
|
||||||
|
///
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct KalmanFilter {
|
||||||
|
ndim: usize,
|
||||||
|
dt: f32,
|
||||||
|
motion_mat: Array2<f32>,
|
||||||
|
update_mat: Array2<f32>,
|
||||||
|
std_weight_position: f32,
|
||||||
|
std_weight_velocity: f32,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for KalmanFilter {
|
||||||
|
fn default() -> Self {
|
||||||
|
let (ndim, dt) = (4, 1.);
|
||||||
|
|
||||||
|
// Create Kalman filter model matrices.
|
||||||
|
let mut motion_mat = Array2::eye(2 * ndim);
|
||||||
|
|
||||||
|
for i in 0..ndim {
|
||||||
|
motion_mat[(i, ndim + i)] = dt;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut update_mat = Array2::zeros((ndim, 2 * ndim));
|
||||||
|
|
||||||
|
for i in 0..ndim {
|
||||||
|
update_mat[(i, i)] = 1.0;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Motion and observation uncertainty are chosen relative to the current
|
||||||
|
// state estimate. These weights control the amount of uncertainty in
|
||||||
|
// the model. This is a bit hacky.
|
||||||
|
let std_weight_position = 1.0 / 20.0;
|
||||||
|
let std_weight_velocity = 1.0 / 160.0;
|
||||||
|
|
||||||
|
Self {
|
||||||
|
ndim,
|
||||||
|
dt,
|
||||||
|
motion_mat,
|
||||||
|
update_mat,
|
||||||
|
std_weight_position,
|
||||||
|
std_weight_velocity,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl KalmanFilter {
|
||||||
|
///
|
||||||
|
/// Create track from unassociated measurement.
|
||||||
|
///
|
||||||
|
/// Parameters
|
||||||
|
/// ----------
|
||||||
|
/// measurement : ndarray
|
||||||
|
/// Bounding box coordinates (x, y, a, h) with center position (x, y),
|
||||||
|
/// aspect ratio a, and height h.
|
||||||
|
///
|
||||||
|
/// Returns
|
||||||
|
/// -------
|
||||||
|
/// (ndarray, ndarray)
|
||||||
|
/// Returns the mean vector (8 dimensional) and covariance matrix (8x8
|
||||||
|
/// dimensional) of the new track. Unobserved velocities are initialized
|
||||||
|
/// to 0 mean.
|
||||||
|
///
|
||||||
|
pub fn initiate(&self, measurement: BBox<Xyah>) -> (Array1<f32>, Array2<f32>) {
|
||||||
|
let mut mean = Array1::zeros((8,));
|
||||||
|
mean.slice_mut(s![..4]).assign(&measurement.as_view());
|
||||||
|
|
||||||
|
let std = arr1(&[
|
||||||
|
2.0 * self.std_weight_position * measurement.height(),
|
||||||
|
2.0 * self.std_weight_position * measurement.height(),
|
||||||
|
1.0e-2,
|
||||||
|
2.0 * self.std_weight_position * measurement.height(),
|
||||||
|
10.0 * self.std_weight_velocity * measurement.height(),
|
||||||
|
10.0 * self.std_weight_velocity * measurement.height(),
|
||||||
|
1.0e-5,
|
||||||
|
10.0 * self.std_weight_velocity * measurement.height(),
|
||||||
|
]);
|
||||||
|
|
||||||
|
let covariance = Array2::from_diag(&(&std * &std));
|
||||||
|
|
||||||
|
(mean, covariance)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Run Kalman filter prediction step.
|
||||||
|
///
|
||||||
|
/// Parameters
|
||||||
|
/// ----------
|
||||||
|
/// mean : ndarray
|
||||||
|
/// The 8 dimensional mean vector of the object state at the previous
|
||||||
|
/// time step.
|
||||||
|
/// covariance : ndarray
|
||||||
|
/// The 8x8 dimensional covariance matrix of the object state at the
|
||||||
|
/// previous time step.
|
||||||
|
///
|
||||||
|
/// Returns
|
||||||
|
/// -------
|
||||||
|
/// (ndarray, ndarray)
|
||||||
|
/// Returns the mean vector and covariance matrix of the predicted
|
||||||
|
/// state. Unobserved velocities are initialized to 0 mean.
|
||||||
|
///
|
||||||
|
pub fn predict(&self, mean: ArrayView1<'_, f32>, covariance: ArrayView2<'_, f32>) -> (Array1<f32>, Array2<f32>) {
|
||||||
|
let std = arr1(&[
|
||||||
|
// posititon
|
||||||
|
self.std_weight_position * mean[3],
|
||||||
|
self.std_weight_position * mean[3],
|
||||||
|
1e-2,
|
||||||
|
self.std_weight_position * mean[3],
|
||||||
|
|
||||||
|
// velocity
|
||||||
|
self.std_weight_velocity * mean[3],
|
||||||
|
self.std_weight_velocity * mean[3],
|
||||||
|
1e-5,
|
||||||
|
self.std_weight_velocity * mean[3],
|
||||||
|
]);
|
||||||
|
|
||||||
|
let motion_cov = Array2::from_diag(&(&std * &std));
|
||||||
|
let mean = self.motion_mat.dot(&mean);
|
||||||
|
|
||||||
|
let covariance = self.motion_mat.dot(&covariance).dot(&self.motion_mat.t());
|
||||||
|
|
||||||
|
(mean, covariance + motion_cov)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Project state distribution to measurement space.
|
||||||
|
//
|
||||||
|
// Parameters
|
||||||
|
// ----------
|
||||||
|
// mean : ndarray
|
||||||
|
// The state's mean vector (8 dimensional array).
|
||||||
|
// covariance : ndarray
|
||||||
|
// The state's covariance matrix (8x8 dimensional).
|
||||||
|
//
|
||||||
|
// Returns
|
||||||
|
// -------
|
||||||
|
// (ndarray, ndarray)
|
||||||
|
// Returns the projected mean and covariance matrix of the given state
|
||||||
|
// estimate.
|
||||||
|
//
|
||||||
|
fn project(&self, mean: ArrayView1<'_, f32>, covariance: ArrayView2<'_, f32>) -> (Array1<f32>, Array2<f32>) {
|
||||||
|
let std = arr1(&[
|
||||||
|
self.std_weight_position * mean[3],
|
||||||
|
self.std_weight_position * mean[3],
|
||||||
|
1e-1,
|
||||||
|
self.std_weight_position * mean[3],
|
||||||
|
]);
|
||||||
|
|
||||||
|
let innovation_cov = Array2::from_diag(&(&std * &std));
|
||||||
|
let mean = self.update_mat.dot(&mean);
|
||||||
|
let covariance = self.update_mat.dot(&covariance).dot(&self.update_mat.t());
|
||||||
|
|
||||||
|
(mean, covariance + innovation_cov)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Run Kalman filter correction step.
|
||||||
|
///
|
||||||
|
/// Parameters
|
||||||
|
/// ----------
|
||||||
|
/// mean : ndarray
|
||||||
|
/// The predicted state's mean vector (8 dimensional).
|
||||||
|
/// covariance : ndarray
|
||||||
|
/// The state's covariance matrix (8x8 dimensional).
|
||||||
|
/// measurement : ndarray
|
||||||
|
/// The 4 dimensional measurement vector (x, y, a, h), where (x, y)
|
||||||
|
/// is the center position, a the aspect ratio, and h the height of the
|
||||||
|
/// bounding box.
|
||||||
|
///
|
||||||
|
/// Returns
|
||||||
|
/// -------
|
||||||
|
/// (ndarray, ndarray)
|
||||||
|
/// Returns the measurement-corrected state distribution.
|
||||||
|
///
|
||||||
|
pub fn update(
|
||||||
|
&self,
|
||||||
|
mean: ArrayView1<'_, f32>,
|
||||||
|
covariance: ArrayView2<'_, f32>,
|
||||||
|
measurement: ArrayView1<'_, f32>,
|
||||||
|
) -> (Array1<f32>, Array2<f32>) {
|
||||||
|
let (projected_mean, projected_cov) = self.project(mean, covariance);
|
||||||
|
|
||||||
|
// chol shape (4, 4)
|
||||||
|
let chol = projected_cov.factorizec(UPLO::Lower).unwrap();
|
||||||
|
|
||||||
|
// (8, 4)
|
||||||
|
let mut kalman_gain = covariance.dot(&self.update_mat.t());
|
||||||
|
|
||||||
|
for mut axis in kalman_gain.axis_iter_mut(Axis(0)) {
|
||||||
|
// axis shape (4, )
|
||||||
|
chol.solvec_inplace(&mut axis).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
let innovation = &measurement - &projected_mean;
|
||||||
|
let new_mean = &mean + &innovation.dot(&kalman_gain.t());
|
||||||
|
let new_covariance = &covariance - &kalman_gain.dot(&projected_cov).dot(&kalman_gain.t());
|
||||||
|
|
||||||
|
(new_mean, new_covariance)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Compute gating distance between state distribution and measurements.
|
||||||
|
///
|
||||||
|
/// A suitable distance threshold can be obtained from `chi2inv95`. If
|
||||||
|
/// `only_position` is False, the chi-square distribution has 4 degrees of
|
||||||
|
/// freedom, otherwise 2.
|
||||||
|
///
|
||||||
|
/// Parameters
|
||||||
|
/// ----------
|
||||||
|
/// mean : ndarray
|
||||||
|
/// Mean vector over the state distribution (8 dimensional).
|
||||||
|
/// covariance : ndarray
|
||||||
|
/// Covariance of the state distribution (8x8 dimensional).
|
||||||
|
/// measurements : ndarray
|
||||||
|
/// An Nx4 dimensional matrix of N measurements, each in
|
||||||
|
/// format (x, y, a, h) where (x, y) is the bounding box center
|
||||||
|
/// position, a the aspect ratio, and h the height.
|
||||||
|
/// only_position : Optional[bool]
|
||||||
|
/// If True, distance computation is done with respect to the bounding
|
||||||
|
/// box center position only.
|
||||||
|
///
|
||||||
|
/// Returns
|
||||||
|
/// -------
|
||||||
|
/// ndarray
|
||||||
|
/// Returns an array of length N, where the i-th element contains the
|
||||||
|
/// squared Mahalanobis distance between (mean, covariance) and
|
||||||
|
/// `measurements[i]`.
|
||||||
|
///
|
||||||
|
pub fn gating_distance(
|
||||||
|
&self,
|
||||||
|
mean: ArrayView1<'_, f32>,
|
||||||
|
covariance: ArrayView2<'_, f32>,
|
||||||
|
measurements: ArrayView2<'_, f32>,
|
||||||
|
only_position: bool,
|
||||||
|
) -> Array1<f32> {
|
||||||
|
let (mean, covariance) = self.project(mean, covariance);
|
||||||
|
|
||||||
|
let (mean, covariance, measurements) = if only_position {
|
||||||
|
(
|
||||||
|
mean.slice(s!(..2)),
|
||||||
|
covariance.slice(s!(..2, ..2)),
|
||||||
|
measurements.slice(s!(.., ..2)),
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
(mean.view(), covariance.view(), measurements.view())
|
||||||
|
};
|
||||||
|
|
||||||
|
let d = &measurements - &mean;
|
||||||
|
|
||||||
|
let cholesky_lower = covariance.cholesky(UPLO::Lower).unwrap();
|
||||||
|
let z = cholesky_lower
|
||||||
|
.solve_triangular_into(UPLO::Lower, Diag::NonUnit, d.reversed_axes())
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
(&z * &z).sum_axis(Axis(0))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_kalman() {
|
||||||
|
let mut kl = KalmanFilter::default();
|
||||||
|
let (m, c) = kl.initiate(BBox::xyah(128.0, 128.0, 0.5, 64.0));
|
||||||
|
|
||||||
|
// let target_mean = &[128. , 128. , 0.5, 64. , 0. , 0. , 0. , 0. ]), array([[67.2 , 0. , 0. , 0. , 16. , 0. , 0. ,
|
||||||
|
// 0. ],
|
||||||
|
// [ 0. , 67.2 , 0. , 0. , 0. , 16. , 0. ,
|
||||||
|
// 0. ],
|
||||||
|
// [ 0. , 0. , 0.0002, 0. , 0. , 0. , 0. ,
|
||||||
|
// 0. ],
|
||||||
|
// [ 0. , 0. , 0. , 67.2 , 0. , 0. , 0. ,
|
||||||
|
// 16. ],
|
||||||
|
// [16. , 0. , 0. , 0. , 16.16 , 0. , 0. ,
|
||||||
|
// 0. ],
|
||||||
|
// [ 0. , 16. , 0. , 0. , 0. , 16.16 , 0. ,
|
||||||
|
// 0. ],
|
||||||
|
// [ 0. , 0. , 0. , 0. , 0. , 0. , 0. ,
|
||||||
|
// 0. ],
|
||||||
|
// [ 0. , 0. , 0. , 16. , 0. , 0. , 0. ,
|
||||||
|
// 16.16 ]]))
|
||||||
|
|
||||||
|
// println!("{:?}", kl.predict(m.view(),c.view()))
|
||||||
|
|
||||||
|
let (m, c) = kl.update(m.view(),c.view(),aview1(&[192.0, 192.0, 0.5, 68.0]));
|
||||||
|
|
||||||
|
println!("{:?}", kl.gating_distance(m.view(), c.view(), aview2(&[[256.0, 256.0, 0.5, 80.0]]), false));
|
||||||
|
|
||||||
|
}
|
269
src/sort/linear_assignment.rs
Normal file
269
src/sort/linear_assignment.rs
Normal file
@ -0,0 +1,269 @@
|
|||||||
|
use std::collections::HashSet;
|
||||||
|
|
||||||
|
use ndarray::prelude::*;
|
||||||
|
use crate::sort::{Track, Detection, KalmanFilter};
|
||||||
|
|
||||||
|
pub const INFTY_COST: f32 = 1e+5;
|
||||||
|
|
||||||
|
|
||||||
|
/// Solve linear assignment problem.
|
||||||
|
///
|
||||||
|
/// Parameters
|
||||||
|
/// ----------
|
||||||
|
/// distance_metric : Callable[List[Track], List[Detection], List[int], List[int]) -> ndarray
|
||||||
|
/// The distance metric is given a list of tracks and detections as well as
|
||||||
|
/// a list of N track indices and M detection indices. The metric should
|
||||||
|
/// return the NxM dimensional cost matrix, where element (i, j) is the
|
||||||
|
/// association cost between the i-th track in the given track indices and
|
||||||
|
/// the j-th detection in the given detection_indices.
|
||||||
|
/// max_distance : float
|
||||||
|
/// Gating threshold. Associations with cost larger than this value are
|
||||||
|
/// disregarded.
|
||||||
|
/// tracks : List[track.Track]
|
||||||
|
/// A list of predicted tracks at the current time step.
|
||||||
|
/// detections : List[detection.Detection]
|
||||||
|
/// A list of detections at the current time step.
|
||||||
|
/// track_indices : List[int]
|
||||||
|
/// List of track indices that maps rows in `cost_matrix` to tracks in
|
||||||
|
/// `tracks` (see description above).
|
||||||
|
/// detection_indices : List[int]
|
||||||
|
/// List of detection indices that maps columns in `cost_matrix` to
|
||||||
|
/// detections in `detections` (see description above).
|
||||||
|
///
|
||||||
|
/// Returns
|
||||||
|
/// -------
|
||||||
|
/// (List[(int, int)], List[int], List[int])
|
||||||
|
/// Returns a tuple with the following three entries:
|
||||||
|
/// * A list of matched track and detection indices.
|
||||||
|
/// * A list of unmatched track indices.
|
||||||
|
/// * A list of unmatched detection indices.
|
||||||
|
///
|
||||||
|
pub fn min_cost_matching<D: Fn(&[Track], &[Detection], &[usize], &[usize]) -> Array2<f32>>(
|
||||||
|
distance_metric: &D,
|
||||||
|
max_distance: f32,
|
||||||
|
tracks: &[Track],
|
||||||
|
detections: &[Detection],
|
||||||
|
track_indices: Option<Vec<usize>>,
|
||||||
|
detection_indices: Option<Vec<usize>>
|
||||||
|
) -> (Vec<(usize, usize)>, Vec<usize>, Vec<usize>) {
|
||||||
|
let track_indices = track_indices.unwrap_or_else(||(0..tracks.len()).collect());
|
||||||
|
let detection_indices = detection_indices.unwrap_or_else(||(0..detections.len()).collect());
|
||||||
|
|
||||||
|
if detection_indices.is_empty() || track_indices.is_empty() {
|
||||||
|
return (vec![], track_indices, detection_indices); // Nothing to match.
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut cost_matrix = distance_metric(tracks, detections, &track_indices, &detection_indices);
|
||||||
|
cost_matrix.mapv_inplace(|x| if x > max_distance { max_distance + 1.0e-5 } else { x });
|
||||||
|
|
||||||
|
let mut weights = munkres::WeightMatrix::from_row_vec(cost_matrix.nrows(), cost_matrix.iter().copied().collect());
|
||||||
|
let indices = munkres::solve_assignment(&mut weights).unwrap();
|
||||||
|
|
||||||
|
let (mut matches, mut unmatched_tracks, mut unmatched_detections) = (vec![], vec![], vec![]);
|
||||||
|
|
||||||
|
for (idx, &detection_idx) in detection_indices.iter().enumerate() {
|
||||||
|
let mut present = false;
|
||||||
|
|
||||||
|
for pos in indices.iter() {
|
||||||
|
if idx == pos.column && pos.row < track_indices.len() {
|
||||||
|
present = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !present {
|
||||||
|
unmatched_detections.push(detection_idx);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (idx, &track_idx) in track_indices.iter().enumerate() {
|
||||||
|
let mut present = false;
|
||||||
|
|
||||||
|
for pos in indices.iter() {
|
||||||
|
if idx == pos.row && pos.column < detection_indices.len() {
|
||||||
|
present = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !present {
|
||||||
|
unmatched_tracks.push(track_idx);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for pos in indices.into_iter() {
|
||||||
|
if pos.row < track_indices.len() &&
|
||||||
|
pos.column < detection_indices.len() {
|
||||||
|
|
||||||
|
let track_idx = track_indices[pos.row];
|
||||||
|
let detection_idx = detection_indices[pos.column];
|
||||||
|
|
||||||
|
if cost_matrix[(pos.row, pos.column)] > max_distance {
|
||||||
|
unmatched_tracks.push(track_idx);
|
||||||
|
unmatched_detections.push(detection_idx);
|
||||||
|
} else {
|
||||||
|
matches.push((track_idx, detection_idx))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
(matches, unmatched_tracks, unmatched_detections)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Run matching cascade.
|
||||||
|
///
|
||||||
|
/// Parameters
|
||||||
|
/// ----------
|
||||||
|
/// distance_metric : Callable[List[Track], List[Detection], List[int], List[int]) -> ndarray
|
||||||
|
/// The distance metric is given a list of tracks and detections as well as
|
||||||
|
/// a list of N track indices and M detection indices. The metric should
|
||||||
|
/// return the NxM dimensional cost matrix, where element (i, j) is the
|
||||||
|
/// association cost between the i-th track in the given track indices and
|
||||||
|
/// the j-th detection in the given detection indices.
|
||||||
|
/// max_distance : float
|
||||||
|
/// Gating threshold. Associations with cost larger than this value are
|
||||||
|
/// disregarded.
|
||||||
|
/// cascade_depth: int
|
||||||
|
/// The cascade depth, should be se to the maximum track age.
|
||||||
|
/// tracks : List[track.Track]
|
||||||
|
/// A list of predicted tracks at the current time step.
|
||||||
|
/// detections : List[detection.Detection]
|
||||||
|
/// A list of detections at the current time step.
|
||||||
|
/// track_indices : Optional[List[int]]
|
||||||
|
/// List of track indices that maps rows in `cost_matrix` to tracks in
|
||||||
|
/// `tracks` (see description above). Defaults to all tracks.
|
||||||
|
/// detection_indices : Optional[List[int]]
|
||||||
|
/// List of detection indices that maps columns in `cost_matrix` to
|
||||||
|
/// detections in `detections` (see description above). Defaults to all
|
||||||
|
/// detections.
|
||||||
|
///
|
||||||
|
/// Returns
|
||||||
|
/// -------
|
||||||
|
/// (List[(int, int)], List[int], List[int])
|
||||||
|
/// Returns a tuple with the following three entries:
|
||||||
|
/// * A list of matched track and detection indices.
|
||||||
|
/// * A list of unmatched track indices.
|
||||||
|
/// * A list of unmatched detection indices.
|
||||||
|
///
|
||||||
|
pub fn matching_cascade<D: Fn(&[Track], &[Detection], &[usize], &[usize]) -> Array2<f32>>(
|
||||||
|
distance_metric: &D,
|
||||||
|
max_distance: f32,
|
||||||
|
cascade_depth: i32,
|
||||||
|
tracks: &[Track],
|
||||||
|
detections: &[Detection],
|
||||||
|
track_indices: Option<Vec<usize>>,
|
||||||
|
detection_indices: Option<Vec<usize>>) -> (Vec<(usize, usize)>, Vec<usize>, Vec<usize>)
|
||||||
|
{
|
||||||
|
let track_indices = track_indices.unwrap_or_else(||(0..tracks.len()).collect());
|
||||||
|
let detection_indices = detection_indices.unwrap_or_else(||(0..detections.len()).collect());
|
||||||
|
let mut unmatched_detections = detection_indices.clone();
|
||||||
|
let mut matches = vec![];
|
||||||
|
|
||||||
|
for level in 0..cascade_depth {
|
||||||
|
if unmatched_detections.is_empty() { // No detections left
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
let track_indices_l: Vec<_> = track_indices
|
||||||
|
.iter()
|
||||||
|
.copied()
|
||||||
|
.filter(|&idx|tracks[idx].time_since_update == 1 + level)
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
if track_indices_l.is_empty() { // Nothing to match at this level
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let (mut matches_l, _, unmatched_detections_new) = min_cost_matching(
|
||||||
|
distance_metric, max_distance, tracks, detections,
|
||||||
|
Some(track_indices_l), Some(unmatched_detections.clone()));
|
||||||
|
|
||||||
|
unmatched_detections = unmatched_detections_new;
|
||||||
|
|
||||||
|
matches.append(&mut matches_l);
|
||||||
|
}
|
||||||
|
|
||||||
|
let track_indices_set: HashSet<_> = track_indices.into_iter().collect();
|
||||||
|
let matches_track_indices_set: HashSet<_> = matches.iter().map(|&(k, _)|k).collect();
|
||||||
|
|
||||||
|
let unmatched_tracks: Vec<usize> = track_indices_set.difference(&matches_track_indices_set)
|
||||||
|
.copied()
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
(matches, unmatched_tracks, unmatched_detections)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Invalidate infeasible entries in cost matrix based on the state distributions obtained by Kalman filtering.
|
||||||
|
///
|
||||||
|
/// Parameters
|
||||||
|
/// ----------
|
||||||
|
/// kf : The Kalman filter.
|
||||||
|
/// cost_matrix : ndarray
|
||||||
|
/// The NxM dimensional cost matrix, where N is the number of track indices
|
||||||
|
/// and M is the number of detection indices, such that entry (i, j) is the
|
||||||
|
/// association cost between `tracks[track_indices[i]]` and
|
||||||
|
/// `detections[detection_indices[j]]`.
|
||||||
|
/// tracks : List[track.Track]
|
||||||
|
/// A list of predicted tracks at the current time step.
|
||||||
|
/// detections : List[detection.Detection]
|
||||||
|
/// A list of detections at the current time step.
|
||||||
|
/// track_indices : List[int]
|
||||||
|
/// List of track indices that maps rows in `cost_matrix` to tracks in
|
||||||
|
/// `tracks` (see description above).
|
||||||
|
/// detection_indices : List[int]
|
||||||
|
/// List of detection indices that maps columns in `cost_matrix` to
|
||||||
|
/// detections in `detections` (see description above).
|
||||||
|
/// gated_cost : Optional[float]
|
||||||
|
/// Entries in the cost matrix corresponding to infeasible associations are
|
||||||
|
/// set this value. Defaults to a very large value.
|
||||||
|
/// only_position : Optional[bool]
|
||||||
|
/// If True, only the x, y position of the state distribution is considered
|
||||||
|
/// during gating. Defaults to False.
|
||||||
|
///
|
||||||
|
/// Returns
|
||||||
|
/// -------
|
||||||
|
/// ndarray
|
||||||
|
/// Returns the modified cost matrix.
|
||||||
|
///
|
||||||
|
pub fn gate_cost_matrix(
|
||||||
|
kf: &KalmanFilter,
|
||||||
|
mut cost_matrix: ArrayViewMut2<'_, f32>,
|
||||||
|
tracks: &[Track],
|
||||||
|
detections: &[Detection],
|
||||||
|
track_indices: &[usize],
|
||||||
|
detection_indices: &[usize],
|
||||||
|
gated_cost: Option<f32>,
|
||||||
|
only_position: Option<bool>)
|
||||||
|
{
|
||||||
|
let gated_cost = gated_cost.unwrap_or(INFTY_COST);
|
||||||
|
let only_position = only_position.unwrap_or(false);
|
||||||
|
let gating_dim = if only_position {1} else {3}; // indexes for 2 and 4 dims respectivly
|
||||||
|
let gating_threshold = crate::sort::kalman_filter::CHI_2_INV_95[gating_dim];
|
||||||
|
|
||||||
|
let mut measurements: Array2<f32> = unsafe { Array2::uninitialized((detection_indices.len(), 4)) };
|
||||||
|
|
||||||
|
for (mut row, &idx) in measurements.axis_iter_mut(Axis(0)).zip(detection_indices.iter()) {
|
||||||
|
let bbox = &detections[idx].bbox.as_xyah();
|
||||||
|
|
||||||
|
row.assign(&bbox.as_view());
|
||||||
|
}
|
||||||
|
|
||||||
|
for (row, &track_idx) in track_indices.iter().enumerate() {
|
||||||
|
let track = &tracks[track_idx];
|
||||||
|
let gating_distance = kf.gating_distance(
|
||||||
|
track.mean(), track.covariance(), measurements.view(), only_position);
|
||||||
|
|
||||||
|
let mut axis = cost_matrix
|
||||||
|
.index_axis_mut(Axis(0), row);
|
||||||
|
|
||||||
|
for (idx, val) in axis.indexed_iter_mut() {
|
||||||
|
if idx >= gating_distance.len() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
if gating_distance[idx] > gating_threshold {
|
||||||
|
*val = gated_cost;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
246
src/sort/mod.rs
Normal file
246
src/sort/mod.rs
Normal file
@ -0,0 +1,246 @@
|
|||||||
|
pub mod detection;
|
||||||
|
pub mod iou_matching;
|
||||||
|
pub mod kalman_filter;
|
||||||
|
pub mod nn_matching;
|
||||||
|
pub mod linear_assignment;
|
||||||
|
pub mod tracker;
|
||||||
|
pub mod track;
|
||||||
|
|
||||||
|
pub use detection::Detection;
|
||||||
|
pub use kalman_filter::KalmanFilter;
|
||||||
|
pub use tracker::Tracker;
|
||||||
|
pub use track::{Track, TrackState};
|
||||||
|
pub use nn_matching::*;
|
||||||
|
|
||||||
|
use core::marker::PhantomData;
|
||||||
|
use ndarray::prelude::*;
|
||||||
|
|
||||||
|
pub trait BBoxFormat: std::fmt::Debug {}
|
||||||
|
|
||||||
|
#[derive(Debug, Copy, Clone)]
|
||||||
|
pub struct Ltwh;
|
||||||
|
impl BBoxFormat for Ltwh {}
|
||||||
|
|
||||||
|
#[derive(Debug, Copy, Clone)]
|
||||||
|
pub struct Xyah;
|
||||||
|
impl BBoxFormat for Xyah {}
|
||||||
|
|
||||||
|
#[derive(Debug, Copy, Clone)]
|
||||||
|
pub struct Ltrb;
|
||||||
|
impl BBoxFormat for Ltrb {}
|
||||||
|
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct BBox<F: BBoxFormat>([f32; 4], PhantomData<F>);
|
||||||
|
impl<F: BBoxFormat> BBox<F> {
|
||||||
|
#[inline]
|
||||||
|
pub fn as_view(&self) -> ArrayView1<'_, f32> {
|
||||||
|
aview1(&self.0)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn xyah(x1: f32, x2: f32, x3: f32, x4: f32) -> Self {
|
||||||
|
BBox(
|
||||||
|
[x1, x2, x3, x4],
|
||||||
|
Default::default(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl BBox<Ltwh> {
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn left(&self) -> f32 {
|
||||||
|
self.0[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn top(&self) -> f32 {
|
||||||
|
self.0[1]
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn width(&self) -> f32 {
|
||||||
|
self.0[2]
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn height(&self) -> f32 {
|
||||||
|
self.0[3]
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn as_xyah(&self) -> BBox<Xyah> {
|
||||||
|
self.into()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn as_ltrb(&self) -> BBox<Ltrb> {
|
||||||
|
self.into()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn ltwh(x1: f32, x2: f32, x3: f32, x4: f32) -> Self {
|
||||||
|
BBox(
|
||||||
|
[x1, x2, x3, x4],
|
||||||
|
Default::default(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl BBox<Ltrb> {
|
||||||
|
#[inline]
|
||||||
|
pub fn ltrb(x1: f32, x2: f32, x3: f32, x4: f32) -> Self {
|
||||||
|
BBox(
|
||||||
|
[x1, x2, x3, x4],
|
||||||
|
Default::default(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn as_ltwh(&self) -> BBox<Ltwh> {
|
||||||
|
self.into()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn left(&self) -> f32 {
|
||||||
|
self.0[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn top(&self) -> f32 {
|
||||||
|
self.0[1]
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn right(&self) -> f32 {
|
||||||
|
self.0[2]
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn bottom(&self) -> f32 {
|
||||||
|
self.0[3]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl BBox<Xyah> {
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn as_ltrb(&self) -> BBox<Ltrb> {
|
||||||
|
self.into()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn as_ltwh(&self) -> BBox<Ltwh> {
|
||||||
|
self.into()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn cx(&self) -> f32 {
|
||||||
|
self.0[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn cy(&self) -> f32 {
|
||||||
|
self.0[1]
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn height(&self) -> f32 {
|
||||||
|
self.0[3]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl <'a> From<&'a BBox<Ltwh>> for BBox<Xyah> {
|
||||||
|
#[inline]
|
||||||
|
fn from(v: &'a BBox<Ltwh>) -> Self {
|
||||||
|
Self([
|
||||||
|
v.0[0] + v.0[2] / 2.0,
|
||||||
|
v.0[1] + v.0[3] / 2.0,
|
||||||
|
v.0[2] / v.0[3],
|
||||||
|
v.0[3],
|
||||||
|
], Default::default())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl <'a> From<&'a BBox<Ltwh>> for BBox<Ltrb> {
|
||||||
|
#[inline]
|
||||||
|
fn from(v: &'a BBox<Ltwh>) -> Self {
|
||||||
|
Self([
|
||||||
|
v.0[0],
|
||||||
|
v.0[1],
|
||||||
|
v.0[2] + v.0[0],
|
||||||
|
v.0[3] + v.0[1],
|
||||||
|
], Default::default())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl <'a> From<&'a BBox<Ltrb>> for BBox<Ltwh> {
|
||||||
|
#[inline]
|
||||||
|
fn from(v: &'a BBox<Ltrb>) -> Self {
|
||||||
|
Self([
|
||||||
|
v.0[0],
|
||||||
|
v.0[1],
|
||||||
|
v.0[2] - v.0[0],
|
||||||
|
v.0[3] - v.0[1],
|
||||||
|
], Default::default())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl <'a> From<&'a BBox<Xyah>> for BBox<Ltwh> {
|
||||||
|
#[inline]
|
||||||
|
fn from(v: &'a BBox<Xyah>) -> Self {
|
||||||
|
let height = v.0[3];
|
||||||
|
let width = v.0[2] * height;
|
||||||
|
|
||||||
|
Self([
|
||||||
|
v.0[0] - width / 2.0,
|
||||||
|
v.0[1] - height / 2.0,
|
||||||
|
width,
|
||||||
|
height,
|
||||||
|
], Default::default())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl <'a> From<&'a BBox<Xyah>> for BBox<Ltrb> {
|
||||||
|
#[inline]
|
||||||
|
fn from(v: &'a BBox<Xyah>) -> Self {
|
||||||
|
(&v.as_ltwh()).into()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait DistanceMetric {
|
||||||
|
|
||||||
|
/// Getting the matching threshold
|
||||||
|
///
|
||||||
|
fn matching_threshold(&self) -> f32;
|
||||||
|
|
||||||
|
/// Update the distance metric with new data.
|
||||||
|
///
|
||||||
|
/// Parameters
|
||||||
|
/// ----------
|
||||||
|
/// features : ndarray
|
||||||
|
/// An NxM matrix of N features of dimensionality M.
|
||||||
|
/// targets : ndarray
|
||||||
|
/// An integer array of associated target identities.
|
||||||
|
/// active_targets : List[int]
|
||||||
|
/// A list of targets that are currently present in the scene.
|
||||||
|
///
|
||||||
|
fn partial_fit(&mut self, features: Vec<Array1<f32>>, targets: Vec<i32>, active_targets: Vec<i32>);
|
||||||
|
|
||||||
|
/// Compute distance between features and targets.
|
||||||
|
///
|
||||||
|
/// Parameters
|
||||||
|
/// ----------
|
||||||
|
/// features : ndarray
|
||||||
|
/// An NxM matrix of N features of dimensionality M.
|
||||||
|
/// targets : List[int]
|
||||||
|
/// A list of targets to match the given `features` against.
|
||||||
|
///
|
||||||
|
/// Returns
|
||||||
|
/// -------
|
||||||
|
/// ndarray
|
||||||
|
/// Returns a cost matrix of shape len(targets), len(features), where
|
||||||
|
/// element (i, j) contains the closest squared distance between
|
||||||
|
/// `targets[i]` and `features[j]`.
|
||||||
|
///
|
||||||
|
fn distance(&self, features: ArrayView2<'_, f32>, targets: Vec<i32>) -> Array2<f32>;
|
||||||
|
}
|
249
src/sort/nn_matching.rs
Normal file
249
src/sort/nn_matching.rs
Normal file
@ -0,0 +1,249 @@
|
|||||||
|
use std::collections::{VecDeque, HashMap};
|
||||||
|
use ndarray::prelude::*;
|
||||||
|
use crate::sort::DistanceMetric;
|
||||||
|
use crate::sort::linear_assignment::INFTY_COST;
|
||||||
|
|
||||||
|
// Compute pair-wise squared distance between points in `a` and `b`.
|
||||||
|
//
|
||||||
|
// Parameters
|
||||||
|
// ----------
|
||||||
|
// a : array_like
|
||||||
|
// An NxM matrix of N samples of dimensionality M.
|
||||||
|
// b : array_like
|
||||||
|
// An LxM matrix of L samples of dimensionality M.
|
||||||
|
//
|
||||||
|
// Returns
|
||||||
|
// -------
|
||||||
|
// ndarray
|
||||||
|
// Returns a matrix of size len(a), len(b) such that eleement (i, j)
|
||||||
|
// contains the squared distance between `a[i]` and `b[j]`.
|
||||||
|
//
|
||||||
|
fn pdist(a: ArrayView2<'_, f32>, b: ArrayView2<'_, f32>) -> Array2<f32> {
|
||||||
|
if a.is_empty() || b.is_empty() {
|
||||||
|
return Array2::zeros((a.len(), b.len()));
|
||||||
|
}
|
||||||
|
|
||||||
|
let (a2, b2) = (
|
||||||
|
(&a * &a).sum_axis(Axis(1)).insert_axis(Axis(1)),
|
||||||
|
(&b * &b).sum_axis(Axis(1)).insert_axis(Axis(0))
|
||||||
|
);
|
||||||
|
|
||||||
|
let mut r2 = -2.0 * a.dot(&b.t()) + a2 + b2;
|
||||||
|
|
||||||
|
r2.mapv_inplace(|x| if x < 0.0 {
|
||||||
|
0.0
|
||||||
|
} else {
|
||||||
|
x
|
||||||
|
});
|
||||||
|
|
||||||
|
r2
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Compute pair-wise cosine distance between points in `a` and `b`.
|
||||||
|
///
|
||||||
|
/// Parameters
|
||||||
|
/// ----------
|
||||||
|
/// a : array_like
|
||||||
|
/// An NxM matrix of N samples of dimensionality M.
|
||||||
|
/// b : array_like
|
||||||
|
/// An LxM matrix of L samples of dimensionality M.
|
||||||
|
/// data_is_normalized : Optional[bool]
|
||||||
|
/// If True, assumes rows in a and b are unit length vectors.
|
||||||
|
/// Otherwise, a and b are explicitly normalized to lenght 1.
|
||||||
|
///
|
||||||
|
/// Returns
|
||||||
|
/// -------
|
||||||
|
/// ndarray
|
||||||
|
/// Returns a matrix of size len(a), len(b) such that eleement (i, j)
|
||||||
|
/// contains the squared distance between `a[i]` and `b[j]`.
|
||||||
|
///
|
||||||
|
fn cosine_distance(a: ArrayView2<'_, f32>, b: ArrayView2<'_, f32>, data_is_normalized: bool) -> Array2<f32> {
|
||||||
|
if data_is_normalized {
|
||||||
|
-a.dot(&b.t()) + 1.0
|
||||||
|
} else {
|
||||||
|
let length_a = a.map_axis(Axis(1), |x|x.fold(0.0, |a, x|a + x*x).sqrt());
|
||||||
|
let length_b = b.map_axis(Axis(1), |x|x.fold(0.0, |a, x|a + x*x).sqrt());
|
||||||
|
|
||||||
|
let a = &a / &length_a.insert_axis(Axis(1));
|
||||||
|
let b = &b / &length_b.insert_axis(Axis(1));
|
||||||
|
|
||||||
|
-a.dot(&b.t()) + 1.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn cosine_distance_test() {
|
||||||
|
println!("{:?}", cosine_distance(aview2(&[[1.0f32,2.,3.,4.,5.,6.]]), aview2(&[[5.,6.,7.,8.,9.,10.]]), false));
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Helper function for nearest neighbor distance metric (Euclidean).
|
||||||
|
///
|
||||||
|
/// Parameters
|
||||||
|
/// ----------
|
||||||
|
/// x : ndarray
|
||||||
|
/// A matrix of N row-vectors (sample points).
|
||||||
|
/// y : ndarray
|
||||||
|
/// A matrix of M row-vectors (query points).
|
||||||
|
///
|
||||||
|
/// Returns
|
||||||
|
/// -------
|
||||||
|
/// ndarray
|
||||||
|
/// A vector of length M that contains for each entry in `y` the
|
||||||
|
/// smallest Euclidean distance to a sample in `x`.
|
||||||
|
///
|
||||||
|
fn nn_euclidean_distance(x: ArrayView2<'_, f32>, y: ArrayView2<'_, f32>) -> Array1<f32> {
|
||||||
|
let distances = pdist(x, y);
|
||||||
|
|
||||||
|
distances.map_axis(Axis(0), |view|view.fold(f32::MAX, |a, &x| if x < a { x } else { a }))
|
||||||
|
.mapv_into(|x|x.max(0.0))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Helper function for nearest neighbor distance metric (cosine).
|
||||||
|
///
|
||||||
|
/// Parameters
|
||||||
|
/// ----------
|
||||||
|
/// x : ndarray
|
||||||
|
/// A matrix of N row-vectors (sample points).
|
||||||
|
/// y : ndarray
|
||||||
|
/// A matrix of M row-vectors (query points).
|
||||||
|
///
|
||||||
|
/// Returns
|
||||||
|
/// -------
|
||||||
|
/// ndarray
|
||||||
|
/// A vector of length M that contains for each entry in `y` the
|
||||||
|
/// smallest cosine distance to a sample in `x`.
|
||||||
|
///
|
||||||
|
fn nn_cosine_distance(x: ArrayView2<'_, f32>, y: ArrayView2<'_, f32>) -> Array1<f32> {
|
||||||
|
let distances = cosine_distance(x, y, false);
|
||||||
|
|
||||||
|
distances.map_axis(Axis(0), |view|view.fold(f32::MAX, |a, &x| if x < a { x } else { a }))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub enum NearestNeighborMetricKind {
|
||||||
|
EuclideanDistance,
|
||||||
|
CosineDistance,
|
||||||
|
}
|
||||||
|
|
||||||
|
///
|
||||||
|
/// A nearest neighbor distance metric that, for each target, returns
|
||||||
|
/// the closest distance to any sample that has been observed so far.
|
||||||
|
///
|
||||||
|
/// Parameters
|
||||||
|
/// ----------
|
||||||
|
/// metric : str
|
||||||
|
/// Either "euclidean" or "cosine".
|
||||||
|
/// matching_threshold: float
|
||||||
|
/// The matching threshold. Samples with larger distance are considered an
|
||||||
|
/// invalid match.
|
||||||
|
/// budget : Optional[int]
|
||||||
|
/// If not None, fix samples per class to at most this number. Removes
|
||||||
|
/// the oldest samples when the budget is reached.
|
||||||
|
///
|
||||||
|
/// Attributes
|
||||||
|
/// ----------
|
||||||
|
/// samples : Dict[int -> List[ndarray]]
|
||||||
|
/// A dictionary that maps from target identities to the list of samples
|
||||||
|
/// that have been observed so far.
|
||||||
|
///
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct NearestNeighborDistanceMetric {
|
||||||
|
metric_kind: NearestNeighborMetricKind,
|
||||||
|
matching_threshold: f32,
|
||||||
|
budget: Option<usize>,
|
||||||
|
samples: HashMap<i32, VecDeque<Array1<f32>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl NearestNeighborDistanceMetric {
|
||||||
|
pub fn new(metric_kind: NearestNeighborMetricKind, matching_threshold: f32, budget: Option<usize>) -> Self {
|
||||||
|
Self {
|
||||||
|
metric_kind,
|
||||||
|
matching_threshold,
|
||||||
|
budget,
|
||||||
|
samples: Default::default(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DistanceMetric for NearestNeighborDistanceMetric {
|
||||||
|
#[inline]
|
||||||
|
fn matching_threshold(&self) -> f32 {
|
||||||
|
self.matching_threshold
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Update the distance metric with new data.
|
||||||
|
///
|
||||||
|
/// Parameters
|
||||||
|
/// ----------
|
||||||
|
/// features : ndarray
|
||||||
|
/// An NxM matrix of N features of dimensionality M.
|
||||||
|
/// targets : ndarray
|
||||||
|
/// An integer array of associated target identities.
|
||||||
|
/// active_targets : List[int]
|
||||||
|
/// A list of targets that are currently present in the scene.
|
||||||
|
///
|
||||||
|
fn partial_fit(&mut self, features: Vec<Array1<f32>>, targets: Vec<i32>, active_targets: Vec<i32>) {
|
||||||
|
for (feature, target) in features.into_iter().zip(targets.into_iter()) {
|
||||||
|
let deque = self.samples
|
||||||
|
.entry(target)
|
||||||
|
.or_insert_with(VecDeque::new);
|
||||||
|
|
||||||
|
deque.push_front(feature);
|
||||||
|
|
||||||
|
if let Some(budget) = self.budget {
|
||||||
|
deque.truncate(budget);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let new_samples = active_targets
|
||||||
|
.into_iter()
|
||||||
|
.filter_map(|k| Some((k, self.samples.remove(&k)?)))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
self.samples = new_samples;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Compute distance between features and targets.
|
||||||
|
///
|
||||||
|
/// Parameters
|
||||||
|
/// ----------
|
||||||
|
/// features : ndarray
|
||||||
|
/// An NxM matrix of N features of dimensionality M.
|
||||||
|
/// targets : List[int]
|
||||||
|
/// A list of targets to match the given `features` against.
|
||||||
|
///
|
||||||
|
/// Returns
|
||||||
|
/// -------
|
||||||
|
/// ndarray
|
||||||
|
/// Returns a cost matrix of shape len(targets), len(features), where
|
||||||
|
/// element (i, j) contains the closest squared distance between
|
||||||
|
/// `targets[i]` and `features[j]`.
|
||||||
|
///
|
||||||
|
fn distance(&self, features: ArrayView2<'_, f32>, targets: Vec<i32>) -> Array2<f32> {
|
||||||
|
let ntargets = targets.len();
|
||||||
|
let nfeatures = features.nrows();
|
||||||
|
let n = nfeatures.max(ntargets);
|
||||||
|
|
||||||
|
let mut cost_matrix = Array2::from_elem((n, n), INFTY_COST);
|
||||||
|
|
||||||
|
for (i, target) in targets.into_iter().enumerate() {
|
||||||
|
let sample_features_deq = &self.samples[&target];
|
||||||
|
|
||||||
|
let mut sample_features = unsafe { Array::uninitialized((sample_features_deq.len(), features.ncols())) };
|
||||||
|
sample_features_deq
|
||||||
|
.iter()
|
||||||
|
.enumerate()
|
||||||
|
.for_each(|(idx, arr)| sample_features.index_axis_mut(Axis(0), idx).assign(&arr));
|
||||||
|
|
||||||
|
cost_matrix
|
||||||
|
.slice_mut(s![i, ..nfeatures])
|
||||||
|
.assign(&match self.metric_kind {
|
||||||
|
NearestNeighborMetricKind::EuclideanDistance => nn_euclidean_distance(sample_features.view(), features.view()),
|
||||||
|
NearestNeighborMetricKind::CosineDistance => nn_cosine_distance(sample_features.view(), features.view()),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
cost_matrix
|
||||||
|
}
|
||||||
|
}
|
198
src/sort/track.rs
Normal file
198
src/sort/track.rs
Normal file
@ -0,0 +1,198 @@
|
|||||||
|
use ndarray::prelude::*;
|
||||||
|
use crate::sort::{Detection, KalmanFilter, BBox, Xyah};
|
||||||
|
|
||||||
|
///
|
||||||
|
/// Enumeration type for the single target track state. Newly created tracks are
|
||||||
|
/// classified as `tentative` until enough evidence has been collected. Then,
|
||||||
|
/// the track state is changed to `confirmed`. Tracks that are no longer alive
|
||||||
|
/// are classified as `deleted` to mark them for removal from the set of active
|
||||||
|
/// tracks.
|
||||||
|
///
|
||||||
|
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
|
||||||
|
pub enum TrackState {
|
||||||
|
Tentative,
|
||||||
|
Confirmed,
|
||||||
|
Deleted,
|
||||||
|
}
|
||||||
|
|
||||||
|
///
|
||||||
|
/// A single target track with state space `(x, y, a, h)` and associated
|
||||||
|
/// velocities, where `(x, y)` is the center of the bounding box, `a` is the
|
||||||
|
/// aspect ratio and `h` is the height.
|
||||||
|
///
|
||||||
|
/// Parameters
|
||||||
|
/// ----------
|
||||||
|
/// mean : ndarray
|
||||||
|
/// Mean vector of the initial state distribution.
|
||||||
|
/// covariance : ndarray
|
||||||
|
/// Covariance matrix of the initial state distribution.
|
||||||
|
/// track_id : int
|
||||||
|
/// A unique track identifier.
|
||||||
|
/// n_init : int
|
||||||
|
/// Number of consecutive detections before the track is confirmed. The
|
||||||
|
/// track state is set to `Deleted` if a miss occurs within the first
|
||||||
|
/// `n_init` frames.
|
||||||
|
/// max_age : int
|
||||||
|
/// The maximum number of consecutive misses before the track state is
|
||||||
|
/// set to `Deleted`.
|
||||||
|
/// feature : Optional[ndarray]
|
||||||
|
/// Feature vector of the detection this track originates from. If not None,
|
||||||
|
/// this feature is added to the `features` cache.
|
||||||
|
///
|
||||||
|
/// Attributes
|
||||||
|
/// ----------
|
||||||
|
/// mean : ndarray
|
||||||
|
/// Mean vector of the initial state distribution.
|
||||||
|
/// covariance : ndarray
|
||||||
|
/// Covariance matrix of the initial state distribution.
|
||||||
|
/// track_id : int
|
||||||
|
/// A unique track identifier.
|
||||||
|
/// hits : int
|
||||||
|
/// Total number of measurement updates.
|
||||||
|
/// age : int
|
||||||
|
/// Total number of frames since first occurance.
|
||||||
|
/// time_since_update : int
|
||||||
|
/// Total number of frames since last measurement update.
|
||||||
|
/// state : TrackState
|
||||||
|
/// The current track state.
|
||||||
|
/// features : List[ndarray]
|
||||||
|
/// A cache of features. On each measurement update, the associated feature
|
||||||
|
/// vector is added to this list.
|
||||||
|
///
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct Track {
|
||||||
|
pub track_id: i32,
|
||||||
|
pub time_since_update: i32,
|
||||||
|
pub features: Vec<Array1<f32>>,
|
||||||
|
|
||||||
|
covariance: Array2<f32>,
|
||||||
|
mean: Array1<f32>,
|
||||||
|
hits: i32,
|
||||||
|
age: i32,
|
||||||
|
state: TrackState,
|
||||||
|
n_init: i32,
|
||||||
|
max_age: i32,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Track {
|
||||||
|
pub fn new(mean: Array1<f32>, covariance: Array2<f32>, track_id: i32, n_init: i32, max_age: i32, feature: Option<Array1<f32>>) -> Self {
|
||||||
|
Self {
|
||||||
|
track_id,
|
||||||
|
mean,
|
||||||
|
covariance,
|
||||||
|
hits: 1,
|
||||||
|
age: 1,
|
||||||
|
time_since_update: 0,
|
||||||
|
state: TrackState::Tentative,
|
||||||
|
features: feature.into_iter().collect(),
|
||||||
|
n_init,
|
||||||
|
max_age,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get current position in bounding box format `(top left x, top left y, width, height)`.
|
||||||
|
///
|
||||||
|
/// Returns
|
||||||
|
/// -------
|
||||||
|
/// ndarray
|
||||||
|
/// The bounding box.
|
||||||
|
///
|
||||||
|
#[inline]
|
||||||
|
pub fn bbox(&self) -> BBox<Xyah> {
|
||||||
|
BBox::xyah(
|
||||||
|
self.mean[0],
|
||||||
|
self.mean[1],
|
||||||
|
self.mean[2],
|
||||||
|
self.mean[3],
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn mean(&self) -> ArrayView1<'_, f32> {
|
||||||
|
self.mean.view()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn covariance(&self) -> ArrayView2<'_, f32> {
|
||||||
|
self.covariance.view()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Propagate the state distribution to the current time step using a
|
||||||
|
/// Kalman filter prediction step.
|
||||||
|
///
|
||||||
|
/// Parameters
|
||||||
|
/// ----------
|
||||||
|
/// kf : kalman_filter.KalmanFilter
|
||||||
|
/// The Kalman filter.
|
||||||
|
///
|
||||||
|
pub fn predict(&mut self, kf: &mut KalmanFilter) {
|
||||||
|
let (mean, covariance) = kf.predict(self.mean.view(), self.covariance.view());
|
||||||
|
self.mean = mean;
|
||||||
|
self.covariance = covariance;
|
||||||
|
self.age += 1;
|
||||||
|
self.time_since_update += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Perform Kalman filter measurement update step and update the feature cache.
|
||||||
|
///
|
||||||
|
/// Parameters
|
||||||
|
/// ----------
|
||||||
|
/// kf : kalman_filter.KalmanFilter
|
||||||
|
/// The Kalman filter.
|
||||||
|
/// detection : Detection
|
||||||
|
/// The associated detection.
|
||||||
|
///
|
||||||
|
pub fn update(&mut self, kf: &mut KalmanFilter, detection: &Detection) {
|
||||||
|
let (mean, covariance) = kf.update(
|
||||||
|
self.mean.view(),
|
||||||
|
self.covariance.view(),
|
||||||
|
detection.bbox.as_xyah().as_view()
|
||||||
|
);
|
||||||
|
|
||||||
|
self.mean = mean;
|
||||||
|
self.covariance = covariance;
|
||||||
|
self.features.extend(detection.feature.clone().into_iter());
|
||||||
|
|
||||||
|
self.hits += 1;
|
||||||
|
self.time_since_update = 0;
|
||||||
|
|
||||||
|
if self.state == TrackState::Tentative && self.hits >= self.n_init {
|
||||||
|
self.state = TrackState::Confirmed;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
///
|
||||||
|
/// Mark this track as missed (no association at the current time step).
|
||||||
|
///
|
||||||
|
#[inline]
|
||||||
|
pub fn mark_missed(&mut self) {
|
||||||
|
if self.state == TrackState::Tentative {
|
||||||
|
self.state = TrackState::Deleted;
|
||||||
|
} else if self.time_since_update > self.max_age {
|
||||||
|
self.state = TrackState::Deleted;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
///
|
||||||
|
/// Returns True if this track is tentative (unconfirmed).
|
||||||
|
///
|
||||||
|
#[inline]
|
||||||
|
pub fn is_tentative(&self) -> bool {
|
||||||
|
self.state == TrackState::Tentative
|
||||||
|
}
|
||||||
|
|
||||||
|
///
|
||||||
|
/// Returns True if this track is confirmed.
|
||||||
|
///
|
||||||
|
#[inline]
|
||||||
|
pub fn is_confirmed(&self) -> bool {
|
||||||
|
self.state == TrackState::Confirmed
|
||||||
|
}
|
||||||
|
|
||||||
|
///
|
||||||
|
/// Returns True if this track is dead and should be deleted.
|
||||||
|
///
|
||||||
|
#[inline]
|
||||||
|
pub fn is_deleted(&self) -> bool {
|
||||||
|
self.state == TrackState::Deleted
|
||||||
|
}
|
||||||
|
}
|
210
src/sort/tracker.rs
Normal file
210
src/sort/tracker.rs
Normal file
@ -0,0 +1,210 @@
|
|||||||
|
use ndarray::prelude::*;
|
||||||
|
use std::collections::HashSet;
|
||||||
|
use crate::sort::{Track, Detection, KalmanFilter, DistanceMetric};
|
||||||
|
|
||||||
|
/// This is the multi-target tracker.
|
||||||
|
///
|
||||||
|
/// Parameters
|
||||||
|
/// ----------
|
||||||
|
/// metric : nn_matching.NearestNeighborDistanceMetric
|
||||||
|
/// A distance metric for measurement-to-track association.
|
||||||
|
/// max_age : int
|
||||||
|
/// Maximum number of missed misses before a track is deleted.
|
||||||
|
/// n_init : int
|
||||||
|
/// Number of consecutive detections before the track is confirmed. The
|
||||||
|
/// track state is set to `Deleted` if a miss occurs within the first
|
||||||
|
/// `n_init` frames.
|
||||||
|
///
|
||||||
|
/// Attributes
|
||||||
|
/// ----------
|
||||||
|
/// metric : nn_matching.NearestNeighborDistanceMetric
|
||||||
|
/// The distance metric used for measurement to track association.
|
||||||
|
/// max_age : int
|
||||||
|
/// Maximum number of missed misses before a track is deleted.
|
||||||
|
/// n_init : int
|
||||||
|
/// Number of frames that a track remains in initialization phase.
|
||||||
|
/// kf : kalman_filter.KalmanFilter
|
||||||
|
/// A Kalman filter to filter target trajectories in image space.
|
||||||
|
/// tracks : List[Track]
|
||||||
|
/// The list of active tracks at the current time step.
|
||||||
|
///
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct Tracker<M: DistanceMetric> {
|
||||||
|
metric: M,
|
||||||
|
max_iou_distance: f32,
|
||||||
|
max_age: i32,
|
||||||
|
n_init: i32,
|
||||||
|
kf: KalmanFilter,
|
||||||
|
tracks: Vec<Track>,
|
||||||
|
next_id: i32,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<M: DistanceMetric> Tracker<M> {
|
||||||
|
pub fn new(metric: M, max_iou_distance: f32 /*=0.7*/, max_age: i32/*=70*/, n_init: i32/*=3*/) -> Self {
|
||||||
|
Self {
|
||||||
|
metric,
|
||||||
|
max_iou_distance,
|
||||||
|
max_age,
|
||||||
|
n_init,
|
||||||
|
kf: Default::default(),
|
||||||
|
next_id: 1,
|
||||||
|
tracks: Vec::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn tracks(&self) -> &[Track] {
|
||||||
|
self.tracks.as_slice()
|
||||||
|
}
|
||||||
|
|
||||||
|
///
|
||||||
|
/// Propagate track state distributions one time step forward.
|
||||||
|
///
|
||||||
|
/// This function should be called once every time step, before `update`.
|
||||||
|
///
|
||||||
|
pub fn predict(&mut self) {
|
||||||
|
for track in &mut self.tracks {
|
||||||
|
track.predict(&mut self.kf);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Perform measurement update and track management.
|
||||||
|
///
|
||||||
|
/// Parameters
|
||||||
|
/// ----------
|
||||||
|
/// detections : List[deep_sort.detection.Detection]
|
||||||
|
/// A list of detections at the current time step.
|
||||||
|
///
|
||||||
|
pub fn update(&mut self, detections: &[Detection]) {
|
||||||
|
let (matches, unmatched_tracks, unmatched_detections) = self.do_match(&detections);
|
||||||
|
|
||||||
|
for (track_idx, detection_idx) in matches {
|
||||||
|
self.tracks[track_idx].update(
|
||||||
|
&mut self.kf, &detections[detection_idx]);
|
||||||
|
}
|
||||||
|
|
||||||
|
for track_idx in unmatched_tracks {
|
||||||
|
self.tracks[track_idx].mark_missed();
|
||||||
|
}
|
||||||
|
|
||||||
|
for detection_idx in unmatched_detections {
|
||||||
|
self.initiate_track(&detections[detection_idx]);
|
||||||
|
}
|
||||||
|
|
||||||
|
self.tracks.retain(|t|!t.is_deleted());
|
||||||
|
|
||||||
|
let (
|
||||||
|
mut features,
|
||||||
|
mut targets,
|
||||||
|
mut active_targets) = (vec![], vec![], vec![]);
|
||||||
|
|
||||||
|
for track in &mut self.tracks {
|
||||||
|
if !track.is_confirmed() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
active_targets.push(track.track_id);
|
||||||
|
|
||||||
|
for feature in track.features.iter() {
|
||||||
|
targets.push(track.track_id);
|
||||||
|
features.push(feature.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
self.metric.partial_fit(
|
||||||
|
features,
|
||||||
|
targets,
|
||||||
|
active_targets
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn do_match(&mut self, detections: &[Detection]) -> (Vec<(usize, usize)>, Vec<usize>, Vec<usize>) {
|
||||||
|
let gated_metric = |tracks: &[Track], dets: &[Detection], track_indices: &[usize], detection_indices: &[usize]| {
|
||||||
|
let mut features = unsafe { Array2::uninitialized((detection_indices.len(), dets[0].feature.as_ref().unwrap().len())) };
|
||||||
|
|
||||||
|
for (idx, mut axis) in features.axis_iter_mut(Axis(0)).enumerate() {
|
||||||
|
let index = dets[detection_indices[idx]].feature.as_ref().unwrap();
|
||||||
|
axis.assign(index);
|
||||||
|
}
|
||||||
|
|
||||||
|
let targets: Vec<_> = track_indices.iter().map(|&i|tracks[i].track_id).collect();
|
||||||
|
let mut cost_matrix = self.metric.distance(features.view(), targets);
|
||||||
|
|
||||||
|
crate::sort::linear_assignment::gate_cost_matrix(
|
||||||
|
&self.kf,
|
||||||
|
cost_matrix.view_mut(),
|
||||||
|
&tracks,
|
||||||
|
&dets,
|
||||||
|
&track_indices,
|
||||||
|
detection_indices,
|
||||||
|
None,
|
||||||
|
None
|
||||||
|
);
|
||||||
|
|
||||||
|
cost_matrix
|
||||||
|
};
|
||||||
|
|
||||||
|
// Split track set into confirmed and unconfirmed tracks.
|
||||||
|
let (mut confirmed_tracks, mut unconfirmed_tracks) = (vec![], vec![]);
|
||||||
|
for (i, t) in self.tracks.iter().enumerate() {
|
||||||
|
if t.is_confirmed() {
|
||||||
|
confirmed_tracks.push(i);
|
||||||
|
} else {
|
||||||
|
unconfirmed_tracks.push(i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let matching_threshold = self.metric.matching_threshold();
|
||||||
|
let max_age = self.max_age;
|
||||||
|
|
||||||
|
// Associate confirmed tracks using appearance features.
|
||||||
|
let (matches_a, unmatched_tracks_a, unmatched_detections) =
|
||||||
|
crate::sort::linear_assignment::matching_cascade(
|
||||||
|
&gated_metric,
|
||||||
|
matching_threshold,
|
||||||
|
max_age,
|
||||||
|
&self.tracks,
|
||||||
|
detections,
|
||||||
|
Some(confirmed_tracks),
|
||||||
|
None);
|
||||||
|
|
||||||
|
// Associate remaining tracks together with unconfirmed tracks using IOU.
|
||||||
|
let (iou_track_candidates, unmatched_tracks_a): (Vec<_>, Vec<_>) = unmatched_tracks_a
|
||||||
|
.into_iter()
|
||||||
|
.partition(|&k|self.tracks[k].time_since_update == 1);
|
||||||
|
|
||||||
|
let iou_track_candidates = [unconfirmed_tracks.as_slice(), iou_track_candidates.as_slice()].concat();
|
||||||
|
|
||||||
|
let (matches_b, unmatched_tracks_b, unmatched_detections) =
|
||||||
|
crate::sort::linear_assignment::min_cost_matching(
|
||||||
|
&crate::sort::iou_matching::iou_cost,
|
||||||
|
self.max_iou_distance,
|
||||||
|
&self.tracks,
|
||||||
|
&detections,
|
||||||
|
Some(iou_track_candidates),
|
||||||
|
Some(unmatched_detections));
|
||||||
|
|
||||||
|
let matches = [matches_a, matches_b].concat();
|
||||||
|
let unmatched_tracks: HashSet<_> = unmatched_tracks_a
|
||||||
|
.into_iter()
|
||||||
|
.chain(unmatched_tracks_b.into_iter())
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
(matches, unmatched_tracks.into_iter().collect(), unmatched_detections)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn initiate_track(&mut self, detection: &Detection) {
|
||||||
|
let (mean, covariance) = self.kf.initiate(detection.bbox.as_xyah());
|
||||||
|
|
||||||
|
self.tracks.push(Track::new(
|
||||||
|
mean,
|
||||||
|
covariance,
|
||||||
|
self.next_id,
|
||||||
|
self.n_init,
|
||||||
|
self.max_age,
|
||||||
|
detection.feature.clone()
|
||||||
|
));
|
||||||
|
|
||||||
|
self.next_id += 1;
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user