nautilus_test_kit/
files.rs

// -------------------------------------------------------------------------------------------------
//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
//  https://nautechsystems.io
//
//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
//  You may not use this file except in compliance with the License.
//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
//
//  Unless required by applicable law or agreed to in writing, software
//  distributed under the License is distributed on an "AS IS" BASIS,
//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//  See the License for the specific language governing permissions and
//  limitations under the License.
// -------------------------------------------------------------------------------------------------
15
16use std::{
17    fs::{File, OpenOptions},
18    io::{copy, BufReader, BufWriter, Read},
19    path::Path,
20};
21
22use reqwest::blocking::Client;
23use ring::digest;
24use serde_json::Value;
25
26/// Ensures that a file exists at the specified path by downloading it if necessary.
27///
28/// If the file already exists, it checks the integrity of the file using a SHA-256 checksum
29/// from the optional `checksums` file. If the checksum is valid, the function exits early. If
30/// the checksum is invalid or missing, the function updates the checksums file with the correct
31/// hash for the existing file without redownloading it.
32///
33/// If the file does not exist, it downloads the file from the specified `url` and updates the
34/// checksums file (if provided) with the calculated SHA-256 checksum of the downloaded file.
35pub fn ensure_file_exists_or_download_http(
36    filepath: &Path,
37    url: &str,
38    checksums: Option<&Path>,
39) -> anyhow::Result<()> {
40    if filepath.exists() {
41        println!("File already exists: {}", filepath.display());
42
43        if let Some(checksums_file) = checksums {
44            if verify_sha256_checksum(filepath, checksums_file)? {
45                println!("File is valid");
46                return Ok(());
47            } else {
48                let new_checksum = calculate_sha256(filepath)?;
49                println!("Adding checksum for existing file: {new_checksum}");
50                update_sha256_checksums(filepath, checksums_file, &new_checksum)?;
51                return Ok(());
52            }
53        } else {
54            return Ok(());
55        }
56    }
57
58    download_file(filepath, url)?;
59
60    if let Some(checksums_file) = checksums {
61        let new_checksum = calculate_sha256(filepath)?;
62        update_sha256_checksums(filepath, checksums_file, &new_checksum)?;
63    }
64
65    Ok(())
66}
67
68fn download_file(filepath: &Path, url: &str) -> anyhow::Result<()> {
69    println!("Downloading file from {url} to {}", filepath.display());
70
71    if let Some(parent) = filepath.parent() {
72        std::fs::create_dir_all(parent)?;
73    }
74
75    let mut response = Client::new().get(url).send()?;
76    if !response.status().is_success() {
77        anyhow::bail!("Failed to download file: HTTP {}", response.status());
78    }
79
80    let mut out = File::create(filepath)?;
81    copy(&mut response, &mut out)?;
82
83    println!("File downloaded to {}", filepath.display());
84    Ok(())
85}
86
87fn calculate_sha256(filepath: &Path) -> anyhow::Result<String> {
88    let mut file = File::open(filepath)?;
89    let mut context = digest::Context::new(&digest::SHA256);
90    let mut buffer = [0; 4096];
91
92    loop {
93        let count = file.read(&mut buffer)?;
94        if count == 0 {
95            break;
96        }
97        context.update(&buffer[..count]);
98    }
99
100    let digest = context.finish();
101    Ok(hex::encode(digest.as_ref()))
102}
103
104fn verify_sha256_checksum(filepath: &Path, checksums: &Path) -> anyhow::Result<bool> {
105    let file = File::open(checksums)?;
106    let reader = BufReader::new(file);
107    let checksums: Value = serde_json::from_reader(reader)?;
108
109    let filename = filepath.file_name().unwrap().to_str().unwrap();
110    if let Some(expected_checksum) = checksums.get(filename) {
111        let expected_checksum_str = expected_checksum.as_str().unwrap();
112        let expected_hash = expected_checksum_str
113            .strip_prefix("sha256:")
114            .unwrap_or(expected_checksum_str);
115        let calculated_checksum = calculate_sha256(filepath)?;
116        if expected_hash == calculated_checksum {
117            return Ok(true);
118        }
119    }
120
121    Ok(false)
122}
123
124fn update_sha256_checksums(
125    filepath: &Path,
126    checksums_file: &Path,
127    new_checksum: &str,
128) -> anyhow::Result<()> {
129    let checksums: Value = if checksums_file.exists() {
130        let file = File::open(checksums_file)?;
131        let reader = BufReader::new(file);
132        serde_json::from_reader(reader)?
133    } else {
134        serde_json::json!({})
135    };
136
137    let mut checksums_map = checksums.as_object().unwrap().clone();
138
139    // Add or update the checksum
140    let filename = filepath.file_name().unwrap().to_str().unwrap().to_string();
141    let prefixed_checksum = format!("sha256:{new_checksum}");
142    checksums_map.insert(filename, Value::String(prefixed_checksum));
143
144    let file = OpenOptions::new()
145        .write(true)
146        .create(true)
147        .truncate(true)
148        .open(checksums_file)?;
149    let writer = BufWriter::new(file);
150    serde_json::to_writer_pretty(writer, &serde_json::Value::Object(checksums_map))?;
151
152    Ok(())
153}
154
////////////////////////////////////////////////////////////////////////////////
// Tests
////////////////////////////////////////////////////////////////////////////////
#[cfg(test)]
mod tests {
    use std::{
        fs,
        io::{BufWriter, Write},
        net::SocketAddr,
    };

