1use std::{collections::HashMap, hash::Hash, str::FromStr, sync::Arc, time::Duration};
19
20use bytes::Bytes;
21use http::{HeaderValue, StatusCode, status::InvalidStatusCode};
22use reqwest::{
23 Method, Response, Url,
24 header::{HeaderMap, HeaderName},
25};
26
27use crate::ratelimiter::{RateLimiter, clock::MonotonicClock, quota::Quota};
28
/// An HTTP status code returned by a request.
///
/// Thin wrapper around the `http` crate's [`StatusCode`], exposed to Python when the
/// `python` feature is enabled.
#[derive(Clone, Debug)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
pub struct HttpStatus {
    // The wrapped `http` crate status code.
    inner: StatusCode,
}
41
42impl HttpStatus {
43 #[must_use]
45 pub const fn new(code: StatusCode) -> Self {
46 Self { inner: code }
47 }
48
49 pub fn from(code: u16) -> Result<Self, InvalidStatusCode> {
55 Ok(Self {
56 inner: StatusCode::from_u16(code)?,
57 })
58 }
59
60 #[inline]
62 #[must_use]
63 pub const fn as_u16(&self) -> u16 {
64 self.inner.as_u16()
65 }
66
67 #[inline]
69 #[must_use]
70 pub fn as_str(&self) -> &str {
71 self.inner.as_str()
72 }
73
74 #[inline]
76 #[must_use]
77 pub fn is_informational(&self) -> bool {
78 self.inner.is_informational()
79 }
80
81 #[inline]
83 #[must_use]
84 pub fn is_success(&self) -> bool {
85 self.inner.is_success()
86 }
87
88 #[inline]
90 #[must_use]
91 pub fn is_redirection(&self) -> bool {
92 self.inner.is_redirection()
93 }
94
95 #[inline]
97 #[must_use]
98 pub fn is_client_error(&self) -> bool {
99 self.inner.is_client_error()
100 }
101
102 #[inline]
104 #[must_use]
105 pub fn is_server_error(&self) -> bool {
106 self.inner.is_server_error()
107 }
108}
109
/// An HTTP request method.
///
/// Each variant maps one-to-one onto the `reqwest` [`Method`] of the same name.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(eq, eq_int, module = "nautilus_trader.core.nautilus_pyo3.network")
)]
pub enum HttpMethod {
    GET,
    POST,
    PUT,
    DELETE,
    PATCH,
}
123
124#[allow(clippy::from_over_into)]
125impl Into<Method> for HttpMethod {
126 fn into(self) -> Method {
127 match self {
128 Self::GET => Method::GET,
129 Self::POST => Method::POST,
130 Self::PUT => Method::PUT,
131 Self::DELETE => Method::DELETE,
132 Self::PATCH => Method::PATCH,
133 }
134 }
135}
136
/// The decoded result of an HTTP request.
///
/// Produced by [`InnerHttpClient::to_response`].
#[derive(Clone, Debug)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
pub struct HttpResponse {
    /// The HTTP status code returned by the server.
    pub status: HttpStatus,
    /// Response headers, limited to the names configured in the client's `header_keys`.
    pub headers: HashMap<String, String>,
    /// The raw response body.
    pub body: Bytes,
}
154
/// Errors that can occur while building, sending, or reading an HTTP request.
#[derive(thiserror::Error, Debug)]
pub enum HttpClientError {
    /// Any failure other than a timeout.
    ///
    /// Examples: an invalid URL or header, a connection failure, or a body-read error.
    #[error("HTTP error occurred: {0}")]
    Error(String),

