nautilus_testkit/
files.rs1use std::{
17 fs::{File, OpenOptions},
18 io::{BufReader, BufWriter, Read, copy},
19 path::Path,
20 thread::sleep,
21 time::Duration,
22};
23
24use aws_lc_rs::digest::{self, Context};
25use rand::{Rng, rng};
26use reqwest::blocking::Client;
27use serde_json::Value;
28
29pub fn ensure_file_exists_or_download_http(
46 filepath: &Path,
47 url: &str,
48 checksums: Option<&Path>,
49) -> anyhow::Result<()> {
50 if filepath.exists() {
51 println!("File already exists: {filepath:?}");
52
53 if let Some(checksums_file) = checksums {
54 if verify_sha256_checksum(filepath, checksums_file)? {
55 println!("File is valid");
56 return Ok(());
57 } else {
58 let new_checksum = calculate_sha256(filepath)?;
59 println!("Adding checksum for existing file: {new_checksum}");
60 update_sha256_checksums(filepath, checksums_file, &new_checksum)?;
61 return Ok(());
62 }
63 }
64 return Ok(());
65 }
66
67 let jitter_delay = {
71 let mut r = rng();
72 Duration::from_millis(r.random_range(100..=600))
73 };
74 sleep(jitter_delay);
75
76 download_file(filepath, url)?;
77
78 if let Some(checksums_file) = checksums {
79 let new_checksum = calculate_sha256(filepath)?;
80 update_sha256_checksums(filepath, checksums_file, &new_checksum)?;
81 }
82
83 Ok(())
84}
85
86fn download_file(filepath: &Path, url: &str) -> anyhow::Result<()> {
87 const MAX_RETRIES: usize = 3;
88 const BASE_DELAY_MS: u64 = 1000;
89 const TIMEOUT_SECONDS: u64 = 30;
90
91 println!("Downloading file from {url} to {filepath:?}");
92
93 if let Some(parent) = filepath.parent() {
94 std::fs::create_dir_all(parent)?;
95 }
96
97 let client = Client::builder()
98 .timeout(Duration::from_secs(TIMEOUT_SECONDS))
99 .build()?;
100
101 let mut last_error = None;
102
103 for attempt in 0..MAX_RETRIES {
104 if attempt > 0 {
105 let exponential_delay_ms = BASE_DELAY_MS * 2_u64.pow(attempt as u32 - 1);
107 let jitter_ms = rng().random_range(0..BASE_DELAY_MS);
108 let delay = Duration::from_millis(exponential_delay_ms + jitter_ms);
109 println!(
110 "Retrying download in {delay:?} (attempt {}/{MAX_RETRIES})",
111 attempt + 1
112 );
113 sleep(delay);
114 }
115
116 match client.get(url).send() {
117 Ok(mut response) => {
118 let status = response.status();
119 if status.is_success() {
120 let mut out = File::create(filepath)?;
121 copy(&mut response, &mut out)?;
123 println!("File downloaded to {filepath:?}");
124 return Ok(());
125 } else if status.is_server_error() {
126 println!("Server error (HTTP {status}), retrying...");
128 last_error = Some(anyhow::anyhow!("Server error: HTTP {status}"));
129 continue;
130 }
131 anyhow::bail!("Client error: HTTP {status}");
133 }
134 Err(e) => {
135 println!("Request failed: {e}");
136 last_error = Some(anyhow::anyhow!("Request failed: {e}"));
137 continue;
138 }
139 }
140 }
141
142 Err(last_error
143 .unwrap_or_else(|| anyhow::anyhow!("Download failed after {MAX_RETRIES} attempts")))
144}
145
146fn calculate_sha256(filepath: &Path) -> anyhow::Result<String> {
147 let mut file = File::open(filepath)?;
148 let mut ctx = Context::new(&digest::SHA256);
149 let mut buffer = [0u8; 4096];
150
151 loop {
152 let count = file.read(&mut buffer)?;
153 if count == 0 {
154 break;
155 }
156 ctx.update(&buffer[..count]);
157 }
158
159 let digest = ctx.finish();
160 Ok(hex::encode(digest.as_ref()))
161}
162
163fn verify_sha256_checksum(filepath: &Path, checksums: &Path) -> anyhow::Result<bool> {
164 let file = File::open(checksums)?;
165 let reader = BufReader::new(file);
166 let checksums: Value = serde_json::from_reader(reader)?;
167
168 let filename = filepath.file_name().unwrap().to_str().unwrap();
169 if let Some(expected_checksum) = checksums.get(filename) {
170 let expected_checksum_str = expected_checksum.as_str().unwrap();
171 let expected_hash = expected_checksum_str
172 .strip_prefix("sha256:")
173 .unwrap_or(expected_checksum_str);
174 let calculated_checksum = calculate_sha256(filepath)?;
175 if expected_hash == calculated_checksum {
176 return Ok(true);
177 }
178 }
179
180 Ok(false)
181}
182
183fn update_sha256_checksums(
184 filepath: &Path,
185 checksums_file: &Path,
186 new_checksum: &str,
187) -> anyhow::Result<()> {
188 let checksums: Value = if checksums_file.exists() {
189 let file = File::open(checksums_file)?;
190 let reader = BufReader::new(file);
191 serde_json::from_reader(reader)?
192 } else {
193 serde_json::json!({})
194 };
195
196 let mut checksums_map = checksums.as_object().unwrap().clone();
197
198 let filename = filepath.file_name().unwrap().to_str().unwrap().to_string();
200 let prefixed_checksum = format!("sha256:{new_checksum}");
201 checksums_map.insert(filename, Value::String(prefixed_checksum));
202
203 let file = OpenOptions::new()
204 .write(true)
205 .create(true)
206 .truncate(true)
207 .open(checksums_file)?;
208 let writer = BufWriter::new(file);
209 serde_json::to_writer_pretty(writer, &serde_json::Value::Object(checksums_map))?;
210
211 Ok(())
212}
213
#[cfg(test)]
mod tests {
    use std::{
        fs,
        io::{BufWriter, Write},
        net::SocketAddr,
        sync::Arc,
    };

