1use std::{collections::HashMap, hash::Hash, str::FromStr, sync::Arc, time::Duration};
19
20use bytes::Bytes;
21use http::{HeaderValue, StatusCode, status::InvalidStatusCode};
22use reqwest::{
23 Method, Response, Url,
24 header::{HeaderMap, HeaderName},
25};
26
27use crate::ratelimiter::{RateLimiter, clock::MonotonicClock, quota::Quota};
28
29#[derive(Clone, Debug)]
34#[cfg_attr(
35 feature = "python",
36 pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
37)]
38pub struct HttpStatus {
39 inner: StatusCode,
40}
41
42impl HttpStatus {
43 #[must_use]
45 pub const fn new(code: StatusCode) -> Self {
46 Self { inner: code }
47 }
48
49 pub fn from(code: u16) -> Result<Self, InvalidStatusCode> {
55 Ok(Self {
56 inner: StatusCode::from_u16(code)?,
57 })
58 }
59
60 #[inline]
62 #[must_use]
63 pub const fn as_u16(&self) -> u16 {
64 self.inner.as_u16()
65 }
66
67 #[inline]
69 #[must_use]
70 pub fn as_str(&self) -> &str {
71 self.inner.as_str()
72 }
73
74 #[inline]
76 #[must_use]
77 pub fn is_informational(&self) -> bool {
78 self.inner.is_informational()
79 }
80
81 #[inline]
83 #[must_use]
84 pub fn is_success(&self) -> bool {
85 self.inner.is_success()
86 }
87
88 #[inline]
90 #[must_use]
91 pub fn is_redirection(&self) -> bool {
92 self.inner.is_redirection()
93 }
94
95 #[inline]
97 #[must_use]
98 pub fn is_client_error(&self) -> bool {
99 self.inner.is_client_error()
100 }
101
102 #[inline]
104 #[must_use]
105 pub fn is_server_error(&self) -> bool {
106 self.inner.is_server_error()
107 }
108}
109
/// The subset of HTTP request methods supported by [`HttpClient`].
///
/// Variant order is significant: with the `python` feature the enum is exposed with
/// `eq_int`, so the implicit discriminants (`GET = 0` … `PATCH = 4`) are observable.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(eq, eq_int, module = "nautilus_trader.core.nautilus_pyo3.network")
)]
pub enum HttpMethod {
    /// HTTP `GET`.
    GET,
    /// HTTP `POST`.
    POST,
    /// HTTP `PUT`.
    PUT,
    /// HTTP `DELETE`.
    DELETE,
    /// HTTP `PATCH`.
    PATCH,
}
123
124#[allow(clippy::from_over_into)]
125impl Into<Method> for HttpMethod {
126 fn into(self) -> Method {
127 match self {
128 Self::GET => Method::GET,
129 Self::POST => Method::POST,
130 Self::PUT => Method::PUT,
131 Self::DELETE => Method::DELETE,
132 Self::PATCH => Method::PATCH,
133 }
134 }
135}
136
137#[derive(Clone, Debug)]
142#[cfg_attr(
143 feature = "python",
144 pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
145)]
146pub struct HttpResponse {
147 pub status: HttpStatus,
149 pub headers: HashMap<String, String>,
151 pub body: Bytes,
153}
154
155#[derive(thiserror::Error, Debug)]
159pub enum HttpClientError {
160 #[error("HTTP error occurred: {0}")]
161 Error(String),
162
163 #[error("HTTP request timed out: {0}")]
164 TimeoutError(String),
165}
166
167impl From<reqwest::Error> for HttpClientError {
168 fn from(source: reqwest::Error) -> Self {
169 if source.is_timeout() {
170 Self::TimeoutError(source.to_string())
171 } else {
172 Self::Error(source.to_string())
173 }
174 }
175}
176
177impl From<String> for HttpClientError {
178 fn from(value: String) -> Self {
179 Self::Error(value)
180 }
181}
182
183#[derive(Clone)]
193#[cfg_attr(
194 feature = "python",
195 pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
196)]
197pub struct HttpClient {
198 pub(crate) client: InnerHttpClient,
200 pub(crate) rate_limiter: Arc<RateLimiter<String, MonotonicClock>>,
202}
203
204impl HttpClient {
205 #[must_use]
207 pub fn new(
208 headers: HashMap<String, String>,
209 header_keys: Vec<String>,
210 keyed_quotas: Vec<(String, Quota)>,
211 default_quota: Option<Quota>,
212 timeout_secs: Option<u64>,
213 ) -> Self {
214 let mut header_map = HeaderMap::new();
216 for (key, value) in headers {
217 let header_name = HeaderName::from_str(&key).expect("Invalid header name");
218 let header_value = HeaderValue::from_str(&value).expect("Invalid header value");
219 header_map.insert(header_name, header_value);
220 }
221
222 let mut client_builder = reqwest::Client::builder().default_headers(header_map);
223 if let Some(timeout_secs) = timeout_secs {
224 client_builder = client_builder.timeout(Duration::from_secs(timeout_secs));
225 }
226
227 let client = client_builder
228 .build()
229 .expect("Failed to build reqwest client");
230
231 let client = InnerHttpClient {
232 client,
233 header_keys: Arc::new(header_keys),
234 };
235 let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
236
237 Self {
238 client,
239 rate_limiter,
240 }
241 }
242
243 #[allow(clippy::too_many_arguments)]
260 pub async fn request(
261 &self,
262 method: Method,
263 url: String,
264 headers: Option<HashMap<String, String>>,
265 body: Option<Vec<u8>>,
266 keys: Option<Vec<String>>,
267 timeout_secs: Option<u64>,
268 ) -> Result<HttpResponse, HttpClientError> {
269 let rate_limiter = self.rate_limiter.clone();
270
271 rate_limiter.await_keys_ready(keys).await;
272 self.client
273 .send_request(method, url, headers, body, timeout_secs)
274 .await
275 }
276}
277
278#[derive(Clone, Debug)]
287pub struct InnerHttpClient {
288 pub(crate) client: reqwest::Client,
289 pub(crate) header_keys: Arc<Vec<String>>,
290}
291
292impl InnerHttpClient {
293 pub async fn send_request(
305 &self,
306 method: Method,
307 url: String,
308 headers: Option<HashMap<String, String>>,
309 body: Option<Vec<u8>>,
310 timeout_secs: Option<u64>,
311 ) -> Result<HttpResponse, HttpClientError> {
312 let headers = headers.unwrap_or_default();
313 let reqwest_url = Url::parse(url.as_str())
314 .map_err(|e| HttpClientError::from(format!("URL parse error: {e}")))?;
315
316 let mut header_map = HeaderMap::new();
317 for (header_key, header_value) in &headers {
318 let key = HeaderName::from_bytes(header_key.as_bytes())
319 .map_err(|e| HttpClientError::from(format!("Invalid header name: {e}")))?;
320 let _ = header_map.insert(
321 key,
322 header_value
323 .parse()
324 .map_err(|e| HttpClientError::from(format!("Invalid header value: {e}")))?,
325 );
326 }
327
328 let mut request_builder = self.client.request(method, reqwest_url).headers(header_map);
329
330 if let Some(timeout_secs) = timeout_secs {
331 request_builder = request_builder.timeout(Duration::new(timeout_secs, 0));
332 }
333
334 let request = match body {
335 Some(b) => request_builder
336 .body(b)
337 .build()
338 .map_err(HttpClientError::from)?,
339 None => request_builder.build().map_err(HttpClientError::from)?,
340 };
341
342 tracing::trace!("{request:?}");
343
344 let response = self
345 .client
346 .execute(request)
347 .await
348 .map_err(HttpClientError::from)?;
349
350 self.to_response(response).await
351 }
352
353 pub async fn to_response(&self, response: Response) -> Result<HttpResponse, HttpClientError> {
359 tracing::trace!("{response:?}");
360
361 let headers: HashMap<String, String> = self
362 .header_keys
363 .iter()
364 .filter_map(|key| response.headers().get(key).map(|val| (key, val)))
365 .filter_map(|(key, val)| val.to_str().map(|v| (key, v)).ok())
366 .map(|(k, v)| (k.clone(), v.to_owned()))
367 .collect();
368 let status = HttpStatus::new(response.status());
369 let body = response.bytes().await.map_err(HttpClientError::from)?;
370
371 Ok(HttpResponse {
372 status,
373 headers,
374 body,
375 })
376 }
377}
378
379impl Default for InnerHttpClient {
380 fn default() -> Self {
384 let client = reqwest::Client::new();
385 Self {
386 client,
387 header_keys: Default::default(),
388 }
389 }
390}
391
#[cfg(test)]
#[cfg(target_os = "linux")]
mod tests {
    use std::net::SocketAddr;

