1use std::{
17 cmp,
18 fs::{File, OpenOptions},
19 io::{BufReader, BufWriter, Read, copy},
20 path::Path,
21 thread::sleep,
22 time::{Duration, Instant},
23};
24
25use aws_lc_rs::digest::{self, Context};
26use nautilus_network::retry::RetryConfig;
27use rand::{Rng, rng};
28use reqwest::blocking::Client;
29use serde_json::Value;
30
/// Outcome classification for a failed download attempt.
///
/// The variant tells the retry loop whether another attempt is worthwhile; the
/// wrapped `String` carries the human-readable cause.
#[derive(Debug)]
enum DownloadError {
    /// The attempt failed in a way that a later attempt may overcome.
    Retryable(String),
    /// The attempt failed in a way that repeating it will not fix.
    NonRetryable(String),
}

impl std::fmt::Display for DownloadError {
    /// Writes the classification label followed by the wrapped message.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            Self::NonRetryable(ref msg) => write!(f, "Non-retryable error: {msg}"),
            Self::Retryable(ref msg) => write!(f, "Retryable error: {msg}"),
        }
    }
}

// No source or backtrace to expose: the default `Error` methods are sufficient.
impl std::error::Error for DownloadError {}
47
48fn execute_with_retry_blocking<T, E, F>(
49 config: &RetryConfig,
50 mut op: F,
51 should_retry: impl Fn(&E) -> bool,
52) -> Result<T, E>
53where
54 E: std::error::Error,
55 F: FnMut() -> Result<T, E>,
56{
57 let start = Instant::now();
58 let mut delay = Duration::from_millis(config.initial_delay_ms);
59
60 for attempt in 0..=config.max_retries {
61 if attempt > 0 && !config.immediate_first {
62 let jitter = rng().random_range(0..=config.jitter_ms);
63 let sleep_for = delay + Duration::from_millis(jitter);
64 sleep(sleep_for);
65 let next = (delay.as_millis() as f64 * config.backoff_factor) as u64;
66 delay = cmp::min(
67 Duration::from_millis(next),
68 Duration::from_millis(config.max_delay_ms),
69 );
70 }
71
72 if let Some(max_total) = config.max_elapsed_ms
73 && start.elapsed() >= Duration::from_millis(max_total)
74 {
75 break;
76 }
77
78 match op() {
79 Ok(v) => return Ok(v),
80 Err(e) if attempt < config.max_retries && should_retry(&e) => continue,
81 Err(e) => return Err(e),
82 }
83 }
84
85 op()
86}
87
88pub fn ensure_file_exists_or_download_http(
108 filepath: &Path,
109 url: &str,
110 checksums: Option<&Path>,
111 timeout_secs: Option<u64>,
112) -> anyhow::Result<()> {
113 ensure_file_exists_or_download_http_with_timeout(
114 filepath,
115 url,
116 checksums,
117 timeout_secs.unwrap_or(30),
118 )
119}
120
121pub fn ensure_file_exists_or_download_http_with_timeout(
130 filepath: &Path,
131 url: &str,
132 checksums: Option<&Path>,
133 timeout_secs: u64,
134) -> anyhow::Result<()> {
135 if filepath.exists() {
136 println!("File already exists: {filepath:?}");
137
138 if let Some(checksums_file) = checksums {
139 if verify_sha256_checksum(filepath, checksums_file)? {
140 println!("File is valid");
141 return Ok(());
142 } else {
143 let new_checksum = calculate_sha256(filepath)?;
144 println!("Adding checksum for existing file: {new_checksum}");
145 update_sha256_checksums(filepath, checksums_file, &new_checksum)?;
146 return Ok(());
147 }
148 }
149 return Ok(());
150 }
151
152 let jitter_delay = {
156 let mut r = rng();
157 Duration::from_millis(r.random_range(100..=600))
158 };
159 sleep(jitter_delay);
160
161 download_file(filepath, url, timeout_secs)?;
162
163 if let Some(checksums_file) = checksums {
164 let new_checksum = calculate_sha256(filepath)?;
165 update_sha256_checksums(filepath, checksums_file, &new_checksum)?;
166 }
167
168 Ok(())
169}
170
171fn download_file(filepath: &Path, url: &str, timeout_secs: u64) -> anyhow::Result<()> {
172 println!("Downloading file from {url} to {filepath:?}");
173
174 if let Some(parent) = filepath.parent() {
175 std::fs::create_dir_all(parent)?;
176 }
177
178 let client = Client::builder()
179 .timeout(Duration::from_secs(timeout_secs))
180 .build()?;
181
182 let max_retries = 5u32;
183 let op_timeout_ms = timeout_secs.saturating_mul(1000);
184 let per_attempt_ms = std::cmp::max(1000u64, op_timeout_ms / (max_retries as u64 + 1));
187 let cfg = RetryConfig {
188 max_retries,
189 initial_delay_ms: 1_000,
190 max_delay_ms: 10_000,
191 backoff_factor: 2.0,
192 jitter_ms: 1_000,
193 operation_timeout_ms: Some(per_attempt_ms),
194 immediate_first: false,
195 max_elapsed_ms: Some(op_timeout_ms),
196 };
197
198 let op = || -> Result<(), DownloadError> {
199 match client.get(url).send() {
200 Ok(mut response) => {
201 let status = response.status();
202 if status.is_success() {
203 let mut out = File::create(filepath)
204 .map_err(|e| DownloadError::NonRetryable(e.to_string()))?;
205 copy(&mut response, &mut out)
207 .map_err(|e| DownloadError::NonRetryable(e.to_string()))?;
208 println!("File downloaded to {filepath:?}");
209 Ok(())
210 } else if status.is_server_error()
211 || status.as_u16() == 429
212 || status.as_u16() == 408
213 {
214 println!("HTTP error {status}, retrying...");
215 Err(DownloadError::Retryable(format!("HTTP {status}")))
216 } else {
217 Err(DownloadError::NonRetryable(format!(
219 "Client error: HTTP {status}"
220 )))
221 }
222 }
223 Err(e) => {
224 println!("Request failed: {e}");
225 Err(DownloadError::Retryable(e.to_string()))
226 }
227 }
228 };
229
230 let should_retry = |e: &DownloadError| matches!(e, DownloadError::Retryable(_));
231
232 execute_with_retry_blocking(&cfg, op, should_retry).map_err(|e| anyhow::anyhow!(e.to_string()))
233}
234
235fn calculate_sha256(filepath: &Path) -> anyhow::Result<String> {
236 let mut file = File::open(filepath)?;
237 let mut ctx = Context::new(&digest::SHA256);
238 let mut buffer = [0u8; 4096];
239
240 loop {
241 let count = file.read(&mut buffer)?;
242 if count == 0 {
243 break;
244 }
245 ctx.update(&buffer[..count]);
246 }
247
248 let digest = ctx.finish();
249 Ok(hex::encode(digest.as_ref()))
250}
251
252fn verify_sha256_checksum(filepath: &Path, checksums: &Path) -> anyhow::Result<bool> {
253 let file = File::open(checksums)?;
254 let reader = BufReader::new(file);
255 let checksums: Value = serde_json::from_reader(reader)?;
256
257 let filename = filepath.file_name().unwrap().to_str().unwrap();
258 if let Some(expected_checksum) = checksums.get(filename) {
259 let expected_checksum_str = expected_checksum.as_str().unwrap();
260 let expected_hash = expected_checksum_str
261 .strip_prefix("sha256:")
262 .unwrap_or(expected_checksum_str);
263 let calculated_checksum = calculate_sha256(filepath)?;
264 if expected_hash == calculated_checksum {
265 return Ok(true);
266 }
267 }
268
269 Ok(false)
270}
271
272fn update_sha256_checksums(
273 filepath: &Path,
274 checksums_file: &Path,
275 new_checksum: &str,
276) -> anyhow::Result<()> {
277 let checksums: Value = if checksums_file.exists() {
278 let file = File::open(checksums_file)?;
279 let reader = BufReader::new(file);
280 serde_json::from_reader(reader)?
281 } else {
282 serde_json::json!({})
283 };
284
285 let mut checksums_map = checksums.as_object().unwrap().clone();
286
287 let filename = filepath.file_name().unwrap().to_str().unwrap().to_string();
289 let prefixed_checksum = format!("sha256:{new_checksum}");
290 checksums_map.insert(filename, Value::String(prefixed_checksum));
291
292 let file = OpenOptions::new()
293 .write(true)
294 .create(true)
295 .truncate(true)
296 .open(checksums_file)?;
297 let writer = BufWriter::new(file);
298 serde_json::to_writer_pretty(writer, &serde_json::Value::Object(checksums_map))?;
299
300 Ok(())
301}
302
#[cfg(test)]
mod tests {
    use std::{
        fs,
        io::{BufWriter, Write},
        net::SocketAddr,
        sync::{
            Arc,
            atomic::{AtomicUsize, Ordering},
        },
    };

