1use std::{
17 cmp,
18 fmt::Display,
19 fs::{File, OpenOptions},
20 io::{BufReader, BufWriter, Read, copy},
21 path::Path,
22 thread::sleep,
23 time::{Duration, Instant},
24};
25
26use aws_lc_rs::digest::{self, Context};
27use nautilus_network::retry::RetryConfig;
28use rand::{Rng, rng};
29use reqwest::blocking::Client;
30use serde_json::Value;
31
/// Outcome classification for a failed download attempt.
///
/// The variant tells the retry loop whether repeating the request can help; the payload
/// is a human-readable description of the failure.
#[derive(Debug)]
enum DownloadError {
    /// The failure may be transient, so the request can be attempted again.
    Retryable(String),
    /// The failure is considered permanent; further attempts are pointless.
    NonRetryable(String),
}

impl Display for DownloadError {
    /// Writes the error kind followed by the stored message.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            DownloadError::NonRetryable(msg) => write!(formatter, "Non-retryable error: {msg}"),
            DownloadError::Retryable(msg) => write!(formatter, "Retryable error: {msg}"),
        }
    }
}

impl std::error::Error for DownloadError {}
48
49fn execute_with_retry_blocking<T, E, F>(
50 config: &RetryConfig,
51 mut op: F,
52 should_retry: impl Fn(&E) -> bool,
53) -> Result<T, E>
54where
55 E: std::error::Error,
56 F: FnMut() -> Result<T, E>,
57{
58 let start = Instant::now();
59 let mut delay = Duration::from_millis(config.initial_delay_ms);
60
61 for attempt in 0..=config.max_retries {
62 if attempt > 0 && !config.immediate_first {
63 let jitter = rng().random_range(0..=config.jitter_ms);
64 let sleep_for = delay + Duration::from_millis(jitter);
65 sleep(sleep_for);
66 let next = (delay.as_millis() as f64 * config.backoff_factor) as u64;
67 delay = cmp::min(
68 Duration::from_millis(next),
69 Duration::from_millis(config.max_delay_ms),
70 );
71 }
72
73 if let Some(max_total) = config.max_elapsed_ms
74 && start.elapsed() >= Duration::from_millis(max_total)
75 {
76 break;
77 }
78
79 match op() {
80 Ok(v) => return Ok(v),
81 Err(e) if attempt < config.max_retries && should_retry(&e) => continue,
82 Err(e) => return Err(e),
83 }
84 }
85
86 op()
87}
88
89pub fn ensure_file_exists_or_download_http(
109 filepath: &Path,
110 url: &str,
111 checksums: Option<&Path>,
112 timeout_secs: Option<u64>,
113) -> anyhow::Result<()> {
114 ensure_file_exists_or_download_http_with_config(
115 filepath,
116 url,
117 checksums,
118 timeout_secs.unwrap_or(30),
119 None,
120 None,
121 )
122}
123
124pub fn ensure_file_exists_or_download_http_with_timeout(
133 filepath: &Path,
134 url: &str,
135 checksums: Option<&Path>,
136 timeout_secs: u64,
137) -> anyhow::Result<()> {
138 ensure_file_exists_or_download_http_with_config(
139 filepath,
140 url,
141 checksums,
142 timeout_secs,
143 None,
144 None,
145 )
146}
147
148pub fn ensure_file_exists_or_download_http_with_config(
167 filepath: &Path,
168 url: &str,
169 checksums: Option<&Path>,
170 timeout_secs: u64,
171 retry_config: Option<RetryConfig>,
172 initial_jitter_ms: Option<u64>,
173) -> anyhow::Result<()> {
174 if filepath.exists() {
175 println!("File already exists: {filepath:?}");
176
177 if let Some(checksums_file) = checksums {
178 if verify_sha256_checksum(filepath, checksums_file)? {
179 println!("File is valid");
180 return Ok(());
181 } else {
182 let new_checksum = calculate_sha256(filepath)?;
183 println!("Adding checksum for existing file: {new_checksum}");
184 update_sha256_checksums(filepath, checksums_file, &new_checksum)?;
185 return Ok(());
186 }
187 }
188 return Ok(());
189 }
190
191 if let Some(jitter_ms) = initial_jitter_ms {
194 if jitter_ms > 0 {
195 sleep(Duration::from_millis(jitter_ms));
196 }
197 } else {
198 let jitter_delay = {
199 let mut r = rng();
200 Duration::from_millis(r.random_range(100..=600))
201 };
202 sleep(jitter_delay);
203 }
204
205 download_file(filepath, url, timeout_secs, retry_config)?;
206
207 if let Some(checksums_file) = checksums {
208 let new_checksum = calculate_sha256(filepath)?;
209 update_sha256_checksums(filepath, checksums_file, &new_checksum)?;
210 }
211
212 Ok(())
213}
214
215fn download_file(
216 filepath: &Path,
217 url: &str,
218 timeout_secs: u64,
219 retry_config: Option<RetryConfig>,
220) -> anyhow::Result<()> {
221 #[cfg(not(test))]
225 if !url.starts_with("https://") {
226 anyhow::bail!("URL must use HTTPS protocol for security: {url}");
227 }
228
229 println!("Downloading file from {url} to {filepath:?}");
230
231 if let Some(parent) = filepath.parent() {
232 std::fs::create_dir_all(parent)?;
233 }
234
235 let client = Client::builder()
236 .timeout(Duration::from_secs(timeout_secs))
237 .build()?;
238
239 let cfg = if let Some(config) = retry_config {
240 config
241 } else {
242 let max_retries = 5u32;
244 let op_timeout_ms = timeout_secs.saturating_mul(1000);
245 let per_attempt_ms = std::cmp::max(1000u64, op_timeout_ms / (max_retries as u64 + 1));
248 RetryConfig {
249 max_retries,
250 initial_delay_ms: 1_000,
251 max_delay_ms: 10_000,
252 backoff_factor: 2.0,
253 jitter_ms: 1_000,
254 operation_timeout_ms: Some(per_attempt_ms),
255 immediate_first: false,
256 max_elapsed_ms: Some(op_timeout_ms),
257 }
258 };
259
260 let op = || -> Result<(), DownloadError> {
261 match client.get(url).send() {
262 Ok(mut response) => {
263 let status = response.status();
264 if status.is_success() {
265 let mut out = File::create(filepath)
266 .map_err(|e| DownloadError::NonRetryable(e.to_string()))?;
267 copy(&mut response, &mut out)
269 .map_err(|e| DownloadError::NonRetryable(e.to_string()))?;
270 println!("File downloaded to {filepath:?}");
271 Ok(())
272 } else if status.is_server_error()
273 || status.as_u16() == 429
274 || status.as_u16() == 408
275 {
276 println!("HTTP error {status}, retrying...");
277 Err(DownloadError::Retryable(format!("HTTP {status}")))
278 } else {
279 Err(DownloadError::NonRetryable(format!(
281 "Client error: HTTP {status}"
282 )))
283 }
284 }
285 Err(e) => {
286 println!("Request failed: {e}");
287 Err(DownloadError::Retryable(e.to_string()))
288 }
289 }
290 };
291
292 let should_retry = |e: &DownloadError| matches!(e, DownloadError::Retryable(_));
293
294 execute_with_retry_blocking(&cfg, op, should_retry).map_err(|e| anyhow::anyhow!(e.to_string()))
295}
296
297fn calculate_sha256(filepath: &Path) -> anyhow::Result<String> {
298 let mut file = File::open(filepath)?;
299 let mut ctx = Context::new(&digest::SHA256);
300 let mut buffer = [0u8; 4096];
301
302 loop {
303 let count = file.read(&mut buffer)?;
304 if count == 0 {
305 break;
306 }
307 ctx.update(&buffer[..count]);
308 }
309
310 let digest = ctx.finish();
311 Ok(hex::encode(digest.as_ref()))
312}
313
314fn verify_sha256_checksum(filepath: &Path, checksums: &Path) -> anyhow::Result<bool> {
315 let file = File::open(checksums)?;
316 let reader = BufReader::new(file);
317 let checksums: Value = serde_json::from_reader(reader)?;
318
319 let filename = filepath.file_name().unwrap().to_str().unwrap();
320 if let Some(expected_checksum) = checksums.get(filename) {
321 let expected_checksum_str = expected_checksum.as_str().unwrap();
322 let expected_hash = expected_checksum_str
323 .strip_prefix("sha256:")
324 .unwrap_or(expected_checksum_str);
325 let calculated_checksum = calculate_sha256(filepath)?;
326 if expected_hash == calculated_checksum {
327 return Ok(true);
328 }
329 }
330
331 Ok(false)
332}
333
334fn update_sha256_checksums(
335 filepath: &Path,
336 checksums_file: &Path,
337 new_checksum: &str,
338) -> anyhow::Result<()> {
339 let checksums: Value = if checksums_file.exists() {
340 let file = File::open(checksums_file)?;
341 let reader = BufReader::new(file);
342 serde_json::from_reader(reader)?
343 } else {
344 serde_json::json!({})
345 };
346
347 let mut checksums_map = checksums.as_object().unwrap().clone();
348
349 let filename = filepath.file_name().unwrap().to_str().unwrap().to_string();
351 let prefixed_checksum = format!("sha256:{new_checksum}");
352 checksums_map.insert(filename, Value::String(prefixed_checksum));
353
354 let file = OpenOptions::new()
355 .write(true)
356 .create(true)
357 .truncate(true)
358 .open(checksums_file)?;
359 let writer = BufWriter::new(file);
360 serde_json::to_writer_pretty(writer, &serde_json::Value::Object(checksums_map))?;
361
362 Ok(())
363}
364
365#[cfg(test)]
366mod tests {
367 use std::{
368 fs,
369 io::{BufWriter, Write},
370 net::SocketAddr,
371 sync::{
372 Arc,
373 atomic::{AtomicUsize, Ordering},
374 },
375 };
376
377 use axum::{Router, http::StatusCode, routing::get, serve};
378 use rstest::*;
379 use serde_json::{json, to_writer};
380 use tempfile::TempDir;
381 use tokio::{
382 net::TcpListener,
383 task,
384 time::{Duration, sleep},
385 };
386
387 use super::*;
388
389 fn test_retry_config() -> RetryConfig {
392 RetryConfig {
393 max_retries: 5,
394 initial_delay_ms: 10,
395 max_delay_ms: 50,
396 backoff_factor: 2.0,
397 jitter_ms: 5,
398 operation_timeout_ms: Some(500),
399 immediate_first: false,
400 max_elapsed_ms: Some(2000),
401 }
402 }
403
404 async fn setup_test_server(
405 server_content: Option<String>,
406 status_code: StatusCode,
407 ) -> SocketAddr {
408 let server_content = Arc::new(server_content);
409 let server_content_clone = server_content.clone();
410 let app = Router::new().route(
411 "/testfile.txt",
412 get(move || {
413 let server_content = server_content_clone.clone();
414 async move {
415 let response_body = match &*server_content {
416 Some(content) => content.clone(),
417 None => "File not found".to_string(),
418 };
419 (status_code, response_body)
420 }
421 }),
422 );
423
424 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
425 let addr = listener.local_addr().unwrap();
426 let server = serve(listener, app);
427
428 task::spawn(async move {
429 if let Err(e) = server.await {
430 eprintln!("server error: {e}");
431 }
432 });
433
434 sleep(Duration::from_millis(100)).await;
435
436 addr
437 }
438
439 #[tokio::test]
440 async fn test_file_already_exists() {
441 let temp_dir = TempDir::new().unwrap();
442 let file_path = temp_dir.path().join("testfile.txt");
443 fs::write(&file_path, "Existing file content").unwrap();
444
445 let url = "http://example.com/testfile.txt".to_string();
446 let result = ensure_file_exists_or_download_http(&file_path, &url, None, Some(5));
447
448 assert!(result.is_ok());
449 let content = fs::read_to_string(&file_path).unwrap();
450 assert_eq!(content, "Existing file content");
451 }
452
453 #[tokio::test]
454 async fn test_download_file_success() {
455 let temp_dir = TempDir::new().unwrap();
456 let filepath = temp_dir.path().join("testfile.txt");
457 let filepath_clone = filepath.clone();
458
459 let server_content = "Server file content".to_string();
460 let status_code = StatusCode::OK;
461 let addr = setup_test_server(Some(server_content.clone()), status_code).await;
462 let url = format!("http://{addr}/testfile.txt");
463
464 let result = tokio::task::spawn_blocking(move || {
465 ensure_file_exists_or_download_http_with_config(
466 &filepath_clone,
467 &url,
468 None,
469 5,
470 Some(test_retry_config()),
471 Some(0),
472 )
473 })
474 .await
475 .unwrap();
476
477 assert!(result.is_ok());
478 let content = fs::read_to_string(&filepath).unwrap();
479 assert_eq!(content, server_content);
480 }
481
482 #[tokio::test]
483 async fn test_download_file_not_found() {
484 let temp_dir = TempDir::new().unwrap();
485 let file_path = temp_dir.path().join("testfile.txt");
486
487 let server_content = None;
488 let status_code = StatusCode::NOT_FOUND;
489 let addr = setup_test_server(server_content, status_code).await;
490 let url = format!("http://{addr}/testfile.txt");
491
492 let result = tokio::task::spawn_blocking(move || {
493 ensure_file_exists_or_download_http_with_config(
494 &file_path,
495 &url,
496 None,
497 1,
498 Some(test_retry_config()),
499 Some(0),
500 )
501 })
502 .await
503 .unwrap();
504
505 assert!(result.is_err());
506 let err_msg = format!("{}", result.unwrap_err());
507 assert!(
508 err_msg.contains("Client error: HTTP"),
509 "Unexpected error message: {err_msg}"
510 );
511 }
512
513 #[tokio::test]
514 async fn test_network_error() {
515 let temp_dir = TempDir::new().unwrap();
516 let file_path = temp_dir.path().join("testfile.txt");
517
518 let url = "http://127.0.0.1:0/testfile.txt".to_string();
520
521 let result = tokio::task::spawn_blocking(move || {
522 ensure_file_exists_or_download_http_with_config(
523 &file_path,
524 &url,
525 None,
526 2,
527 Some(test_retry_config()),
528 Some(0),
529 )
530 })
531 .await
532 .unwrap();
533
534 assert!(result.is_err());
535 let err_msg = format!("{}", result.unwrap_err());
536 assert!(
537 err_msg.contains("error"),
538 "Unexpected error message: {err_msg}"
539 );
540 }
541
542 #[tokio::test]
543 async fn test_retry_then_success_on_500() {
544 let temp_dir = TempDir::new().unwrap();
545 let filepath = temp_dir.path().join("testfile.txt");
546 let filepath_clone = filepath.clone();
547
548 let counter = Arc::new(AtomicUsize::new(0));
549 let counter_clone = counter.clone();
550
551 let app = Router::new().route(
552 "/testfile.txt",
553 get(move || {
554 let c = counter_clone.clone();
555 async move {
556 let n = c.fetch_add(1, Ordering::SeqCst);
557 if n < 2 {
558 (StatusCode::INTERNAL_SERVER_ERROR, "temporary error")
559 } else {
560 (StatusCode::OK, "eventual success")
561 }
562 }
563 }),
564 );
565
566 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
567 let addr = listener.local_addr().unwrap();
568 let server = serve(listener, app);
569 task::spawn(async move {
570 let _ = server.await;
571 });
572 sleep(Duration::from_millis(100)).await;
573
574 let url = format!("http://{addr}/testfile.txt");
575
576 let result = tokio::task::spawn_blocking(move || {
577 ensure_file_exists_or_download_http_with_config(
578 &filepath_clone,
579 &url,
580 None,
581 5,
582 Some(test_retry_config()),
583 Some(0),
584 )
585 })
586 .await
587 .unwrap();
588
589 assert!(result.is_ok());
590 let content = std::fs::read_to_string(&filepath).unwrap();
591 assert_eq!(content, "eventual success");
592 assert!(counter.load(Ordering::SeqCst) >= 2);
593 }
594
595 #[tokio::test]
596 async fn test_retry_then_success_on_429() {
597 let temp_dir = TempDir::new().unwrap();
598 let filepath = temp_dir.path().join("testfile.txt");
599 let filepath_clone = filepath.clone();
600
601 let counter = Arc::new(AtomicUsize::new(0));
602 let counter_clone = counter.clone();
603
604 let app = Router::new().route(
605 "/testfile.txt",
606 get(move || {
607 let c = counter_clone.clone();
608 async move {
609 let n = c.fetch_add(1, Ordering::SeqCst);
610 if n < 1 {
611 (StatusCode::TOO_MANY_REQUESTS, "rate limited")
612 } else {
613 (StatusCode::OK, "ok after retry")
614 }
615 }
616 }),
617 );
618
619 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
620 let addr = listener.local_addr().unwrap();
621 let server = serve(listener, app);
622 task::spawn(async move {
623 let _ = server.await;
624 });
625 sleep(Duration::from_millis(100)).await;
626
627 let url = format!("http://{addr}/testfile.txt");
628
629 let result = tokio::task::spawn_blocking(move || {
630 ensure_file_exists_or_download_http_with_config(
631 &filepath_clone,
632 &url,
633 None,
634 5,
635 Some(test_retry_config()),
636 Some(0),
637 )
638 })
639 .await
640 .unwrap();
641
642 assert!(result.is_ok());
643 let content = std::fs::read_to_string(&filepath).unwrap();
644 assert_eq!(content, "ok after retry");
645 assert!(counter.load(Ordering::SeqCst) >= 2);
646 }
647
648 #[tokio::test]
649 async fn test_no_retry_on_404() {
650 let temp_dir = TempDir::new().unwrap();
651 let filepath = temp_dir.path().join("testfile.txt");
652 let filepath_clone = filepath.clone();
653
654 let counter = Arc::new(AtomicUsize::new(0));
655 let counter_clone = counter.clone();
656
657 let app = Router::new().route(
658 "/testfile.txt",
659 get(move || {
660 let c = counter_clone.clone();
661 async move {
662 c.fetch_add(1, Ordering::SeqCst);
663 (StatusCode::NOT_FOUND, "missing")
664 }
665 }),
666 );
667
668 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
669 let addr = listener.local_addr().unwrap();
670 let server = serve(listener, app);
671 task::spawn(async move {
672 let _ = server.await;
673 });
674 sleep(Duration::from_millis(100)).await;
675
676 let url = format!("http://{addr}/testfile.txt");
677
678 let result = tokio::task::spawn_blocking(move || {
679 ensure_file_exists_or_download_http_with_config(
680 &filepath_clone,
681 &url,
682 None,
683 5,
684 Some(test_retry_config()),
685 Some(0),
686 )
687 })
688 .await
689 .unwrap();
690
691 assert!(result.is_err());
692 assert_eq!(counter.load(Ordering::SeqCst), 1, "should not retry on 404");
693 }
694
695 #[rstest]
696 #[allow(clippy::panic_in_result_fn)]
697 fn test_calculate_sha256() -> anyhow::Result<()> {
698 let temp_dir = TempDir::new()?;
699 let test_file_path = temp_dir.path().join("test_file.txt");
700 let mut test_file = File::create(&test_file_path)?;
701 let content = b"Hello, world!";
702 test_file.write_all(content)?;
703
704 let expected_hash = "315f5bdb76d078c43b8ac0064e4a0164612b1fce77c869345bfc94c75894edd3";
705 let calculated_hash = calculate_sha256(&test_file_path)?;
706
707 assert_eq!(calculated_hash, expected_hash);
708 Ok(())
709 }
710
711 #[rstest]
712 #[allow(clippy::panic_in_result_fn)]
713 fn test_verify_sha256_checksum() -> anyhow::Result<()> {
714 let temp_dir = TempDir::new()?;
715 let test_file_path = temp_dir.path().join("test_file.txt");
716 let mut test_file = File::create(&test_file_path)?;
717 let content = b"Hello, world!";
718 test_file.write_all(content)?;
719
720 let calculated_checksum = calculate_sha256(&test_file_path)?;
721
722 let checksums_path = temp_dir.path().join("checksums.json");
724 let checksums_data = json!({
725 "test_file.txt": format!("sha256:{}", calculated_checksum)
726 });
727 let checksums_file = File::create(&checksums_path)?;
728 let writer = BufWriter::new(checksums_file);
729 to_writer(writer, &checksums_data)?;
730
731 let is_valid = verify_sha256_checksum(&test_file_path, &checksums_path)?;
732 assert!(is_valid, "The checksum should be valid");
733 Ok(())
734 }
735}