nautilus_testkit/
files.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::{
17    cmp,
18    fmt::Display,
19    fs::{File, OpenOptions},
20    io::{BufReader, BufWriter, Read, copy},
21    path::Path,
22    thread::sleep,
23    time::{Duration, Instant},
24};
25
26use aws_lc_rs::digest::{self, Context};
27use nautilus_network::retry::RetryConfig;
28use rand::{Rng, rng};
29use reqwest::blocking::Client;
30use serde_json::Value;
31
/// Classification of a failed download attempt.
///
/// The variant tells the retry loop whether repeating the request is worthwhile; the payload is
/// a human-readable description of what went wrong.
#[derive(Debug)]
enum DownloadError {
    /// A transient failure (for example a transport error or a 5xx/408/429 response) that may
    /// succeed if the request is repeated.
    Retryable(String),
    /// A permanent failure that repeating the request will not fix.
    NonRetryable(String),
}

impl Display for DownloadError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            DownloadError::Retryable(ref msg) => write!(f, "Retryable error: {msg}"),
            DownloadError::NonRetryable(ref msg) => write!(f, "Non-retryable error: {msg}"),
        }
    }
}

/// Needed to satisfy the `E: std::error::Error` bound of `execute_with_retry_blocking`.
impl std::error::Error for DownloadError {}
48
49fn execute_with_retry_blocking<T, E, F>(
50    config: &RetryConfig,
51    mut op: F,
52    should_retry: impl Fn(&E) -> bool,
53) -> Result<T, E>
54where
55    E: std::error::Error,
56    F: FnMut() -> Result<T, E>,
57{
58    let start = Instant::now();
59    let mut delay = Duration::from_millis(config.initial_delay_ms);
60
61    for attempt in 0..=config.max_retries {
62        if attempt > 0 && !config.immediate_first {
63            let jitter = rng().random_range(0..=config.jitter_ms);
64            let sleep_for = delay + Duration::from_millis(jitter);
65            sleep(sleep_for);
66            let next = (delay.as_millis() as f64 * config.backoff_factor) as u64;
67            delay = cmp::min(
68                Duration::from_millis(next),
69                Duration::from_millis(config.max_delay_ms),
70            );
71        }
72
73        if let Some(max_total) = config.max_elapsed_ms
74            && start.elapsed() >= Duration::from_millis(max_total)
75        {
76            break;
77        }
78
79        match op() {
80            Ok(v) => return Ok(v),
81            Err(e) if attempt < config.max_retries && should_retry(&e) => continue,
82            Err(e) => return Err(e),
83        }
84    }
85
86    op()
87}
88
89/// Ensures that a file exists at the specified path by downloading it if necessary.
90///
91/// If the file already exists, it checks the integrity of the file using a SHA-256 checksum
92/// from the optional `checksums` file. If the checksum is valid, the function exits early. If
93/// the checksum is invalid or missing, the function updates the checksums file with the correct
94/// hash for the existing file without redownloading it.
95///
96/// If the file does not exist, it downloads the file from the specified `url` and updates the
97/// checksums file (if provided) with the calculated SHA-256 checksum of the downloaded file.
98///
99/// The `timeout_secs` parameter specifies the timeout in seconds for the HTTP request.
100/// If `None` is provided, a default timeout of 30 seconds will be used.
101///
102/// # Errors
103///
104/// Returns an error if:
105/// - The HTTP request cannot be sent or returns a non-success status code.
106/// - Any I/O operation fails during file creation, reading, or writing.
107/// - Checksum verification or JSON parsing fails.
108pub fn ensure_file_exists_or_download_http(
109    filepath: &Path,
110    url: &str,
111    checksums: Option<&Path>,
112    timeout_secs: Option<u64>,
113) -> anyhow::Result<()> {
114    ensure_file_exists_or_download_http_with_config(
115        filepath,
116        url,
117        checksums,
118        timeout_secs.unwrap_or(30),
119        None,
120        None,
121    )
122}
123
124/// Ensures that a file exists at the specified path by downloading it if necessary, with a custom timeout.
125///
126/// # Errors
127///
128/// Returns an error if:
129/// - The HTTP request cannot be sent or returns a non-success status code after retries.
130/// - Any I/O operation fails during file creation, reading, or writing.
131/// - Checksum verification or JSON parsing fails.
132pub fn ensure_file_exists_or_download_http_with_timeout(
133    filepath: &Path,
134    url: &str,
135    checksums: Option<&Path>,
136    timeout_secs: u64,
137) -> anyhow::Result<()> {
138    ensure_file_exists_or_download_http_with_config(
139        filepath,
140        url,
141        checksums,
142        timeout_secs,
143        None,
144        None,
145    )
146}
147
148/// Ensures that a file exists at the specified path by downloading it if necessary,
149/// with custom timeout, retry config, and initial jitter delay.
150///
151/// # Parameters
152///
153/// - `filepath`: The path where the file should exist.
154/// - `url`: The URL to download from if the file doesn't exist.
155/// - `checksums`: Optional path to checksums file for verification.
156/// - `timeout_secs`: Timeout in seconds for HTTP requests.
157/// - `retry_config`: Optional custom retry configuration (uses sensible defaults if None).
158/// - `initial_jitter_ms`: Optional initial jitter delay in milliseconds before download (defaults to 100-600ms if None).
159///
160/// # Errors
161///
162/// Returns an error if:
163/// - The HTTP request cannot be sent or returns a non-success status code after retries.
164/// - Any I/O operation fails during file creation, reading, or writing.
165/// - Checksum verification or JSON parsing fails.
166pub fn ensure_file_exists_or_download_http_with_config(
167    filepath: &Path,
168    url: &str,
169    checksums: Option<&Path>,
170    timeout_secs: u64,
171    retry_config: Option<RetryConfig>,
172    initial_jitter_ms: Option<u64>,
173) -> anyhow::Result<()> {
174    if filepath.exists() {
175        println!("File already exists: {filepath:?}");
176
177        if let Some(checksums_file) = checksums {
178            if verify_sha256_checksum(filepath, checksums_file)? {
179                println!("File is valid");
180                return Ok(());
181            } else {
182                let new_checksum = calculate_sha256(filepath)?;
183                println!("Adding checksum for existing file: {new_checksum}");
184                update_sha256_checksums(filepath, checksums_file, &new_checksum)?;
185                return Ok(());
186            }
187        }
188        return Ok(());
189    }
190
191    // Add a small random delay to avoid bursting the remote server when
192    // many downloads start concurrently. Can be disabled by passing Some(0).
193    if let Some(jitter_ms) = initial_jitter_ms {
194        if jitter_ms > 0 {
195            sleep(Duration::from_millis(jitter_ms));
196        }
197    } else {
198        let jitter_delay = {
199            let mut r = rng();
200            Duration::from_millis(r.random_range(100..=600))
201        };
202        sleep(jitter_delay);
203    }
204
205    download_file(filepath, url, timeout_secs, retry_config)?;
206
207    if let Some(checksums_file) = checksums {
208        let new_checksum = calculate_sha256(filepath)?;
209        update_sha256_checksums(filepath, checksums_file, &new_checksum)?;
210    }
211
212    Ok(())
213}
214
215fn download_file(
216    filepath: &Path,
217    url: &str,
218    timeout_secs: u64,
219    retry_config: Option<RetryConfig>,
220) -> anyhow::Result<()> {
221    println!("Downloading file from {url} to {filepath:?}");
222
223    if let Some(parent) = filepath.parent() {
224        std::fs::create_dir_all(parent)?;
225    }
226
227    let client = Client::builder()
228        .timeout(Duration::from_secs(timeout_secs))
229        .build()?;
230
231    let cfg = if let Some(config) = retry_config {
232        config
233    } else {
234        // Default production config
235        let max_retries = 5u32;
236        let op_timeout_ms = timeout_secs.saturating_mul(1000);
237        // Make the provided timeout a hard ceiling for total elapsed time.
238        // Split it across attempts (at least 1000 ms per attempt) and cap total at op_timeout_ms.
239        let per_attempt_ms = std::cmp::max(1000u64, op_timeout_ms / (max_retries as u64 + 1));
240        RetryConfig {
241            max_retries,
242            initial_delay_ms: 1_000,
243            max_delay_ms: 10_000,
244            backoff_factor: 2.0,
245            jitter_ms: 1_000,
246            operation_timeout_ms: Some(per_attempt_ms),
247            immediate_first: false,
248            max_elapsed_ms: Some(op_timeout_ms),
249        }
250    };
251
252    let op = || -> Result<(), DownloadError> {
253        match client.get(url).send() {
254            Ok(mut response) => {
255                let status = response.status();
256                if status.is_success() {
257                    let mut out = File::create(filepath)
258                        .map_err(|e| DownloadError::NonRetryable(e.to_string()))?;
259                    // Stream the response body directly to disk to avoid large allocations
260                    copy(&mut response, &mut out)
261                        .map_err(|e| DownloadError::NonRetryable(e.to_string()))?;
262                    println!("File downloaded to {filepath:?}");
263                    Ok(())
264                } else if status.is_server_error()
265                    || status.as_u16() == 429
266                    || status.as_u16() == 408
267                {
268                    println!("HTTP error {status}, retrying...");
269                    Err(DownloadError::Retryable(format!("HTTP {status}")))
270                } else {
271                    // Preserve existing error text used by tests
272                    Err(DownloadError::NonRetryable(format!(
273                        "Client error: HTTP {status}"
274                    )))
275                }
276            }
277            Err(e) => {
278                println!("Request failed: {e}");
279                Err(DownloadError::Retryable(e.to_string()))
280            }
281        }
282    };
283
284    let should_retry = |e: &DownloadError| matches!(e, DownloadError::Retryable(_));
285
286    execute_with_retry_blocking(&cfg, op, should_retry).map_err(|e| anyhow::anyhow!(e.to_string()))
287}
288
289fn calculate_sha256(filepath: &Path) -> anyhow::Result<String> {
290    let mut file = File::open(filepath)?;
291    let mut ctx = Context::new(&digest::SHA256);
292    let mut buffer = [0u8; 4096];
293
294    loop {
295        let count = file.read(&mut buffer)?;
296        if count == 0 {
297            break;
298        }
299        ctx.update(&buffer[..count]);
300    }
301
302    let digest = ctx.finish();
303    Ok(hex::encode(digest.as_ref()))
304}
305
306fn verify_sha256_checksum(filepath: &Path, checksums: &Path) -> anyhow::Result<bool> {
307    let file = File::open(checksums)?;
308    let reader = BufReader::new(file);
309    let checksums: Value = serde_json::from_reader(reader)?;
310
311    let filename = filepath.file_name().unwrap().to_str().unwrap();
312    if let Some(expected_checksum) = checksums.get(filename) {
313        let expected_checksum_str = expected_checksum.as_str().unwrap();
314        let expected_hash = expected_checksum_str
315            .strip_prefix("sha256:")
316            .unwrap_or(expected_checksum_str);
317        let calculated_checksum = calculate_sha256(filepath)?;
318        if expected_hash == calculated_checksum {
319            return Ok(true);
320        }
321    }
322
323    Ok(false)
324}
325
326fn update_sha256_checksums(
327    filepath: &Path,
328    checksums_file: &Path,
329    new_checksum: &str,
330) -> anyhow::Result<()> {
331    let checksums: Value = if checksums_file.exists() {
332        let file = File::open(checksums_file)?;
333        let reader = BufReader::new(file);
334        serde_json::from_reader(reader)?
335    } else {
336        serde_json::json!({})
337    };
338
339    let mut checksums_map = checksums.as_object().unwrap().clone();
340
341    // Add or update the checksum
342    let filename = filepath.file_name().unwrap().to_str().unwrap().to_string();
343    let prefixed_checksum = format!("sha256:{new_checksum}");
344    checksums_map.insert(filename, Value::String(prefixed_checksum));
345
346    let file = OpenOptions::new()
347        .write(true)
348        .create(true)
349        .truncate(true)
350        .open(checksums_file)?;
351    let writer = BufWriter::new(file);
352    serde_json::to_writer_pretty(writer, &serde_json::Value::Object(checksums_map))?;
353
354    Ok(())
355}
356
////////////////////////////////////////////////////////////////////////////////
// Tests
////////////////////////////////////////////////////////////////////////////////
/// Unit tests for the download/checksum helpers.
///
/// HTTP tests spin up a local axum server on an ephemeral port. They call the blocking download
/// API from `spawn_blocking`, so the reqwest blocking client never runs on the async runtime
/// thread.
#[cfg(test)]
mod tests {
    use std::{
        fs,
        io::{BufWriter, Write},
        net::SocketAddr,
        sync::{
            Arc,
            atomic::{AtomicUsize, Ordering},
        },
    };

