nautilus_testkit/
files.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::{
17    cmp,
18    fmt::Display,
19    fs::{File, OpenOptions},
20    io::{BufReader, BufWriter, Read, copy},
21    path::Path,
22    thread::sleep,
23    time::{Duration, Instant},
24};
25
26use aws_lc_rs::digest::{self, Context};
27use nautilus_network::retry::RetryConfig;
28use rand::{Rng, rng};
29use reqwest::blocking::Client;
30use serde_json::Value;
31
/// Outcome classification for a failed download attempt.
///
/// The variant tells the retry loop whether another attempt is allowed; the payload is the
/// human-readable description of what went wrong.
#[derive(Debug)]
enum DownloadError {
    /// The attempt failed in a way the retry loop may try again.
    Retryable(String),
    /// The attempt failed in a way that must be reported immediately.
    NonRetryable(String),
}

impl Display for DownloadError {
    /// Renders the error as `<classification> error: <message>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            DownloadError::Retryable(msg) => f.write_fmt(format_args!("Retryable error: {msg}")),
            DownloadError::NonRetryable(msg) => {
                f.write_fmt(format_args!("Non-retryable error: {msg}"))
            }
        }
    }
}

impl std::error::Error for DownloadError {}
48
49fn execute_with_retry_blocking<T, E, F>(
50    config: &RetryConfig,
51    mut op: F,
52    should_retry: impl Fn(&E) -> bool,
53) -> Result<T, E>
54where
55    E: std::error::Error,
56    F: FnMut() -> Result<T, E>,
57{
58    let start = Instant::now();
59    let mut delay = Duration::from_millis(config.initial_delay_ms);
60
61    for attempt in 0..=config.max_retries {
62        if attempt > 0 && !config.immediate_first {
63            let jitter = rng().random_range(0..=config.jitter_ms);
64            let sleep_for = delay + Duration::from_millis(jitter);
65            sleep(sleep_for);
66            let next = (delay.as_millis() as f64 * config.backoff_factor) as u64;
67            delay = cmp::min(
68                Duration::from_millis(next),
69                Duration::from_millis(config.max_delay_ms),
70            );
71        }
72
73        if let Some(max_total) = config.max_elapsed_ms
74            && start.elapsed() >= Duration::from_millis(max_total)
75        {
76            break;
77        }
78
79        match op() {
80            Ok(v) => return Ok(v),
81            Err(e) if attempt < config.max_retries && should_retry(&e) => continue,
82            Err(e) => return Err(e),
83        }
84    }
85
86    op()
87}
88
89/// Ensures that a file exists at the specified path by downloading it if necessary.
90///
91/// If the file already exists, it checks the integrity of the file using a SHA-256 checksum
92/// from the optional `checksums` file. If the checksum is valid, the function exits early. If
93/// the checksum is invalid or missing, the function updates the checksums file with the correct
94/// hash for the existing file without redownloading it.
95///
96/// If the file does not exist, it downloads the file from the specified `url` and updates the
97/// checksums file (if provided) with the calculated SHA-256 checksum of the downloaded file.
98///
99/// The `timeout_secs` parameter specifies the timeout in seconds for the HTTP request.
100/// If `None` is provided, a default timeout of 30 seconds will be used.
101///
102/// # Errors
103///
104/// Returns an error if:
105/// - The HTTP request cannot be sent or returns a non-success status code.
106/// - Any I/O operation fails during file creation, reading, or writing.
107/// - Checksum verification or JSON parsing fails.
108pub fn ensure_file_exists_or_download_http(
109    filepath: &Path,
110    url: &str,
111    checksums: Option<&Path>,
112    timeout_secs: Option<u64>,
113) -> anyhow::Result<()> {
114    ensure_file_exists_or_download_http_with_config(
115        filepath,
116        url,
117        checksums,
118        timeout_secs.unwrap_or(30),
119        None,
120        None,
121    )
122}
123
124/// Ensures that a file exists at the specified path by downloading it if necessary, with a custom timeout.
125///
126/// # Errors
127///
128/// Returns an error if:
129/// - The HTTP request cannot be sent or returns a non-success status code after retries.
130/// - Any I/O operation fails during file creation, reading, or writing.
131/// - Checksum verification or JSON parsing fails.
132pub fn ensure_file_exists_or_download_http_with_timeout(
133    filepath: &Path,
134    url: &str,
135    checksums: Option<&Path>,
136    timeout_secs: u64,
137) -> anyhow::Result<()> {
138    ensure_file_exists_or_download_http_with_config(
139        filepath,
140        url,
141        checksums,
142        timeout_secs,
143        None,
144        None,
145    )
146}
147
148/// Ensures that a file exists at the specified path by downloading it if necessary,
149/// with custom timeout, retry config, and initial jitter delay.
150///
151/// # Parameters
152///
153/// - `filepath`: The path where the file should exist.
154/// - `url`: The URL to download from if the file doesn't exist.
155/// - `checksums`: Optional path to checksums file for verification.
156/// - `timeout_secs`: Timeout in seconds for HTTP requests.
157/// - `retry_config`: Optional custom retry configuration (uses sensible defaults if None).
158/// - `initial_jitter_ms`: Optional initial jitter delay in milliseconds before download (defaults to 100-600ms if None).
159///
160/// # Errors
161///
162/// Returns an error if:
163/// - The HTTP request cannot be sent or returns a non-success status code after retries.
164/// - Any I/O operation fails during file creation, reading, or writing.
165/// - Checksum verification or JSON parsing fails.
166pub fn ensure_file_exists_or_download_http_with_config(
167    filepath: &Path,
168    url: &str,
169    checksums: Option<&Path>,
170    timeout_secs: u64,
171    retry_config: Option<RetryConfig>,
172    initial_jitter_ms: Option<u64>,
173) -> anyhow::Result<()> {
174    if filepath.exists() {
175        println!("File already exists: {filepath:?}");
176
177        if let Some(checksums_file) = checksums {
178            if verify_sha256_checksum(filepath, checksums_file)? {
179                println!("File is valid");
180                return Ok(());
181            } else {
182                let new_checksum = calculate_sha256(filepath)?;
183                println!("Adding checksum for existing file: {new_checksum}");
184                update_sha256_checksums(filepath, checksums_file, &new_checksum)?;
185                return Ok(());
186            }
187        }
188        return Ok(());
189    }
190
191    // Add a small random delay to avoid bursting the remote server when
192    // many downloads start concurrently. Can be disabled by passing Some(0).
193    if let Some(jitter_ms) = initial_jitter_ms {
194        if jitter_ms > 0 {
195            sleep(Duration::from_millis(jitter_ms));
196        }
197    } else {
198        let jitter_delay = {
199            let mut r = rng();
200            Duration::from_millis(r.random_range(100..=600))
201        };
202        sleep(jitter_delay);
203    }
204
205    download_file(filepath, url, timeout_secs, retry_config)?;
206
207    if let Some(checksums_file) = checksums {
208        let new_checksum = calculate_sha256(filepath)?;
209        update_sha256_checksums(filepath, checksums_file, &new_checksum)?;
210    }
211
212    Ok(())
213}
214
215fn download_file(
216    filepath: &Path,
217    url: &str,
218    timeout_secs: u64,
219    retry_config: Option<RetryConfig>,
220) -> anyhow::Result<()> {
221    // Validate HTTPS for security in production builds,
222    // HTTP is intentionally allowed in test builds for local test servers (127.0.0.1),
223    // CodeQL flags this as "non-https-url" but it's a deliberate design choice for testkit.
224    #[cfg(not(test))]
225    if !url.starts_with("https://") {
226        anyhow::bail!("URL must use HTTPS protocol for security: {url}");
227    }
228
229    println!("Downloading file from {url} to {filepath:?}");
230
231    if let Some(parent) = filepath.parent() {
232        std::fs::create_dir_all(parent)?;
233    }
234
235    let client = Client::builder()
236        .timeout(Duration::from_secs(timeout_secs))
237        .build()?;
238
239    let cfg = if let Some(config) = retry_config {
240        config
241    } else {
242        // Default production config
243        let max_retries = 5u32;
244        let op_timeout_ms = timeout_secs.saturating_mul(1000);
245        // Make the provided timeout a hard ceiling for total elapsed time.
246        // Split it across attempts (at least 1000 ms per attempt) and cap total at op_timeout_ms.
247        let per_attempt_ms = std::cmp::max(1000u64, op_timeout_ms / (max_retries as u64 + 1));
248        RetryConfig {
249            max_retries,
250            initial_delay_ms: 1_000,
251            max_delay_ms: 10_000,
252            backoff_factor: 2.0,
253            jitter_ms: 1_000,
254            operation_timeout_ms: Some(per_attempt_ms),
255            immediate_first: false,
256            max_elapsed_ms: Some(op_timeout_ms),
257        }
258    };
259
260    let op = || -> Result<(), DownloadError> {
261        match client.get(url).send() {
262            Ok(mut response) => {
263                let status = response.status();
264                if status.is_success() {
265                    let mut out = File::create(filepath)
266                        .map_err(|e| DownloadError::NonRetryable(e.to_string()))?;
267                    // Stream the response body directly to disk to avoid large allocations
268                    copy(&mut response, &mut out)
269                        .map_err(|e| DownloadError::NonRetryable(e.to_string()))?;
270                    println!("File downloaded to {filepath:?}");
271                    Ok(())
272                } else if status.is_server_error()
273                    || status.as_u16() == 429
274                    || status.as_u16() == 408
275                {
276                    println!("HTTP error {status}, retrying...");
277                    Err(DownloadError::Retryable(format!("HTTP {status}")))
278                } else {
279                    // Preserve existing error text used by tests
280                    Err(DownloadError::NonRetryable(format!(
281                        "Client error: HTTP {status}"
282                    )))
283                }
284            }
285            Err(e) => {
286                println!("Request failed: {e}");
287                Err(DownloadError::Retryable(e.to_string()))
288            }
289        }
290    };
291
292    let should_retry = |e: &DownloadError| matches!(e, DownloadError::Retryable(_));
293
294    execute_with_retry_blocking(&cfg, op, should_retry).map_err(|e| anyhow::anyhow!(e.to_string()))
295}
296
297fn calculate_sha256(filepath: &Path) -> anyhow::Result<String> {
298    let mut file = File::open(filepath)?;
299    let mut ctx = Context::new(&digest::SHA256);
300    let mut buffer = [0u8; 4096];
301
302    loop {
303        let count = file.read(&mut buffer)?;
304        if count == 0 {
305            break;
306        }
307        ctx.update(&buffer[..count]);
308    }
309
310    let digest = ctx.finish();
311    Ok(hex::encode(digest.as_ref()))
312}
313
314fn verify_sha256_checksum(filepath: &Path, checksums: &Path) -> anyhow::Result<bool> {
315    let file = File::open(checksums)?;
316    let reader = BufReader::new(file);
317    let checksums: Value = serde_json::from_reader(reader)?;
318
319    let filename = filepath.file_name().unwrap().to_str().unwrap();
320    if let Some(expected_checksum) = checksums.get(filename) {
321        let expected_checksum_str = expected_checksum.as_str().unwrap();
322        let expected_hash = expected_checksum_str
323            .strip_prefix("sha256:")
324            .unwrap_or(expected_checksum_str);
325        let calculated_checksum = calculate_sha256(filepath)?;
326        if expected_hash == calculated_checksum {
327            return Ok(true);
328        }
329    }
330
331    Ok(false)
332}
333
334fn update_sha256_checksums(
335    filepath: &Path,
336    checksums_file: &Path,
337    new_checksum: &str,
338) -> anyhow::Result<()> {
339    let checksums: Value = if checksums_file.exists() {
340        let file = File::open(checksums_file)?;
341        let reader = BufReader::new(file);
342        serde_json::from_reader(reader)?
343    } else {
344        serde_json::json!({})
345    };
346
347    let mut checksums_map = checksums.as_object().unwrap().clone();
348
349    // Add or update the checksum
350    let filename = filepath.file_name().unwrap().to_str().unwrap().to_string();
351    let prefixed_checksum = format!("sha256:{new_checksum}");
352    checksums_map.insert(filename, Value::String(prefixed_checksum));
353
354    let file = OpenOptions::new()
355        .write(true)
356        .create(true)
357        .truncate(true)
358        .open(checksums_file)?;
359    let writer = BufWriter::new(file);
360    serde_json::to_writer_pretty(writer, &serde_json::Value::Object(checksums_map))?;
361
362    Ok(())
363}
364
#[cfg(test)]
mod tests {
    use std::{
        fs,
        io::{BufWriter, Write},
        net::SocketAddr,
        sync::{
            Arc,
            atomic::{AtomicUsize, Ordering},
        },
    };

