1use std::{collections::HashMap, hash::Hash, str::FromStr, sync::Arc, time::Duration};
19
20use bytes::Bytes;
21use http::{HeaderValue, StatusCode, status::InvalidStatusCode};
22use nautilus_core::collections::into_ustr_vec;
23use reqwest::{
24 Method, Response, Url,
25 header::{HeaderMap, HeaderName},
26};
27use ustr::Ustr;
28
29use crate::ratelimiter::{RateLimiter, clock::MonotonicClock, quota::Quota};
30
/// HTTP response status code, wrapping an [`http::StatusCode`].
///
/// When the `python` feature is enabled this type is also exposed to Python as a
/// pyclass in the `nautilus_trader.core.nautilus_pyo3.network` module.
#[derive(Clone, Debug)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
pub struct HttpStatus {
    /// The underlying `http` status code; all accessors delegate to it.
    inner: StatusCode,
}
43
44impl HttpStatus {
45 #[must_use]
47 pub const fn new(code: StatusCode) -> Self {
48 Self { inner: code }
49 }
50
51 #[inline]
53 #[must_use]
54 pub const fn as_u16(&self) -> u16 {
55 self.inner.as_u16()
56 }
57
58 #[inline]
60 #[must_use]
61 pub fn as_str(&self) -> &str {
62 self.inner.as_str()
63 }
64
65 #[inline]
67 #[must_use]
68 pub fn is_informational(&self) -> bool {
69 self.inner.is_informational()
70 }
71
72 #[inline]
74 #[must_use]
75 pub fn is_success(&self) -> bool {
76 self.inner.is_success()
77 }
78
79 #[inline]
81 #[must_use]
82 pub fn is_redirection(&self) -> bool {
83 self.inner.is_redirection()
84 }
85
86 #[inline]
88 #[must_use]
89 pub fn is_client_error(&self) -> bool {
90 self.inner.is_client_error()
91 }
92
93 #[inline]
95 #[must_use]
96 pub fn is_server_error(&self) -> bool {
97 self.inner.is_server_error()
98 }
99}
100
101impl TryFrom<u16> for HttpStatus {
102 type Error = InvalidStatusCode;
103
104 fn try_from(code: u16) -> Result<Self, Self::Error> {
110 Ok(Self {
111 inner: StatusCode::from_u16(code)?,
112 })
113 }
114}
115
/// HTTP request methods supported by the client.
///
/// Converted into a `reqwest` [`Method`] via `From<HttpMethod>`. When the `python`
/// feature is enabled the enum is exposed to Python with the pyo3 `eq` / `eq_int`
/// options, so variant order (and therefore each discriminant) is part of its behavior.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(eq, eq_int, module = "nautilus_trader.core.nautilus_pyo3.network")
)]
pub enum HttpMethod {
    GET,
    POST,
    PUT,
    DELETE,
    PATCH,
}
129
130impl From<HttpMethod> for Method {
131 fn from(value: HttpMethod) -> Self {
132 match value {
133 HttpMethod::GET => Self::GET,
134 HttpMethod::POST => Self::POST,
135 HttpMethod::PUT => Self::PUT,
136 HttpMethod::DELETE => Self::DELETE,
137 HttpMethod::PATCH => Self::PATCH,
138 }
139 }
140}
141
/// Response produced by the HTTP client.
///
/// When the `python` feature is enabled this type is also exposed to Python as a
/// pyclass in the `nautilus_trader.core.nautilus_pyo3.network` module.
#[derive(Clone, Debug)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
pub struct HttpResponse {
    /// Status code returned by the server.
    pub status: HttpStatus,
    /// Selected response headers. When built by `InnerHttpClient::to_response`, only the
    /// names listed in the client's `header_keys` are included (first value per name;
    /// values that are not valid visible-ASCII strings are skipped).
    pub headers: HashMap<String, String>,
    /// Raw response body bytes.
    pub body: Bytes,
}
159
/// Errors produced by the HTTP client.
#[derive(thiserror::Error, Debug)]
pub enum HttpClientError {
    /// Any failure that is not a timeout: invalid URL or header, request build or
    /// transport error, body read error, or an arbitrary message converted from `String`.
    #[error("HTTP error occurred: {0}")]
    Error(String),

