// nautilus_testkit/files.rs

// -------------------------------------------------------------------------------------------------
//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
//  https://nautechsystems.io
//
//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
//  You may not use this file except in compliance with the License.
//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
//
//  Unless required by applicable law or agreed to in writing, software
//  distributed under the License is distributed on an "AS IS" BASIS,
//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//  See the License for the specific language governing permissions and
//  limitations under the License.
// -------------------------------------------------------------------------------------------------

use std::{
    cmp,
    fs::{File, OpenOptions},
    io::{BufReader, BufWriter, Read, copy},
    path::Path,
    thread::sleep,
    time::{Duration, Instant},
};

use aws_lc_rs::digest::{self, Context};
use nautilus_network::retry::RetryConfig;
use rand::{Rng, rng};
use reqwest::blocking::Client;
use serde_json::Value;

/// Classification of a failed download attempt, consumed by the retry loop to decide
/// whether another attempt is worthwhile.
#[derive(Debug)]
enum DownloadError {
    /// Transient failure (request/connection error, HTTP 5xx/408/429); may be retried.
    Retryable(String),
    /// Permanent failure (other HTTP error statuses, local file errors); retrying won't help.
    NonRetryable(String),
}

impl std::fmt::Display for DownloadError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The prefix tells log readers which retry class the failure fell into
        match self {
            DownloadError::NonRetryable(msg) => write!(f, "Non-retryable error: {msg}"),
            DownloadError::Retryable(msg) => write!(f, "Retryable error: {msg}"),
        }
    }
}

impl std::error::Error for DownloadError {}

