1use std::{
17 cmp,
18 fmt::Display,
19 fs::{File, OpenOptions},
20 io::{BufReader, BufWriter, Read, copy},
21 path::Path,
22 sync::{Mutex, OnceLock},
23 thread::sleep,
24 time::{Duration, Instant},
25};
26
27use aws_lc_rs::digest::{self, Context};
28use nautilus_network::retry::RetryConfig;
29use rand::{RngExt, rng};
30use reqwest::blocking::Client;
31use serde_json::Value;
32
33static LARGE_CHECKSUMS_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
34
35fn lock_large_checksums() -> anyhow::Result<std::sync::MutexGuard<'static, ()>> {
36 LARGE_CHECKSUMS_LOCK
37 .get_or_init(|| Mutex::new(()))
38 .lock()
39 .map_err(|e| anyhow::anyhow!("Failed to lock checksums file access: {e}"))
40}
41
/// Failure of a single download attempt, classified for the retry loop.
///
/// Each variant carries a human-readable description of what went wrong.
#[derive(Debug)]
enum DownloadError {
    /// Transient failure worth retrying (e.g. HTTP 5xx/408/429 or a failed request).
    Retryable(String),
    /// Failure that retrying cannot fix (e.g. other non-success HTTP statuses or local file I/O).
    NonRetryable(String),
}

impl DownloadError {
    /// Returns the description carried by either variant.
    fn detail(&self) -> &str {
        match self {
            Self::Retryable(detail) | Self::NonRetryable(detail) => detail,
        }
    }
}

impl Display for DownloadError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let detail = self.detail();
        if let Self::Retryable(_) = self {
            write!(f, "Retryable error: {detail}")
        } else {
            write!(f, "Non-retryable error: {detail}")
        }
    }
}