    use axum::{http::StatusCode, routing::get, serve, Router};
    use rstest::*;
    use serde_json::{json, to_writer};
    use tempfile::TempDir;
    use tokio::{net::TcpListener, task};

    use super::*;

    /// Starts a local HTTP server that answers `GET /testfile.txt` with `status_code`.
    ///
    /// The response body is `server_content`, or `"File not found"` when `server_content` is
    /// `None`. The listener is bound before this returns, so clients may connect immediately.
    /// Pending connections wait in the listen backlog until the spawned server task accepts them,
    /// so no startup delay is needed.
    async fn setup_test_server(
        server_content: Option<String>,
        status_code: StatusCode,
    ) -> SocketAddr {
        let body = server_content.unwrap_or_else(|| "File not found".to_string());
        let app = Router::new().route(
            "/testfile.txt",
            get(move || {
                let body = body.clone();
                async move { (status_code, body) }
            }),
        );

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let server = serve(listener, app);

        task::spawn(async move {
            if let Err(e) = server.await {
                eprintln!("server error: {e}");
            }
        });

        addr
    }

    #[rstest]
    fn test_file_already_exists() {
        let temp_dir = TempDir::new().unwrap();
        let file_path = temp_dir.path().join("testfile.txt");
        fs::write(&file_path, "Existing file content").unwrap();

        // The URL is never contacted because the file is already present.
        let url = "http://example.com/testfile.txt";
        let result = ensure_file_exists_or_download_http(&file_path, url, None);

        assert!(result.is_ok());
        let content = fs::read_to_string(&file_path).unwrap();
        assert_eq!(content, "Existing file content");
    }

    #[rstest]
    fn test_existing_file_with_stale_checksum_is_rewritten() -> anyhow::Result<()> {
        let temp_dir = TempDir::new()?;
        let file_path = temp_dir.path().join("testfile.txt");
        fs::write(&file_path, "Existing file content")?;

        let checksums_path = temp_dir.path().join("checksums.json");
        fs::write(
            &checksums_path,
            json!({ "testfile.txt": "sha256:deadbeef" }).to_string(),
        )?;

        ensure_file_exists_or_download_http(
            &file_path,
            "http://example.com/testfile.txt",
            Some(&checksums_path),
        )?;

        assert!(verify_sha256_checksum(&file_path, &checksums_path)?);
        assert_eq!(fs::read_to_string(&file_path)?, "Existing file content");
        Ok(())
    }

    #[tokio::test]
    async fn test_download_file_success() {
        let temp_dir = TempDir::new().unwrap();
        let filepath = temp_dir.path().join("testfile.txt");
        let filepath_clone = filepath.clone();

        let server_content = Some("Server file content".to_string());
        let addr = setup_test_server(server_content.clone(), StatusCode::OK).await;
        let url = format!("http://{addr}/testfile.txt");

        let result = tokio::task::spawn_blocking(move || {
            ensure_file_exists_or_download_http(&filepath_clone, &url, None)
        })
        .await
        .unwrap();

        assert!(result.is_ok(), "download failed: {result:?}");
        let content = fs::read_to_string(&filepath).unwrap();
        assert_eq!(content, server_content.unwrap());
    }

