1use std::{
17 cmp,
18 fmt::Display,
19 fs::{File, OpenOptions},
20 io::{BufReader, BufWriter, Read, copy},
21 path::Path,
22 thread::sleep,
23 time::{Duration, Instant},
24};
25
26use aws_lc_rs::digest::{self, Context};
27use nautilus_network::retry::RetryConfig;
28use rand::{Rng, rng};
29use reqwest::blocking::Client;
30use serde_json::Value;
31
/// Failure of a single download attempt, tagged by whether another attempt is worthwhile.
#[derive(Debug)]
enum DownloadError {
    /// The attempt failed in a way that the retry loop is allowed to repeat.
    Retryable(String),
    /// The attempt failed in a way that the retry loop must not repeat.
    NonRetryable(String),
}

impl Display for DownloadError {
    /// Renders the carried message, prefixed with the retry classification.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            DownloadError::Retryable(msg) => f.write_fmt(format_args!("Retryable error: {msg}")),
            DownloadError::NonRetryable(msg) => {
                f.write_fmt(format_args!("Non-retryable error: {msg}"))
            }
        }
    }
}

// No `source()` override: the underlying cause is carried only as text.
impl std::error::Error for DownloadError {}
48
49fn execute_with_retry_blocking<T, E, F>(
50 config: &RetryConfig,
51 mut op: F,
52 should_retry: impl Fn(&E) -> bool,
53) -> Result<T, E>
54where
55 E: std::error::Error,
56 F: FnMut() -> Result<T, E>,
57{
58 let start = Instant::now();
59 let mut delay = Duration::from_millis(config.initial_delay_ms);
60
61 for attempt in 0..=config.max_retries {
62 if attempt > 0 && !config.immediate_first {
63 let jitter = rng().random_range(0..=config.jitter_ms);
64 let sleep_for = delay + Duration::from_millis(jitter);
65 sleep(sleep_for);
66 let next = (delay.as_millis() as f64 * config.backoff_factor) as u64;
67 delay = cmp::min(
68 Duration::from_millis(next),
69 Duration::from_millis(config.max_delay_ms),
70 );
71 }
72
73 if let Some(max_total) = config.max_elapsed_ms
74 && start.elapsed() >= Duration::from_millis(max_total)
75 {
76 break;
77 }
78
79 match op() {
80 Ok(v) => return Ok(v),
81 Err(e) if attempt < config.max_retries && should_retry(&e) => continue,
82 Err(e) => return Err(e),
83 }
84 }
85
86 op()
87}
88
89pub fn ensure_file_exists_or_download_http(
109 filepath: &Path,
110 url: &str,
111 checksums: Option<&Path>,
112 timeout_secs: Option<u64>,
113) -> anyhow::Result<()> {
114 ensure_file_exists_or_download_http_with_config(
115 filepath,
116 url,
117 checksums,
118 timeout_secs.unwrap_or(30),
119 None,
120 None,
121 )
122}
123
124pub fn ensure_file_exists_or_download_http_with_timeout(
133 filepath: &Path,
134 url: &str,
135 checksums: Option<&Path>,
136 timeout_secs: u64,
137) -> anyhow::Result<()> {
138 ensure_file_exists_or_download_http_with_config(
139 filepath,
140 url,
141 checksums,
142 timeout_secs,
143 None,
144 None,
145 )
146}
147
148pub fn ensure_file_exists_or_download_http_with_config(
167 filepath: &Path,
168 url: &str,
169 checksums: Option<&Path>,
170 timeout_secs: u64,
171 retry_config: Option<RetryConfig>,
172 initial_jitter_ms: Option<u64>,
173) -> anyhow::Result<()> {
174 if filepath.exists() {
175 println!("File already exists: {filepath:?}");
176
177 if let Some(checksums_file) = checksums {
178 if verify_sha256_checksum(filepath, checksums_file)? {
179 println!("File is valid");
180 return Ok(());
181 } else {
182 let new_checksum = calculate_sha256(filepath)?;
183 println!("Adding checksum for existing file: {new_checksum}");
184 update_sha256_checksums(filepath, checksums_file, &new_checksum)?;
185 return Ok(());
186 }
187 }
188 return Ok(());
189 }
190
191 if let Some(jitter_ms) = initial_jitter_ms {
194 if jitter_ms > 0 {
195 sleep(Duration::from_millis(jitter_ms));
196 }
197 } else {
198 let jitter_delay = {
199 let mut r = rng();
200 Duration::from_millis(r.random_range(100..=600))
201 };
202 sleep(jitter_delay);
203 }
204
205 download_file(filepath, url, timeout_secs, retry_config)?;
206
207 if let Some(checksums_file) = checksums {
208 let new_checksum = calculate_sha256(filepath)?;
209 update_sha256_checksums(filepath, checksums_file, &new_checksum)?;
210 }
211
212 Ok(())
213}
214
215fn download_file(
216 filepath: &Path,
217 url: &str,
218 timeout_secs: u64,
219 retry_config: Option<RetryConfig>,
220) -> anyhow::Result<()> {
221 println!("Downloading file from {url} to {filepath:?}");
222
223 if let Some(parent) = filepath.parent() {
224 std::fs::create_dir_all(parent)?;
225 }
226
227 let client = Client::builder()
228 .timeout(Duration::from_secs(timeout_secs))
229 .build()?;
230
231 let cfg = if let Some(config) = retry_config {
232 config
233 } else {
234 let max_retries = 5u32;
236 let op_timeout_ms = timeout_secs.saturating_mul(1000);
237 let per_attempt_ms = std::cmp::max(1000u64, op_timeout_ms / (max_retries as u64 + 1));
240 RetryConfig {
241 max_retries,
242 initial_delay_ms: 1_000,
243 max_delay_ms: 10_000,
244 backoff_factor: 2.0,
245 jitter_ms: 1_000,
246 operation_timeout_ms: Some(per_attempt_ms),
247 immediate_first: false,
248 max_elapsed_ms: Some(op_timeout_ms),
249 }
250 };
251
252 let op = || -> Result<(), DownloadError> {
253 match client.get(url).send() {
254 Ok(mut response) => {
255 let status = response.status();
256 if status.is_success() {
257 let mut out = File::create(filepath)
258 .map_err(|e| DownloadError::NonRetryable(e.to_string()))?;
259 copy(&mut response, &mut out)
261 .map_err(|e| DownloadError::NonRetryable(e.to_string()))?;
262 println!("File downloaded to {filepath:?}");
263 Ok(())
264 } else if status.is_server_error()
265 || status.as_u16() == 429
266 || status.as_u16() == 408
267 {
268 println!("HTTP error {status}, retrying...");
269 Err(DownloadError::Retryable(format!("HTTP {status}")))
270 } else {
271 Err(DownloadError::NonRetryable(format!(
273 "Client error: HTTP {status}"
274 )))
275 }
276 }
277 Err(e) => {
278 println!("Request failed: {e}");
279 Err(DownloadError::Retryable(e.to_string()))
280 }
281 }
282 };
283
284 let should_retry = |e: &DownloadError| matches!(e, DownloadError::Retryable(_));
285
286 execute_with_retry_blocking(&cfg, op, should_retry).map_err(|e| anyhow::anyhow!(e.to_string()))
287}
288
289fn calculate_sha256(filepath: &Path) -> anyhow::Result<String> {
290 let mut file = File::open(filepath)?;
291 let mut ctx = Context::new(&digest::SHA256);
292 let mut buffer = [0u8; 4096];
293
294 loop {
295 let count = file.read(&mut buffer)?;
296 if count == 0 {
297 break;
298 }
299 ctx.update(&buffer[..count]);
300 }
301
302 let digest = ctx.finish();
303 Ok(hex::encode(digest.as_ref()))
304}
305
306fn verify_sha256_checksum(filepath: &Path, checksums: &Path) -> anyhow::Result<bool> {
307 let file = File::open(checksums)?;
308 let reader = BufReader::new(file);
309 let checksums: Value = serde_json::from_reader(reader)?;
310
311 let filename = filepath.file_name().unwrap().to_str().unwrap();
312 if let Some(expected_checksum) = checksums.get(filename) {
313 let expected_checksum_str = expected_checksum.as_str().unwrap();
314 let expected_hash = expected_checksum_str
315 .strip_prefix("sha256:")
316 .unwrap_or(expected_checksum_str);
317 let calculated_checksum = calculate_sha256(filepath)?;
318 if expected_hash == calculated_checksum {
319 return Ok(true);
320 }
321 }
322
323 Ok(false)
324}
325
326fn update_sha256_checksums(
327 filepath: &Path,
328 checksums_file: &Path,
329 new_checksum: &str,
330) -> anyhow::Result<()> {
331 let checksums: Value = if checksums_file.exists() {
332 let file = File::open(checksums_file)?;
333 let reader = BufReader::new(file);
334 serde_json::from_reader(reader)?
335 } else {
336 serde_json::json!({})
337 };
338
339 let mut checksums_map = checksums.as_object().unwrap().clone();
340
341 let filename = filepath.file_name().unwrap().to_str().unwrap().to_string();
343 let prefixed_checksum = format!("sha256:{new_checksum}");
344 checksums_map.insert(filename, Value::String(prefixed_checksum));
345
346 let file = OpenOptions::new()
347 .write(true)
348 .create(true)
349 .truncate(true)
350 .open(checksums_file)?;
351 let writer = BufWriter::new(file);
352 serde_json::to_writer_pretty(writer, &serde_json::Value::Object(checksums_map))?;
353
354 Ok(())
355}
356
357#[cfg(test)]
361mod tests {
362 use std::{
363 fs,
364 io::{BufWriter, Write},
365 net::SocketAddr,
366 sync::{
367 Arc,
368 atomic::{AtomicUsize, Ordering},
369 },
370 };
371
372 use axum::{Router, http::StatusCode, routing::get, serve};
373 use rstest::*;
374 use serde_json::{json, to_writer};
375 use tempfile::TempDir;
376 use tokio::{
377 net::TcpListener,
378 task,
379 time::{Duration, sleep},
380 };
381
382 use super::*;
383
384 fn test_retry_config() -> RetryConfig {
387 RetryConfig {
388 max_retries: 5,
389 initial_delay_ms: 10,
390 max_delay_ms: 50,
391 backoff_factor: 2.0,
392 jitter_ms: 5,
393 operation_timeout_ms: Some(500),
394 immediate_first: false,
395 max_elapsed_ms: Some(2000),
396 }
397 }
398
399 async fn setup_test_server(
400 server_content: Option<String>,
401 status_code: StatusCode,
402 ) -> SocketAddr {
403 let server_content = Arc::new(server_content);
404 let server_content_clone = server_content.clone();
405 let app = Router::new().route(
406 "/testfile.txt",
407 get(move || {
408 let server_content = server_content_clone.clone();
409 async move {
410 let response_body = match &*server_content {
411 Some(content) => content.clone(),
412 None => "File not found".to_string(),
413 };
414 (status_code, response_body)
415 }
416 }),
417 );
418
419 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
420 let addr = listener.local_addr().unwrap();
421 let server = serve(listener, app);
422
423 task::spawn(async move {
424 if let Err(e) = server.await {
425 eprintln!("server error: {e}");
426 }
427 });
428
429 sleep(Duration::from_millis(100)).await;
430
431 addr
432 }
433
434 #[tokio::test]
435 async fn test_file_already_exists() {
436 let temp_dir = TempDir::new().unwrap();
437 let file_path = temp_dir.path().join("testfile.txt");
438 fs::write(&file_path, "Existing file content").unwrap();
439
440 let url = "http://example.com/testfile.txt".to_string();
441 let result = ensure_file_exists_or_download_http(&file_path, &url, None, Some(5));
442
443 assert!(result.is_ok());
444 let content = fs::read_to_string(&file_path).unwrap();
445 assert_eq!(content, "Existing file content");
446 }
447
448 #[tokio::test]
449 async fn test_download_file_success() {
450 let temp_dir = TempDir::new().unwrap();
451 let filepath = temp_dir.path().join("testfile.txt");
452 let filepath_clone = filepath.clone();
453
454 let server_content = "Server file content".to_string();
455 let status_code = StatusCode::OK;
456 let addr = setup_test_server(Some(server_content.clone()), status_code).await;
457 let url = format!("http://{addr}/testfile.txt");
458
459 let result = tokio::task::spawn_blocking(move || {
460 ensure_file_exists_or_download_http_with_config(
461 &filepath_clone,
462 &url,
463 None,
464 5,
465 Some(test_retry_config()),
466 Some(0),
467 )
468 })
469 .await
470 .unwrap();
471
472 assert!(result.is_ok());
473 let content = fs::read_to_string(&filepath).unwrap();
474 assert_eq!(content, server_content);
475 }
476
477 #[tokio::test]
478 async fn test_download_file_not_found() {
479 let temp_dir = TempDir::new().unwrap();
480 let file_path = temp_dir.path().join("testfile.txt");
481
482 let server_content = None;
483 let status_code = StatusCode::NOT_FOUND;
484 let addr = setup_test_server(server_content, status_code).await;
485 let url = format!("http://{addr}/testfile.txt");
486
487 let result = tokio::task::spawn_blocking(move || {
488 ensure_file_exists_or_download_http_with_config(
489 &file_path,
490 &url,
491 None,
492 1,
493 Some(test_retry_config()),
494 Some(0),
495 )
496 })
497 .await
498 .unwrap();
499
500 assert!(result.is_err());
501 let err_msg = format!("{}", result.unwrap_err());
502 assert!(
503 err_msg.contains("Client error: HTTP"),
504 "Unexpected error message: {err_msg}"
505 );
506 }
507
508 #[tokio::test]
509 async fn test_network_error() {
510 let temp_dir = TempDir::new().unwrap();
511 let file_path = temp_dir.path().join("testfile.txt");
512
513 let url = "http://127.0.0.1:0/testfile.txt".to_string();
515
516 let result = tokio::task::spawn_blocking(move || {
517 ensure_file_exists_or_download_http_with_config(
518 &file_path,
519 &url,
520 None,
521 2,
522 Some(test_retry_config()),
523 Some(0),
524 )
525 })
526 .await
527 .unwrap();
528
529 assert!(result.is_err());
530 let err_msg = format!("{}", result.unwrap_err());
531 assert!(
532 err_msg.contains("error"),
533 "Unexpected error message: {err_msg}"
534 );
535 }
536
537 #[tokio::test]
538 async fn test_retry_then_success_on_500() {
539 let temp_dir = TempDir::new().unwrap();
540 let filepath = temp_dir.path().join("testfile.txt");
541 let filepath_clone = filepath.clone();
542
543 let counter = Arc::new(AtomicUsize::new(0));
544 let counter_clone = counter.clone();
545
546 let app = Router::new().route(
547 "/testfile.txt",
548 get(move || {
549 let c = counter_clone.clone();
550 async move {
551 let n = c.fetch_add(1, Ordering::SeqCst);
552 if n < 2 {
553 (StatusCode::INTERNAL_SERVER_ERROR, "temporary error")
554 } else {
555 (StatusCode::OK, "eventual success")
556 }
557 }
558 }),
559 );
560
561 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
562 let addr = listener.local_addr().unwrap();
563 let server = serve(listener, app);
564 task::spawn(async move {
565 let _ = server.await;
566 });
567 sleep(Duration::from_millis(100)).await;
568
569 let url = format!("http://{addr}/testfile.txt");
570
571 let result = tokio::task::spawn_blocking(move || {
572 ensure_file_exists_or_download_http_with_config(
573 &filepath_clone,
574 &url,
575 None,
576 5,
577 Some(test_retry_config()),
578 Some(0),
579 )
580 })
581 .await
582 .unwrap();
583
584 assert!(result.is_ok());
585 let content = std::fs::read_to_string(&filepath).unwrap();
586 assert_eq!(content, "eventual success");
587 assert!(counter.load(Ordering::SeqCst) >= 2);
588 }
589
590 #[tokio::test]
591 async fn test_retry_then_success_on_429() {
592 let temp_dir = TempDir::new().unwrap();
593 let filepath = temp_dir.path().join("testfile.txt");
594 let filepath_clone = filepath.clone();
595
596 let counter = Arc::new(AtomicUsize::new(0));
597 let counter_clone = counter.clone();
598
599 let app = Router::new().route(
600 "/testfile.txt",
601 get(move || {
602 let c = counter_clone.clone();
603 async move {
604 let n = c.fetch_add(1, Ordering::SeqCst);
605 if n < 1 {
606 (StatusCode::TOO_MANY_REQUESTS, "rate limited")
607 } else {
608 (StatusCode::OK, "ok after retry")
609 }
610 }
611 }),
612 );
613
614 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
615 let addr = listener.local_addr().unwrap();
616 let server = serve(listener, app);
617 task::spawn(async move {
618 let _ = server.await;
619 });
620 sleep(Duration::from_millis(100)).await;
621
622 let url = format!("http://{addr}/testfile.txt");
623
624 let result = tokio::task::spawn_blocking(move || {
625 ensure_file_exists_or_download_http_with_config(
626 &filepath_clone,
627 &url,
628 None,
629 5,
630 Some(test_retry_config()),
631 Some(0),
632 )
633 })
634 .await
635 .unwrap();
636
637 assert!(result.is_ok());
638 let content = std::fs::read_to_string(&filepath).unwrap();
639 assert_eq!(content, "ok after retry");
640 assert!(counter.load(Ordering::SeqCst) >= 2);
641 }
642
643 #[tokio::test]
644 async fn test_no_retry_on_404() {
645 let temp_dir = TempDir::new().unwrap();
646 let filepath = temp_dir.path().join("testfile.txt");
647 let filepath_clone = filepath.clone();
648
649 let counter = Arc::new(AtomicUsize::new(0));
650 let counter_clone = counter.clone();
651
652 let app = Router::new().route(
653 "/testfile.txt",
654 get(move || {
655 let c = counter_clone.clone();
656 async move {
657 c.fetch_add(1, Ordering::SeqCst);
658 (StatusCode::NOT_FOUND, "missing")
659 }
660 }),
661 );
662
663 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
664 let addr = listener.local_addr().unwrap();
665 let server = serve(listener, app);
666 task::spawn(async move {
667 let _ = server.await;
668 });
669 sleep(Duration::from_millis(100)).await;
670
671 let url = format!("http://{addr}/testfile.txt");
672
673 let result = tokio::task::spawn_blocking(move || {
674 ensure_file_exists_or_download_http_with_config(
675 &filepath_clone,
676 &url,
677 None,
678 5,
679 Some(test_retry_config()),
680 Some(0),
681 )
682 })
683 .await
684 .unwrap();
685
686 assert!(result.is_err());
687 assert_eq!(counter.load(Ordering::SeqCst), 1, "should not retry on 404");
688 }
689
690 #[rstest]
691 #[allow(clippy::panic_in_result_fn)]
692 fn test_calculate_sha256() -> anyhow::Result<()> {
693 let temp_dir = TempDir::new()?;
694 let test_file_path = temp_dir.path().join("test_file.txt");
695 let mut test_file = File::create(&test_file_path)?;
696 let content = b"Hello, world!";
697 test_file.write_all(content)?;
698
699 let expected_hash = "315f5bdb76d078c43b8ac0064e4a0164612b1fce77c869345bfc94c75894edd3";
700 let calculated_hash = calculate_sha256(&test_file_path)?;
701
702 assert_eq!(calculated_hash, expected_hash);
703 Ok(())
704 }
705
706 #[rstest]
707 #[allow(clippy::panic_in_result_fn)]
708 fn test_verify_sha256_checksum() -> anyhow::Result<()> {
709 let temp_dir = TempDir::new()?;
710 let test_file_path = temp_dir.path().join("test_file.txt");
711 let mut test_file = File::create(&test_file_path)?;
712 let content = b"Hello, world!";
713 test_file.write_all(content)?;
714
715 let calculated_checksum = calculate_sha256(&test_file_path)?;
716
717 let checksums_path = temp_dir.path().join("checksums.json");
719 let checksums_data = json!({
720 "test_file.txt": format!("sha256:{}", calculated_checksum)
721 });
722 let checksums_file = File::create(&checksums_path)?;
723 let writer = BufWriter::new(checksums_file);
724 to_writer(writer, &checksums_data)?;
725
726 let is_valid = verify_sha256_checksum(&test_file_path, &checksums_path)?;
727 assert!(is_valid, "The checksum should be valid");
728 Ok(())
729 }
730}