nautilus_testkit/
files.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::{
17    cmp,
18    fs::{File, OpenOptions},
19    io::{BufReader, BufWriter, Read, copy},
20    path::Path,
21    thread::sleep,
22    time::{Duration, Instant},
23};
24
25use aws_lc_rs::digest::{self, Context};
26use nautilus_network::retry::RetryConfig;
27use rand::{Rng, rng};
28use reqwest::blocking::Client;
29use serde_json::Value;
30
/// Outcome classification for a failed download attempt.
///
/// Each variant carries a human-readable description of what went wrong.
#[derive(Debug)]
enum DownloadError {
    /// A failure for which another attempt may succeed.
    Retryable(String),
    /// A failure for which another attempt is pointless.
    NonRetryable(String),
}

impl std::fmt::Display for DownloadError {
    /// Formats the error as `"<kind> error: <message>"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (label, detail) = match self {
            Self::Retryable(detail) => ("Retryable", detail),
            Self::NonRetryable(detail) => ("Non-retryable", detail),
        };
        write!(f, "{label} error: {detail}")
    }
}

impl std::error::Error for DownloadError {}
47
48fn execute_with_retry_blocking<T, E, F>(
49    config: &RetryConfig,
50    mut op: F,
51    should_retry: impl Fn(&E) -> bool,
52) -> Result<T, E>
53where
54    E: std::error::Error,
55    F: FnMut() -> Result<T, E>,
56{
57    let start = Instant::now();
58    let mut delay = Duration::from_millis(config.initial_delay_ms);
59
60    for attempt in 0..=config.max_retries {
61        if attempt > 0 && !config.immediate_first {
62            let jitter = rng().random_range(0..=config.jitter_ms);
63            let sleep_for = delay + Duration::from_millis(jitter);
64            sleep(sleep_for);
65            let next = (delay.as_millis() as f64 * config.backoff_factor) as u64;
66            delay = cmp::min(
67                Duration::from_millis(next),
68                Duration::from_millis(config.max_delay_ms),
69            );
70        }
71
72        if let Some(max_total) = config.max_elapsed_ms
73            && start.elapsed() >= Duration::from_millis(max_total)
74        {
75            break;
76        }
77
78        match op() {
79            Ok(v) => return Ok(v),
80            Err(e) if attempt < config.max_retries && should_retry(&e) => continue,
81            Err(e) => return Err(e),
82        }
83    }
84
85    op()
86}
87
88/// Ensures that a file exists at the specified path by downloading it if necessary.
89///
90/// If the file already exists, it checks the integrity of the file using a SHA-256 checksum
91/// from the optional `checksums` file. If the checksum is valid, the function exits early. If
92/// the checksum is invalid or missing, the function updates the checksums file with the correct
93/// hash for the existing file without redownloading it.
94///
95/// If the file does not exist, it downloads the file from the specified `url` and updates the
96/// checksums file (if provided) with the calculated SHA-256 checksum of the downloaded file.
97///
98/// The `timeout_secs` parameter specifies the timeout in seconds for the HTTP request.
99/// If `None` is provided, a default timeout of 30 seconds will be used.
100///
101/// # Errors
102///
103/// Returns an error if:
104/// - The HTTP request cannot be sent or returns a non-success status code.
105/// - Any I/O operation fails during file creation, reading, or writing.
106/// - Checksum verification or JSON parsing fails.
107pub fn ensure_file_exists_or_download_http(
108    filepath: &Path,
109    url: &str,
110    checksums: Option<&Path>,
111    timeout_secs: Option<u64>,
112) -> anyhow::Result<()> {
113    ensure_file_exists_or_download_http_with_timeout(
114        filepath,
115        url,
116        checksums,
117        timeout_secs.unwrap_or(30),
118    )
119}
120
121/// Ensures that a file exists at the specified path by downloading it if necessary, with a custom timeout.
122///
123/// # Errors
124///
125/// Returns an error if:
126/// - The HTTP request cannot be sent or returns a non-success status code after retries.
127/// - Any I/O operation fails during file creation, reading, or writing.
128/// - Checksum verification or JSON parsing fails.
129pub fn ensure_file_exists_or_download_http_with_timeout(
130    filepath: &Path,
131    url: &str,
132    checksums: Option<&Path>,
133    timeout_secs: u64,
134) -> anyhow::Result<()> {
135    if filepath.exists() {
136        println!("File already exists: {filepath:?}");
137
138        if let Some(checksums_file) = checksums {
139            if verify_sha256_checksum(filepath, checksums_file)? {
140                println!("File is valid");
141                return Ok(());
142            } else {
143                let new_checksum = calculate_sha256(filepath)?;
144                println!("Adding checksum for existing file: {new_checksum}");
145                update_sha256_checksums(filepath, checksums_file, &new_checksum)?;
146                return Ok(());
147            }
148        }
149        return Ok(());
150    }
151
152    // Add a small random delay (100–600 ms) to avoid bursting the remote server when
153    // many tests start concurrently. A true random jitter is preferred over a
154    // deterministic hash to prevent synchronized traffic spikes.
155    let jitter_delay = {
156        let mut r = rng();
157        Duration::from_millis(r.random_range(100..=600))
158    };
159    sleep(jitter_delay);
160
161    download_file(filepath, url, timeout_secs)?;
162
163    if let Some(checksums_file) = checksums {
164        let new_checksum = calculate_sha256(filepath)?;
165        update_sha256_checksums(filepath, checksums_file, &new_checksum)?;
166    }
167
168    Ok(())
169}
170
171fn download_file(filepath: &Path, url: &str, timeout_secs: u64) -> anyhow::Result<()> {
172    println!("Downloading file from {url} to {filepath:?}");
173
174    if let Some(parent) = filepath.parent() {
175        std::fs::create_dir_all(parent)?;
176    }
177
178    let client = Client::builder()
179        .timeout(Duration::from_secs(timeout_secs))
180        .build()?;
181
182    let max_retries = 5u32;
183    let op_timeout_ms = timeout_secs.saturating_mul(1000);
184    // Make the provided timeout a hard ceiling for total elapsed time.
185    // Split it across attempts (at least 1000 ms per attempt) and cap total at op_timeout_ms.
186    let per_attempt_ms = std::cmp::max(1000u64, op_timeout_ms / (max_retries as u64 + 1));
187    let cfg = RetryConfig {
188        max_retries,
189        initial_delay_ms: 1_000,
190        max_delay_ms: 10_000,
191        backoff_factor: 2.0,
192        jitter_ms: 1_000,
193        operation_timeout_ms: Some(per_attempt_ms),
194        immediate_first: false,
195        max_elapsed_ms: Some(op_timeout_ms),
196    };
197
198    let op = || -> Result<(), DownloadError> {
199        match client.get(url).send() {
200            Ok(mut response) => {
201                let status = response.status();
202                if status.is_success() {
203                    let mut out = File::create(filepath)
204                        .map_err(|e| DownloadError::NonRetryable(e.to_string()))?;
205                    // Stream the response body directly to disk to avoid large allocations
206                    copy(&mut response, &mut out)
207                        .map_err(|e| DownloadError::NonRetryable(e.to_string()))?;
208                    println!("File downloaded to {filepath:?}");
209                    Ok(())
210                } else if status.is_server_error()
211                    || status.as_u16() == 429
212                    || status.as_u16() == 408
213                {
214                    println!("HTTP error {status}, retrying...");
215                    Err(DownloadError::Retryable(format!("HTTP {status}")))
216                } else {
217                    // Preserve existing error text used by tests
218                    Err(DownloadError::NonRetryable(format!(
219                        "Client error: HTTP {status}"
220                    )))
221                }
222            }
223            Err(e) => {
224                println!("Request failed: {e}");
225                Err(DownloadError::Retryable(e.to_string()))
226            }
227        }
228    };
229
230    let should_retry = |e: &DownloadError| matches!(e, DownloadError::Retryable(_));
231
232    execute_with_retry_blocking(&cfg, op, should_retry).map_err(|e| anyhow::anyhow!(e.to_string()))
233}
234
235fn calculate_sha256(filepath: &Path) -> anyhow::Result<String> {
236    let mut file = File::open(filepath)?;
237    let mut ctx = Context::new(&digest::SHA256);
238    let mut buffer = [0u8; 4096];
239
240    loop {
241        let count = file.read(&mut buffer)?;
242        if count == 0 {
243            break;
244        }
245        ctx.update(&buffer[..count]);
246    }
247
248    let digest = ctx.finish();
249    Ok(hex::encode(digest.as_ref()))
250}
251
252fn verify_sha256_checksum(filepath: &Path, checksums: &Path) -> anyhow::Result<bool> {
253    let file = File::open(checksums)?;
254    let reader = BufReader::new(file);
255    let checksums: Value = serde_json::from_reader(reader)?;
256
257    let filename = filepath.file_name().unwrap().to_str().unwrap();
258    if let Some(expected_checksum) = checksums.get(filename) {
259        let expected_checksum_str = expected_checksum.as_str().unwrap();
260        let expected_hash = expected_checksum_str
261            .strip_prefix("sha256:")
262            .unwrap_or(expected_checksum_str);
263        let calculated_checksum = calculate_sha256(filepath)?;
264        if expected_hash == calculated_checksum {
265            return Ok(true);
266        }
267    }
268
269    Ok(false)
270}
271
272fn update_sha256_checksums(
273    filepath: &Path,
274    checksums_file: &Path,
275    new_checksum: &str,
276) -> anyhow::Result<()> {
277    let checksums: Value = if checksums_file.exists() {
278        let file = File::open(checksums_file)?;
279        let reader = BufReader::new(file);
280        serde_json::from_reader(reader)?
281    } else {
282        serde_json::json!({})
283    };
284
285    let mut checksums_map = checksums.as_object().unwrap().clone();
286
287    // Add or update the checksum
288    let filename = filepath.file_name().unwrap().to_str().unwrap().to_string();
289    let prefixed_checksum = format!("sha256:{new_checksum}");
290    checksums_map.insert(filename, Value::String(prefixed_checksum));
291
292    let file = OpenOptions::new()
293        .write(true)
294        .create(true)
295        .truncate(true)
296        .open(checksums_file)?;
297    let writer = BufWriter::new(file);
298    serde_json::to_writer_pretty(writer, &serde_json::Value::Object(checksums_map))?;
299
300    Ok(())
301}
302
303////////////////////////////////////////////////////////////////////////////////
304// Tests
305////////////////////////////////////////////////////////////////////////////////
#[cfg(test)]
mod tests {
    use std::{
        fs,
        io::{BufWriter, Write},
        net::SocketAddr,
        sync::{
            Arc,
            atomic::{AtomicUsize, Ordering},
        },
    };