    use axum::{Router, http::StatusCode, routing::get, serve};
    use rstest::*;
    use serde_json::{json, to_writer};
    use tempfile::TempDir;
    use tokio::{
        net::TcpListener,
        task,
        time::{Duration, sleep},
    };

    use super::*;

    /// Retry policy with millisecond-scale delays so that retry tests finish quickly.
    fn test_retry_config() -> RetryConfig {
        RetryConfig {
            max_retries: 5,
            initial_delay_ms: 10,
            max_delay_ms: 50,
            backoff_factor: 2.0,
            jitter_ms: 5,
            operation_timeout_ms: Some(500),
            immediate_first: false,
            max_elapsed_ms: Some(2000),
        }
    }

    /// Serves `app` on an ephemeral localhost port from a background task and returns the bound
    /// address after a short pause. Server errors are printed to stderr only when `log_errors`
    /// is set.
    async fn serve_in_background(app: Router, log_errors: bool) -> SocketAddr {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let server = serve(listener, app);

        task::spawn(async move {
            match server.await {
                Err(e) if log_errors => eprintln!("server error: {e}"),
                _ => {}
            }
        });

        sleep(Duration::from_millis(100)).await;

        addr
    }

    /// Starts a server whose `/testfile.txt` route always answers with `status_code` and either
    /// `server_content` or, when that is `None`, the text `File not found`.
    async fn setup_test_server(
        server_content: Option<String>,
        status_code: StatusCode,
    ) -> SocketAddr {
        let response_body = server_content.unwrap_or_else(|| "File not found".to_string());
        let app = Router::new().route(
            "/testfile.txt",
            get(move || {
                let response_body = response_body.clone();
                async move { (status_code, response_body) }
            }),
        );

        serve_in_background(app, true).await
    }