    use axum::{Router, http::StatusCode, routing::get, serve};
    use rstest::*;
    use serde_json::{json, to_writer};
    use tempfile::TempDir;
    use tokio::{
        net::TcpListener,
        task,
        time::{Duration, sleep},
    };

    use super::*;

    /// Starts a local server that answers `GET /testfile.txt` with `status_code` and
    /// either the given content or the text `File not found`.
    async fn setup_test_server(
        server_content: Option<String>,
        status_code: StatusCode,
    ) -> SocketAddr {
        let body = Arc::new(server_content.unwrap_or_else(|| "File not found".to_string()));
        let app = Router::new().route(
            "/testfile.txt",
            get(move || {
                let body = Arc::clone(&body);
                async move { (status_code, (*body).clone()) }
            }),
        );

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let server = serve(listener, app);

        task::spawn(async move {
            if let Err(e) = server.await {
                eprintln!("server error: {e}");
            }
        });

        // Give the spawned server task a moment to start.
        sleep(Duration::from_millis(100)).await;

        addr
    }

    /// Serves `app` on an ephemeral local port in a background task, ignoring any
    /// server error, and returns the bound address.
    async fn spawn_quiet_server(app: Router) -> SocketAddr {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let server = serve(listener, app);
        task::spawn(async move {
            let _ = server.await;
        });
        sleep(Duration::from_millis(100)).await;
        addr
    }

