nautilus_testkit/
files.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::{
17    fs::{File, OpenOptions},
18    io::{BufReader, BufWriter, Read, copy},
19    path::Path,
20    thread::sleep,
21    time::Duration,
22};
23
24use aws_lc_rs::digest::{self, Context};
25use rand::{Rng, rng};
26use reqwest::blocking::Client;
27use serde_json::Value;
28
29/// Ensures that a file exists at the specified path by downloading it if necessary.
30///
31/// If the file already exists, it checks the integrity of the file using a SHA-256 checksum
32/// from the optional `checksums` file. If the checksum is valid, the function exits early. If
33/// the checksum is invalid or missing, the function updates the checksums file with the correct
34/// hash for the existing file without redownloading it.
35///
36/// If the file does not exist, it downloads the file from the specified `url` and updates the
37/// checksums file (if provided) with the calculated SHA-256 checksum of the downloaded file.
38///
39/// # Errors
40///
41/// Returns an error if:
42/// - The HTTP request cannot be sent or returns a non-success status code.
43/// - Any I/O operation fails during file creation, reading, or writing.
44/// - Checksum verification or JSON parsing fails.
45pub fn ensure_file_exists_or_download_http(
46    filepath: &Path,
47    url: &str,
48    checksums: Option<&Path>,
49) -> anyhow::Result<()> {
50    if filepath.exists() {
51        println!("File already exists: {filepath:?}");
52
53        if let Some(checksums_file) = checksums {
54            if verify_sha256_checksum(filepath, checksums_file)? {
55                println!("File is valid");
56                return Ok(());
57            } else {
58                let new_checksum = calculate_sha256(filepath)?;
59                println!("Adding checksum for existing file: {new_checksum}");
60                update_sha256_checksums(filepath, checksums_file, &new_checksum)?;
61                return Ok(());
62            }
63        }
64        return Ok(());
65    }
66
67    // Add a small random delay (100–600 ms) to avoid bursting the remote server when
68    // many tests start concurrently. A true random jitter is preferred over a
69    // deterministic hash to prevent synchronized traffic spikes.
70    let jitter_delay = {
71        let mut r = rng();
72        Duration::from_millis(r.random_range(100..=600))
73    };
74    sleep(jitter_delay);
75
76    download_file(filepath, url)?;
77
78    if let Some(checksums_file) = checksums {
79        let new_checksum = calculate_sha256(filepath)?;
80        update_sha256_checksums(filepath, checksums_file, &new_checksum)?;
81    }
82
83    Ok(())
84}
85
86fn download_file(filepath: &Path, url: &str) -> anyhow::Result<()> {
87    const MAX_RETRIES: usize = 3;
88    const BASE_DELAY_MS: u64 = 1000;
89    const TIMEOUT_SECONDS: u64 = 30;
90
91    println!("Downloading file from {url} to {filepath:?}");
92
93    if let Some(parent) = filepath.parent() {
94        std::fs::create_dir_all(parent)?;
95    }
96
97    let client = Client::builder()
98        .timeout(Duration::from_secs(TIMEOUT_SECONDS))
99        .build()?;
100
101    let mut last_error = None;
102
103    for attempt in 0..MAX_RETRIES {
104        if attempt > 0 {
105            // Exponential backoff with additional random jitter up to BASE_DELAY_MS
106            let exponential_delay_ms = BASE_DELAY_MS * 2_u64.pow(attempt as u32 - 1);
107            let jitter_ms = rng().random_range(0..BASE_DELAY_MS);
108            let delay = Duration::from_millis(exponential_delay_ms + jitter_ms);
109            println!(
110                "Retrying download in {delay:?} (attempt {}/{MAX_RETRIES})",
111                attempt + 1
112            );
113            sleep(delay);
114        }
115
116        match client.get(url).send() {
117            Ok(mut response) => {
118                let status = response.status();
119                if status.is_success() {
120                    let mut out = File::create(filepath)?;
121                    // Stream the response body directly to disk to avoid large allocations
122                    copy(&mut response, &mut out)?;
123                    println!("File downloaded to {filepath:?}");
124                    return Ok(());
125                } else if status.is_server_error() {
126                    // Retry on 5xx server errors
127                    println!("Server error (HTTP {status}), retrying...");
128                    last_error = Some(anyhow::anyhow!("Server error: HTTP {status}"));
129                    continue;
130                }
131                // 4xx errors are considered client side and not retried
132                anyhow::bail!("Client error: HTTP {status}");
133            }
134            Err(e) => {
135                println!("Request failed: {e}");
136                last_error = Some(anyhow::anyhow!("Request failed: {e}"));
137                continue;
138            }
139        }
140    }
141
142    Err(last_error
143        .unwrap_or_else(|| anyhow::anyhow!("Download failed after {MAX_RETRIES} attempts")))
144}
145
146fn calculate_sha256(filepath: &Path) -> anyhow::Result<String> {
147    let mut file = File::open(filepath)?;
148    let mut ctx = Context::new(&digest::SHA256);
149    let mut buffer = [0u8; 4096];
150
151    loop {
152        let count = file.read(&mut buffer)?;
153        if count == 0 {
154            break;
155        }
156        ctx.update(&buffer[..count]);
157    }
158
159    let digest = ctx.finish();
160    Ok(hex::encode(digest.as_ref()))
161}
162
163fn verify_sha256_checksum(filepath: &Path, checksums: &Path) -> anyhow::Result<bool> {
164    let file = File::open(checksums)?;
165    let reader = BufReader::new(file);
166    let checksums: Value = serde_json::from_reader(reader)?;
167
168    let filename = filepath.file_name().unwrap().to_str().unwrap();
169    if let Some(expected_checksum) = checksums.get(filename) {
170        let expected_checksum_str = expected_checksum.as_str().unwrap();
171        let expected_hash = expected_checksum_str
172            .strip_prefix("sha256:")
173            .unwrap_or(expected_checksum_str);
174        let calculated_checksum = calculate_sha256(filepath)?;
175        if expected_hash == calculated_checksum {
176            return Ok(true);
177        }
178    }
179
180    Ok(false)
181}
182
183fn update_sha256_checksums(
184    filepath: &Path,
185    checksums_file: &Path,
186    new_checksum: &str,
187) -> anyhow::Result<()> {
188    let checksums: Value = if checksums_file.exists() {
189        let file = File::open(checksums_file)?;
190        let reader = BufReader::new(file);
191        serde_json::from_reader(reader)?
192    } else {
193        serde_json::json!({})
194    };
195
196    let mut checksums_map = checksums.as_object().unwrap().clone();
197
198    // Add or update the checksum
199    let filename = filepath.file_name().unwrap().to_str().unwrap().to_string();
200    let prefixed_checksum = format!("sha256:{new_checksum}");
201    checksums_map.insert(filename, Value::String(prefixed_checksum));
202
203    let file = OpenOptions::new()
204        .write(true)
205        .create(true)
206        .truncate(true)
207        .open(checksums_file)?;
208    let writer = BufWriter::new(file);
209    serde_json::to_writer_pretty(writer, &serde_json::Value::Object(checksums_map))?;
210
211    Ok(())
212}
213
214////////////////////////////////////////////////////////////////////////////////
215// Tests
216////////////////////////////////////////////////////////////////////////////////
#[cfg(test)]
mod tests {
    use std::{
        fs,
        io::{BufWriter, Write},
        net::SocketAddr,
        sync::Arc,
    };

