1use std::{
17 cmp,
18 fs::{File, OpenOptions},
19 io::{BufReader, BufWriter, Read, copy},
20 path::Path,
21 thread::sleep,
22 time::{Duration, Instant},
23};
24
25use aws_lc_rs::digest::{self, Context};
26use nautilus_network::retry::RetryConfig;
27use rand::{Rng, rng};
28use reqwest::blocking::Client;
29use serde_json::Value;
30
/// Classification of a failed download attempt, used to decide whether another attempt is
/// worthwhile.
#[derive(Debug)]
enum DownloadError {
    /// Failure that may succeed if the attempt is repeated.
    Retryable(String),
    /// Failure that repeating the attempt is not expected to fix.
    NonRetryable(String),
}

impl std::fmt::Display for DownloadError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (label, msg) = match self {
            Self::Retryable(msg) => ("Retryable", msg),
            Self::NonRetryable(msg) => ("Non-retryable", msg),
        };
        write!(f, "{label} error: {msg}")
    }
}

impl std::error::Error for DownloadError {}
47
48fn execute_with_retry_blocking<T, E, F>(
49 config: &RetryConfig,
50 mut op: F,
51 should_retry: impl Fn(&E) -> bool,
52) -> Result<T, E>
53where
54 E: std::error::Error,
55 F: FnMut() -> Result<T, E>,
56{
57 let start = Instant::now();
58 let mut delay = Duration::from_millis(config.initial_delay_ms);
59
60 for attempt in 0..=config.max_retries {
61 if attempt > 0 && !config.immediate_first {
62 let jitter = rng().random_range(0..=config.jitter_ms);
63 let sleep_for = delay + Duration::from_millis(jitter);
64 sleep(sleep_for);
65 let next = (delay.as_millis() as f64 * config.backoff_factor) as u64;
66 delay = cmp::min(
67 Duration::from_millis(next),
68 Duration::from_millis(config.max_delay_ms),
69 );
70 }
71
72 if let Some(max_total) = config.max_elapsed_ms
73 && start.elapsed() >= Duration::from_millis(max_total)
74 {
75 break;
76 }
77
78 match op() {
79 Ok(v) => return Ok(v),
80 Err(e) if attempt < config.max_retries && should_retry(&e) => continue,
81 Err(e) => return Err(e),
82 }
83 }
84
85 op()
86}
87
88pub fn ensure_file_exists_or_download_http(
108 filepath: &Path,
109 url: &str,
110 checksums: Option<&Path>,
111 timeout_secs: Option<u64>,
112) -> anyhow::Result<()> {
113 ensure_file_exists_or_download_http_with_config(
114 filepath,
115 url,
116 checksums,
117 timeout_secs.unwrap_or(30),
118 None,
119 None,
120 )
121}
122
123pub fn ensure_file_exists_or_download_http_with_timeout(
132 filepath: &Path,
133 url: &str,
134 checksums: Option<&Path>,
135 timeout_secs: u64,
136) -> anyhow::Result<()> {
137 ensure_file_exists_or_download_http_with_config(
138 filepath,
139 url,
140 checksums,
141 timeout_secs,
142 None,
143 None,
144 )
145}
146
147pub fn ensure_file_exists_or_download_http_with_config(
166 filepath: &Path,
167 url: &str,
168 checksums: Option<&Path>,
169 timeout_secs: u64,
170 retry_config: Option<RetryConfig>,
171 initial_jitter_ms: Option<u64>,
172) -> anyhow::Result<()> {
173 if filepath.exists() {
174 println!("File already exists: {filepath:?}");
175
176 if let Some(checksums_file) = checksums {
177 if verify_sha256_checksum(filepath, checksums_file)? {
178 println!("File is valid");
179 return Ok(());
180 } else {
181 let new_checksum = calculate_sha256(filepath)?;
182 println!("Adding checksum for existing file: {new_checksum}");
183 update_sha256_checksums(filepath, checksums_file, &new_checksum)?;
184 return Ok(());
185 }
186 }
187 return Ok(());
188 }
189
190 if let Some(jitter_ms) = initial_jitter_ms {
193 if jitter_ms > 0 {
194 sleep(Duration::from_millis(jitter_ms));
195 }
196 } else {
197 let jitter_delay = {
198 let mut r = rng();
199 Duration::from_millis(r.random_range(100..=600))
200 };
201 sleep(jitter_delay);
202 }
203
204 download_file(filepath, url, timeout_secs, retry_config)?;
205
206 if let Some(checksums_file) = checksums {
207 let new_checksum = calculate_sha256(filepath)?;
208 update_sha256_checksums(filepath, checksums_file, &new_checksum)?;
209 }
210
211 Ok(())
212}
213
214fn download_file(
215 filepath: &Path,
216 url: &str,
217 timeout_secs: u64,
218 retry_config: Option<RetryConfig>,
219) -> anyhow::Result<()> {
220 println!("Downloading file from {url} to {filepath:?}");
221
222 if let Some(parent) = filepath.parent() {
223 std::fs::create_dir_all(parent)?;
224 }
225
226 let client = Client::builder()
227 .timeout(Duration::from_secs(timeout_secs))
228 .build()?;
229
230 let cfg = if let Some(config) = retry_config {
231 config
232 } else {
233 let max_retries = 5u32;
235 let op_timeout_ms = timeout_secs.saturating_mul(1000);
236 let per_attempt_ms = std::cmp::max(1000u64, op_timeout_ms / (max_retries as u64 + 1));
239 RetryConfig {
240 max_retries,
241 initial_delay_ms: 1_000,
242 max_delay_ms: 10_000,
243 backoff_factor: 2.0,
244 jitter_ms: 1_000,
245 operation_timeout_ms: Some(per_attempt_ms),
246 immediate_first: false,
247 max_elapsed_ms: Some(op_timeout_ms),
248 }
249 };
250
251 let op = || -> Result<(), DownloadError> {
252 match client.get(url).send() {
253 Ok(mut response) => {
254 let status = response.status();
255 if status.is_success() {
256 let mut out = File::create(filepath)
257 .map_err(|e| DownloadError::NonRetryable(e.to_string()))?;
258 copy(&mut response, &mut out)
260 .map_err(|e| DownloadError::NonRetryable(e.to_string()))?;
261 println!("File downloaded to {filepath:?}");
262 Ok(())
263 } else if status.is_server_error()
264 || status.as_u16() == 429
265 || status.as_u16() == 408
266 {
267 println!("HTTP error {status}, retrying...");
268 Err(DownloadError::Retryable(format!("HTTP {status}")))
269 } else {
270 Err(DownloadError::NonRetryable(format!(
272 "Client error: HTTP {status}"
273 )))
274 }
275 }
276 Err(e) => {
277 println!("Request failed: {e}");
278 Err(DownloadError::Retryable(e.to_string()))
279 }
280 }
281 };
282
283 let should_retry = |e: &DownloadError| matches!(e, DownloadError::Retryable(_));
284
285 execute_with_retry_blocking(&cfg, op, should_retry).map_err(|e| anyhow::anyhow!(e.to_string()))
286}
287
288fn calculate_sha256(filepath: &Path) -> anyhow::Result<String> {
289 let mut file = File::open(filepath)?;
290 let mut ctx = Context::new(&digest::SHA256);
291 let mut buffer = [0u8; 4096];
292
293 loop {
294 let count = file.read(&mut buffer)?;
295 if count == 0 {
296 break;
297 }
298 ctx.update(&buffer[..count]);
299 }
300
301 let digest = ctx.finish();
302 Ok(hex::encode(digest.as_ref()))
303}
304
305fn verify_sha256_checksum(filepath: &Path, checksums: &Path) -> anyhow::Result<bool> {
306 let file = File::open(checksums)?;
307 let reader = BufReader::new(file);
308 let checksums: Value = serde_json::from_reader(reader)?;
309
310 let filename = filepath.file_name().unwrap().to_str().unwrap();
311 if let Some(expected_checksum) = checksums.get(filename) {
312 let expected_checksum_str = expected_checksum.as_str().unwrap();
313 let expected_hash = expected_checksum_str
314 .strip_prefix("sha256:")
315 .unwrap_or(expected_checksum_str);
316 let calculated_checksum = calculate_sha256(filepath)?;
317 if expected_hash == calculated_checksum {
318 return Ok(true);
319 }
320 }
321
322 Ok(false)
323}
324
325fn update_sha256_checksums(
326 filepath: &Path,
327 checksums_file: &Path,
328 new_checksum: &str,
329) -> anyhow::Result<()> {
330 let checksums: Value = if checksums_file.exists() {
331 let file = File::open(checksums_file)?;
332 let reader = BufReader::new(file);
333 serde_json::from_reader(reader)?
334 } else {
335 serde_json::json!({})
336 };
337
338 let mut checksums_map = checksums.as_object().unwrap().clone();
339
340 let filename = filepath.file_name().unwrap().to_str().unwrap().to_string();
342 let prefixed_checksum = format!("sha256:{new_checksum}");
343 checksums_map.insert(filename, Value::String(prefixed_checksum));
344
345 let file = OpenOptions::new()
346 .write(true)
347 .create(true)
348 .truncate(true)
349 .open(checksums_file)?;
350 let writer = BufWriter::new(file);
351 serde_json::to_writer_pretty(writer, &serde_json::Value::Object(checksums_map))?;
352
353 Ok(())
354}
355
#[cfg(test)]
mod tests {
    use std::{
        fs,
        io::{BufWriter, Write},
        net::SocketAddr,
        path::PathBuf,
        sync::{
            Arc,
            atomic::{AtomicUsize, Ordering},
        },
    };