    /// A `reqwest` error for which `is_timeout()` returned `true`.
    #[error("HTTP request timed out: {0}")]
    TimeoutError(String),
}
171
172impl From<reqwest::Error> for HttpClientError {
173 fn from(source: reqwest::Error) -> Self {
174 if source.is_timeout() {
175 Self::TimeoutError(source.to_string())
176 } else {
177 Self::Error(source.to_string())
178 }
179 }
180}
181
182impl From<String> for HttpClientError {
183 fn from(value: String) -> Self {
184 Self::Error(value)
185 }
186}
187
/// Rate-limited asynchronous HTTP client.
///
/// Combines an [`InnerHttpClient`] (which performs the actual requests) with a keyed
/// [`RateLimiter`] that is awaited before each request. The rate limiter is held in an
/// `Arc`, so clones of an `HttpClient` share the same limiter state.
///
/// When the `python` feature is enabled this type is also exposed to Python as a
/// pyclass in the `nautilus_trader.core.nautilus_pyo3.network` module.
#[derive(Clone, Debug)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
pub struct HttpClient {
    /// Client used to send requests and convert responses.
    pub(crate) client: InnerHttpClient,
    /// Keyed rate limiter shared between clones of this client.
    pub(crate) rate_limiter: Arc<RateLimiter<Ustr, MonotonicClock>>,
}
208
209impl HttpClient {
210 #[must_use]
216 pub fn new(
217 headers: HashMap<String, String>,
218 header_keys: Vec<String>,
219 keyed_quotas: Vec<(String, Quota)>,
220 default_quota: Option<Quota>,
221 timeout_secs: Option<u64>,
222 ) -> Self {
223 let mut header_map = HeaderMap::new();
225 for (key, value) in headers {
226 let header_name = HeaderName::from_str(&key).expect("Invalid header name");
227 let header_value = HeaderValue::from_str(&value).expect("Invalid header value");
228 header_map.insert(header_name, header_value);
229 }
230
231 let mut client_builder = reqwest::Client::builder().default_headers(header_map);
232 if let Some(timeout_secs) = timeout_secs {
233 client_builder = client_builder.timeout(Duration::from_secs(timeout_secs));
234 }
235
236 let client = client_builder
237 .build()
238 .expect("Failed to build reqwest client");
239
240 let client = InnerHttpClient {
241 client,
242 header_keys: Arc::new(header_keys),
243 };
244
245 let keyed_quotas = keyed_quotas
246 .into_iter()
247 .map(|(key, quota)| (Ustr::from(&key), quota))
248 .collect();
249
250 let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
251
252 Self {
253 client,
254 rate_limiter,
255 }
256 }
257
258 #[allow(clippy::too_many_arguments)]
275 pub async fn request(
276 &self,
277 method: Method,
278 url: String,
279 headers: Option<HashMap<String, String>>,
280 body: Option<Vec<u8>>,
281 timeout_secs: Option<u64>,
282 keys: Option<Vec<String>>,
283 ) -> Result<HttpResponse, HttpClientError> {
284 let keys = keys.map(into_ustr_vec);
285
286 self.request_with_ustr_keys(method, url, headers, body, timeout_secs, keys)
287 .await
288 }
289
290 #[allow(clippy::too_many_arguments)]
296 pub async fn request_with_ustr_keys(
297 &self,
298 method: Method,
299 url: String,
300 headers: Option<HashMap<String, String>>,
301 body: Option<Vec<u8>>,
302 timeout_secs: Option<u64>,
303 keys: Option<Vec<Ustr>>,
304 ) -> Result<HttpResponse, HttpClientError> {
305 let rate_limiter = self.rate_limiter.clone();
306 rate_limiter.await_keys_ready(keys).await;
307
308 self.client
309 .send_request(method, url, headers, body, timeout_secs)
310 .await
311 }
312}
313
/// Thin wrapper around a [`reqwest::Client`] that sends requests and converts the
/// replies into [`HttpResponse`]s. It performs no rate limiting itself.
#[derive(Clone, Debug)]
pub struct InnerHttpClient {
    /// Underlying `reqwest` client (when built by `HttpClient::new`, it carries the
    /// default headers and optional client-wide timeout).
    pub(crate) client: reqwest::Client,
    /// Names of response headers to copy into [`HttpResponse::headers`]; shared
    /// between clones via `Arc`.
    pub(crate) header_keys: Arc<Vec<String>>,
}
327
328impl InnerHttpClient {
329 pub async fn send_request(
341 &self,
342 method: Method,
343 url: String,
344 headers: Option<HashMap<String, String>>,
345 body: Option<Vec<u8>>,
346 timeout_secs: Option<u64>,
347 ) -> Result<HttpResponse, HttpClientError> {
348 let headers = headers.unwrap_or_default();
349 let reqwest_url = Url::parse(url.as_str())
350 .map_err(|e| HttpClientError::from(format!("URL parse error: {e}")))?;
351
352 let mut header_map = HeaderMap::new();
353 for (header_key, header_value) in &headers {
354 let key = HeaderName::from_bytes(header_key.as_bytes())
355 .map_err(|e| HttpClientError::from(format!("Invalid header name: {e}")))?;
356 if let Some(old_value) = header_map.insert(
357 key.clone(),
358 header_value
359 .parse()
360 .map_err(|e| HttpClientError::from(format!("Invalid header value: {e}")))?,
361 ) {
362 tracing::trace!("Replaced header '{key}': old={old_value:?}, new={header_value}");
363 }
364 }
365
366 let mut request_builder = self.client.request(method, reqwest_url).headers(header_map);
367
368 if let Some(timeout_secs) = timeout_secs {
369 request_builder = request_builder.timeout(Duration::new(timeout_secs, 0));
370 }
371
372 let request = match body {
373 Some(b) => request_builder
374 .body(b)
375 .build()
376 .map_err(HttpClientError::from)?,
377 None => request_builder.build().map_err(HttpClientError::from)?,
378 };
379
380 tracing::trace!("{request:?}");
381
382 let response = self
383 .client
384 .execute(request)
385 .await
386 .map_err(HttpClientError::from)?;
387
388 self.to_response(response).await
389 }
390
391 pub async fn to_response(&self, response: Response) -> Result<HttpResponse, HttpClientError> {
397 tracing::trace!("{response:?}");
398
399 let headers: HashMap<String, String> = self
400 .header_keys
401 .iter()
402 .filter_map(|key| response.headers().get(key).map(|val| (key, val)))
403 .filter_map(|(key, val)| val.to_str().map(|v| (key, v)).ok())
404 .map(|(k, v)| (k.clone(), v.to_owned()))
405 .collect();
406 let status = HttpStatus::new(response.status());
407 let body = response.bytes().await.map_err(HttpClientError::from)?;
408
409 Ok(HttpResponse {
410 status,
411 headers,
412 body,
413 })
414 }
415}
416
417impl Default for InnerHttpClient {
418 fn default() -> Self {
422 let client = reqwest::Client::new();
423 Self {
424 client,
425 header_keys: Default::default(),
426 }
427 }
428}
429
#[cfg(test)]
#[cfg(target_os = "linux")]
mod tests {
    use std::net::SocketAddr;