    use axum::{Router, http::StatusCode, routing::get, serve};
    use rstest::*;
    use serde_json::{json, to_writer};
    use tempfile::TempDir;
    // These explicit imports take precedence over the glob `use super::*` below, so `sleep` in
    // this module is tokio's async sleep (it is `.await`ed).
    use tokio::{
        net::TcpListener,
        task,
        time::{Duration, sleep},
    };

    use super::*;

    /// Creates a fast retry config for tests.
    /// Uses very short delays (10-50 ms backoff plus up to 5 ms of random jitter) and a 2 s
    /// total budget, so retry tests run quickly.
    fn test_retry_config() -> RetryConfig {
        RetryConfig {
            max_retries: 5,
            initial_delay_ms: 10,
            max_delay_ms: 50,
            backoff_factor: 2.0,
            jitter_ms: 5,
            operation_timeout_ms: Some(500),
            immediate_first: false,
            max_elapsed_ms: Some(2000),
        }
    }

    /// Starts a local server that answers `GET /testfile.txt` with `status_code`.
    ///
    /// The body is `server_content`, or `"File not found"` when it is `None`. Returns the bound
    /// `127.0.0.1` address; the server runs on a detached tokio task.
    async fn setup_test_server(
        server_content: Option<String>,
        status_code: StatusCode,
    ) -> SocketAddr {
        let server_content = Arc::new(server_content);
        let server_content_clone = server_content.clone();
        let app = Router::new().route(
            "/testfile.txt",
            get(move || {
                let server_content = server_content_clone.clone();
                async move {
                    let response_body = match &*server_content {
                        Some(content) => content.clone(),
                        None => "File not found".to_string(),
                    };
                    (status_code, response_body)
                }
            }),
        );

        // Port 0 lets the OS pick a free port; read it back for the URL.
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let server = serve(listener, app);

        task::spawn(async move {
            if let Err(e) = server.await {
                eprintln!("server error: {e}");
            }
        });

        // Presumably gives the spawned server task a chance to start before requests arrive.
        sleep(Duration::from_millis(100)).await;

        addr
    }

