core/
database.rs

1//! Utility functions for managing the PostgreSQL database.
2
3use postgresql_embedded::{PostgreSQL, Settings};
4use sqlx::PgPool;
5use std::time::Duration;
6use testcontainers::{
7    ContainerAsync, GenericImage, ImageExt,
8    core::{IntoContainerPort, WaitFor},
9    runners::AsyncRunner,
10};
11use tracing::info;
12
13use types::{
14    cv::Object,
15    db::{GPSRow, ObjectRow},
16};
17
18use crate::utils::get_env;
19
20/// Sets up the tables of the database.
21pub async fn create_schema(pool: &PgPool) -> anyhow::Result<()> {
22    sqlx::query(
23        r#"
24CREATE TABLE IF NOT EXISTS gps (
25    timestamp INTEGER PRIMARY KEY CHECK (timestamp >= 0),
26    lat REAL NOT NULL,
27    long REAL NOT NULL,
28    alt REAL NOT NULL,
29    heading REAL NOT NULL
30);
31"#,
32    )
33    .execute(pool)
34    .await?;
35
36    sqlx::query(
37        r#"
38CREATE TABLE IF NOT EXISTS images (
39    filepath TEXT PRIMARY KEY,
40    timestamp INTEGER NOT NULL CHECK (timestamp >= 0)
41);
42"#,
43    )
44    .execute(pool)
45    .await?;
46
47    sqlx::query(
48        r#"
49CREATE TABLE IF NOT EXISTS geotags (
50    filepath TEXT PRIMARY KEY REFERENCES images (filepath) ON UPDATE CASCADE ON DELETE CASCADE,
51    image_timestamp INTEGER NOT NULL CHECK (image_timestamp >= 0),
52    gps_time INTEGER NOT NULL REFERENCES gps (timestamp) ON UPDATE CASCADE ON DELETE RESTRICT,
53    lat REAL NOT NULL,
54    long REAL NOT NULL,
55    alt REAL NOT NULL,
56    heading REAL NOT NULL
57);
58"#,
59    )
60    .execute(pool)
61    .await?;
62
63    sqlx::query(
64        r#"
65CREATE TABLE IF NOT EXISTS objects (
66    id INTEGER PRIMARY KEY,
67    original_filepath TEXT NOT NULL,
68    lat REAL NOT NULL,
69    long REAL NOT NULL,
70    class INTEGER NOT NULL,
71    max_confidence REAL NOT NULL,
72    num_detections INTEGER NOT NULL,
73    UNIQUE(lat, long, class)
74);
75"#,
76    )
77    .execute(pool)
78    .await?;
79    Ok(())
80}
81
82/// Returns the closest 2 timestamps by their distance to the camera timestamp.
83pub async fn get_gps(pool: &PgPool, timestamp: &i64) -> anyhow::Result<(GPSRow, GPSRow)> {
84    let row = sqlx::query_as::<_, GPSRow>("SELECT * FROM gps ORDER BY ABS(timestamp - ?) LIMIT 2")
85        .bind(timestamp)
86        .fetch_all(pool)
87        .await?;
88
89    Ok((row[0].clone(), row[1].clone()))
90}
91
92/// Insert image file_path and timestamp into database (currently unused).
93pub async fn insert_image(
94    pool: &PgPool,
95    file_path: &String,
96    timestamp: &i64,
97) -> Result<(), sqlx::Error> {
98    sqlx::query("INSERT INTO images (filepath, timestamp) VALUES (?, ?);")
99        .bind(file_path)
100        .bind(timestamp)
101        .execute(pool)
102        .await?;
103    Ok(())
104}
105
106/// Inserts all objects into the database, or updates them if they already exist.
107async fn bulk_upsert_objects(pool: &sqlx::PgPool, objects: &[ObjectRow]) -> anyhow::Result<()> {
108    // deconstruct the objects to insert
109    let mut lats = Vec::with_capacity(objects.len());
110    let mut longs = Vec::with_capacity(objects.len());
111    let mut classes = Vec::with_capacity(objects.len());
112    let mut max_confidences = Vec::with_capacity(objects.len());
113    let mut num_detections = Vec::with_capacity(objects.len());
114    let mut original_filepaths = Vec::with_capacity(objects.len());
115
116    for obj in objects {
117        lats.push(obj.lat);
118        longs.push(obj.long);
119        classes.push(obj.class);
120        max_confidences.push(obj.max_confidence);
121        num_detections.push(obj.num_detections);
122        original_filepaths.push(&obj.original_filepath);
123    }
124
125    // bulk upsert query. (either replace if it exists, or insert if not)
126    sqlx::query(
127        r#"
128            INSERT INTO objects (lat, long, class, max_confidence, num_detections, original_filepath)
129            SELECT * FROM UNNEST(
130                $1::real[],
131                $2::real[],
132                $3::integer[],
133                $4::real[],
134                $5::integer[],
135                &6::text[]
136            )
137            ON CONFLICT (lat, long, class) DO UPDATE
138            SET
139                max_confidence = excluded.max_confidence,
140                num_detections = objects.num_detections + 1;
141        "#,
142    )
143    .bind(lats)
144    .bind(longs)
145    .bind(classes)
146    .bind(max_confidences)
147    .bind(num_detections)
148    .bind(original_filepaths)
149    .execute(pool)
150    .await?;
151
152    Ok(())
153}
154
155/// Updates the object table in the database with a dynamic object matching strategy.
156///
157/// Currently uses a basic clustering algorithm which just groups objects within a certain distance.
158///
159/// # Arguments
160///
161/// * `pool` - A reference to the PostgreSQL connection pool
162/// * `objects` - A vector of `Object` instances to be updated in the database
163/// * `file_path` - The filepath of the image which the objects were localized from
164pub async fn smart_update_database(
165    pool: &PgPool,
166    objects: Vec<Object>,
167    file_path: &String,
168) -> anyhow::Result<()> {
169    // detects similarly placed objects (withing 11m currently) and then tracks them as an object in the database increasing label confidence and count
170    let current_objects = sqlx::query_as::<_, ObjectRow>("SELECT * from objects")
171        .fetch_all(pool)
172        .await?;
173
174    let mut new_objects = Vec::<ObjectRow>::new();
175
176    let object_tolerance = get_env("OBJECT_CORD_TOLERANCE", 0.0001_f64).abs();
177
178    for object in objects.iter() {
179        let mut found = false;
180        for current_object in current_objects.iter() {
181            if (current_object.lat - object.lat).abs() < object_tolerance
182                && (current_object.long - object.long).abs() < object_tolerance
183                && object.class == current_object.class
184            {
185                // TODO Create other object detection clustering algoirthms (might have to rework the current db to also store all detection if a clustering alg is used or similar).
186                new_objects.push(ObjectRow {
187                    num_detections: current_object.num_detections + 1,
188                    max_confidence: f32::max(current_object.max_confidence, object.confidence),
189                    ..(*current_object).clone()
190                });
191                found = true;
192            }
193        }
194        if !found {
195            new_objects.push(ObjectRow {
196                lat: object.lat,
197                long: object.long,
198                class: object.class,
199                max_confidence: object.confidence,
200                num_detections: 1,
201                original_filepath: file_path.clone(),
202            });
203        }
204    }
205
206    bulk_upsert_objects(pool, &new_objects).await?;
207
208    Ok(())
209}
210
211/// Starts a PostgreSQL container using Docker.
212///
213/// Returns a container instance that can be used to interact with the PostgreSQL instance.
214pub async fn start_postgres_container() -> anyhow::Result<ContainerAsync<GenericImage>> {
215    Ok(GenericImage::new("postgres", "latest")
216        .with_wait_for(WaitFor::Duration {
217            length: Duration::new(4, 0), // ideally fix hardcoded value, but not nessecary since this is for dev
218        })
219        .with_wait_for(WaitFor::message_on_stdout(
220            "database system is ready to accept connections",
221        ))
222        .with_mapped_port(5432, 5432.tcp()) // 5432 is default postgres port
223        .with_network("bridge")
224        .with_env_var("POSTGRES_DB", "local")
225        .with_env_var("POSTGRES_USER", "user")
226        .with_env_var("POSTGRES_PASSWORD", "password")
227        .start()
228        .await?)
229}
230
231/// Starts a PostgreSQL instance embedded in the application.
232///
233/// Returns a tuple containing the database settings and the URL of the PostgreSQL instance.
234pub async fn start_postgres_embedded() -> anyhow::Result<(PostgreSQL, String)> {
235    let settings = Settings::new(); //from_url(env::var("DATABASE_URL").unwrap_or("postgresql://user:password@localhost:5432/local".to_string())).expect("Failed to extract settings from DATABASE_URL (Incorrectly formatted url)"); Theoretically should work so we don't have to overwrite database_url but credentials don't set right
236
237    info!("{settings:?}");
238    let mut embedded_postgresql = PostgreSQL::new(settings);
239
240    embedded_postgresql.setup().await?;
241    embedded_postgresql.start().await?;
242    let database_name = "local";
243    embedded_postgresql.create_database(database_name).await?;
244    let url = embedded_postgresql.settings().url(database_name);
245
246    info!("Required connection string: {url}");
247
248    Ok((embedded_postgresql, url))
249}