nautilus_test_kit/
files.rs1use std::{
17 fs::{File, OpenOptions},
18 io::{copy, BufReader, BufWriter, Read},
19 path::Path,
20};
21
22use reqwest::blocking::Client;
23use ring::digest;
24use serde_json::Value;
25
26pub fn ensure_file_exists_or_download_http(
36 filepath: &Path,
37 url: &str,
38 checksums: Option<&Path>,
39) -> anyhow::Result<()> {
40 if filepath.exists() {
41 println!("File already exists: {}", filepath.display());
42
43 if let Some(checksums_file) = checksums {
44 if verify_sha256_checksum(filepath, checksums_file)? {
45 println!("File is valid");
46 return Ok(());
47 } else {
48 let new_checksum = calculate_sha256(filepath)?;
49 println!("Adding checksum for existing file: {new_checksum}");
50 update_sha256_checksums(filepath, checksums_file, &new_checksum)?;
51 return Ok(());
52 }
53 } else {
54 return Ok(());
55 }
56 }
57
58 download_file(filepath, url)?;
59
60 if let Some(checksums_file) = checksums {
61 let new_checksum = calculate_sha256(filepath)?;
62 update_sha256_checksums(filepath, checksums_file, &new_checksum)?;
63 }
64
65 Ok(())
66}
67
68fn download_file(filepath: &Path, url: &str) -> anyhow::Result<()> {
69 println!("Downloading file from {url} to {}", filepath.display());
70
71 if let Some(parent) = filepath.parent() {
72 std::fs::create_dir_all(parent)?;
73 }
74
75 let mut response = Client::new().get(url).send()?;
76 if !response.status().is_success() {
77 anyhow::bail!("Failed to download file: HTTP {}", response.status());
78 }
79
80 let mut out = File::create(filepath)?;
81 copy(&mut response, &mut out)?;
82
83 println!("File downloaded to {}", filepath.display());
84 Ok(())
85}
86
87fn calculate_sha256(filepath: &Path) -> anyhow::Result<String> {
88 let mut file = File::open(filepath)?;
89 let mut context = digest::Context::new(&digest::SHA256);
90 let mut buffer = [0; 4096];
91
92 loop {
93 let count = file.read(&mut buffer)?;
94 if count == 0 {
95 break;
96 }
97 context.update(&buffer[..count]);
98 }
99
100 let digest = context.finish();
101 Ok(hex::encode(digest.as_ref()))
102}
103
104fn verify_sha256_checksum(filepath: &Path, checksums: &Path) -> anyhow::Result<bool> {
105 let file = File::open(checksums)?;
106 let reader = BufReader::new(file);
107 let checksums: Value = serde_json::from_reader(reader)?;
108
109 let filename = filepath.file_name().unwrap().to_str().unwrap();
110 if let Some(expected_checksum) = checksums.get(filename) {
111 let expected_checksum_str = expected_checksum.as_str().unwrap();
112 let expected_hash = expected_checksum_str
113 .strip_prefix("sha256:")
114 .unwrap_or(expected_checksum_str);
115 let calculated_checksum = calculate_sha256(filepath)?;
116 if expected_hash == calculated_checksum {
117 return Ok(true);
118 }
119 }
120
121 Ok(false)
122}
123
124fn update_sha256_checksums(
125 filepath: &Path,
126 checksums_file: &Path,
127 new_checksum: &str,
128) -> anyhow::Result<()> {
129 let checksums: Value = if checksums_file.exists() {
130 let file = File::open(checksums_file)?;
131 let reader = BufReader::new(file);
132 serde_json::from_reader(reader)?
133 } else {
134 serde_json::json!({})
135 };
136
137 let mut checksums_map = checksums.as_object().unwrap().clone();
138
139 let filename = filepath.file_name().unwrap().to_str().unwrap().to_string();
141 let prefixed_checksum = format!("sha256:{new_checksum}");
142 checksums_map.insert(filename, Value::String(prefixed_checksum));
143
144 let file = OpenOptions::new()
145 .write(true)
146 .create(true)
147 .truncate(true)
148 .open(checksums_file)?;
149 let writer = BufWriter::new(file);
150 serde_json::to_writer_pretty(writer, &serde_json::Value::Object(checksums_map))?;
151
152 Ok(())
153}
154
#[cfg(test)]
mod tests {
    use std::{
        fs,
        io::{BufWriter, Write},
        net::SocketAddr,
        sync::Arc,
    };

    use axum::{http::StatusCode, routing::get, serve, Router};
    use rstest::*;
    use serde_json::{json, to_writer};
    use tempfile::TempDir;
    use tokio::{net::TcpListener, task};

    use super::*;

