1use postgresql_embedded::{PostgreSQL, Settings};
4use sqlx::PgPool;
5use std::time::Duration;
6use testcontainers::{
7 ContainerAsync, GenericImage, ImageExt,
8 core::{IntoContainerPort, WaitFor},
9 runners::AsyncRunner,
10};
11use tracing::info;
12
13use types::{
14 cv::Object,
15 db::{GPSRow, ObjectRow},
16};
17
18use crate::utils::get_env;
19
20pub async fn create_schema(pool: &PgPool) -> anyhow::Result<()> {
22 sqlx::query(
23 r#"
24CREATE TABLE IF NOT EXISTS gps (
25 timestamp BIGINT PRIMARY KEY CHECK (timestamp >= 0),
26 lat REAL NOT NULL,
27 long REAL NOT NULL,
28 alt REAL NOT NULL,
29 heading REAL NOT NULL
30);
31"#,
32 )
33 .execute(pool)
34 .await?;
35
36 sqlx::query(
37 r#"
38CREATE TABLE IF NOT EXISTS speed (
39 timestamp INTEGER PRIMARY KEY CHECK (timestamp >= 0),
40 ground_speed REAL NOT NULL,
41 vertical_speed REAL NOT NULL,
42 air_speed REAL NOT NULL
43);
44"#,
45 )
46 .execute(pool)
47 .await?;
48
49 sqlx::query(
50 r#"
51CREATE TABLE IF NOT EXISTS images (
52 filepath TEXT PRIMARY KEY,
53 timestamp BIGINT NOT NULL CHECK (timestamp >= 0)
54);
55"#,
56 )
57 .execute(pool)
58 .await?;
59
60 sqlx::query(
61 r#"
62CREATE TABLE IF NOT EXISTS geotags (
63 filepath TEXT PRIMARY KEY REFERENCES images (filepath) ON UPDATE CASCADE ON DELETE CASCADE,
64 image_timestamp BIGINT NOT NULL CHECK (image_timestamp >= 0),
65 gps_time BIGINT NOT NULL REFERENCES gps (timestamp) ON UPDATE CASCADE ON DELETE RESTRICT,
66 lat REAL NOT NULL,
67 long REAL NOT NULL,
68 alt REAL NOT NULL,
69 heading REAL NOT NULL
70);
71"#,
72 )
73 .execute(pool)
74 .await?;
75
76 sqlx::query(
77 r#"
78CREATE TABLE IF NOT EXISTS objects (
79 id INTEGER PRIMARY KEY,
80 original_filepath TEXT NOT NULL,
81 lat REAL NOT NULL,
82 long REAL NOT NULL,
83 class INTEGER NOT NULL,
84 max_confidence REAL NOT NULL,
85 num_detections INTEGER NOT NULL,
86 UNIQUE(lat, long, class)
87);
88"#,
89 )
90 .execute(pool)
91 .await?;
92 Ok(())
93}
94
95pub async fn get_gps(pool: &PgPool, timestamp: &i64) -> anyhow::Result<(GPSRow, GPSRow)> {
97 let row = sqlx::query_as::<_, GPSRow>("SELECT * FROM gps ORDER BY ABS(timestamp - ?) LIMIT 2")
98 .bind(timestamp)
99 .fetch_all(pool)
100 .await?;
101
102 Ok((row[0].clone(), row[1].clone()))
103}
104
105pub async fn insert_image(
107 pool: &PgPool,
108 file_path: &String,
109 timestamp: &i64,
110) -> Result<(), sqlx::Error> {
111 sqlx::query("INSERT INTO images (filepath, timestamp) VALUES ($1, $2);")
112 .bind(file_path)
113 .bind(timestamp)
114 .execute(pool)
115 .await?;
116 Ok(())
117}
118
119async fn bulk_upsert_objects(pool: &sqlx::PgPool, objects: &[ObjectRow]) -> anyhow::Result<()> {
121 let mut lats = Vec::with_capacity(objects.len());
123 let mut longs = Vec::with_capacity(objects.len());
124 let mut classes = Vec::with_capacity(objects.len());
125 let mut max_confidences = Vec::with_capacity(objects.len());
126 let mut num_detections = Vec::with_capacity(objects.len());
127 let mut original_filepaths = Vec::with_capacity(objects.len());
128
129 for obj in objects {
130 lats.push(obj.lat);
131 longs.push(obj.long);
132 classes.push(obj.class);
133 max_confidences.push(obj.max_confidence);
134 num_detections.push(obj.num_detections);
135 original_filepaths.push(&obj.original_filepath);
136 }
137
138 sqlx::query(
140 r#"
141 INSERT INTO objects (lat, long, class, max_confidence, num_detections, original_filepath)
142 SELECT * FROM UNNEST(
143 $1::real[],
144 $2::real[],
145 $3::integer[],
146 $4::real[],
147 $5::integer[],
148 $6::text[]
149 )
150 ON CONFLICT (lat, long, class) DO UPDATE
151 SET
152 max_confidence = GREATEST(objects.max_confidence, excluded.max_confidence),
153 num_detections = objects.num_detections + excluded.num_detections
154 "#,
155 )
156 .bind(lats)
157 .bind(longs)
158 .bind(classes)
159 .bind(max_confidences)
160 .bind(num_detections)
161 .bind(original_filepaths)
162 .execute(pool)
163 .await?;
164
165 Ok(())
166}
167
168pub async fn smart_update_database(
178 pool: &PgPool,
179 objects: Vec<Object>,
180 file_path: &String,
181) -> anyhow::Result<()> {
182 let current_objects = sqlx::query_as::<_, ObjectRow>("SELECT * from objects")
184 .fetch_all(pool)
185 .await?;
186
187 let mut new_objects = Vec::<ObjectRow>::new();
188
189 let object_tolerance = get_env("OBJECT_CORD_TOLERANCE", 0.0001_f64).abs();
190
191 for object in objects.iter() {
192 let mut found = false;
193 for current_object in current_objects.iter() {
194 if (current_object.lat - object.lat).abs() < object_tolerance
195 && (current_object.long - object.long).abs() < object_tolerance
196 && object.class == current_object.class
197 {
198 new_objects.push(ObjectRow {
200 num_detections: current_object.num_detections + 1,
201 max_confidence: f32::max(current_object.max_confidence, object.confidence),
202 ..(*current_object).clone()
203 });
204 found = true;
205 }
206 }
207 if !found {
208 new_objects.push(ObjectRow {
209 lat: object.lat,
210 long: object.long,
211 class: object.class,
212 max_confidence: object.confidence,
213 num_detections: 1,
214 original_filepath: file_path.clone(),
215 });
216 }
217 }
218
219 bulk_upsert_objects(pool, &new_objects).await?;
220
221 Ok(())
222}
223
224pub async fn start_postgres_container() -> anyhow::Result<ContainerAsync<GenericImage>> {
228 Ok(GenericImage::new("postgres", "latest")
229 .with_wait_for(WaitFor::Duration {
230 length: Duration::new(4, 0), })
232 .with_wait_for(WaitFor::message_on_stdout(
233 "database system is ready to accept connections",
234 ))
235 .with_mapped_port(5432, 5432.tcp()) .with_network("bridge")
237 .with_env_var("POSTGRES_DB", "local")
238 .with_env_var("POSTGRES_USER", "user")
239 .with_env_var("POSTGRES_PASSWORD", "password")
240 .start()
241 .await?)
242}
243
244pub async fn start_postgres_embedded() -> anyhow::Result<(PostgreSQL, String)> {
248 let settings = Settings::new(); info!("{settings:?}");
251 let mut embedded_postgresql = PostgreSQL::new(settings);
252
253 embedded_postgresql.setup().await?;
254 embedded_postgresql.start().await?;
255 let database_name = "local";
256 embedded_postgresql.create_database(database_name).await?;
257 let url = embedded_postgresql.settings().url(database_name);
258
259 info!("Required connection string: {url}");
260
261 Ok((embedded_postgresql, url))
262}