    use axum::{
        Router,
        routing::{delete, get, patch, post},
        serve,
    };
    use http::status::StatusCode;

    use super::*;

    /// Builds the router served by the local test server.
    fn create_router() -> Router {
        Router::new()
            .route("/get", get(|| async { "hello-world!" }))
            .route("/post", post(|| async { StatusCode::OK }))
            .route("/patch", patch(|| async { StatusCode::OK }))
            .route("/delete", delete(|| async { StatusCode::OK }))
            .route("/notfound", get(|| async { StatusCode::NOT_FOUND }))
            .route(
                "/slow",
                get(|| async {
                    tokio::time::sleep(Duration::from_secs(2)).await;
                    "Eventually responded"
                }),
            )
    }

    /// Starts the test router on an OS-assigned loopback port and returns its address.
    ///
    /// The server listener itself is bound to port 0, so the OS assigns a free port
    /// atomically. Probing for a free port, releasing it, and re-binding would leave a
    /// window in which a concurrently running test could claim the same port.
    async fn start_test_server() -> Result<SocketAddr, Box<dyn std::error::Error + Send + Sync>> {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?;
        let addr = listener.local_addr()?;

        tokio::spawn(async move {
            serve(listener, create_router()).await.unwrap();
        });

        Ok(addr)
    }

