sauron/
yolo.rs

1//! Utility functions for a yolo model.
2
3use image::{DynamicImage, GenericImageView, imageops::FilterType};
4use ndarray::{Array, Axis, s};
5use ort::inputs;
6use ort::session::builder::GraphOptimizationLevel;
7use ort::session::{Session, SessionOutputs};
8use std::env;
9use std::fmt::Display;
10use std::str::FromStr;
11
12use types::cv::{BoxDetection, Detections, MatInfo, Model};
13
/// Reads the environment variable `key` and parses it as `T`.
///
/// Returns `default` unchanged when the variable is unset, is not valid
/// Unicode, or cannot be parsed as `T`.
fn get_env<T: FromStr + Display>(key: &str, default: T) -> T {
    env::var(key)
        .ok()
        .and_then(|value| value.parse().ok())
        // Lazily fall back to the caller's value itself, rather than a
        // Display -> parse round-trip of it (which may not reproduce it).
        .unwrap_or(default)
}
22
23/// Loads a YOLO model from the path specified by the environment variable `SAURON_MODEL_PATH`.
24pub fn load_model() -> anyhow::Result<Model> {
25    let model = Session::builder()?
26        .with_optimization_level(GraphOptimizationLevel::Level3)?
27        .with_intra_threads(4)?
28        .commit_from_file(get_env(
29            "SAURON_MODEL_PATH",
30            "./crates/core/models/yolo11n.onnx".to_string(),
31        ))?;
32
33    let input_size = get_env("SAURON_INPUT_SIZE", 640_i32);
34
35    Ok(Model { model, input_size })
36}
37
38/// Pre-processes an image for YOLO detection.
39fn pre_process(model_data: &Model, img: &DynamicImage) -> anyhow::Result<DynamicImage> {
40    // Get the width and height of the image
41    // let width = img.width();
42    // let height = img.height();
43
44    // TODO tile photos
45
46    // Resize the image to 640 x 640 pxels, as that's what yolo uses
47    // https://legacy.imagemagick.org/Usage/filter/#windowed for more info on filtertype
48    // https://docs.rs/image/latest/image/imageops/enum.FilterType.html for the speed of each filter
49    let resized_image = img.resize_exact(
50        model_data.input_size as u32,
51        model_data.input_size as u32,
52        FilterType::CatmullRom,
53    );
54
55    Ok(resized_image)
56}
57
58/// Performs the actual detection and outputs the results.
59///
60/// # Arguments
61/// * `model_data` - The YOLO model.
62/// * `img` - The input image to be processed.
63pub async fn detect(model_data: &Model, img: &DynamicImage) -> anyhow::Result<Detections> {
64    let model = &model_data.model;
65
66    // Also used to be model_config
67    // Any reference to sauron_config used to be model_config
68
69    let mat_info = MatInfo {
70        width: img.width() as f32,
71        height: img.height() as f32,
72        scaled_size: model_data.input_size as f32,
73    };
74
75    let resized_mat = pre_process(model_data, img).unwrap();
76
77    // Create a blank input tensor
78    // shape: (batch, RGB channels, height, weight).
79    let usize_model_size = model_data.input_size as usize;
80    let mut input = Array::zeros((1, 3, usize_model_size, usize_model_size));
81
82    // Populate the tensor.
83    // Tensor is 4d array. [[batch, color, height, width]]
84    // height then width because of yolo Row-Major order
85    for pixel in resized_mat.pixels() {
86        let x = pixel.0 as _;
87        let y = pixel.1 as _;
88        let [r, g, b, _] = pixel.2.0;
89        input[[0, 0, y, x]] = (r as f32) / 255.;
90        input[[0, 1, y, x]] = (g as f32) / 255.;
91        input[[0, 2, y, x]] = (b as f32) / 255.;
92    }
93
94    // Convert ndarray array to Value
95    let outputs: SessionOutputs<'_, '_> = model.run(inputs!["images" => input.view()]?)?;
96
97    // Post Process the outputs to resize everything back to the normal image
98    let detections = post_process(&outputs, &mat_info)?;
99
100    Ok(detections)
101}
102
103/// Finds the overlapping area between two bounding boxes. Returns a negative value if the boxes don't overlap.
104fn intersection(box1: &BoxDetection, box2: &BoxDetection) -> f32 {
105    (box1.x2.min(box2.x2) - box1.x1.max(box2.x1)) * (box1.y2.min(box2.y2) - box1.y1.max(box2.y1))
106}
107
108/// Finds the total area covered by both bounding boxes (But breaks if the boxes don't overlap).
109fn union(box1: &BoxDetection, box2: &BoxDetection) -> f32 {
110    ((box1.x2 - box1.x1) * (box1.y2 - box1.y1)) + ((box2.x2 - box2.x1) * (box2.y2 - box2.y1))
111        - intersection(box1, box2)
112}
113
114/// Performs non-maximum suppression on the detections.
115fn post_process(
116    outputs: &SessionOutputs<'_, '_>,
117    mat_info: &MatInfo,
118) -> anyhow::Result<Detections> {
119    // Convert the outputs vector into an ArrayBase to "easier use" (idk, I copied this from a repo)
120    let output = outputs["output0"]
121        .try_extract_tensor::<f32>()?
122        .t()
123        .into_owned();
124
125    // Calculate the bounding box locations in reference to the orginial image
126    // and save them
127    let mut boxes = Vec::new();
128    let output = output.slice(s![.., .., 0]);
129    for row in output.axis_iter(Axis(0)) {
130        let row: Vec<_> = row.iter().copied().collect();
131        let (class_id, prob) = row
132            .iter()
133            // skip bounding box coordinates
134            .skip(4)
135            .enumerate()
136            .map(|(index, value)| (index, *value))
137            .reduce(|accum, row| if row.1 > accum.1 { row } else { accum })
138            .unwrap();
139
140        // If the probability is too low, ignore that box
141        if prob < 0.5 {
142            continue;
143        }
144
145        // CLONED THE LABEL. Hopefully doesn't cause memory issues
146        // It causes memory issues, (aka the struct is wacky). so we no lonnger get the string name
147        // let label = model_config.class_names[class_id].clone();
148        let xc = row[0] / mat_info.scaled_size * mat_info.width;
149        let yc = row[1] / mat_info.scaled_size * mat_info.height;
150        let w = row[2] / mat_info.scaled_size * mat_info.width;
151        let h = row[3] / mat_info.scaled_size * mat_info.height;
152        boxes.push(BoxDetection {
153            x1: xc - w / 2.,
154            y1: yc - h / 2.,
155            x2: xc + w / 2.,
156            y2: yc + h / 2.,
157            class_index: class_id as i32,
158            conf: prob,
159        });
160    }
161
162    // Run Non-Maximum Suppression (NMS) on the detection boxes
163    boxes.sort_by(|box1, box2| box2.conf.total_cmp(&box1.conf));
164    let mut result = Vec::new();
165
166    while !boxes.is_empty() {
167        result.push(boxes[0]);
168        boxes = boxes
169            .iter()
170            .filter(|box1| intersection(&boxes[0], box1) / union(&boxes[0], box1) < 0.7)
171            .copied()
172            .collect();
173    }
174
175    Ok(result)
176}