    use axum::{
        Router,
        routing::{delete, get, patch, post},
        serve,
    };
    use http::status::StatusCode;

    use super::*;

    /// Builds the router served by the in-process test server.
    fn create_router() -> Router {
        Router::new()
            .route("/get", get(|| async { "hello-world!" }))
            .route("/post", post(|| async { StatusCode::OK }))
            .route("/patch", patch(|| async { StatusCode::OK }))
            .route("/delete", delete(|| async { StatusCode::OK }))
            .route("/notfound", get(|| async { StatusCode::NOT_FOUND }))
            .route(
                "/slow",
                get(|| async {
                    tokio::time::sleep(Duration::from_secs(2)).await;
                    "Eventually responded"
                }),
            )
    }

    /// Spawns the test server on a background task and returns its bound address.
    ///
    /// The listener binds `127.0.0.1:0` directly, so the OS assigns a free ephemeral port
    /// and the server owns the socket from the start. The alternative, probing for a free
    /// port with a throwaway listener and then re-binding it, leaves a window in which a
    /// concurrently running test can claim the same port.
    async fn start_test_server() -> Result<SocketAddr, Box<dyn std::error::Error + Send + Sync>> {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?;
        let addr = listener.local_addr()?;

        tokio::spawn(async move {
            serve(listener, create_router()).await.unwrap();
        });

        Ok(addr)
    }

