1use std::{collections::HashMap, hash::Hash, str::FromStr, sync::Arc, time::Duration};
19
20use bytes::Bytes;
21use http::{HeaderValue, StatusCode, status::InvalidStatusCode};
22use reqwest::{
23 Method, Response, Url,
24 header::{HeaderMap, HeaderName},
25};
26
27use crate::ratelimiter::{RateLimiter, clock::MonotonicClock, quota::Quota};
28
29#[derive(Clone, Debug)]
34#[cfg_attr(
35 feature = "python",
36 pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
37)]
38pub struct HttpStatus {
39 inner: StatusCode,
40}
41
42impl HttpStatus {
43 #[must_use]
45 pub const fn new(code: StatusCode) -> Self {
46 Self { inner: code }
47 }
48
49 #[inline]
51 #[must_use]
52 pub const fn as_u16(&self) -> u16 {
53 self.inner.as_u16()
54 }
55
56 #[inline]
58 #[must_use]
59 pub fn as_str(&self) -> &str {
60 self.inner.as_str()
61 }
62
63 #[inline]
65 #[must_use]
66 pub fn is_informational(&self) -> bool {
67 self.inner.is_informational()
68 }
69
70 #[inline]
72 #[must_use]
73 pub fn is_success(&self) -> bool {
74 self.inner.is_success()
75 }
76
77 #[inline]
79 #[must_use]
80 pub fn is_redirection(&self) -> bool {
81 self.inner.is_redirection()
82 }
83
84 #[inline]
86 #[must_use]
87 pub fn is_client_error(&self) -> bool {
88 self.inner.is_client_error()
89 }
90
91 #[inline]
93 #[must_use]
94 pub fn is_server_error(&self) -> bool {
95 self.inner.is_server_error()
96 }
97}
98
99impl TryFrom<u16> for HttpStatus {
100 type Error = InvalidStatusCode;
101
102 fn try_from(code: u16) -> Result<Self, Self::Error> {
108 Ok(Self {
109 inner: StatusCode::from_u16(code)?,
110 })
111 }
112}
113
/// HTTP request methods supported by the client.
///
/// Convertible into a `reqwest` [`Method`] via `From`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(eq, eq_int, module = "nautilus_trader.core.nautilus_pyo3.network")
)]
pub enum HttpMethod {
    /// `GET` request.
    GET,
    /// `POST` request.
    POST,
    /// `PUT` request.
    PUT,
    /// `DELETE` request.
    DELETE,
    /// `PATCH` request.
    PATCH,
}
127
128impl From<HttpMethod> for Method {
129 fn from(value: HttpMethod) -> Self {
130 match value {
131 HttpMethod::GET => Self::GET,
132 HttpMethod::POST => Self::POST,
133 HttpMethod::PUT => Self::PUT,
134 HttpMethod::DELETE => Self::DELETE,
135 HttpMethod::PATCH => Self::PATCH,
136 }
137 }
138}
139
140#[derive(Clone, Debug)]
145#[cfg_attr(
146 feature = "python",
147 pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
148)]
149pub struct HttpResponse {
150 pub status: HttpStatus,
152 pub headers: HashMap<String, String>,
154 pub body: Bytes,
156}
157
158#[derive(thiserror::Error, Debug)]
162pub enum HttpClientError {
163 #[error("HTTP error occurred: {0}")]
164 Error(String),
165
166 #[error("HTTP request timed out: {0}")]
167 TimeoutError(String),
168}
169
170impl From<reqwest::Error> for HttpClientError {
171 fn from(source: reqwest::Error) -> Self {
172 if source.is_timeout() {
173 Self::TimeoutError(source.to_string())
174 } else {
175 Self::Error(source.to_string())
176 }
177 }
178}
179
180impl From<String> for HttpClientError {
181 fn from(value: String) -> Self {
182 Self::Error(value)
183 }
184}
185
186#[derive(Clone, Debug)]
196#[cfg_attr(
197 feature = "python",
198 pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
199)]
200pub struct HttpClient {
201 pub(crate) client: InnerHttpClient,
203 pub(crate) rate_limiter: Arc<RateLimiter<String, MonotonicClock>>,
205}
206
207impl HttpClient {
208 #[must_use]
214 pub fn new(
215 headers: HashMap<String, String>,
216 header_keys: Vec<String>,
217 keyed_quotas: Vec<(String, Quota)>,
218 default_quota: Option<Quota>,
219 timeout_secs: Option<u64>,
220 ) -> Self {
221 let mut header_map = HeaderMap::new();
223 for (key, value) in headers {
224 let header_name = HeaderName::from_str(&key).expect("Invalid header name");
225 let header_value = HeaderValue::from_str(&value).expect("Invalid header value");
226 header_map.insert(header_name, header_value);
227 }
228
229 let mut client_builder = reqwest::Client::builder().default_headers(header_map);
230 if let Some(timeout_secs) = timeout_secs {
231 client_builder = client_builder.timeout(Duration::from_secs(timeout_secs));
232 }
233
234 let client = client_builder
235 .build()
236 .expect("Failed to build reqwest client");
237
238 let client = InnerHttpClient {
239 client,
240 header_keys: Arc::new(header_keys),
241 };
242
243 let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
244
245 Self {
246 client,
247 rate_limiter,
248 }
249 }
250
251 #[allow(clippy::too_many_arguments)]
268 pub async fn request(
269 &self,
270 method: Method,
271 url: String,
272 headers: Option<HashMap<String, String>>,
273 body: Option<Vec<u8>>,
274 timeout_secs: Option<u64>,
275 keys: Option<Vec<String>>,
276 ) -> Result<HttpResponse, HttpClientError> {
277 let rate_limiter = self.rate_limiter.clone();
278 rate_limiter.await_keys_ready(keys).await;
279
280 self.client
281 .send_request(method, url, headers, body, timeout_secs)
282 .await
283 }
284}
285
286#[derive(Clone, Debug)]
295pub struct InnerHttpClient {
296 pub(crate) client: reqwest::Client,
297 pub(crate) header_keys: Arc<Vec<String>>,
298}
299
300impl InnerHttpClient {
301 pub async fn send_request(
313 &self,
314 method: Method,
315 url: String,
316 headers: Option<HashMap<String, String>>,
317 body: Option<Vec<u8>>,
318 timeout_secs: Option<u64>,
319 ) -> Result<HttpResponse, HttpClientError> {
320 let headers = headers.unwrap_or_default();
321 let reqwest_url = Url::parse(url.as_str())
322 .map_err(|e| HttpClientError::from(format!("URL parse error: {e}")))?;
323
324 let mut header_map = HeaderMap::new();
325 for (header_key, header_value) in &headers {
326 let key = HeaderName::from_bytes(header_key.as_bytes())
327 .map_err(|e| HttpClientError::from(format!("Invalid header name: {e}")))?;
328 let _ = header_map.insert(
329 key,
330 header_value
331 .parse()
332 .map_err(|e| HttpClientError::from(format!("Invalid header value: {e}")))?,
333 );
334 }
335
336 let mut request_builder = self.client.request(method, reqwest_url).headers(header_map);
337
338 if let Some(timeout_secs) = timeout_secs {
339 request_builder = request_builder.timeout(Duration::new(timeout_secs, 0));
340 }
341
342 let request = match body {
343 Some(b) => request_builder
344 .body(b)
345 .build()
346 .map_err(HttpClientError::from)?,
347 None => request_builder.build().map_err(HttpClientError::from)?,
348 };
349
350 tracing::trace!("{request:?}");
351
352 let response = self
353 .client
354 .execute(request)
355 .await
356 .map_err(HttpClientError::from)?;
357
358 self.to_response(response).await
359 }
360
361 pub async fn to_response(&self, response: Response) -> Result<HttpResponse, HttpClientError> {
367 tracing::trace!("{response:?}");
368
369 let headers: HashMap<String, String> = self
370 .header_keys
371 .iter()
372 .filter_map(|key| response.headers().get(key).map(|val| (key, val)))
373 .filter_map(|(key, val)| val.to_str().map(|v| (key, v)).ok())
374 .map(|(k, v)| (k.clone(), v.to_owned()))
375 .collect();
376 let status = HttpStatus::new(response.status());
377 let body = response.bytes().await.map_err(HttpClientError::from)?;
378
379 Ok(HttpResponse {
380 status,
381 headers,
382 body,
383 })
384 }
385}
386
387impl Default for InnerHttpClient {
388 fn default() -> Self {
392 let client = reqwest::Client::new();
393 Self {
394 client,
395 header_keys: Default::default(),
396 }
397 }
398}
399
#[cfg(test)]
#[cfg(target_os = "linux")]
mod tests {
    use std::net::SocketAddr;