    /// An existing file is left untouched; the (unreachable) URL is never contacted.
    #[tokio::test]
    async fn test_file_already_exists() {
        let temp_dir = TempDir::new().unwrap();
        let file_path = temp_dir.path().join("testfile.txt");
        fs::write(&file_path, "Existing file content").unwrap();

        let url = "http://example.com/testfile.txt".to_string();
        let result = ensure_file_exists_or_download_http(&file_path, &url, None, Some(5));

        assert!(result.is_ok());
        let content = fs::read_to_string(&file_path).unwrap();
        assert_eq!(content, "Existing file content");
    }

    /// A 200 response is written to disk verbatim.
    #[tokio::test]
    async fn test_download_file_success() {
        let temp_dir = TempDir::new().unwrap();
        let filepath = temp_dir.path().join("testfile.txt");
        let filepath_clone = filepath.clone();

        let server_content = "Server file content".to_string();
        let status_code = StatusCode::OK;
        let addr = setup_test_server(Some(server_content.clone()), status_code).await;
        let url = format!("http://{addr}/testfile.txt");

        let result = tokio::task::spawn_blocking(move || {
            ensure_file_exists_or_download_http_with_config(
                &filepath_clone,
                &url,
                None,
                5,
                Some(test_retry_config()),
                Some(0),
            )
        })
        .await
        .unwrap();

        assert!(result.is_ok());
        let content = fs::read_to_string(&filepath).unwrap();
        assert_eq!(content, server_content);
    }