    use axum::{Router, http::StatusCode, routing::get, serve};
    use rstest::*;
    use serde_json::{json, to_writer};
    use tempfile::TempDir;
    use tokio::{
        net::TcpListener,
        task,
        time::{Duration, sleep},
    };

    use super::*;

    /// Starts a local HTTP server on an ephemeral port that answers `GET /testfile.txt`
    /// with `status_code` and either `server_content` or a "File not found" body.
    ///
    /// Returns the address the server is bound to.
    async fn setup_test_server(
        server_content: Option<String>,
        status_code: StatusCode,
    ) -> SocketAddr {
        let shared_body = Arc::new(server_content);
        let handler = move || {
            let shared_body = Arc::clone(&shared_body);
            async move {
                let body = (*shared_body)
                    .clone()
                    .unwrap_or_else(|| "File not found".to_string());
                (status_code, body)
            }
        };
        let app = Router::new().route("/testfile.txt", get(handler));

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let server = serve(listener, app);

        task::spawn(async move {
            if let Err(e) = server.await {
                eprintln!("server error: {e}");
            }
        });

        // Give the spawned server task a moment to start before handing out the address.
        sleep(Duration::from_millis(100)).await;

        addr
    }

    #[tokio::test]
    async fn test_file_already_exists() {
        let workdir = TempDir::new().unwrap();
        let target = workdir.path().join("testfile.txt");
        fs::write(&target, "Existing file content").unwrap();

        let outcome =
            ensure_file_exists_or_download_http(&target, "http://example.com/testfile.txt", None);

        assert!(outcome.is_ok());
        assert_eq!(
            fs::read_to_string(&target).unwrap(),
            "Existing file content"
        );
    }