    /// The underlying `reqwest` error reported a timeout.
    #[error("HTTP request timed out: {0}")]
    TimeoutError(String),
}
166
167impl From<reqwest::Error> for HttpClientError {
168 fn from(source: reqwest::Error) -> Self {
169 if source.is_timeout() {
170 Self::TimeoutError(source.to_string())
171 } else {
172 Self::Error(source.to_string())
173 }
174 }
175}
176
177impl From<String> for HttpClientError {
178 fn from(value: String) -> Self {
179 Self::Error(value)
180 }
181}
182
/// A rate-limited HTTP client.
///
/// Each request first waits on the [`RateLimiter`] for its keys, then is sent through
/// the wrapped [`InnerHttpClient`]. The rate limiter is held behind an `Arc`, so clones
/// of this client share it.
#[derive(Clone)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
pub struct HttpClient {
    // Builds and sends the actual requests.
    pub(crate) client: InnerHttpClient,
    // Awaited before every request for the caller-supplied keys.
    pub(crate) rate_limiter: Arc<RateLimiter<String, MonotonicClock>>,
}
203
204impl HttpClient {
205 #[must_use]
207 pub fn new(
208 headers: HashMap<String, String>,
209 header_keys: Vec<String>,
210 keyed_quotas: Vec<(String, Quota)>,
211 default_quota: Option<Quota>,
212 ) -> Self {
213 let mut header_map = HeaderMap::new();
215 for (key, value) in headers {
216 let header_name = HeaderName::from_str(&key).expect("Invalid header name");
217 let header_value = HeaderValue::from_str(&value).expect("Invalid header value");
218 header_map.insert(header_name, header_value);
219 }
220
221 let client = reqwest::Client::builder()
222 .default_headers(header_map)
223 .build()
224 .expect("Failed to build reqwest client");
225
226 let client = InnerHttpClient {
227 client,
228 header_keys: Arc::new(header_keys),
229 };
230 let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
231
232 Self {
233 client,
234 rate_limiter,
235 }
236 }
237
238 #[allow(clippy::too_many_arguments)]
255 pub async fn request(
256 &self,
257 method: Method,
258 url: String,
259 headers: Option<HashMap<String, String>>,
260 body: Option<Vec<u8>>,
261 keys: Option<Vec<String>>,
262 timeout_secs: Option<u64>,
263 ) -> Result<HttpResponse, HttpClientError> {
264 let rate_limiter = self.rate_limiter.clone();
265
266 rate_limiter.await_keys_ready(keys).await;
267 self.client
268 .send_request(method, url, headers, body, timeout_secs)
269 .await
270 }
271}
272
/// The underlying HTTP client that builds, sends, and decodes requests.
///
/// It performs no rate limiting itself; see [`HttpClient`] for the rate-limited wrapper.
#[derive(Clone, Debug)]
pub struct InnerHttpClient {
    // The `reqwest` client used to build and execute requests.
    pub(crate) client: reqwest::Client,
    // Names of response headers copied into each `HttpResponse`.
    pub(crate) header_keys: Arc<Vec<String>>,
}
286
287impl InnerHttpClient {
288 pub async fn send_request(
300 &self,
301 method: Method,
302 url: String,
303 headers: Option<HashMap<String, String>>,
304 body: Option<Vec<u8>>,
305 timeout_secs: Option<u64>,
306 ) -> Result<HttpResponse, HttpClientError> {
307 let headers = headers.unwrap_or_default();
308 let reqwest_url = Url::parse(url.as_str())
309 .map_err(|e| HttpClientError::from(format!("URL parse error: {e}")))?;
310
311 let mut header_map = HeaderMap::new();
312 for (header_key, header_value) in &headers {
313 let key = HeaderName::from_bytes(header_key.as_bytes())
314 .map_err(|e| HttpClientError::from(format!("Invalid header name: {e}")))?;
315 let _ = header_map.insert(
316 key,
317 header_value
318 .parse()
319 .map_err(|e| HttpClientError::from(format!("Invalid header value: {e}")))?,
320 );
321 }
322
323 let mut request_builder = self.client.request(method, reqwest_url).headers(header_map);
324
325 if let Some(timeout_secs) = timeout_secs {
326 request_builder = request_builder.timeout(Duration::new(timeout_secs, 0));
327 }
328
329 let request = match body {
330 Some(b) => request_builder
331 .body(b)
332 .build()
333 .map_err(HttpClientError::from)?,
334 None => request_builder.build().map_err(HttpClientError::from)?,
335 };
336
337 tracing::trace!("{request:?}");
338
339 let response = self
340 .client
341 .execute(request)
342 .await
343 .map_err(HttpClientError::from)?;
344
345 self.to_response(response).await
346 }
347
348 pub async fn to_response(&self, response: Response) -> Result<HttpResponse, HttpClientError> {
354 tracing::trace!("{response:?}");
355
356 let headers: HashMap<String, String> = self
357 .header_keys
358 .iter()
359 .filter_map(|key| response.headers().get(key).map(|val| (key, val)))
360 .filter_map(|(key, val)| val.to_str().map(|v| (key, v)).ok())
361 .map(|(k, v)| (k.clone(), v.to_owned()))
362 .collect();
363 let status = HttpStatus::new(response.status());
364 let body = response.bytes().await.map_err(HttpClientError::from)?;
365
366 Ok(HttpResponse {
367 status,
368 headers,
369 body,
370 })
371 }
372}
373
374impl Default for InnerHttpClient {
375 fn default() -> Self {
379 let client = reqwest::Client::new();
380 Self {
381 client,
382 header_keys: Default::default(),
383 }
384 }
385}
386
#[cfg(test)]
#[cfg(target_os = "linux")]
mod tests {
    use std::net::SocketAddr;