    /// A 404 surfaces as an error carrying the "Client error: HTTP" message.
    #[tokio::test]
    async fn test_download_file_not_found() {
        let temp_dir = TempDir::new().unwrap();
        let file_path = temp_dir.path().join("testfile.txt");

        let server_content = None;
        let status_code = StatusCode::NOT_FOUND;
        let addr = setup_test_server(server_content, status_code).await;
        let url = format!("http://{addr}/testfile.txt");

        let result = tokio::task::spawn_blocking(move || {
            ensure_file_exists_or_download_http_with_config(
                &file_path,
                &url,
                None,
                1,
                Some(test_retry_config()),
                Some(0),
            )
        })
        .await
        .unwrap();

        assert!(result.is_err());
        let err_msg = format!("{}", result.unwrap_err());
        assert!(
            err_msg.contains("Client error: HTTP"),
            "Unexpected error message: {err_msg}"
        );
    }

    /// A connection failure eventually surfaces as an error once retries are exhausted.
    #[tokio::test]
    async fn test_network_error() {
        let temp_dir = TempDir::new().unwrap();
        let file_path = temp_dir.path().join("testfile.txt");

        // Use an unreachable address to simulate a network error
        let url = "http://127.0.0.1:0/testfile.txt".to_string();

        let result = tokio::task::spawn_blocking(move || {
            ensure_file_exists_or_download_http_with_config(
                &file_path,
                &url,
                None,
                2,
                Some(test_retry_config()),
                Some(0),
            )
        })
        .await
        .unwrap();

        assert!(result.is_err());
        let err_msg = format!("{}", result.unwrap_err());
        // Only checks for the word "error", which appears in both DownloadError display forms.
        assert!(
            err_msg.contains("error"),
            "Unexpected error message: {err_msg}"
        );
    }