/// Required by the `E: std::error::Error` bound of `execute_with_retry_blocking`.
impl std::error::Error for DownloadError {}
58
59fn execute_with_retry_blocking<T, E, F>(
60 config: &RetryConfig,
61 mut op: F,
62 should_retry: impl Fn(&E) -> bool,
63) -> Result<T, E>
64where
65 E: std::error::Error,
66 F: FnMut() -> Result<T, E>,
67{
68 let start = Instant::now();
69 let mut delay = Duration::from_millis(config.initial_delay_ms);
70
71 for attempt in 0..=config.max_retries {
72 if attempt > 0 && !config.immediate_first {
73 let jitter = rng().random_range(0..=config.jitter_ms);
74 let sleep_for = delay + Duration::from_millis(jitter);
75 sleep(sleep_for);
76 let next = (delay.as_millis() as f64 * config.backoff_factor) as u64;
77 delay = cmp::min(
78 Duration::from_millis(next),
79 Duration::from_millis(config.max_delay_ms),
80 );
81 }
82
83 if let Some(max_total) = config.max_elapsed_ms
84 && start.elapsed() >= Duration::from_millis(max_total)
85 {
86 break;
87 }
88
89 match op() {
90 Ok(v) => return Ok(v),
91 Err(e) if attempt < config.max_retries && should_retry(&e) => continue,
92 Err(e) => return Err(e),
93 }
94 }
95
96 op()
97}
98
99pub fn ensure_file_exists_or_download_http(
119 filepath: &Path,
120 url: &str,
121 checksums: Option<&Path>,
122 timeout_secs: Option<u64>,
123) -> anyhow::Result<()> {
124 ensure_file_exists_or_download_http_with_config(
125 filepath,
126 url,
127 checksums,
128 timeout_secs.unwrap_or(30),
129 None,
130 None,
131 )
132}
133
134pub fn ensure_file_exists_or_download_http_with_timeout(
143 filepath: &Path,
144 url: &str,
145 checksums: Option<&Path>,
146 timeout_secs: u64,
147) -> anyhow::Result<()> {
148 ensure_file_exists_or_download_http_with_config(
149 filepath,
150 url,
151 checksums,
152 timeout_secs,
153 None,
154 None,
155 )
156}
157
158pub fn ensure_file_exists_or_download_http_with_config(
177 filepath: &Path,
178 url: &str,
179 checksums: Option<&Path>,
180 timeout_secs: u64,
181 retry_config: Option<RetryConfig>,
182 initial_jitter_ms: Option<u64>,
183) -> anyhow::Result<()> {
184 if filepath.exists() {
185 println!("File already exists: {filepath:?}");
186
187 if let Some(checksums_file) = checksums {
188 let _guard = lock_large_checksums()?;
189 if verify_sha256_checksum(filepath, checksums_file)? {
190 println!("File is valid");
191 return Ok(());
192 } else {
193 let new_checksum = calculate_sha256(filepath)?;
194 println!("Adding checksum for existing file: {new_checksum}");
195 update_sha256_checksums(filepath, checksums_file, &new_checksum)?;
196 return Ok(());
197 }
198 }
199 return Ok(());
200 }
201
202 if let Some(jitter_ms) = initial_jitter_ms {
205 if jitter_ms > 0 {
206 sleep(Duration::from_millis(jitter_ms));
207 }
208 } else {
209 let jitter_delay = {
210 let mut r = rng();
211 Duration::from_millis(r.random_range(100..=600))
212 };
213 sleep(jitter_delay);
214 }
215
216 download_file(filepath, url, timeout_secs, retry_config)?;
217
218 if let Some(checksums_file) = checksums {
219 let _guard = lock_large_checksums()?;
220 if !verify_sha256_checksum(filepath, checksums_file)? {
221 let new_checksum = calculate_sha256(filepath)?;
222 update_sha256_checksums(filepath, checksums_file, &new_checksum)?;
223 }
224 }
225
226 Ok(())
227}
228
229fn download_file(
230 filepath: &Path,
231 url: &str,
232 timeout_secs: u64,
233 retry_config: Option<RetryConfig>,
234) -> anyhow::Result<()> {
235 #[cfg(not(test))]
239 if !url.starts_with("https://") {
240 anyhow::bail!("URL must use HTTPS protocol for security: {url}");
241 }
242
243 println!("Downloading file from {url} to {filepath:?}");
244
245 if let Some(parent) = filepath.parent() {
246 std::fs::create_dir_all(parent)?;
247 }
248
249 let client = Client::builder()
250 .timeout(Duration::from_secs(timeout_secs))
251 .build()?;
252
253 let cfg = if let Some(config) = retry_config {
254 config
255 } else {
256 let max_retries = 5u32;
258 let op_timeout_ms = timeout_secs.saturating_mul(1000);
259 let per_attempt_ms = std::cmp::max(1000u64, op_timeout_ms / (max_retries as u64 + 1));
262 RetryConfig {
263 max_retries,
264 initial_delay_ms: 1_000,
265 max_delay_ms: 10_000,
266 backoff_factor: 2.0,
267 jitter_ms: 1_000,
268 operation_timeout_ms: Some(per_attempt_ms),
269 immediate_first: false,
270 max_elapsed_ms: Some(op_timeout_ms),
271 }
272 };
273
274 let op = || -> Result<(), DownloadError> {
275 match client.get(url).send() {
276 Ok(mut response) => {
277 let status = response.status();
278 if status.is_success() {
279 let mut out = File::create(filepath)
280 .map_err(|e| DownloadError::NonRetryable(e.to_string()))?;
281 copy(&mut response, &mut out)
283 .map_err(|e| DownloadError::NonRetryable(e.to_string()))?;
284 println!("File downloaded to {filepath:?}");
285 Ok(())
286 } else if status.is_server_error()
287 || status.as_u16() == 429
288 || status.as_u16() == 408
289 {
290 println!("HTTP error {status}, retrying...");
291 Err(DownloadError::Retryable(format!("HTTP {status}")))
292 } else {
293 Err(DownloadError::NonRetryable(format!(
295 "Client error: HTTP {status}"
296 )))
297 }
298 }
299 Err(e) => {
300 println!("Request failed: {e}");
301 Err(DownloadError::Retryable(e.to_string()))
302 }
303 }
304 };
305
306 let should_retry = |e: &DownloadError| matches!(e, DownloadError::Retryable(_));
307
308 execute_with_retry_blocking(&cfg, op, should_retry).map_err(|e| anyhow::anyhow!(e.to_string()))
309}
310
311fn calculate_sha256(filepath: &Path) -> anyhow::Result<String> {
312 let mut file = File::open(filepath)?;
313 let mut ctx = Context::new(&digest::SHA256);
314 let mut buffer = [0u8; 4096];
315
316 loop {
317 let count = file.read(&mut buffer)?;
318 if count == 0 {
319 break;
320 }
321 ctx.update(&buffer[..count]);
322 }
323
324 let digest = ctx.finish();
325 Ok(hex::encode(digest.as_ref()))
326}
327
328fn verify_sha256_checksum(filepath: &Path, checksums: &Path) -> anyhow::Result<bool> {
329 let file = File::open(checksums)?;
330 let reader = BufReader::new(file);
331 let checksums: Value = serde_json::from_reader(reader)?;
332
333 let filename = filepath.file_name().unwrap().to_str().unwrap();
334 if let Some(expected_checksum) = checksums.get(filename) {
335 let expected_checksum_str = expected_checksum.as_str().unwrap();
336 let expected_hash = expected_checksum_str
337 .strip_prefix("sha256:")
338 .unwrap_or(expected_checksum_str);
339 let calculated_checksum = calculate_sha256(filepath)?;
340 if expected_hash == calculated_checksum {
341 return Ok(true);
342 }
343 }
344
345 Ok(false)
346}
347
348fn update_sha256_checksums(
349 filepath: &Path,
350 checksums_file: &Path,
351 new_checksum: &str,
352) -> anyhow::Result<()> {
353 let checksums: Value = if checksums_file.exists() {
354 let file = File::open(checksums_file)?;
355 let reader = BufReader::new(file);
356 serde_json::from_reader(reader)?
357 } else {
358 serde_json::json!({})
359 };
360
361 let mut checksums_map = checksums.as_object().unwrap().clone();
362
363 let filename = filepath.file_name().unwrap().to_str().unwrap().to_string();
365 let prefixed_checksum = format!("sha256:{new_checksum}");
366 checksums_map.insert(filename, Value::String(prefixed_checksum));
367
368 let file = OpenOptions::new()
369 .write(true)
370 .create(true)
371 .truncate(true)
372 .open(checksums_file)?;
373 let writer = BufWriter::new(file);
374 serde_json::to_writer_pretty(writer, &serde_json::Value::Object(checksums_map))?;
375
376 Ok(())
377}
378
#[cfg(test)]
mod tests {
    use std::{
        fs,
        io::{BufWriter, Write},
        net::SocketAddr,
        path::PathBuf,
        sync::{
            Arc,
            atomic::{AtomicUsize, Ordering},
        },
    };

