1use postgresql_embedded::{PostgreSQL, Settings};
4use sqlx::PgPool;
5use std::time::Duration;
6use testcontainers::{
7 ContainerAsync, GenericImage, ImageExt,
8 core::{IntoContainerPort, WaitFor},
9 runners::AsyncRunner,
10};
11use tracing::info;
12
13use types::{
14 cv::Object,
15 db::{GPSRow, ObjectRow},
16};
17
18use crate::utils::get_env;
19
20pub async fn create_schema(pool: &PgPool) -> anyhow::Result<()> {
22 sqlx::query(
23 r#"
24CREATE TABLE IF NOT EXISTS gps (
25 timestamp INTEGER PRIMARY KEY CHECK (timestamp >= 0),
26 lat REAL NOT NULL,
27 long REAL NOT NULL,
28 alt REAL NOT NULL,
29 heading REAL NOT NULL
30);
31"#,
32 )
33 .execute(pool)
34 .await?;
35
36 sqlx::query(
37 r#"
38CREATE TABLE IF NOT EXISTS images (
39 filepath TEXT PRIMARY KEY,
40 timestamp INTEGER NOT NULL CHECK (timestamp >= 0)
41);
42"#,
43 )
44 .execute(pool)
45 .await?;
46
47 sqlx::query(
48 r#"
49CREATE TABLE IF NOT EXISTS geotags (
50 filepath TEXT PRIMARY KEY REFERENCES images (filepath) ON UPDATE CASCADE ON DELETE CASCADE,
51 image_timestamp INTEGER NOT NULL CHECK (image_timestamp >= 0),
52 gps_time INTEGER NOT NULL REFERENCES gps (timestamp) ON UPDATE CASCADE ON DELETE RESTRICT,
53 lat REAL NOT NULL,
54 long REAL NOT NULL,
55 alt REAL NOT NULL,
56 heading REAL NOT NULL
57);
58"#,
59 )
60 .execute(pool)
61 .await?;
62
63 sqlx::query(
64 r#"
65CREATE TABLE IF NOT EXISTS objects (
66 id INTEGER PRIMARY KEY,
67 original_filepath TEXT NOT NULL,
68 lat REAL NOT NULL,
69 long REAL NOT NULL,
70 class INTEGER NOT NULL,
71 max_confidence REAL NOT NULL,
72 num_detections INTEGER NOT NULL,
73 UNIQUE(lat, long, class)
74);
75"#,
76 )
77 .execute(pool)
78 .await?;
79 Ok(())
80}
81
82pub async fn get_gps(pool: &PgPool, timestamp: &i64) -> anyhow::Result<(GPSRow, GPSRow)> {
84 let row = sqlx::query_as::<_, GPSRow>("SELECT * FROM gps ORDER BY ABS(timestamp - ?) LIMIT 2")
85 .bind(timestamp)
86 .fetch_all(pool)
87 .await?;
88
89 Ok((row[0].clone(), row[1].clone()))
90}
91
92pub async fn insert_image(
94 pool: &PgPool,
95 file_path: &String,
96 timestamp: &i64,
97) -> Result<(), sqlx::Error> {
98 sqlx::query("INSERT INTO images (filepath, timestamp) VALUES (?, ?);")
99 .bind(file_path)
100 .bind(timestamp)
101 .execute(pool)
102 .await?;
103 Ok(())
104}
105
106async fn bulk_upsert_objects(pool: &sqlx::PgPool, objects: &[ObjectRow]) -> anyhow::Result<()> {
108 let mut lats = Vec::with_capacity(objects.len());
110 let mut longs = Vec::with_capacity(objects.len());
111 let mut classes = Vec::with_capacity(objects.len());
112 let mut max_confidences = Vec::with_capacity(objects.len());
113 let mut num_detections = Vec::with_capacity(objects.len());
114 let mut original_filepaths = Vec::with_capacity(objects.len());
115
116 for obj in objects {
117 lats.push(obj.lat);
118 longs.push(obj.long);
119 classes.push(obj.class);
120 max_confidences.push(obj.max_confidence);
121 num_detections.push(obj.num_detections);
122 original_filepaths.push(&obj.original_filepath);
123 }
124
125 sqlx::query(
127 r#"
128 INSERT INTO objects (lat, long, class, max_confidence, num_detections, original_filepath)
129 SELECT * FROM UNNEST(
130 $1::real[],
131 $2::real[],
132 $3::integer[],
133 $4::real[],
134 $5::integer[],
135 &6::text[]
136 )
137 ON CONFLICT (lat, long, class) DO UPDATE
138 SET
139 max_confidence = excluded.max_confidence,
140 num_detections = objects.num_detections + 1;
141 "#,
142 )
143 .bind(lats)
144 .bind(longs)
145 .bind(classes)
146 .bind(max_confidences)
147 .bind(num_detections)
148 .bind(original_filepaths)
149 .execute(pool)
150 .await?;
151
152 Ok(())
153}
154
155pub async fn smart_update_database(
165 pool: &PgPool,
166 objects: Vec<Object>,
167 file_path: &String,
168) -> anyhow::Result<()> {
169 let current_objects = sqlx::query_as::<_, ObjectRow>("SELECT * from objects")
171 .fetch_all(pool)
172 .await?;
173
174 let mut new_objects = Vec::<ObjectRow>::new();
175
176 let object_tolerance = get_env("OBJECT_CORD_TOLERANCE", 0.0001_f64).abs();
177
178 for object in objects.iter() {
179 let mut found = false;
180 for current_object in current_objects.iter() {
181 if (current_object.lat - object.lat).abs() < object_tolerance
182 && (current_object.long - object.long).abs() < object_tolerance
183 && object.class == current_object.class
184 {
185 new_objects.push(ObjectRow {
187 num_detections: current_object.num_detections + 1,
188 max_confidence: f32::max(current_object.max_confidence, object.confidence),
189 ..(*current_object).clone()
190 });
191 found = true;
192 }
193 }
194 if !found {
195 new_objects.push(ObjectRow {
196 lat: object.lat,
197 long: object.long,
198 class: object.class,
199 max_confidence: object.confidence,
200 num_detections: 1,
201 original_filepath: file_path.clone(),
202 });
203 }
204 }
205
206 bulk_upsert_objects(pool, &new_objects).await?;
207
208 Ok(())
209}
210
211pub async fn start_postgres_container() -> anyhow::Result<ContainerAsync<GenericImage>> {
215 Ok(GenericImage::new("postgres", "latest")
216 .with_wait_for(WaitFor::Duration {
217 length: Duration::new(4, 0), })
219 .with_wait_for(WaitFor::message_on_stdout(
220 "database system is ready to accept connections",
221 ))
222 .with_mapped_port(5432, 5432.tcp()) .with_network("bridge")
224 .with_env_var("POSTGRES_DB", "local")
225 .with_env_var("POSTGRES_USER", "user")
226 .with_env_var("POSTGRES_PASSWORD", "password")
227 .start()
228 .await?)
229}
230
231pub async fn start_postgres_embedded() -> anyhow::Result<(PostgreSQL, String)> {
235 let settings = Settings::new(); info!("{settings:?}");
238 let mut embedded_postgresql = PostgreSQL::new(settings);
239
240 embedded_postgresql.setup().await?;
241 embedded_postgresql.start().await?;
242 let database_name = "local";
243 embedded_postgresql.create_database(database_name).await?;
244 let url = embedded_postgresql.settings().url(database_name);
245
246 info!("Required connection string: {url}");
247
248 Ok((embedded_postgresql, url))
249}