    /// The server returns 500 for the first two requests, then 200; the download must retry.
    #[tokio::test]
    async fn test_retry_then_success_on_500() {
        let temp_dir = TempDir::new().unwrap();
        let filepath = temp_dir.path().join("testfile.txt");
        let filepath_clone = filepath.clone();

        // Counts requests so the handler can fail the first ones and the test can assert retries.
        let counter = Arc::new(AtomicUsize::new(0));
        let counter_clone = counter.clone();

        let app = Router::new().route(
            "/testfile.txt",
            get(move || {
                let c = counter_clone.clone();
                async move {
                    let n = c.fetch_add(1, Ordering::SeqCst);
                    if n < 2 {
                        (StatusCode::INTERNAL_SERVER_ERROR, "temporary error")
                    } else {
                        (StatusCode::OK, "eventual success")
                    }
                }
            }),
        );

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let server = serve(listener, app);
        task::spawn(async move {
            let _ = server.await;
        });
        sleep(Duration::from_millis(100)).await;

        let url = format!("http://{addr}/testfile.txt");

        let result = tokio::task::spawn_blocking(move || {
            ensure_file_exists_or_download_http_with_config(
                &filepath_clone,
                &url,
                None,
                5,
                Some(test_retry_config()),
                Some(0),
            )
        })
        .await
        .unwrap();

        assert!(result.is_ok());
        let content = std::fs::read_to_string(&filepath).unwrap();
        assert_eq!(content, "eventual success");
        assert!(counter.load(Ordering::SeqCst) >= 2);
    }