    use axum::{Router, http::StatusCode, routing::get, serve};
    use rstest::*;
    use serde_json::{json, to_writer};
    use tempfile::TempDir;
    use tokio::{
        net::TcpListener,
        task,
        time::{Duration, sleep},
    };

    use super::*;

    /// Retry policy with millisecond-scale delays so retry-path tests finish quickly.
    fn test_retry_config() -> RetryConfig {
        RetryConfig {
            max_retries: 5,
            initial_delay_ms: 10,
            max_delay_ms: 50,
            backoff_factor: 2.0,
            jitter_ms: 5,
            operation_timeout_ms: Some(500),
            immediate_first: false,
            max_elapsed_ms: Some(2000),
        }
    }

    /// Serves `app` on an ephemeral localhost port in a background task and returns its address.
    async fn spawn_app(app: Router) -> SocketAddr {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let server = serve(listener, app);

        task::spawn(async move {
            if let Err(e) = server.await {
                eprintln!("server error: {e}");
            }
        });

        sleep(Duration::from_millis(100)).await;
        addr
    }

    /// Serves `/testfile.txt` with `status_code` and `server_content`, or "File not found"
    /// when the content is `None`.
    async fn setup_test_server(
        server_content: Option<String>,
        status_code: StatusCode,
    ) -> SocketAddr {
        let body = Arc::new(server_content.unwrap_or_else(|| "File not found".to_string()));
        let app = Router::new().route(
            "/testfile.txt",
            get(move || {
                let body = Arc::clone(&body);
                async move { (status_code, body.as_str().to_owned()) }
            }),
        );
        spawn_app(app).await
    }