    #[tokio::test]
    async fn test_get() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}");

        let client = InnerHttpClient::default();
        let response = client
            .send_request(reqwest::Method::GET, format!("{url}/get"), None, None, None)
            .await
            .unwrap();

        assert!(response.status.is_success());
        assert_eq!(String::from_utf8_lossy(&response.body), "hello-world!");
    }

    #[tokio::test]
    async fn test_post() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}");

        let client = InnerHttpClient::default();
        let response = client
            .send_request(
                reqwest::Method::POST,
                format!("{url}/post"),
                None,
                None,
                None,
            )
            .await
            .unwrap();

        assert!(response.status.is_success());
    }

    #[tokio::test]
    async fn test_post_with_body() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}");

        let client = InnerHttpClient::default();

        let mut body = HashMap::new();
        body.insert(
            "key1".to_string(),
            serde_json::Value::String("value1".to_string()),
        );
        body.insert(
            "key2".to_string(),
            serde_json::Value::String("value2".to_string()),
        );

        let body_string = serde_json::to_string(&body).unwrap();
        let body_bytes = body_string.into_bytes();

        let response = client
            .send_request(
                reqwest::Method::POST,
                format!("{url}/post"),
                None,
                Some(body_bytes),
                None,
            )
            .await
            .unwrap();

        assert!(response.status.is_success());
    }

    #[tokio::test]
    async fn test_patch() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}");

        let client = InnerHttpClient::default();
        let response = client
            .send_request(
                reqwest::Method::PATCH,
                format!("{url}/patch"),
                None,
                None,
                None,
            )
            .await
            .unwrap();

        assert!(response.status.is_success());
    }

    #[tokio::test]
    async fn test_delete() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}");

        let client = InnerHttpClient::default();
        let response = client
            .send_request(
                reqwest::Method::DELETE,
                format!("{url}/delete"),
                None,
                None,
                None,
            )
            .await
            .unwrap();

        assert!(response.status.is_success());
    }

    #[tokio::test]
    async fn test_not_found() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}/notfound");
        let client = InnerHttpClient::default();

        let response = client
            .send_request(reqwest::Method::GET, url, None, None, None)
            .await
            .unwrap();

        assert!(response.status.is_client_error());
        assert_eq!(response.status.as_u16(), 404);
    }

    #[tokio::test]
    async fn test_timeout() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}/slow");
        let client = InnerHttpClient::default();

        // The `/slow` handler sleeps 2s, so a 1s per-request timeout must trip.
        let result = client
            .send_request(reqwest::Method::GET, url, None, None, Some(1))
            .await;

        match result {
            Err(HttpClientError::TimeoutError(msg)) => {
                println!("Got expected timeout error: {msg}");
            }
            Err(other) => panic!("Expected a timeout error, was: {other:?}"),
            Ok(resp) => panic!("Expected a timeout error, but was a successful response: {resp:?}"),
        }
    }
}