    use axum::{
        Router,
        routing::{delete, get, patch, post},
        serve,
    };
    use http::status::StatusCode;

    use super::*;

    /// Routes exercised by the tests below.
    fn create_router() -> Router {
        Router::new()
            .route("/get", get(|| async { "hello-world!" }))
            .route("/post", post(|| async { StatusCode::OK }))
            .route("/patch", patch(|| async { StatusCode::OK }))
            .route("/delete", delete(|| async { StatusCode::OK }))
            .route("/notfound", get(|| async { StatusCode::NOT_FOUND }))
            .route(
                "/slow",
                get(|| async {
                    tokio::time::sleep(Duration::from_secs(2)).await;
                    "Eventually responded"
                }),
            )
    }

    /// Starts the test server on an OS-assigned port and returns its address.
    ///
    /// The server binds port 0 directly and keeps that listener. Probing for a free port,
    /// releasing it, and binding again would leave a window in which a concurrently running
    /// test could take the same port.
    async fn start_test_server() -> Result<SocketAddr, Box<dyn std::error::Error + Send + Sync>> {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?;
        let addr = listener.local_addr()?;

        tokio::spawn(async move {
            serve(listener, create_router()).await.unwrap();
        });

        Ok(addr)
    }

    #[tokio::test]
    async fn test_get() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}");

        let client = InnerHttpClient::default();
        let response = client
            .send_request(reqwest::Method::GET, format!("{url}/get"), None, None, None)
            .await
            .unwrap();

        assert!(response.status.is_success());
        assert_eq!(String::from_utf8_lossy(&response.body), "hello-world!");
    }

    #[tokio::test]
    async fn test_post() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}");

        let client = InnerHttpClient::default();
        let response = client
            .send_request(
                reqwest::Method::POST,
                format!("{url}/post"),
                None,
                None,
                None,
            )
            .await
            .unwrap();

        assert!(response.status.is_success());
    }

    #[tokio::test]
    async fn test_post_with_body() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}");

        let client = InnerHttpClient::default();

        let mut body = HashMap::new();
        body.insert(
            "key1".to_string(),
            serde_json::Value::String("value1".to_string()),
        );
        body.insert(
            "key2".to_string(),
            serde_json::Value::String("value2".to_string()),
        );

        let body_string = serde_json::to_string(&body).unwrap();
        let body_bytes = body_string.into_bytes();

        let response = client
            .send_request(
                reqwest::Method::POST,
                format!("{url}/post"),
                None,
                Some(body_bytes),
                None,
            )
            .await
            .unwrap();

        assert!(response.status.is_success());
    }

    #[tokio::test]
    async fn test_patch() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}");

        let client = InnerHttpClient::default();
        let response = client
            .send_request(
                reqwest::Method::PATCH,
                format!("{url}/patch"),
                None,
                None,
                None,
            )
            .await
            .unwrap();

        assert!(response.status.is_success());
    }

    #[tokio::test]
    async fn test_delete() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}");

        let client = InnerHttpClient::default();
        let response = client
            .send_request(
                reqwest::Method::DELETE,
                format!("{url}/delete"),
                None,
                None,
                None,
            )
            .await
            .unwrap();

        assert!(response.status.is_success());
    }

    #[tokio::test]
    async fn test_not_found() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}/notfound");
        let client = InnerHttpClient::default();

        let response = client
            .send_request(reqwest::Method::GET, url, None, None, None)
            .await
            .unwrap();

        assert!(response.status.is_client_error());
        assert_eq!(response.status.as_u16(), 404);
    }

    #[tokio::test]
    async fn test_timeout() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}/slow");
        let client = InnerHttpClient::default();

        // `/slow` sleeps for 2s, so a 1s per-request timeout must trip.
        let result = client
            .send_request(reqwest::Method::GET, url, None, None, Some(1))
            .await;

        match result {
            Err(HttpClientError::TimeoutError(msg)) => {
                println!("Got expected timeout error: {msg}");
            }
            Err(other) => panic!("Expected a timeout error, got: {other:?}"),
            Ok(resp) => panic!("Expected a timeout error, but got a successful response: {resp:?}"),
        }
    }
}