    use axum::{Router, http::StatusCode, routing::get, serve};
    use rstest::*;
    use serde_json::{json, to_writer};
    use tempfile::TempDir;
    use tokio::{
        net::TcpListener,
        task,
        time::{Duration, sleep},
    };

    use super::*;

    /// Spawns a local HTTP server on an OS-assigned port that answers `GET /testfile.txt`.
    ///
    /// Every request receives `status_code`. The body is `server_content`, or the text
    /// `"File not found"` when no content is given. Returns the address the server listens on.
    async fn setup_test_server(
        server_content: Option<String>,
        status_code: StatusCode,
    ) -> SocketAddr {
        let body: Arc<str> =
            Arc::from(server_content.unwrap_or_else(|| "File not found".to_string()));
        let handler = move || {
            let body = Arc::clone(&body);
            async move { (status_code, body.to_string()) }
        };
        let app = Router::new().route("/testfile.txt", get(handler));

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let bound_addr = listener.local_addr().unwrap();
        let server = serve(listener, app);

        task::spawn(async move {
            if let Err(e) = server.await {
                eprintln!("server error: {e}");
            }
        });

        // Give the spawned server task a moment to start before clients connect.
        sleep(Duration::from_millis(100)).await;

        bound_addr
    }

    #[tokio::test]
    async fn test_file_already_exists() {
        let temp_dir = TempDir::new().unwrap();
        let existing = temp_dir.path().join("testfile.txt");
        fs::write(&existing, "Existing file content").unwrap();

        // The URL is never contacted because the file is already present.
        let outcome =
            ensure_file_exists_or_download_http(&existing, "http://example.com/testfile.txt", None);

        assert!(outcome.is_ok());
        let on_disk = fs::read_to_string(&existing).unwrap();
        assert_eq!(on_disk, "Existing file content");
    }

    #[tokio::test]
    async fn test_download_file_success() {
        let temp_dir = TempDir::new().unwrap();
        let target = temp_dir.path().join("testfile.txt");
        let expected_body = "Server file content".to_string();

        let addr = setup_test_server(Some(expected_body.clone()), StatusCode::OK).await;
        let url = format!("http://{addr}/testfile.txt");

        // Run the blocking download off the async runtime thread.
        let worker_target = target.clone();
        let outcome = task::spawn_blocking(move || {
            ensure_file_exists_or_download_http(&worker_target, &url, None)
        })
        .await
        .unwrap();

        assert!(outcome.is_ok());
        let on_disk = fs::read_to_string(&target).unwrap();
        assert_eq!(on_disk, expected_body);
    }

    #[tokio::test]
    async fn test_download_file_not_found() {
        let temp_dir = TempDir::new().unwrap();
        let target = temp_dir.path().join("testfile.txt");

        let addr = setup_test_server(None, StatusCode::NOT_FOUND).await;
        let url = format!("http://{addr}/testfile.txt");

        let outcome = task::spawn_blocking(move || {
            ensure_file_exists_or_download_http(&target, &url, None)
        })
        .await
        .unwrap();

        let failure = outcome.expect_err("a 404 response must fail the download");
        let message = failure.to_string();
        assert!(
            message.contains("Client error: HTTP"),
            "Unexpected error message: {message}"
        );
    }

    #[tokio::test]
    async fn test_network_error() {
        let temp_dir = TempDir::new().unwrap();
        let target = temp_dir.path().join("testfile.txt");

        // Port 0 is not connectable, so every request attempt is expected to fail.
        let unreachable_url = "http://127.0.0.1:0/testfile.txt".to_string();

        let outcome = task::spawn_blocking(move || {
            ensure_file_exists_or_download_http(&target, &unreachable_url, None)
        })
        .await
        .unwrap();

        let failure = outcome.expect_err("an unreachable host must fail the download");
        let message = failure.to_string();
        assert!(
            message.contains("error"),
            "Unexpected error message: {message}"
        );
    }

    #[rstest]
    fn test_calculate_sha256() -> anyhow::Result<()> {
        let temp_dir = TempDir::new()?;
        let sample_path = temp_dir.path().join("test_file.txt");
        File::create(&sample_path)?.write_all(b"Hello, world!")?;

        let digest_hex = calculate_sha256(&sample_path)?;

        assert_eq!(
            digest_hex,
            "315f5bdb76d078c43b8ac0064e4a0164612b1fce77c869345bfc94c75894edd3"
        );
        Ok(())
    }

    #[rstest]
    fn test_verify_sha256_checksum() -> anyhow::Result<()> {
        let temp_dir = TempDir::new()?;
        let sample_path = temp_dir.path().join("test_file.txt");
        File::create(&sample_path)?.write_all(b"Hello, world!")?;

        let digest_hex = calculate_sha256(&sample_path)?;

        let checksums_path = temp_dir.path().join("checksums.json");
        let manifest = json!({
            "test_file.txt": format!("sha256:{}", digest_hex)
        });
        to_writer(BufWriter::new(File::create(&checksums_path)?), &manifest)?;

        let matches = verify_sha256_checksum(&sample_path, &checksums_path)?;
        assert!(matches, "The checksum should be valid");
        Ok(())
    }
}