48fn execute_with_retry_blocking<T, E, F>(
49    config: &RetryConfig,
50    mut op: F,
51    should_retry: impl Fn(&E) -> bool,
52) -> Result<T, E>
53where
54    E: std::error::Error,
55    F: FnMut() -> Result<T, E>,
56{
57    let start = Instant::now();
58    let mut delay = Duration::from_millis(config.initial_delay_ms);
59
60    for attempt in 0..=config.max_retries {
61        if attempt > 0 && !config.immediate_first {
62            let jitter = rng().random_range(0..=config.jitter_ms);
63            let sleep_for = delay + Duration::from_millis(jitter);
64            sleep(sleep_for);
65            let next = (delay.as_millis() as f64 * config.backoff_factor) as u64;
66            delay = cmp::min(
67                Duration::from_millis(next),
68                Duration::from_millis(config.max_delay_ms),
69            );
70        }
71
72        if let Some(max_total) = config.max_elapsed_ms
73            && start.elapsed() >= Duration::from_millis(max_total)
74        {
75            break;
76        }
77
78        match op() {
79            Ok(v) => return Ok(v),
80            Err(e) if attempt < config.max_retries && should_retry(&e) => continue,
81            Err(e) => return Err(e),
82        }
83    }
84
85    op()
86}
87
88/// Ensures that a file exists at the specified path by downloading it if necessary.
89///
90/// If the file already exists, it checks the integrity of the file using a SHA-256 checksum
91/// from the optional `checksums` file. If the checksum is valid, the function exits early. If
92/// the checksum is invalid or missing, the function updates the checksums file with the correct
93/// hash for the existing file without redownloading it.
94///
95/// If the file does not exist, it downloads the file from the specified `url` and updates the
96/// checksums file (if provided) with the calculated SHA-256 checksum of the downloaded file.
97///
98/// The `timeout_secs` parameter specifies the timeout in seconds for the HTTP request.
99/// If `None` is provided, a default timeout of 30 seconds will be used.
100///
101/// # Errors
102///
103/// Returns an error if:
104/// - The HTTP request cannot be sent or returns a non-success status code.
105/// - Any I/O operation fails during file creation, reading, or writing.
106/// - Checksum verification or JSON parsing fails.
107pub fn ensure_file_exists_or_download_http(
108    filepath: &Path,
109    url: &str,
110    checksums: Option<&Path>,
111    timeout_secs: Option<u64>,
112) -> anyhow::Result<()> {
113    ensure_file_exists_or_download_http_with_config(
114        filepath,
115        url,
116        checksums,
117        timeout_secs.unwrap_or(30),
118        None,
119        None,
120    )
121}
122
123/// Ensures that a file exists at the specified path by downloading it if necessary, with a custom timeout.
124///
125/// # Errors
126///
127/// Returns an error if:
128/// - The HTTP request cannot be sent or returns a non-success status code after retries.
129/// - Any I/O operation fails during file creation, reading, or writing.
130/// - Checksum verification or JSON parsing fails.
131pub fn ensure_file_exists_or_download_http_with_timeout(
132    filepath: &Path,
133    url: &str,
134    checksums: Option<&Path>,
135    timeout_secs: u64,
136) -> anyhow::Result<()> {
137    ensure_file_exists_or_download_http_with_config(
138        filepath,
139        url,
140        checksums,
141        timeout_secs,
142        None,
143        None,
144    )
145}
146
147/// Ensures that a file exists at the specified path by downloading it if necessary,
148/// with custom timeout, retry config, and initial jitter delay.
149///
150/// # Parameters
151///
152/// - `filepath`: The path where the file should exist.
153/// - `url`: The URL to download from if the file doesn't exist.
154/// - `checksums`: Optional path to checksums file for verification.
155/// - `timeout_secs`: Timeout in seconds for HTTP requests.
156/// - `retry_config`: Optional custom retry configuration (uses sensible defaults if None).
157/// - `initial_jitter_ms`: Optional initial jitter delay in milliseconds before download (defaults to 100-600ms if None).
158///
159/// # Errors
160///
161/// Returns an error if:
162/// - The HTTP request cannot be sent or returns a non-success status code after retries.
163/// - Any I/O operation fails during file creation, reading, or writing.
164/// - Checksum verification or JSON parsing fails.
165pub fn ensure_file_exists_or_download_http_with_config(
166    filepath: &Path,
167    url: &str,
168    checksums: Option<&Path>,
169    timeout_secs: u64,
170    retry_config: Option<RetryConfig>,
171    initial_jitter_ms: Option<u64>,
172) -> anyhow::Result<()> {
173    if filepath.exists() {
174        println!("File already exists: {filepath:?}");
175
176        if let Some(checksums_file) = checksums {
177            if verify_sha256_checksum(filepath, checksums_file)? {
178                println!("File is valid");
179                return Ok(());
180            } else {
181                let new_checksum = calculate_sha256(filepath)?;
182                println!("Adding checksum for existing file: {new_checksum}");
183                update_sha256_checksums(filepath, checksums_file, &new_checksum)?;
184                return Ok(());
185            }
186        }
187        return Ok(());
188    }
189
190    // Add a small random delay to avoid bursting the remote server when
191    // many downloads start concurrently. Can be disabled by passing Some(0).
192    if let Some(jitter_ms) = initial_jitter_ms {
193        if jitter_ms > 0 {
194            sleep(Duration::from_millis(jitter_ms));
195        }
196    } else {
197        let jitter_delay = {
198            let mut r = rng();
199            Duration::from_millis(r.random_range(100..=600))
200        };
201        sleep(jitter_delay);
202    }
203
204    download_file(filepath, url, timeout_secs, retry_config)?;
205
206    if let Some(checksums_file) = checksums {
207        let new_checksum = calculate_sha256(filepath)?;
208        update_sha256_checksums(filepath, checksums_file, &new_checksum)?;
209    }
210
211    Ok(())
212}
213
214fn download_file(
215    filepath: &Path,
216    url: &str,
217    timeout_secs: u64,
218    retry_config: Option<RetryConfig>,
219) -> anyhow::Result<()> {
220    println!("Downloading file from {url} to {filepath:?}");
221
222    if let Some(parent) = filepath.parent() {
223        std::fs::create_dir_all(parent)?;
224    }
225
226    let client = Client::builder()
227        .timeout(Duration::from_secs(timeout_secs))
228        .build()?;
229
230    let cfg = if let Some(config) = retry_config {
231        config
232    } else {
233        // Default production config
234        let max_retries = 5u32;
235        let op_timeout_ms = timeout_secs.saturating_mul(1000);
236        // Make the provided timeout a hard ceiling for total elapsed time.
237        // Split it across attempts (at least 1000 ms per attempt) and cap total at op_timeout_ms.
238        let per_attempt_ms = std::cmp::max(1000u64, op_timeout_ms / (max_retries as u64 + 1));
239        RetryConfig {
240            max_retries,
241            initial_delay_ms: 1_000,
242            max_delay_ms: 10_000,
243            backoff_factor: 2.0,
244            jitter_ms: 1_000,
245            operation_timeout_ms: Some(per_attempt_ms),
246            immediate_first: false,
247            max_elapsed_ms: Some(op_timeout_ms),
248        }
249    };
250
251    let op = || -> Result<(), DownloadError> {
252        match client.get(url).send() {
253            Ok(mut response) => {
254                let status = response.status();
255                if status.is_success() {
256                    let mut out = File::create(filepath)
257                        .map_err(|e| DownloadError::NonRetryable(e.to_string()))?;
258                    // Stream the response body directly to disk to avoid large allocations
259                    copy(&mut response, &mut out)
260                        .map_err(|e| DownloadError::NonRetryable(e.to_string()))?;
261                    println!("File downloaded to {filepath:?}");
262                    Ok(())
263                } else if status.is_server_error()
264                    || status.as_u16() == 429
265                    || status.as_u16() == 408
266                {
267                    println!("HTTP error {status}, retrying...");
268                    Err(DownloadError::Retryable(format!("HTTP {status}")))
269                } else {
270                    // Preserve existing error text used by tests
271                    Err(DownloadError::NonRetryable(format!(
272                        "Client error: HTTP {status}"
273                    )))
274                }
275            }
276            Err(e) => {
277                println!("Request failed: {e}");
278                Err(DownloadError::Retryable(e.to_string()))
279            }
280        }
281    };
282
283    let should_retry = |e: &DownloadError| matches!(e, DownloadError::Retryable(_));
284
285    execute_with_retry_blocking(&cfg, op, should_retry).map_err(|e| anyhow::anyhow!(e.to_string()))
286}
287
288fn calculate_sha256(filepath: &Path) -> anyhow::Result<String> {
289    let mut file = File::open(filepath)?;
290    let mut ctx = Context::new(&digest::SHA256);
291    let mut buffer = [0u8; 4096];
292
293    loop {
294        let count = file.read(&mut buffer)?;
295        if count == 0 {
296            break;
297        }
298        ctx.update(&buffer[..count]);
299    }
300
301    let digest = ctx.finish();
302    Ok(hex::encode(digest.as_ref()))
303}
304
305fn verify_sha256_checksum(filepath: &Path, checksums: &Path) -> anyhow::Result<bool> {
306    let file = File::open(checksums)?;
307    let reader = BufReader::new(file);
308    let checksums: Value = serde_json::from_reader(reader)?;
309
310    let filename = filepath.file_name().unwrap().to_str().unwrap();
311    if let Some(expected_checksum) = checksums.get(filename) {
312        let expected_checksum_str = expected_checksum.as_str().unwrap();
313        let expected_hash = expected_checksum_str
314            .strip_prefix("sha256:")
315            .unwrap_or(expected_checksum_str);
316        let calculated_checksum = calculate_sha256(filepath)?;
317        if expected_hash == calculated_checksum {
318            return Ok(true);
319        }
320    }
321
322    Ok(false)
323}
324
325fn update_sha256_checksums(
326    filepath: &Path,
327    checksums_file: &Path,
328    new_checksum: &str,
329) -> anyhow::Result<()> {
330    let checksums: Value = if checksums_file.exists() {
331        let file = File::open(checksums_file)?;
332        let reader = BufReader::new(file);
333        serde_json::from_reader(reader)?
334    } else {
335        serde_json::json!({})
336    };
337
338    let mut checksums_map = checksums.as_object().unwrap().clone();
339
340    // Add or update the checksum
341    let filename = filepath.file_name().unwrap().to_str().unwrap().to_string();
342    let prefixed_checksum = format!("sha256:{new_checksum}");
343    checksums_map.insert(filename, Value::String(prefixed_checksum));
344
345    let file = OpenOptions::new()
346        .write(true)
347        .create(true)
348        .truncate(true)
349        .open(checksums_file)?;
350    let writer = BufWriter::new(file);
351    serde_json::to_writer_pretty(writer, &serde_json::Value::Object(checksums_map))?;
352
353    Ok(())
354}
355
////////////////////////////////////////////////////////////////////////////////
// Tests
////////////////////////////////////////////////////////////////////////////////
#[cfg(test)]
mod tests {
    use std::{
        fs,
        io::{BufWriter, Write},
        net::SocketAddr,
        path::PathBuf,
        sync::{
            Arc,
            atomic::{AtomicUsize, Ordering},
        },
    };