    use axum::{
        Router,
        routing::{delete, get, patch, post},
        serve,
    };
    use http::status::StatusCode;

    use super::*;

    /// Builds the router used by the local test server.
    fn create_router() -> Router {
        Router::new()
            .route("/get", get(|| async { "hello-world!" }))
            .route("/post", post(|| async { StatusCode::OK }))
            .route("/patch", patch(|| async { StatusCode::OK }))
            .route("/delete", delete(|| async { StatusCode::OK }))
            .route("/notfound", get(|| async { StatusCode::NOT_FOUND }))
            .route(
                "/slow",
                get(|| async {
                    tokio::time::sleep(Duration::from_secs(2)).await;
                    "Eventually responded"
                }),
            )
    }

    /// Starts the test server on an OS-assigned port and returns its address.
    ///
    /// Binding the real listener directly to port 0 avoids a race. The old approach
    /// probed for a free port with a temporary listener, dropped it, and re-bound it,
    /// and another process could claim the port in between.
    async fn start_test_server() -> Result<SocketAddr, Box<dyn std::error::Error + Send + Sync>> {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?;
        let addr = listener.local_addr()?;

        tokio::spawn(async move {
            serve(listener, create_router()).await.unwrap();
        });

        Ok(addr)
    }

    #[tokio::test]
    async fn test_get() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}");

        let client = InnerHttpClient::default();
        let response = client
            .send_request(reqwest::Method::GET, format!("{url}/get"), None, None, None)
            .await
            .unwrap();

        assert!(response.status.is_success());
        assert_eq!(String::from_utf8_lossy(&response.body), "hello-world!");
    }

    #[tokio::test]
    async fn test_post() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}");

        let client = InnerHttpClient::default();
        let response = client
            .send_request(
                reqwest::Method::POST,
                format!("{url}/post"),
                None,
                None,
                None,
            )
            .await
            .unwrap();

        assert!(response.status.is_success());
    }

    #[tokio::test]
    async fn test_post_with_body() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}");

        let client = InnerHttpClient::default();

        let mut body = HashMap::new();
        body.insert(
            "key1".to_string(),
            serde_json::Value::String("value1".to_string()),
        );
        body.insert(
            "key2".to_string(),
            serde_json::Value::String("value2".to_string()),
        );

        let body_string = serde_json::to_string(&body).unwrap();
        let body_bytes = body_string.into_bytes();

        let response = client
            .send_request(
                reqwest::Method::POST,
                format!("{url}/post"),
                None,
                Some(body_bytes),
                None,
            )
            .await
            .unwrap();

        assert!(response.status.is_success());
    }

    #[tokio::test]
    async fn test_patch() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}");

        let client = InnerHttpClient::default();
        let response = client
            .send_request(
                reqwest::Method::PATCH,
                format!("{url}/patch"),
                None,
                None,
                None,
            )
            .await
            .unwrap();

        assert!(response.status.is_success());
    }

    #[tokio::test]
    async fn test_delete() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}");

        let client = InnerHttpClient::default();
        let response = client
            .send_request(
                reqwest::Method::DELETE,
                format!("{url}/delete"),
                None,
                None,
                None,
            )
            .await
            .unwrap();

        assert!(response.status.is_success());
    }

    #[tokio::test]
    async fn test_not_found() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}/notfound");
        let client = InnerHttpClient::default();

        let response = client
            .send_request(reqwest::Method::GET, url, None, None, None)
            .await
            .unwrap();

        assert!(response.status.is_client_error());
        assert_eq!(response.status.as_u16(), 404);
    }

    #[tokio::test]
    async fn test_timeout() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}/slow");
        let client = InnerHttpClient::default();

        // `/slow` sleeps for 2s, so a 1s per-request timeout must trip first.
        let result = client
            .send_request(reqwest::Method::GET, url, None, None, Some(1))
            .await;

        match result {
            Err(HttpClientError::TimeoutError(msg)) => {
                println!("Got expected timeout error: {msg}");
            }
            Err(other) => panic!("Expected a timeout error, got: {other:?}"),
            Ok(resp) => panic!("Expected a timeout error, but got a successful response: {resp:?}"),
        }
    }
}