    use axum::{Router, http::StatusCode, routing::get, serve};
    use rstest::*;
    use serde_json::{json, to_writer};
    use tempfile::TempDir;
    use tokio::{
        net::TcpListener,
        task,
        time::{Duration, sleep},
    };

    use super::*;

    /// Serves `app` on an ephemeral localhost port from a background task and returns the
    /// bound address.
    async fn spawn_server(app: Router) -> SocketAddr {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let server = serve(listener, app);

        task::spawn(async move {
            if let Err(e) = server.await {
                eprintln!("server error: {e}");
            }
        });

        // Give the spawned task a moment to start accepting connections.
        sleep(Duration::from_millis(100)).await;

        addr
    }

    /// Starts a server whose `/testfile.txt` route always answers with `status_code` and
    /// `server_content` (or `"File not found"` when no content is given).
    async fn setup_test_server(
        server_content: Option<String>,
        status_code: StatusCode,
    ) -> SocketAddr {
        let body = server_content.unwrap_or_else(|| "File not found".to_string());
        let app = Router::new().route(
            "/testfile.txt",
            get(move || {
                let body = body.clone();
                async move { (status_code, body) }
            }),
        );

        spawn_server(app).await
    }

    /// Starts a server whose `/testfile.txt` route answers the first `failures` requests
    /// with the `failure` response and every later request with `200 OK` and
    /// `success_body`. Returns the address together with the request counter.
    async fn setup_counting_server(
        failures: usize,
        failure: (StatusCode, &'static str),
        success_body: &'static str,
    ) -> (SocketAddr, Arc<AtomicUsize>) {
        let counter = Arc::new(AtomicUsize::new(0));
        let handler_counter = counter.clone();

        let app = Router::new().route(
            "/testfile.txt",
            get(move || {
                let c = handler_counter.clone();
                async move {
                    let n = c.fetch_add(1, Ordering::SeqCst);
                    if n < failures {
                        failure
                    } else {
                        (StatusCode::OK, success_body)
                    }
                }
            }),
        );

        (spawn_server(app).await, counter)
    }

    #[tokio::test]
    async fn test_file_already_exists() {
        let temp_dir = TempDir::new().unwrap();
        let file_path = temp_dir.path().join("testfile.txt");
        fs::write(&file_path, "Existing file content").unwrap();

        // The URL is never contacted because the file is already present.
        let url = "http://example.com/testfile.txt".to_string();
        let result = ensure_file_exists_or_download_http(&file_path, &url, None, Some(5));

        assert!(result.is_ok());
        let content = fs::read_to_string(&file_path).unwrap();
        assert_eq!(content, "Existing file content");
    }

    #[tokio::test]
    async fn test_download_file_success() {
        let temp_dir = TempDir::new().unwrap();
        let filepath = temp_dir.path().join("testfile.txt");
        let filepath_clone = filepath.clone();

        let server_content = "Server file content".to_string();
        let addr = setup_test_server(Some(server_content.clone()), StatusCode::OK).await;
        let url = format!("http://{addr}/testfile.txt");

        let result = tokio::task::spawn_blocking(move || {
            ensure_file_exists_or_download_http(&filepath_clone, &url, None, Some(5))
        })
        .await
        .unwrap();

        assert!(result.is_ok());
        let content = fs::read_to_string(&filepath).unwrap();
        assert_eq!(content, server_content);
    }

    #[tokio::test]
    async fn test_download_file_not_found() {
        let temp_dir = TempDir::new().unwrap();
        let file_path = temp_dir.path().join("testfile.txt");

        let addr = setup_test_server(None, StatusCode::NOT_FOUND).await;
        let url = format!("http://{addr}/testfile.txt");

        let result = tokio::task::spawn_blocking(move || {
            ensure_file_exists_or_download_http_with_timeout(&file_path, &url, None, 1)
        })
        .await
        .unwrap();

        assert!(result.is_err());
        let err_msg = format!("{}", result.unwrap_err());
        assert!(
            err_msg.contains("Client error: HTTP"),
            "Unexpected error message: {err_msg}"
        );
    }

    #[tokio::test]
    async fn test_network_error() {
        let temp_dir = TempDir::new().unwrap();
        let file_path = temp_dir.path().join("testfile.txt");

        // Use an unreachable address to simulate a network error
        let url = "http://127.0.0.1:0/testfile.txt".to_string();

        let result = tokio::task::spawn_blocking(move || {
            ensure_file_exists_or_download_http(&file_path, &url, None, Some(2))
        })
        .await
        .unwrap();

        assert!(result.is_err());
        let err_msg = format!("{}", result.unwrap_err());
        assert!(
            err_msg.contains("error"),
            "Unexpected error message: {err_msg}"
        );
    }

    #[tokio::test]
    async fn test_retry_then_success_on_500() {
        let temp_dir = TempDir::new().unwrap();
        let filepath = temp_dir.path().join("testfile.txt");
        let filepath_clone = filepath.clone();

        let (addr, counter) = setup_counting_server(
            2,
            (StatusCode::INTERNAL_SERVER_ERROR, "temporary error"),
            "eventual success",
        )
        .await;

        let url = format!("http://{addr}/testfile.txt");
        let result = tokio::task::spawn_blocking(move || {
            ensure_file_exists_or_download_http(&filepath_clone, &url, None, Some(5))
        })
        .await
        .unwrap();

        assert!(result.is_ok());
        let content = fs::read_to_string(&filepath).unwrap();
        assert_eq!(content, "eventual success");
        // Two failing responses precede the successful one.
        assert!(counter.load(Ordering::SeqCst) >= 3);
    }

    #[tokio::test]
    async fn test_retry_then_success_on_429() {
        let temp_dir = TempDir::new().unwrap();
        let filepath = temp_dir.path().join("testfile.txt");
        let filepath_clone = filepath.clone();

        let (addr, counter) = setup_counting_server(
            1,
            (StatusCode::TOO_MANY_REQUESTS, "rate limited"),
            "ok after retry",
        )
        .await;

        let url = format!("http://{addr}/testfile.txt");
        let result = tokio::task::spawn_blocking(move || {
            ensure_file_exists_or_download_http(&filepath_clone, &url, None, Some(5))
        })
        .await
        .unwrap();

        assert!(result.is_ok());
        let content = fs::read_to_string(&filepath).unwrap();
        assert_eq!(content, "ok after retry");
        assert!(counter.load(Ordering::SeqCst) >= 2);
    }

    #[tokio::test]
    async fn test_no_retry_on_404() {
        let temp_dir = TempDir::new().unwrap();
        let filepath = temp_dir.path().join("testfile.txt");
        let filepath_clone = filepath.clone();

        // `usize::MAX` failures: every request is answered with 404.
        let (addr, counter) =
            setup_counting_server(usize::MAX, (StatusCode::NOT_FOUND, "missing"), "").await;

        let url = format!("http://{addr}/testfile.txt");
        let result = tokio::task::spawn_blocking(move || {
            ensure_file_exists_or_download_http(&filepath_clone, &url, None, Some(5))
        })
        .await
        .unwrap();

        assert!(result.is_err());
        assert_eq!(counter.load(Ordering::SeqCst), 1, "should not retry on 404");
    }

    #[rstest]
    fn test_calculate_sha256() -> anyhow::Result<()> {
        let temp_dir = TempDir::new()?;
        let test_file_path = temp_dir.path().join("test_file.txt");
        let mut test_file = File::create(&test_file_path)?;
        let content = b"Hello, world!";
        test_file.write_all(content)?;

        let expected_hash = "315f5bdb76d078c43b8ac0064e4a0164612b1fce77c869345bfc94c75894edd3";
        let calculated_hash = calculate_sha256(&test_file_path)?;

        assert_eq!(calculated_hash, expected_hash);
        Ok(())
    }

    #[rstest]
    fn test_verify_sha256_checksum() -> anyhow::Result<()> {
        let temp_dir = TempDir::new()?;
        let test_file_path = temp_dir.path().join("test_file.txt");
        let mut test_file = File::create(&test_file_path)?;
        let content = b"Hello, world!";
        test_file.write_all(content)?;

        let calculated_checksum = calculate_sha256(&test_file_path)?;

        // Create checksums.json containing the checksum
        let checksums_path = temp_dir.path().join("checksums.json");
        let checksums_data = json!({
            "test_file.txt": format!("sha256:{}", calculated_checksum)
        });
        let checksums_file = File::create(&checksums_path)?;
        let writer = BufWriter::new(checksums_file);
        to_writer(writer, &checksums_data)?;

        let is_valid = verify_sha256_checksum(&test_file_path, &checksums_path)?;
        assert!(is_valid, "The checksum should be valid");
        Ok(())
    }
}