    use axum::{Router, http::StatusCode, routing::get, serve};
    use rstest::*;
    use serde_json::{json, to_writer};
    use tempfile::TempDir;
    use tokio::{net::TcpListener, task};

    use super::*;

    /// Creates a fast, deterministic retry config for tests.
    /// Uses very short delays to make tests run quickly without introducing flakiness.
    fn test_retry_config() -> RetryConfig {
        RetryConfig {
            max_retries: 5,
            initial_delay_ms: 10,
            max_delay_ms: 50,
            backoff_factor: 2.0,
            jitter_ms: 5,
            operation_timeout_ms: Some(500),
            immediate_first: false,
            max_elapsed_ms: Some(2000),
        }
    }

    /// Serves `app` on an ephemeral localhost port and returns the bound address.
    ///
    /// The listener is already bound when this returns, so connections made afterwards are
    /// queued by the OS until the spawned server task accepts them (no startup sleep needed).
    async fn spawn_server(app: Router) -> SocketAddr {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let server = serve(listener, app);

        task::spawn(async move {
            if let Err(e) = server.await {
                eprintln!("server error: {e}");
            }
        });

        addr
    }

    /// Serves `/testfile.txt` with a fixed `status_code` and body ("File not found" when
    /// `server_content` is `None`).
    async fn setup_test_server(
        server_content: Option<String>,
        status_code: StatusCode,
    ) -> SocketAddr {
        let body = server_content.unwrap_or_else(|| "File not found".to_string());
        let app = Router::new().route(
            "/testfile.txt",
            get(move || {
                let body = body.clone();
                async move { (status_code, body) }
            }),
        );
        spawn_server(app).await
    }