    use axum::{Router, http::StatusCode, routing::get, serve};
    use rstest::*;
    use serde_json::{json, to_writer};
    use tempfile::TempDir;
    use tokio::{
        net::TcpListener,
        task,
        time::{Duration, sleep},
    };

    use super::*;

    /// Retry policy with millisecond-scale delays and a 2 s overall budget, so the retry tests
    /// stay fast.
    fn test_retry_config() -> RetryConfig {
        RetryConfig {
            max_retries: 5,
            initial_delay_ms: 10,
            max_delay_ms: 50,
            backoff_factor: 2.0,
            jitter_ms: 5,
            operation_timeout_ms: Some(500),
            immediate_first: false,
            max_elapsed_ms: Some(2000),
        }
    }

    /// Serves `app` on an ephemeral localhost port from a background task and returns the bound
    /// address.
    async fn spawn_app(app: Router) -> SocketAddr {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let server = serve(listener, app);

        task::spawn(async move {
            if let Err(e) = server.await {
                eprintln!("server error: {e}");
            }
        });

        // Give the server task a moment to start before clients connect
        sleep(Duration::from_millis(100)).await;

        addr
    }

    /// Serves `/testfile.txt`, answering every request with `status_code` and `server_content`
    /// (or `"File not found"` when `None`).
    async fn setup_test_server(
        server_content: Option<String>,
        status_code: StatusCode,
    ) -> SocketAddr {
        let server_content = Arc::new(server_content);
        let app = Router::new().route(
            "/testfile.txt",
            get(move || {
                let server_content = server_content.clone();
                async move {
                    let response_body = match &*server_content {
                        Some(content) => content.clone(),
                        None => "File not found".to_string(),
                    };
                    (status_code, response_body)
                }
            }),
        );

        spawn_app(app).await
    }