    /// Runs the download on the blocking thread pool using the fast test retry policy, no
    /// checksums file and no start-up jitter.
    async fn download_with_test_config(
        filepath: std::path::PathBuf,
        url: String,
        timeout_secs: u64,
    ) -> anyhow::Result<()> {
        task::spawn_blocking(move || {
            ensure_file_exists_or_download_http_with_config(
                &filepath,
                &url,
                None,
                timeout_secs,
                Some(test_retry_config()),
                Some(0),
            )
        })
        .await
        .unwrap()
    }

    #[tokio::test]
    async fn test_file_already_exists() {
        let dir = TempDir::new().unwrap();
        let existing = dir.path().join("testfile.txt");
        fs::write(&existing, "Existing file content").unwrap();

        // The file is present, so the call returns before any request is made.
        let url = "http://example.com/testfile.txt".to_string();
        let outcome = ensure_file_exists_or_download_http(&existing, &url, None, Some(5));

        assert!(outcome.is_ok());
        assert_eq!(
            fs::read_to_string(&existing).unwrap(),
            "Existing file content"
        );
    }

    #[tokio::test]
    async fn test_download_file_success() {
        let dir = TempDir::new().unwrap();
        let target = dir.path().join("testfile.txt");

        let server_content = "Server file content".to_string();
        let addr = setup_test_server(Some(server_content.clone()), StatusCode::OK).await;
        let url = format!("http://{addr}/testfile.txt");

        let outcome = download_with_test_config(target.clone(), url, 5).await;

        assert!(outcome.is_ok());
        assert_eq!(fs::read_to_string(&target).unwrap(), server_content);
    }