    /// Starts a local HTTP server on an ephemeral port and returns its address.
    ///
    /// The server answers `GET /testfile.txt` with `status_code`. The body is
    /// `server_content`, or `"File not found"` when that is `None`.
    ///
    /// The listener is already bound and listening when this returns. Connections made
    /// before the spawned server task is first polled are queued rather than refused,
    /// so no start-up delay is needed.
    async fn setup_test_server(
        server_content: Option<String>,
        status_code: StatusCode,
    ) -> SocketAddr {
        let server_content = Arc::new(server_content);
        let app = Router::new().route(
            "/testfile.txt",
            get(move || {
                let server_content = Arc::clone(&server_content);
                async move {
                    let response_body = match &*server_content {
                        Some(content) => content.clone(),
                        None => "File not found".to_string(),
                    };
                    (status_code, response_body)
                }
            }),
        );

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let server = serve(listener, app);

        task::spawn(async move {
            if let Err(e) = server.await {
                eprintln!("server error: {e}");
            }
        });

        addr
    }

    // An existing file is returned untouched; no network access happens on this path.
    #[rstest]
    fn test_file_already_exists() {
        let temp_dir = TempDir::new().unwrap();
        let file_path = temp_dir.path().join("testfile.txt");
        fs::write(&file_path, "Existing file content").unwrap();

        let url = "http://example.com/testfile.txt".to_string();
        let result = ensure_file_exists_or_download_http(&file_path, &url, None);

        assert!(result.is_ok());
        let content = fs::read_to_string(&file_path).unwrap();
        assert_eq!(content, "Existing file content");
    }

    #[tokio::test]
    async fn test_download_file_success() {
        let temp_dir = TempDir::new().unwrap();
        let filepath = temp_dir.path().join("testfile.txt");
        let filepath_clone = filepath.clone();

        let expected = "Server file content";
        let addr = setup_test_server(Some(expected.to_string()), StatusCode::OK).await;
        let url = format!("http://{addr}/testfile.txt");

        // The download uses a blocking HTTP client, so run it off the async runtime.
        let result = tokio::task::spawn_blocking(move || {
            ensure_file_exists_or_download_http(&filepath_clone, &url, None)
        })
        .await
        .unwrap();

        assert!(result.is_ok());
        let content = fs::read_to_string(&filepath).unwrap();
        assert_eq!(content, expected);
    }

    #[tokio::test]
    async fn test_download_file_not_found() {
        let temp_dir = TempDir::new().unwrap();
        let file_path = temp_dir.path().join("testfile.txt");
        let file_path_clone = file_path.clone();

        let addr = setup_test_server(None, StatusCode::NOT_FOUND).await;
        let url = format!("http://{addr}/testfile.txt");

        let result = tokio::task::spawn_blocking(move || {
            ensure_file_exists_or_download_http(&file_path_clone, &url, None)
        })
        .await
        .unwrap();

        let err_msg = result
            .expect_err("a 404 response must be reported as an error")
            .to_string();
        assert!(
            err_msg.contains("Failed to download file"),
            "Unexpected error message: {err_msg}"
        );
        assert!(
            !file_path.exists(),
            "no file should be left behind after a failed download"
        );
    }

    #[tokio::test]
    async fn test_network_error() {
        let temp_dir = TempDir::new().unwrap();
        let file_path = temp_dir.path().join("testfile.txt");

        // Port 0 cannot be connected to, so the request fails before any HTTP exchange.
        let url = "http://127.0.0.1:0/testfile.txt".to_string();

        let result = tokio::task::spawn_blocking(move || {
            ensure_file_exists_or_download_http(&file_path, &url, None)
        })
        .await
        .unwrap();

        let err_msg = result
            .expect_err("an unreachable server must be reported as an error")
            .to_string();
        assert!(
            err_msg.contains("error"),
            "Unexpected error message: {err_msg}"
        );
    }

    #[rstest]
    fn test_calculate_sha256() -> anyhow::Result<()> {
        let temp_dir = TempDir::new()?;
        let test_file_path = temp_dir.path().join("test_file.txt");
        let mut test_file = File::create(&test_file_path)?;
        test_file.write_all(b"Hello, world!")?;

        let expected_hash = "315f5bdb76d078c43b8ac0064e4a0164612b1fce77c869345bfc94c75894edd3";
        let calculated_hash = calculate_sha256(&test_file_path)?;

        assert_eq!(calculated_hash, expected_hash);
        Ok(())
    }

    #[rstest]
    fn test_verify_sha256_checksum() -> anyhow::Result<()> {
        let temp_dir = TempDir::new()?;
        let test_file_path = temp_dir.path().join("test_file.txt");
        let mut test_file = File::create(&test_file_path)?;
        test_file.write_all(b"Hello, world!")?;

        let calculated_checksum = calculate_sha256(&test_file_path)?;

        let checksums_path = temp_dir.path().join("checksums.json");
        let checksums_data = json!({
            "test_file.txt": format!("sha256:{calculated_checksum}")
        });
        let mut writer = BufWriter::new(File::create(&checksums_path)?);
        to_writer(&mut writer, &checksums_data)?;
        // Flush so a write failure surfaces here instead of being swallowed on drop.
        writer.flush()?;

        let is_valid = verify_sha256_checksum(&test_file_path, &checksums_path)?;
        assert!(is_valid, "The checksum should be valid");
        Ok(())
    }
}