    #[tokio::test]
    async fn test_download_file_records_checksum() {
        let temp_dir = TempDir::new().unwrap();
        let filepath = temp_dir.path().join("testfile.txt");
        let checksums_path = temp_dir.path().join("checksums.json");
        let filepath_clone = filepath.clone();
        let checksums_clone = checksums_path.clone();

        let addr = setup_test_server(Some("Server file content".to_string()), StatusCode::OK).await;
        let url = format!("http://{addr}/testfile.txt");

        let result = tokio::task::spawn_blocking(move || {
            ensure_file_exists_or_download_http(&filepath_clone, &url, Some(&checksums_clone))
        })
        .await
        .unwrap();

        assert!(result.is_ok(), "download failed: {result:?}");
        assert!(verify_sha256_checksum(&filepath, &checksums_path).unwrap());
    }

    #[tokio::test]
    async fn test_download_file_not_found() {
        let temp_dir = TempDir::new().unwrap();
        let file_path = temp_dir.path().join("testfile.txt");

        let addr = setup_test_server(None, StatusCode::NOT_FOUND).await;
        let url = format!("http://{addr}/testfile.txt");

        let result = tokio::task::spawn_blocking(move || {
            ensure_file_exists_or_download_http(&file_path, &url, None)
        })
        .await
        .unwrap();

        assert!(result.is_err());
        let err_msg = format!("{}", result.unwrap_err());
        assert!(
            err_msg.contains("Failed to download file"),
            "Unexpected error message: {err_msg}"
        );
    }

    #[tokio::test]
    async fn test_network_error() {
        let temp_dir = TempDir::new().unwrap();
        let file_path = temp_dir.path().join("testfile.txt");

        // Port 0 is never a valid connection target, which simulates a network error.
        let url = "http://127.0.0.1:0/testfile.txt".to_string();

        let result = tokio::task::spawn_blocking(move || {
            ensure_file_exists_or_download_http(&file_path, &url, None)
        })
        .await
        .unwrap();

        assert!(result.is_err());
        let err_msg = format!("{}", result.unwrap_err());
        assert!(
            err_msg.contains("error"),
            "Unexpected error message: {err_msg}"
        );
    }

    #[rstest]
    fn test_calculate_sha256() -> anyhow::Result<()> {
        let temp_dir = TempDir::new()?;
        let test_file_path = temp_dir.path().join("test_file.txt");
        let mut test_file = File::create(&test_file_path)?;
        test_file.write_all(b"Hello, world!")?;

        let expected_hash = "315f5bdb76d078c43b8ac0064e4a0164612b1fce77c869345bfc94c75894edd3";
        let calculated_hash = calculate_sha256(&test_file_path)?;

        assert_eq!(calculated_hash, expected_hash);
        Ok(())
    }

    #[rstest]
    fn test_verify_sha256_checksum() -> anyhow::Result<()> {
        let temp_dir = TempDir::new()?;
        let test_file_path = temp_dir.path().join("test_file.txt");
        let mut test_file = File::create(&test_file_path)?;
        test_file.write_all(b"Hello, world!")?;

        let calculated_checksum = calculate_sha256(&test_file_path)?;

        // Create checksums.json containing the checksum
        let checksums_path = temp_dir.path().join("checksums.json");
        let checksums_data = json!({
            "test_file.txt": format!("sha256:{}", calculated_checksum)
        });
        let checksums_file = File::create(&checksums_path)?;
        let writer = BufWriter::new(checksums_file);
        to_writer(writer, &checksums_data)?;

        let is_valid = verify_sha256_checksum(&test_file_path, &checksums_path)?;
        assert!(is_valid, "The checksum should be valid");
        Ok(())
    }

    #[rstest]
    fn test_update_sha256_checksums_preserves_other_entries() -> anyhow::Result<()> {
        let temp_dir = TempDir::new()?;
        let test_file_path = temp_dir.path().join("test_file.txt");
        let checksums_path = temp_dir.path().join("checksums.json");
        fs::write(&checksums_path, json!({ "other.txt": "sha256:0123" }).to_string())?;

        update_sha256_checksums(&test_file_path, &checksums_path, "abcd")?;

        let stored: Value = serde_json::from_str(&fs::read_to_string(&checksums_path)?)?;
        assert_eq!(
            stored,
            json!({ "other.txt": "sha256:0123", "test_file.txt": "sha256:abcd" })
        );
        Ok(())
    }
}