    #[tokio::test]
    async fn test_download_file_not_found() {
        let dir = TempDir::new().unwrap();
        let target = dir.path().join("testfile.txt");

        let addr = setup_test_server(None, StatusCode::NOT_FOUND).await;
        let url = format!("http://{addr}/testfile.txt");

        let outcome = download_with_test_config(target, url, 1).await;

        assert!(outcome.is_err());
        let err_msg = outcome.unwrap_err().to_string();
        assert!(
            err_msg.contains("Client error: HTTP"),
            "Unexpected error message: {err_msg}"
        );
    }

    #[tokio::test]
    async fn test_network_error() {
        let dir = TempDir::new().unwrap();
        let target = dir.path().join("testfile.txt");

        // Use an unreachable address to simulate a network error
        let url = "http://127.0.0.1:0/testfile.txt".to_string();

        let outcome = download_with_test_config(target, url, 2).await;

        assert!(outcome.is_err());
        let err_msg = outcome.unwrap_err().to_string();
        assert!(
            err_msg.contains("error"),
            "Unexpected error message: {err_msg}"
        );
    }

    #[tokio::test]
    async fn test_retry_then_success_on_500() {
        let dir = TempDir::new().unwrap();
        let target = dir.path().join("testfile.txt");

        // The first two requests fail with 500; every later one succeeds.
        let hits = Arc::new(AtomicUsize::new(0));
        let app = Router::new().route(
            "/testfile.txt",
            get({
                let hits = Arc::clone(&hits);
                move || {
                    let hits = Arc::clone(&hits);
                    async move {
                        if hits.fetch_add(1, Ordering::SeqCst) < 2 {
                            (StatusCode::INTERNAL_SERVER_ERROR, "temporary error")
                        } else {
                            (StatusCode::OK, "eventual success")
                        }
                    }
                }
            }),
        );
        let addr = serve_in_background(app, false).await;
        let url = format!("http://{addr}/testfile.txt");

        let outcome = download_with_test_config(target.clone(), url, 5).await;

        assert!(outcome.is_ok());
        assert_eq!(fs::read_to_string(&target).unwrap(), "eventual success");
        assert!(hits.load(Ordering::SeqCst) >= 2);
    }