    /// Serves `/testfile.txt`, answering each request with `respond(n)` where `n` is the
    /// zero-based request count, and returns the counter so tests can assert attempt counts.
    async fn setup_counting_server(
        respond: fn(usize) -> (StatusCode, &'static str),
    ) -> (SocketAddr, Arc<AtomicUsize>) {
        let counter = Arc::new(AtomicUsize::new(0));
        let handler_counter = counter.clone();
        let app = Router::new().route(
            "/testfile.txt",
            get(move || {
                let c = handler_counter.clone();
                async move { respond(c.fetch_add(1, Ordering::SeqCst)) }
            }),
        );
        (spawn_server(app).await, counter)
    }

    /// Runs the blocking download entry point off the async runtime, using the fast test
    /// retry config and no initial jitter.
    async fn download_blocking(
        filepath: PathBuf,
        url: String,
        timeout_secs: u64,
    ) -> anyhow::Result<()> {
        task::spawn_blocking(move || {
            ensure_file_exists_or_download_http_with_config(
                &filepath,
                &url,
                None,
                timeout_secs,
                Some(test_retry_config()),
                Some(0),
            )
        })
        .await
        .unwrap()
    }

    fn dir_entry_count(dir: &Path) -> usize {
        fs::read_dir(dir).unwrap().count()
    }

    #[rstest]
    fn test_file_already_exists() {
        let temp_dir = TempDir::new().unwrap();
        let file_path = temp_dir.path().join("testfile.txt");
        fs::write(&file_path, "Existing file content").unwrap();

        // The URL is never contacted because the file is already present
        let url = "http://example.com/testfile.txt";
        let result = ensure_file_exists_or_download_http(&file_path, url, None, Some(5));

        assert!(result.is_ok());
        let content = fs::read_to_string(&file_path).unwrap();
        assert_eq!(content, "Existing file content");
    }

    #[rstest]
    fn test_existing_file_checksum_is_recorded() -> anyhow::Result<()> {
        let temp_dir = TempDir::new()?;
        let file_path = temp_dir.path().join("testfile.txt");
        fs::write(&file_path, "Existing file content")?;
        let checksums_path = temp_dir.path().join("checksums.json");
        fs::write(&checksums_path, "{}")?;

        let url = "http://example.com/testfile.txt";
        ensure_file_exists_or_download_http(&file_path, url, Some(&checksums_path), Some(5))?;

        let checksums: Value = serde_json::from_str(&fs::read_to_string(&checksums_path)?)?;
        let expected = format!("sha256:{}", calculate_sha256(&file_path)?);
        assert_eq!(checksums["testfile.txt"], json!(expected));
        assert!(verify_sha256_checksum(&file_path, &checksums_path)?);
        Ok(())
    }

    #[tokio::test]
    async fn test_download_file_success() {
        let temp_dir = TempDir::new().unwrap();
        let filepath = temp_dir.path().join("testfile.txt");

        let server_content = "Server file content".to_string();
        let addr = setup_test_server(Some(server_content.clone()), StatusCode::OK).await;
        let url = format!("http://{addr}/testfile.txt");

        let result = download_blocking(filepath.clone(), url, 5).await;

        assert!(result.is_ok(), "download failed: {result:?}");
        let content = fs::read_to_string(&filepath).unwrap();
        assert_eq!(content, server_content);
        // Only the final file may remain; no temporary download files left behind
        assert_eq!(dir_entry_count(temp_dir.path()), 1);
    }

    #[tokio::test]
    async fn test_download_file_not_found() {
        let temp_dir = TempDir::new().unwrap();
        let file_path = temp_dir.path().join("testfile.txt");

        let addr = setup_test_server(None, StatusCode::NOT_FOUND).await;
        let url = format!("http://{addr}/testfile.txt");

        let result = download_blocking(file_path.clone(), url, 1).await;

        assert!(result.is_err());
        let err_msg = format!("{}", result.unwrap_err());
        assert!(
            err_msg.contains("Client error: HTTP"),
            "Unexpected error message: {err_msg}"
        );
        // A failed download must not leave a (partial) file behind
        assert!(!file_path.exists());
        assert_eq!(dir_entry_count(temp_dir.path()), 0);
    }

    #[tokio::test]
    async fn test_network_error() {
        let temp_dir = TempDir::new().unwrap();
        let file_path = temp_dir.path().join("testfile.txt");

        // Use an unreachable address to simulate a network error
        let url = "http://127.0.0.1:0/testfile.txt".to_string();

        let result = download_blocking(file_path, url, 2).await;

        assert!(result.is_err());
        let err_msg = format!("{}", result.unwrap_err());
        assert!(
            err_msg.contains("error"),
            "Unexpected error message: {err_msg}"
        );
    }

    #[tokio::test]
    async fn test_retry_then_success_on_500() {
        let temp_dir = TempDir::new().unwrap();
        let filepath = temp_dir.path().join("testfile.txt");

        let (addr, counter) = setup_counting_server(|n| {
            if n < 2 {
                (StatusCode::INTERNAL_SERVER_ERROR, "temporary error")
            } else {
                (StatusCode::OK, "eventual success")
            }
        })
        .await;
        let url = format!("http://{addr}/testfile.txt");

        let result = download_blocking(filepath.clone(), url, 5).await;

        assert!(result.is_ok(), "download failed: {result:?}");
        let content = fs::read_to_string(&filepath).unwrap();
        assert_eq!(content, "eventual success");
        // Two failures followed by one success
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_retry_then_success_on_429() {
        let temp_dir = TempDir::new().unwrap();
        let filepath = temp_dir.path().join("testfile.txt");

        let (addr, counter) = setup_counting_server(|n| {
            if n < 1 {
                (StatusCode::TOO_MANY_REQUESTS, "rate limited")
            } else {
                (StatusCode::OK, "ok after retry")
            }
        })
        .await;
        let url = format!("http://{addr}/testfile.txt");

        let result = download_blocking(filepath.clone(), url, 5).await;

        assert!(result.is_ok(), "download failed: {result:?}");
        let content = fs::read_to_string(&filepath).unwrap();
        assert_eq!(content, "ok after retry");
        // One rate-limited attempt followed by one success
        assert_eq!(counter.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_no_retry_on_404() {
        let temp_dir = TempDir::new().unwrap();
        let filepath = temp_dir.path().join("testfile.txt");

        let (addr, counter) =
            setup_counting_server(|_| (StatusCode::NOT_FOUND, "missing")).await;
        let url = format!("http://{addr}/testfile.txt");

        let result = download_blocking(filepath, url, 5).await;

        assert!(result.is_err());
        assert_eq!(counter.load(Ordering::SeqCst), 1, "should not retry on 404");
    }

    #[rstest]
    fn test_calculate_sha256() -> anyhow::Result<()> {
        let temp_dir = TempDir::new()?;
        let test_file_path = temp_dir.path().join("test_file.txt");
        let mut test_file = File::create(&test_file_path)?;
        let content = b"Hello, world!";
        test_file.write_all(content)?;

        let expected_hash = "315f5bdb76d078c43b8ac0064e4a0164612b1fce77c869345bfc94c75894edd3";
        let calculated_hash = calculate_sha256(&test_file_path)?;

        assert_eq!(calculated_hash, expected_hash);
        Ok(())
    }

    #[rstest]
    fn test_verify_sha256_checksum() -> anyhow::Result<()> {
        let temp_dir = TempDir::new()?;
        let test_file_path = temp_dir.path().join("test_file.txt");
        let mut test_file = File::create(&test_file_path)?;
        let content = b"Hello, world!";
        test_file.write_all(content)?;

        let calculated_checksum = calculate_sha256(&test_file_path)?;

        // Create checksums.json containing the checksum
        let checksums_path = temp_dir.path().join("checksums.json");
        let checksums_data = json!({
            "test_file.txt": format!("sha256:{}", calculated_checksum)
        });
        let checksums_file = File::create(&checksums_path)?;
        let writer = BufWriter::new(checksums_file);
        to_writer(writer, &checksums_data)?;

        let is_valid = verify_sha256_checksum(&test_file_path, &checksums_path)?;
        assert!(is_valid, "The checksum should be valid");
        Ok(())
    }

    #[rstest]
    fn test_verify_sha256_checksum_mismatch() -> anyhow::Result<()> {
        let temp_dir = TempDir::new()?;
        let test_file_path = temp_dir.path().join("test_file.txt");
        fs::write(&test_file_path, b"Hello, world!")?;

        let checksums_path = temp_dir.path().join("checksums.json");
        let wrong_checksum = format!("sha256:{}", "0".repeat(64));
        fs::write(
            &checksums_path,
            json!({ "test_file.txt": wrong_checksum }).to_string(),
        )?;

        let is_valid = verify_sha256_checksum(&test_file_path, &checksums_path)?;
        assert!(!is_valid, "A mismatching checksum must not validate");
        Ok(())
    }

    #[rstest]
    fn test_update_sha256_checksums_creates_and_preserves_entries() -> anyhow::Result<()> {
        let temp_dir = TempDir::new()?;
        let checksums_path = temp_dir.path().join("checksums.json");
        let first = temp_dir.path().join("a.txt");
        let second = temp_dir.path().join("b.txt");

        // First call creates the checksums file; second call must keep the first entry
        update_sha256_checksums(&first, &checksums_path, "aaaa")?;
        update_sha256_checksums(&second, &checksums_path, "bbbb")?;

        let checksums: Value = serde_json::from_str(&fs::read_to_string(&checksums_path)?)?;
        assert_eq!(
            checksums,
            json!({ "a.txt": "sha256:aaaa", "b.txt": "sha256:bbbb" })
        );
        Ok(())
    }
}