    /// The server returns 429 once, then 200; rate limiting must be treated as retryable.
    #[tokio::test]
    async fn test_retry_then_success_on_429() {
        let temp_dir = TempDir::new().unwrap();
        let filepath = temp_dir.path().join("testfile.txt");
        let filepath_clone = filepath.clone();

        let counter = Arc::new(AtomicUsize::new(0));
        let counter_clone = counter.clone();

        let app = Router::new().route(
            "/testfile.txt",
            get(move || {
                let c = counter_clone.clone();
                async move {
                    let n = c.fetch_add(1, Ordering::SeqCst);
                    if n < 1 {
                        (StatusCode::TOO_MANY_REQUESTS, "rate limited")
                    } else {
                        (StatusCode::OK, "ok after retry")
                    }
                }
            }),
        );

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let server = serve(listener, app);
        task::spawn(async move {
            let _ = server.await;
        });
        sleep(Duration::from_millis(100)).await;

        let url = format!("http://{addr}/testfile.txt");

        let result = tokio::task::spawn_blocking(move || {
            ensure_file_exists_or_download_http_with_config(
                &filepath_clone,
                &url,
                None,
                5,
                Some(test_retry_config()),
                Some(0),
            )
        })
        .await
        .unwrap();

        assert!(result.is_ok());
        let content = std::fs::read_to_string(&filepath).unwrap();
        assert_eq!(content, "ok after retry");
        assert!(counter.load(Ordering::SeqCst) >= 2);
    }

    /// A 404 is a client error: exactly one request must be made.
    #[tokio::test]
    async fn test_no_retry_on_404() {
        let temp_dir = TempDir::new().unwrap();
        let filepath = temp_dir.path().join("testfile.txt");
        let filepath_clone = filepath.clone();

        let counter = Arc::new(AtomicUsize::new(0));
        let counter_clone = counter.clone();

        let app = Router::new().route(
            "/testfile.txt",
            get(move || {
                let c = counter_clone.clone();
                async move {
                    c.fetch_add(1, Ordering::SeqCst);
                    (StatusCode::NOT_FOUND, "missing")
                }
            }),
        );

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let server = serve(listener, app);
        task::spawn(async move {
            let _ = server.await;
        });
        sleep(Duration::from_millis(100)).await;

        let url = format!("http://{addr}/testfile.txt");

        let result = tokio::task::spawn_blocking(move || {
            ensure_file_exists_or_download_http_with_config(
                &filepath_clone,
                &url,
                None,
                5,
                Some(test_retry_config()),
                Some(0),
            )
        })
        .await
        .unwrap();

        assert!(result.is_err());
        assert_eq!(counter.load(Ordering::SeqCst), 1, "should not retry on 404");
    }

    /// `calculate_sha256` matches the SHA-256 hex digest the test expects for
    /// `b"Hello, world!"`.
    #[rstest]
    #[allow(clippy::panic_in_result_fn)]
    fn test_calculate_sha256() -> anyhow::Result<()> {
        let temp_dir = TempDir::new()?;
        let test_file_path = temp_dir.path().join("test_file.txt");
        let mut test_file = File::create(&test_file_path)?;
        let content = b"Hello, world!";
        test_file.write_all(content)?;

        let expected_hash = "315f5bdb76d078c43b8ac0064e4a0164612b1fce77c869345bfc94c75894edd3";
        let calculated_hash = calculate_sha256(&test_file_path)?;

        assert_eq!(calculated_hash, expected_hash);
        Ok(())
    }

    /// A checksums file in the `"sha256:<hex>"` format written by `update_sha256_checksums`
    /// verifies successfully.
    #[rstest]
    #[allow(clippy::panic_in_result_fn)]
    fn test_verify_sha256_checksum() -> anyhow::Result<()> {
        let temp_dir = TempDir::new()?;
        let test_file_path = temp_dir.path().join("test_file.txt");
        let mut test_file = File::create(&test_file_path)?;
        let content = b"Hello, world!";
        test_file.write_all(content)?;

        let calculated_checksum = calculate_sha256(&test_file_path)?;

        // Create checksums.json containing the checksum
        let checksums_path = temp_dir.path().join("checksums.json");
        let checksums_data = json!({
            "test_file.txt": format!("sha256:{}", calculated_checksum)
        });
        let checksums_file = File::create(&checksums_path)?;
        let writer = BufWriter::new(checksums_file);
        to_writer(writer, &checksums_data)?;

        let is_valid = verify_sha256_checksum(&test_file_path, &checksums_path)?;
        assert!(is_valid, "The checksum should be valid");
        Ok(())
    }
}