    #[tokio::test]
    async fn test_retry_then_success_on_429() {
        let dir = TempDir::new().unwrap();
        let target = dir.path().join("testfile.txt");

        // Only the first request is rate limited.
        let hits = Arc::new(AtomicUsize::new(0));
        let app = Router::new().route(
            "/testfile.txt",
            get({
                let hits = Arc::clone(&hits);
                move || {
                    let hits = Arc::clone(&hits);
                    async move {
                        if hits.fetch_add(1, Ordering::SeqCst) < 1 {
                            (StatusCode::TOO_MANY_REQUESTS, "rate limited")
                        } else {
                            (StatusCode::OK, "ok after retry")
                        }
                    }
                }
            }),
        );
        let addr = serve_in_background(app, false).await;
        let url = format!("http://{addr}/testfile.txt");

        let outcome = download_with_test_config(target.clone(), url, 5).await;

        assert!(outcome.is_ok());
        assert_eq!(fs::read_to_string(&target).unwrap(), "ok after retry");
        assert!(hits.load(Ordering::SeqCst) >= 2);
    }

    #[tokio::test]
    async fn test_no_retry_on_404() {
        let dir = TempDir::new().unwrap();
        let target = dir.path().join("testfile.txt");

        // Every request is answered with 404 and counted.
        let hits = Arc::new(AtomicUsize::new(0));
        let app = Router::new().route(
            "/testfile.txt",
            get({
                let hits = Arc::clone(&hits);
                move || {
                    let hits = Arc::clone(&hits);
                    async move {
                        hits.fetch_add(1, Ordering::SeqCst);
                        (StatusCode::NOT_FOUND, "missing")
                    }
                }
            }),
        );
        let addr = serve_in_background(app, false).await;
        let url = format!("http://{addr}/testfile.txt");

        let outcome = download_with_test_config(target, url, 5).await;

        assert!(outcome.is_err());
        assert_eq!(hits.load(Ordering::SeqCst), 1, "should not retry on 404");
    }

    // Assertions are intentional in these `Result`-returning tests.
    #[rstest]
    #[allow(clippy::panic_in_result_fn)]
    fn test_calculate_sha256() -> anyhow::Result<()> {
        let dir = TempDir::new()?;
        let sample_path = dir.path().join("test_file.txt");
        let mut sample = File::create(&sample_path)?;
        sample.write_all(b"Hello, world!")?;

        let expected_hash = "315f5bdb76d078c43b8ac0064e4a0164612b1fce77c869345bfc94c75894edd3";

        assert_eq!(calculate_sha256(&sample_path)?, expected_hash);
        Ok(())
    }

    // Assertions are intentional in these `Result`-returning tests.
    #[rstest]
    #[allow(clippy::panic_in_result_fn)]
    fn test_verify_sha256_checksum() -> anyhow::Result<()> {
        let dir = TempDir::new()?;
        let sample_path = dir.path().join("test_file.txt");
        let mut sample = File::create(&sample_path)?;
        sample.write_all(b"Hello, world!")?;

        let calculated_checksum = calculate_sha256(&sample_path)?;

        // Create checksums.json containing the checksum
        let checksums_path = dir.path().join("checksums.json");
        let checksums_data = json!({
            "test_file.txt": format!("sha256:{}", calculated_checksum)
        });
        let writer = BufWriter::new(File::create(&checksums_path)?);
        to_writer(writer, &checksums_data)?;

        assert!(
            verify_sha256_checksum(&sample_path, &checksums_path)?,
            "The checksum should be valid"
        );
        Ok(())
    }
}
735}