1use image::{DynamicImage, GenericImageView, imageops::FilterType};
4use ndarray::{Array, Axis, s};
5use ort::inputs;
6use ort::session::builder::GraphOptimizationLevel;
7use ort::session::{Session, SessionOutputs};
8use std::env;
9use std::fmt::Display;
10use std::str::FromStr;
11
12use types::cv::{BoxDetection, Detections, MatInfo, Model};
13
/// Reads the environment variable `key` and parses it as `T`.
///
/// Returns `default` unchanged when the variable is unset, is not valid
/// Unicode, or its value cannot be parsed as `T`. The default is never
/// round-tripped through `Display`/`FromStr`, so the caller always gets back
/// exactly the value it supplied.
///
/// The `Display` bound is not needed by the body; it is kept so the signature
/// stays identical for existing callers.
fn get_env<T: FromStr + Display>(key: &str, default: T) -> T {
    match env::var(key) {
        // The raw value is parsed as-is (no trimming).
        Ok(raw) => raw.parse().unwrap_or(default),
        Err(_) => default,
    }
}
22
23pub fn load_model() -> anyhow::Result<Model> {
25 let model = Session::builder()?
26 .with_optimization_level(GraphOptimizationLevel::Level3)?
27 .with_intra_threads(4)?
28 .commit_from_file(get_env(
29 "SAURON_MODEL_PATH",
30 "./crates/core/models/yolo11n.onnx".to_string(),
31 ))?;
32
33 let input_size = get_env("SAURON_INPUT_SIZE", 640_i32);
34
35 Ok(Model { model, input_size })
36}
37
38fn pre_process(model_data: &Model, img: &DynamicImage) -> anyhow::Result<DynamicImage> {
40 let resized_image = img.resize_exact(
50 model_data.input_size as u32,
51 model_data.input_size as u32,
52 FilterType::CatmullRom,
53 );
54
55 Ok(resized_image)
56}
57
58pub async fn detect(model_data: &Model, img: &DynamicImage) -> anyhow::Result<Detections> {
64 let model = &model_data.model;
65
66 let mat_info = MatInfo {
70 width: img.width() as f32,
71 height: img.height() as f32,
72 scaled_size: model_data.input_size as f32,
73 };
74
75 let resized_mat = pre_process(model_data, img).unwrap();
76
77 let usize_model_size = model_data.input_size as usize;
80 let mut input = Array::zeros((1, 3, usize_model_size, usize_model_size));
81
82 for pixel in resized_mat.pixels() {
86 let x = pixel.0 as _;
87 let y = pixel.1 as _;
88 let [r, g, b, _] = pixel.2.0;
89 input[[0, 0, y, x]] = (r as f32) / 255.;
90 input[[0, 1, y, x]] = (g as f32) / 255.;
91 input[[0, 2, y, x]] = (b as f32) / 255.;
92 }
93
94 let outputs: SessionOutputs<'_, '_> = model.run(inputs!["images" => input.view()]?)?;
96
97 let detections = post_process(&outputs, &mat_info)?;
99
100 Ok(detections)
101}
102
103fn intersection(box1: &BoxDetection, box2: &BoxDetection) -> f32 {
105 (box1.x2.min(box2.x2) - box1.x1.max(box2.x1)) * (box1.y2.min(box2.y2) - box1.y1.max(box2.y1))
106}
107
108fn union(box1: &BoxDetection, box2: &BoxDetection) -> f32 {
110 ((box1.x2 - box1.x1) * (box1.y2 - box1.y1)) + ((box2.x2 - box2.x1) * (box2.y2 - box2.y1))
111 - intersection(box1, box2)
112}
113
114fn post_process(
116 outputs: &SessionOutputs<'_, '_>,
117 mat_info: &MatInfo,
118) -> anyhow::Result<Detections> {
119 let output = outputs["output0"]
121 .try_extract_tensor::<f32>()?
122 .t()
123 .into_owned();
124
125 let mut boxes = Vec::new();
128 let output = output.slice(s![.., .., 0]);
129 for row in output.axis_iter(Axis(0)) {
130 let row: Vec<_> = row.iter().copied().collect();
131 let (class_id, prob) = row
132 .iter()
133 .skip(4)
135 .enumerate()
136 .map(|(index, value)| (index, *value))
137 .reduce(|accum, row| if row.1 > accum.1 { row } else { accum })
138 .unwrap();
139
140 if prob < 0.5 {
142 continue;
143 }
144
145 let xc = row[0] / mat_info.scaled_size * mat_info.width;
149 let yc = row[1] / mat_info.scaled_size * mat_info.height;
150 let w = row[2] / mat_info.scaled_size * mat_info.width;
151 let h = row[3] / mat_info.scaled_size * mat_info.height;
152 boxes.push(BoxDetection {
153 x1: xc - w / 2.,
154 y1: yc - h / 2.,
155 x2: xc + w / 2.,
156 y2: yc + h / 2.,
157 class_index: class_id as i32,
158 conf: prob,
159 });
160 }
161
162 boxes.sort_by(|box1, box2| box2.conf.total_cmp(&box1.conf));
164 let mut result = Vec::new();
165
166 while !boxes.is_empty() {
167 result.push(boxes[0]);
168 boxes = boxes
169 .iter()
170 .filter(|box1| intersection(&boxes[0], box1) / union(&boxes[0], box1) < 0.7)
171 .copied()
172 .collect();
173 }
174
175 Ok(result)
176}