    /// Serves `/testfile.txt`, answering the request with zero-based index `n` with
    /// `respond(n)`. Returns the server address and the shared request counter.
    async fn spawn_counting_server(
        respond: fn(usize) -> (StatusCode, &'static str),
    ) -> (SocketAddr, Arc<AtomicUsize>) {
        let hits = Arc::new(AtomicUsize::new(0));
        let hits_for_route = Arc::clone(&hits);
        let app = Router::new().route(
            "/testfile.txt",
            get(move || {
                let hits = Arc::clone(&hits_for_route);
                async move { respond(hits.fetch_add(1, Ordering::SeqCst)) }
            }),
        );
        (spawn_app(app).await, hits)
    }

    /// Runs the blocking download entry point on Tokio's blocking pool, using the fast test
    /// retry policy, no checksums file, and no initial jitter.
    async fn download_blocking(
        filepath: PathBuf,
        url: String,
        timeout_secs: u64,
    ) -> anyhow::Result<()> {
        task::spawn_blocking(move || {
            ensure_file_exists_or_download_http_with_config(
                &filepath,
                &url,
                None,
                timeout_secs,
                Some(test_retry_config()),
                Some(0),
            )
        })
        .await
        .unwrap()
    }

    #[tokio::test]
    async fn test_file_already_exists() {
        let temp_dir = TempDir::new().unwrap();
        let file_path = temp_dir.path().join("testfile.txt");
        fs::write(&file_path, "Existing file content").unwrap();

        // The URL is never contacted because the file is already present.
        let result = ensure_file_exists_or_download_http(
            &file_path,
            "http://example.com/testfile.txt",
            None,
            Some(5),
        );

        assert!(result.is_ok());
        assert_eq!(fs::read_to_string(&file_path).unwrap(), "Existing file content");
    }

    #[tokio::test]
    async fn test_download_file_success() {
        let temp_dir = TempDir::new().unwrap();
        let filepath = temp_dir.path().join("testfile.txt");
        let server_content = "Server file content".to_string();
        let addr = setup_test_server(Some(server_content.clone()), StatusCode::OK).await;

        let result =
            download_blocking(filepath.clone(), format!("http://{addr}/testfile.txt"), 5).await;

        assert!(result.is_ok());
        assert_eq!(fs::read_to_string(&filepath).unwrap(), server_content);
    }

    #[tokio::test]
    async fn test_download_file_not_found() {
        let temp_dir = TempDir::new().unwrap();
        let file_path = temp_dir.path().join("testfile.txt");
        let addr = setup_test_server(None, StatusCode::NOT_FOUND).await;

        let result = download_blocking(file_path, format!("http://{addr}/testfile.txt"), 1).await;

        let err_msg = result.expect_err("a 404 must fail the download").to_string();
        assert!(
            err_msg.contains("Client error: HTTP"),
            "Unexpected error message: {err_msg}"
        );
    }

    #[tokio::test]
    async fn test_network_error() {
        let temp_dir = TempDir::new().unwrap();
        let file_path = temp_dir.path().join("testfile.txt");

        // Port 0 is never connectable, so every attempt fails at the transport level.
        let result =
            download_blocking(file_path, "http://127.0.0.1:0/testfile.txt".to_string(), 2).await;

        let err_msg = result.expect_err("an unreachable host must fail").to_string();
        assert!(
            err_msg.contains("error"),
            "Unexpected error message: {err_msg}"
        );
    }

    #[tokio::test]
    async fn test_retry_then_success_on_500() {
        let temp_dir = TempDir::new().unwrap();
        let filepath = temp_dir.path().join("testfile.txt");
        let (addr, counter) = spawn_counting_server(|n| {
            if n < 2 {
                (StatusCode::INTERNAL_SERVER_ERROR, "temporary error")
            } else {
                (StatusCode::OK, "eventual success")
            }
        })
        .await;

        let result =
            download_blocking(filepath.clone(), format!("http://{addr}/testfile.txt"), 5).await;

        assert!(result.is_ok());
        assert_eq!(std::fs::read_to_string(&filepath).unwrap(), "eventual success");
        assert!(counter.load(Ordering::SeqCst) >= 2);
    }

    #[tokio::test]
    async fn test_retry_then_success_on_429() {
        let temp_dir = TempDir::new().unwrap();
        let filepath = temp_dir.path().join("testfile.txt");
        let (addr, counter) = spawn_counting_server(|n| {
            if n < 1 {
                (StatusCode::TOO_MANY_REQUESTS, "rate limited")
            } else {
                (StatusCode::OK, "ok after retry")
            }
        })
        .await;

        let result =
            download_blocking(filepath.clone(), format!("http://{addr}/testfile.txt"), 5).await;

        assert!(result.is_ok());
        assert_eq!(std::fs::read_to_string(&filepath).unwrap(), "ok after retry");
        assert!(counter.load(Ordering::SeqCst) >= 2);
    }

    #[tokio::test]
    async fn test_no_retry_on_404() {
        let temp_dir = TempDir::new().unwrap();
        let filepath = temp_dir.path().join("testfile.txt");
        let (addr, counter) = spawn_counting_server(|_| (StatusCode::NOT_FOUND, "missing")).await;

        let result = download_blocking(filepath, format!("http://{addr}/testfile.txt"), 5).await;

        assert!(result.is_err());
        assert_eq!(counter.load(Ordering::SeqCst), 1, "should not retry on 404");
    }

    #[rstest]
    // Returning `Result` allows `?` on setup steps; assertions still panic on failure.
    #[allow(clippy::panic_in_result_fn)]
    fn test_calculate_sha256() -> anyhow::Result<()> {
        let temp_dir = TempDir::new()?;
        let test_file_path = temp_dir.path().join("test_file.txt");
        File::create(&test_file_path)?.write_all(b"Hello, world!")?;

        assert_eq!(
            calculate_sha256(&test_file_path)?,
            "315f5bdb76d078c43b8ac0064e4a0164612b1fce77c869345bfc94c75894edd3"
        );
        Ok(())
    }

    #[rstest]
    // Returning `Result` allows `?` on setup steps; assertions still panic on failure.
    #[allow(clippy::panic_in_result_fn)]
    fn test_verify_sha256_checksum() -> anyhow::Result<()> {
        let temp_dir = TempDir::new()?;
        let test_file_path = temp_dir.path().join("test_file.txt");
        File::create(&test_file_path)?.write_all(b"Hello, world!")?;
        let calculated_checksum = calculate_sha256(&test_file_path)?;

        let checksums_path = temp_dir.path().join("checksums.json");
        let checksums_data = json!({
            "test_file.txt": format!("sha256:{calculated_checksum}")
        });
        to_writer(BufWriter::new(File::create(&checksums_path)?), &checksums_data)?;

        assert!(
            verify_sha256_checksum(&test_file_path, &checksums_path)?,
            "The checksum should be valid"
        );
        Ok(())
    }
}