    /// Serves `/testfile.txt`, answering the n-th request (0-based) with `respond(n)`.
    ///
    /// Returns the server address and a counter of the requests received.
    async fn spawn_counting_server(
        respond: fn(usize) -> (StatusCode, &'static str),
    ) -> (SocketAddr, Arc<AtomicUsize>) {
        let counter = Arc::new(AtomicUsize::new(0));
        let counter_clone = counter.clone();
        let app = Router::new().route(
            "/testfile.txt",
            get(move || {
                let c = counter_clone.clone();
                async move { respond(c.fetch_add(1, Ordering::SeqCst)) }
            }),
        );

        (spawn_app(app).await, counter)
    }

    /// Downloads `url` to `filepath` using the fast test retry policy and no initial jitter.
    ///
    /// Runs on tokio's blocking pool because reqwest's blocking client must not be driven from
    /// an async runtime thread.
    async fn download(filepath: PathBuf, url: String, timeout_secs: u64) -> anyhow::Result<()> {
        task::spawn_blocking(move || {
            ensure_file_exists_or_download_http_with_config(
                &filepath,
                &url,
                None,
                timeout_secs,
                Some(test_retry_config()),
                Some(0),
            )
        })
        .await
        .unwrap()
    }

    #[tokio::test]
    async fn test_file_already_exists() {
        let temp_dir = TempDir::new().unwrap();
        let file_path = temp_dir.path().join("testfile.txt");
        fs::write(&file_path, "Existing file content").unwrap();

        // The file is already present, so this URL is never contacted
        let url = "http://example.com/testfile.txt".to_string();
        let result = ensure_file_exists_or_download_http(&file_path, &url, None, Some(5));

        assert!(result.is_ok());
        let content = fs::read_to_string(&file_path).unwrap();
        assert_eq!(content, "Existing file content");
    }

    #[tokio::test]
    async fn test_download_file_success() {
        let temp_dir = TempDir::new().unwrap();
        let filepath = temp_dir.path().join("testfile.txt");

        let server_content = "Server file content".to_string();
        let addr = setup_test_server(Some(server_content.clone()), StatusCode::OK).await;
        let url = format!("http://{addr}/testfile.txt");

        let result = download(filepath.clone(), url, 5).await;

        assert!(result.is_ok());
        let content = fs::read_to_string(&filepath).unwrap();
        assert_eq!(content, server_content);
    }

    #[tokio::test]
    async fn test_download_file_not_found() {
        let temp_dir = TempDir::new().unwrap();
        let file_path = temp_dir.path().join("testfile.txt");

        let addr = setup_test_server(None, StatusCode::NOT_FOUND).await;
        let url = format!("http://{addr}/testfile.txt");

        let result = download(file_path, url, 1).await;

        assert!(result.is_err());
        let err_msg = format!("{}", result.unwrap_err());
        assert!(
            err_msg.contains("Client error: HTTP"),
            "Unexpected error message: {err_msg}"
        );
    }

    #[tokio::test]
    async fn test_network_error() {
        let temp_dir = TempDir::new().unwrap();
        let file_path = temp_dir.path().join("testfile.txt");

        // Nothing can listen on port 0, so every attempt fails before any HTTP exchange
        let url = "http://127.0.0.1:0/testfile.txt".to_string();

        let result = download(file_path, url, 2).await;

        assert!(result.is_err());
        let err_msg = format!("{}", result.unwrap_err());
        assert!(
            err_msg.contains("error"),
            "Unexpected error message: {err_msg}"
        );
    }

    #[tokio::test]
    async fn test_retry_then_success_on_500() {
        let temp_dir = TempDir::new().unwrap();
        let filepath = temp_dir.path().join("testfile.txt");

        // The first two requests fail with 500, then the server recovers
        let (addr, counter) = spawn_counting_server(|n| {
            if n < 2 {
                (StatusCode::INTERNAL_SERVER_ERROR, "temporary error")
            } else {
                (StatusCode::OK, "eventual success")
            }
        })
        .await;

        let url = format!("http://{addr}/testfile.txt");
        let result = download(filepath.clone(), url, 5).await;

        assert!(result.is_ok());
        let content = fs::read_to_string(&filepath).unwrap();
        assert_eq!(content, "eventual success");
        assert!(counter.load(Ordering::SeqCst) >= 2);
    }

    #[tokio::test]
    async fn test_retry_then_success_on_429() {
        let temp_dir = TempDir::new().unwrap();
        let filepath = temp_dir.path().join("testfile.txt");

        // The first request is rate limited, the next succeeds
        let (addr, counter) = spawn_counting_server(|n| {
            if n < 1 {
                (StatusCode::TOO_MANY_REQUESTS, "rate limited")
            } else {
                (StatusCode::OK, "ok after retry")
            }
        })
        .await;

        let url = format!("http://{addr}/testfile.txt");
        let result = download(filepath.clone(), url, 5).await;

        assert!(result.is_ok());
        let content = fs::read_to_string(&filepath).unwrap();
        assert_eq!(content, "ok after retry");
        assert!(counter.load(Ordering::SeqCst) >= 2);
    }

    #[tokio::test]
    async fn test_no_retry_on_404() {
        let temp_dir = TempDir::new().unwrap();
        let filepath = temp_dir.path().join("testfile.txt");

        let (addr, counter) = spawn_counting_server(|_| (StatusCode::NOT_FOUND, "missing")).await;

        let url = format!("http://{addr}/testfile.txt");
        let result = download(filepath, url, 5).await;

        assert!(result.is_err());
        assert_eq!(counter.load(Ordering::SeqCst), 1, "should not retry on 404");
    }

    #[rstest]
    fn test_calculate_sha256() -> anyhow::Result<()> {
        let temp_dir = TempDir::new()?;
        let test_file_path = temp_dir.path().join("test_file.txt");
        let mut test_file = File::create(&test_file_path)?;
        let content = b"Hello, world!";
        test_file.write_all(content)?;

        // Known SHA-256 of "Hello, world!"
        let expected_hash = "315f5bdb76d078c43b8ac0064e4a0164612b1fce77c869345bfc94c75894edd3";
        let calculated_hash = calculate_sha256(&test_file_path)?;

        assert_eq!(calculated_hash, expected_hash);
        Ok(())
    }

    #[rstest]
    fn test_verify_sha256_checksum() -> anyhow::Result<()> {
        let temp_dir = TempDir::new()?;
        let test_file_path = temp_dir.path().join("test_file.txt");
        let mut test_file = File::create(&test_file_path)?;
        let content = b"Hello, world!";
        test_file.write_all(content)?;

        let calculated_checksum = calculate_sha256(&test_file_path)?;

        // Record the checksum in the same `sha256:`-prefixed form the production code writes
        let checksums_path = temp_dir.path().join("checksums.json");
        let checksums_data = json!({
            "test_file.txt": format!("sha256:{}", calculated_checksum)
        });
        let checksums_file = File::create(&checksums_path)?;
        let writer = BufWriter::new(checksums_file);
        to_writer(writer, &checksums_data)?;

        let is_valid = verify_sha256_checksum(&test_file_path, &checksums_path)?;
        assert!(is_valid, "The checksum should be valid");
        Ok(())
    }
}