    /// Builds an app whose `GET /testfile.txt` handler counts requests in `hits` and
    /// answers with `respond(n)`, where `n` is the number of earlier requests.
    fn counting_app(
        hits: Arc<AtomicUsize>,
        respond: fn(usize) -> (StatusCode, &'static str),
    ) -> Router {
        Router::new().route(
            "/testfile.txt",
            get(move || {
                let hits = Arc::clone(&hits);
                async move { respond(hits.fetch_add(1, Ordering::SeqCst)) }
            }),
        )
    }

    /// Runs the blocking download entry point on tokio's blocking pool.
    async fn download_blocking(
        filepath: std::path::PathBuf,
        url: String,
        timeout_secs: u64,
    ) -> anyhow::Result<()> {
        tokio::task::spawn_blocking(move || {
            ensure_file_exists_or_download_http(&filepath, &url, None, Some(timeout_secs))
        })
        .await
        .unwrap()
    }

    #[tokio::test]
    async fn test_file_already_exists() {
        let temp_dir = TempDir::new().unwrap();
        let file_path = temp_dir.path().join("testfile.txt");
        fs::write(&file_path, "Existing file content").unwrap();

        // The file exists, so the URL must never be contacted.
        let url = "http://example.com/testfile.txt".to_string();
        let result = ensure_file_exists_or_download_http(&file_path, &url, None, Some(5));

        assert!(result.is_ok());
        let content = fs::read_to_string(&file_path).unwrap();
        assert_eq!(content, "Existing file content");
    }

    #[tokio::test]
    async fn test_download_file_success() {
        let temp_dir = TempDir::new().unwrap();
        let filepath = temp_dir.path().join("testfile.txt");

        let server_content = "Server file content".to_string();
        let addr = setup_test_server(Some(server_content.clone()), StatusCode::OK).await;
        let url = format!("http://{addr}/testfile.txt");

        let result = download_blocking(filepath.clone(), url, 5).await;

        assert!(result.is_ok());
        let content = fs::read_to_string(&filepath).unwrap();
        assert_eq!(content, server_content);
    }

    #[tokio::test]
    async fn test_download_file_not_found() {
        let temp_dir = TempDir::new().unwrap();
        let file_path = temp_dir.path().join("testfile.txt");

        let addr = setup_test_server(None, StatusCode::NOT_FOUND).await;
        let url = format!("http://{addr}/testfile.txt");

        let result = tokio::task::spawn_blocking(move || {
            ensure_file_exists_or_download_http_with_timeout(&file_path, &url, None, 1)
        })
        .await
        .unwrap();

        assert!(result.is_err());
        let err_msg = format!("{}", result.unwrap_err());
        assert!(
            err_msg.contains("Client error: HTTP"),
            "Unexpected error message: {err_msg}"
        );
    }

    #[tokio::test]
    async fn test_network_error() {
        let temp_dir = TempDir::new().unwrap();
        let file_path = temp_dir.path().join("testfile.txt");

        // Port 0 is never a valid destination, so the request itself fails.
        let url = "http://127.0.0.1:0/testfile.txt".to_string();

        let result = download_blocking(file_path, url, 2).await;

        assert!(result.is_err());
        let err_msg = format!("{}", result.unwrap_err());
        assert!(
            err_msg.contains("error"),
            "Unexpected error message: {err_msg}"
        );
    }

    #[tokio::test]
    async fn test_retry_then_success_on_500() {
        let temp_dir = TempDir::new().unwrap();
        let filepath = temp_dir.path().join("testfile.txt");

        let counter = Arc::new(AtomicUsize::new(0));
        // The first two requests fail with 500, later ones succeed.
        let app = counting_app(Arc::clone(&counter), |n| {
            if n < 2 {
                (StatusCode::INTERNAL_SERVER_ERROR, "temporary error")
            } else {
                (StatusCode::OK, "eventual success")
            }
        });
        let addr = spawn_quiet_server(app).await;

        let url = format!("http://{addr}/testfile.txt");
        let result = download_blocking(filepath.clone(), url, 5).await;

        assert!(result.is_ok());
        let content = std::fs::read_to_string(&filepath).unwrap();
        assert_eq!(content, "eventual success");
        assert!(counter.load(Ordering::SeqCst) >= 2);
    }

    #[tokio::test]
    async fn test_retry_then_success_on_429() {
        let temp_dir = TempDir::new().unwrap();
        let filepath = temp_dir.path().join("testfile.txt");

        let counter = Arc::new(AtomicUsize::new(0));
        // Only the very first request is rate limited.
        let app = counting_app(Arc::clone(&counter), |n| {
            if n < 1 {
                (StatusCode::TOO_MANY_REQUESTS, "rate limited")
            } else {
                (StatusCode::OK, "ok after retry")
            }
        });
        let addr = spawn_quiet_server(app).await;

        let url = format!("http://{addr}/testfile.txt");
        let result = download_blocking(filepath.clone(), url, 5).await;

        assert!(result.is_ok());
        let content = std::fs::read_to_string(&filepath).unwrap();
        assert_eq!(content, "ok after retry");
        assert!(counter.load(Ordering::SeqCst) >= 2);
    }

    #[tokio::test]
    async fn test_no_retry_on_404() {
        let temp_dir = TempDir::new().unwrap();
        let filepath = temp_dir.path().join("testfile.txt");

        let counter = Arc::new(AtomicUsize::new(0));
        let app = counting_app(Arc::clone(&counter), |_| (StatusCode::NOT_FOUND, "missing"));
        let addr = spawn_quiet_server(app).await;

        let url = format!("http://{addr}/testfile.txt");
        let result = download_blocking(filepath, url, 5).await;

        assert!(result.is_err());
        assert_eq!(counter.load(Ordering::SeqCst), 1, "should not retry on 404");
    }

    #[rstest]
    fn test_calculate_sha256() -> anyhow::Result<()> {
        let temp_dir = TempDir::new()?;
        let test_file_path = temp_dir.path().join("test_file.txt");
        File::create(&test_file_path)?.write_all(b"Hello, world!")?;

        let calculated_hash = calculate_sha256(&test_file_path)?;
        let expected_hash = "315f5bdb76d078c43b8ac0064e4a0164612b1fce77c869345bfc94c75894edd3";

        assert_eq!(calculated_hash, expected_hash);
        Ok(())
    }

    #[rstest]
    fn test_verify_sha256_checksum() -> anyhow::Result<()> {
        let temp_dir = TempDir::new()?;
        let test_file_path = temp_dir.path().join("test_file.txt");
        File::create(&test_file_path)?.write_all(b"Hello, world!")?;

        let calculated_checksum = calculate_sha256(&test_file_path)?;

        // Record the checksum in a manifest keyed by file name.
        let checksums_path = temp_dir.path().join("checksums.json");
        let checksums_data = json!({
            "test_file.txt": format!("sha256:{}", calculated_checksum)
        });
        let writer = BufWriter::new(File::create(&checksums_path)?);
        to_writer(writer, &checksums_data)?;

        let is_valid = verify_sha256_checksum(&test_file_path, &checksums_path)?;
        assert!(is_valid, "The checksum should be valid");
        Ok(())
    }
}