    #[tokio::test]
    async fn test_download_file_success() {
        let workdir = TempDir::new().unwrap();
        let target = workdir.path().join("testfile.txt");

        let expected_body = "Server file content".to_string();
        let addr = setup_test_server(Some(expected_body.clone()), StatusCode::OK).await;
        let url = format!("http://{addr}/testfile.txt");

        // The downloader is blocking, so run it off the async runtime threads.
        let download_target = target.clone();
        let outcome = task::spawn_blocking(move || {
            ensure_file_exists_or_download_http(&download_target, &url, None)
        })
        .await
        .unwrap();

        assert!(outcome.is_ok());
        assert_eq!(fs::read_to_string(&target).unwrap(), expected_body);
    }

    #[tokio::test]
    async fn test_download_file_not_found() {
        let workdir = TempDir::new().unwrap();
        let target = workdir.path().join("testfile.txt");

        let addr = setup_test_server(None, StatusCode::NOT_FOUND).await;
        let url = format!("http://{addr}/testfile.txt");

        let outcome =
            task::spawn_blocking(move || ensure_file_exists_or_download_http(&target, &url, None))
                .await
                .unwrap();

        assert!(outcome.is_err());
        let err_msg = outcome.unwrap_err().to_string();
        assert!(
            err_msg.contains("Client error: HTTP"),
            "Unexpected error message: {err_msg}"
        );
    }

    #[tokio::test]
    async fn test_network_error() {
        let workdir = TempDir::new().unwrap();
        let target = workdir.path().join("testfile.txt");

        // Use an unreachable address to simulate a network error
        let url = String::from("http://127.0.0.1:0/testfile.txt");

        let outcome =
            task::spawn_blocking(move || ensure_file_exists_or_download_http(&target, &url, None))
                .await
                .unwrap();

        assert!(outcome.is_err());
        let err_msg = outcome.unwrap_err().to_string();
        assert!(
            err_msg.contains("error"),
            "Unexpected error message: {err_msg}"
        );
    }

    #[rstest]
    fn test_calculate_sha256() -> anyhow::Result<()> {
        let workdir = TempDir::new()?;
        let sample_path = workdir.path().join("test_file.txt");
        let mut sample = File::create(&sample_path)?;
        sample.write_all(b"Hello, world!")?;

        let calculated_hash = calculate_sha256(&sample_path)?;

        assert_eq!(
            calculated_hash,
            "315f5bdb76d078c43b8ac0064e4a0164612b1fce77c869345bfc94c75894edd3"
        );
        Ok(())
    }

    #[rstest]
    fn test_verify_sha256_checksum() -> anyhow::Result<()> {
        let workdir = TempDir::new()?;
        let sample_path = workdir.path().join("test_file.txt");
        let mut sample = File::create(&sample_path)?;
        sample.write_all(b"Hello, world!")?;

        let digest_hex = calculate_sha256(&sample_path)?;

        // Create checksums.json containing the checksum
        let manifest_path = workdir.path().join("checksums.json");
        let manifest = json!({
            "test_file.txt": format!("sha256:{}", digest_hex)
        });
        to_writer(BufWriter::new(File::create(&manifest_path)?), &manifest)?;

        let is_valid = verify_sha256_checksum(&sample_path, &manifest_path)?;
        assert!(is_valid, "The checksum should be valid");
        Ok(())
    }
}