    #[tokio::test]
    async fn test_get() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}");

        let client = InnerHttpClient::default();
        let response = client
            .send_request(reqwest::Method::GET, format!("{url}/get"), None, None, None)
            .await
            .unwrap();

        assert!(response.status.is_success());
        assert_eq!(String::from_utf8_lossy(&response.body), "hello-world!");
    }

    #[tokio::test]
    async fn test_post() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}");

        let client = InnerHttpClient::default();
        let response = client
            .send_request(
                reqwest::Method::POST,
                format!("{url}/post"),
                None,
                None,
                None,
            )
            .await
            .unwrap();

        assert!(response.status.is_success());
    }

    #[tokio::test]
    async fn test_post_with_body() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}");

        let client = InnerHttpClient::default();

        let mut body = HashMap::new();
        body.insert(
            "key1".to_string(),
            serde_json::Value::String("value1".to_string()),
        );
        body.insert(
            "key2".to_string(),
            serde_json::Value::String("value2".to_string()),
        );

        let body_string = serde_json::to_string(&body).unwrap();
        let body_bytes = body_string.into_bytes();

        let response = client
            .send_request(
                reqwest::Method::POST,
                format!("{url}/post"),
                None,
                Some(body_bytes),
                None,
            )
            .await
            .unwrap();

        assert!(response.status.is_success());
    }

    #[tokio::test]
    async fn test_patch() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}");

        let client = InnerHttpClient::default();
        let response = client
            .send_request(
                reqwest::Method::PATCH,
                format!("{url}/patch"),
                None,
                None,
                None,
            )
            .await
            .unwrap();

        assert!(response.status.is_success());
    }

    #[tokio::test]
    async fn test_delete() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}");

        let client = InnerHttpClient::default();
        let response = client
            .send_request(
                reqwest::Method::DELETE,
                format!("{url}/delete"),
                None,
                None,
                None,
            )
            .await
            .unwrap();

        assert!(response.status.is_success());
    }

    #[tokio::test]
    async fn test_not_found() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}/notfound");
        let client = InnerHttpClient::default();

        let response = client
            .send_request(reqwest::Method::GET, url, None, None, None)
            .await
            .unwrap();

        assert!(response.status.is_client_error());
        assert_eq!(response.status.as_u16(), 404);
    }

    #[tokio::test]
    async fn test_timeout() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}/slow");
        let client = InnerHttpClient::default();

        // `/slow` sleeps for 2s, so a 1s per-request timeout must fire first.
        let result = client
            .send_request(reqwest::Method::GET, url, None, None, Some(1))
            .await;

        match result {
            Err(HttpClientError::TimeoutError(msg)) => {
                println!("Got expected timeout error: {msg}");
            }
            Err(other) => panic!("Expected a timeout error, got: {other:?}"),
            Ok(resp) => panic!("Expected a timeout error, but got a successful response: {resp:?}"),
        }
    }
}