core/
database.rs

1//! Utility functions for managing the PostgreSQL database.
2
3use postgresql_embedded::{PostgreSQL, Settings};
4use sqlx::PgPool;
5use std::time::Duration;
6use testcontainers::{
7    ContainerAsync, GenericImage, ImageExt,
8    core::{IntoContainerPort, WaitFor},
9    runners::AsyncRunner,
10};
11use tracing::info;
12
13use types::{
14    cv::Object,
15    db::{GPSRow, ObjectRow},
16};
17
18use crate::utils::get_env;
19
20/// Sets up the tables of the database.
21pub async fn create_schema(pool: &PgPool) -> anyhow::Result<()> {
22    sqlx::query(
23        r#"
24CREATE TABLE IF NOT EXISTS gps (
25    timestamp BIGINT PRIMARY KEY CHECK (timestamp >= 0),
26    lat REAL NOT NULL,
27    long REAL NOT NULL,
28    alt REAL NOT NULL,
29    heading REAL NOT NULL
30);
31"#,
32    )
33    .execute(pool)
34    .await?;
35
36    sqlx::query(
37        r#"
38CREATE TABLE IF NOT EXISTS speed (
39    timestamp INTEGER PRIMARY KEY CHECK (timestamp >= 0),
40    ground_speed REAL NOT NULL,
41    vertical_speed REAL NOT NULL,
42    air_speed REAL NOT NULL
43);
44"#,
45    )
46    .execute(pool)
47    .await?;
48
49    sqlx::query(
50        r#"
51CREATE TABLE IF NOT EXISTS images (
52    filepath TEXT PRIMARY KEY,
53    timestamp BIGINT NOT NULL CHECK (timestamp >= 0)
54);
55"#,
56    )
57    .execute(pool)
58    .await?;
59
60    sqlx::query(
61        r#"
62CREATE TABLE IF NOT EXISTS geotags (
63    filepath TEXT PRIMARY KEY REFERENCES images (filepath) ON UPDATE CASCADE ON DELETE CASCADE,
64    image_timestamp BIGINT NOT NULL CHECK (image_timestamp >= 0),
65    gps_time BIGINT NOT NULL REFERENCES gps (timestamp) ON UPDATE CASCADE ON DELETE RESTRICT,
66    lat REAL NOT NULL,
67    long REAL NOT NULL,
68    alt REAL NOT NULL,
69    heading REAL NOT NULL
70);
71"#,
72    )
73    .execute(pool)
74    .await?;
75
76    sqlx::query(
77        r#"
78CREATE TABLE IF NOT EXISTS objects (
79    id INTEGER PRIMARY KEY,
80    original_filepath TEXT NOT NULL,
81    lat REAL NOT NULL,
82    long REAL NOT NULL,
83    class INTEGER NOT NULL,
84    max_confidence REAL NOT NULL,
85    num_detections INTEGER NOT NULL,
86    UNIQUE(lat, long, class)
87);
88"#,
89    )
90    .execute(pool)
91    .await?;
92    Ok(())
93}
94
95/// Returns the closest 2 timestamps by their distance to the camera timestamp.
96pub async fn get_gps(pool: &PgPool, timestamp: &i64) -> anyhow::Result<(GPSRow, GPSRow)> {
97    let row = sqlx::query_as::<_, GPSRow>("SELECT * FROM gps ORDER BY ABS(timestamp - ?) LIMIT 2")
98        .bind(timestamp)
99        .fetch_all(pool)
100        .await?;
101
102    Ok((row[0].clone(), row[1].clone()))
103}
104
105/// Insert image file_path and timestamp into database (currently unused).
106pub async fn insert_image(
107    pool: &PgPool,
108    file_path: &String,
109    timestamp: &i64,
110) -> Result<(), sqlx::Error> {
111    sqlx::query("INSERT INTO images (filepath, timestamp) VALUES ($1, $2);")
112        .bind(file_path)
113        .bind(timestamp)
114        .execute(pool)
115        .await?;
116    Ok(())
117}
118
119/// Inserts all objects into the database, or updates them if they already exist.
120async fn bulk_upsert_objects(pool: &sqlx::PgPool, objects: &[ObjectRow]) -> anyhow::Result<()> {
121    // deconstruct the objects to insert
122    let mut lats = Vec::with_capacity(objects.len());
123    let mut longs = Vec::with_capacity(objects.len());
124    let mut classes = Vec::with_capacity(objects.len());
125    let mut max_confidences = Vec::with_capacity(objects.len());
126    let mut num_detections = Vec::with_capacity(objects.len());
127    let mut original_filepaths = Vec::with_capacity(objects.len());
128
129    for obj in objects {
130        lats.push(obj.lat);
131        longs.push(obj.long);
132        classes.push(obj.class);
133        max_confidences.push(obj.max_confidence);
134        num_detections.push(obj.num_detections);
135        original_filepaths.push(&obj.original_filepath);
136    }
137
138    // bulk upsert query. (either replace if it exists, or insert if not)
139    sqlx::query(
140        r#"
141            INSERT INTO objects (lat, long, class, max_confidence, num_detections, original_filepath)
142            SELECT * FROM UNNEST(
143                $1::real[],
144                $2::real[],
145                $3::integer[],
146                $4::real[],
147                $5::integer[],
148                $6::text[]
149            )
150            ON CONFLICT (lat, long, class) DO UPDATE
151            SET
152                max_confidence = GREATEST(objects.max_confidence, excluded.max_confidence),
153                num_detections = objects.num_detections + excluded.num_detections
154        "#,
155    )
156    .bind(lats)
157    .bind(longs)
158    .bind(classes)
159    .bind(max_confidences)
160    .bind(num_detections)
161    .bind(original_filepaths)
162    .execute(pool)
163    .await?;
164
165    Ok(())
166}
167
168/// Updates the object table in the database with a dynamic object matching strategy.
169///
170/// Currently uses a basic clustering algorithm which just groups objects within a certain distance.
171///
172/// # Arguments
173///
174/// * `pool` - A reference to the PostgreSQL connection pool
175/// * `objects` - A vector of `Object` instances to be updated in the database
176/// * `file_path` - The filepath of the image which the objects were localized from
177pub async fn smart_update_database(
178    pool: &PgPool,
179    objects: Vec<Object>,
180    file_path: &String,
181) -> anyhow::Result<()> {
182    // detects similarly placed objects (withing 11m currently) and then tracks them as an object in the database increasing label confidence and count
183    let current_objects = sqlx::query_as::<_, ObjectRow>("SELECT * from objects")
184        .fetch_all(pool)
185        .await?;
186
187    let mut new_objects = Vec::<ObjectRow>::new();
188
189    let object_tolerance = get_env("OBJECT_CORD_TOLERANCE", 0.0001_f64).abs();
190
191    for object in objects.iter() {
192        let mut found = false;
193        for current_object in current_objects.iter() {
194            if (current_object.lat - object.lat).abs() < object_tolerance
195                && (current_object.long - object.long).abs() < object_tolerance
196                && object.class == current_object.class
197            {
198                // TODO Create other object detection clustering algoirthms (might have to rework the current db to also store all detection if a clustering alg is used or similar).
199                new_objects.push(ObjectRow {
200                    num_detections: current_object.num_detections + 1,
201                    max_confidence: f32::max(current_object.max_confidence, object.confidence),
202                    ..(*current_object).clone()
203                });
204                found = true;
205            }
206        }
207        if !found {
208            new_objects.push(ObjectRow {
209                lat: object.lat,
210                long: object.long,
211                class: object.class,
212                max_confidence: object.confidence,
213                num_detections: 1,
214                original_filepath: file_path.clone(),
215            });
216        }
217    }
218
219    bulk_upsert_objects(pool, &new_objects).await?;
220
221    Ok(())
222}
223
224/// Starts a PostgreSQL container using Docker.
225///
226/// Returns a container instance that can be used to interact with the PostgreSQL instance.
227pub async fn start_postgres_container() -> anyhow::Result<ContainerAsync<GenericImage>> {
228    Ok(GenericImage::new("postgres", "latest")
229        .with_wait_for(WaitFor::Duration {
230            length: Duration::new(4, 0), // ideally fix hardcoded value, but not nessecary since this is for dev
231        })
232        .with_wait_for(WaitFor::message_on_stdout(
233            "database system is ready to accept connections",
234        ))
235        .with_mapped_port(5432, 5432.tcp()) // 5432 is default postgres port
236        .with_network("bridge")
237        .with_env_var("POSTGRES_DB", "local")
238        .with_env_var("POSTGRES_USER", "user")
239        .with_env_var("POSTGRES_PASSWORD", "password")
240        .start()
241        .await?)
242}
243
244/// Starts a PostgreSQL instance embedded in the application.
245///
246/// Returns a tuple containing the database settings and the URL of the PostgreSQL instance.
247pub async fn start_postgres_embedded() -> anyhow::Result<(PostgreSQL, String)> {
248    let settings = Settings::new(); //from_url(env::var("DATABASE_URL").unwrap_or("postgresql://user:password@localhost:5432/local".to_string())).expect("Failed to extract settings from DATABASE_URL (Incorrectly formatted url)"); Theoretically should work so we don't have to overwrite database_url but credentials don't set right
249
250    info!("{settings:?}");
251    let mut embedded_postgresql = PostgreSQL::new(settings);
252
253    embedded_postgresql.setup().await?;
254    embedded_postgresql.start().await?;
255    let database_name = "local";
256    embedded_postgresql.create_database(database_name).await?;
257    let url = embedded_postgresql.settings().url(database_name);
258
259    info!("Required connection string: {url}");
260
261    Ok((embedded_postgresql, url))
262}