1use std::{collections::HashMap, hash::Hash, str::FromStr, sync::Arc, time::Duration};
19
20use bytes::Bytes;
21use http::{status::InvalidStatusCode, HeaderValue, StatusCode};
22use reqwest::{
23 header::{HeaderMap, HeaderName},
24 Method, Response, Url,
25};
26
27use crate::ratelimiter::{clock::MonotonicClock, quota::Quota, RateLimiter};
28
29#[derive(Clone, Debug)]
34#[cfg_attr(
35 feature = "python",
36 pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
37)]
38pub struct HttpStatus {
39 inner: StatusCode,
40}
41
42impl HttpStatus {
43 #[must_use]
45 pub const fn new(code: StatusCode) -> Self {
46 Self { inner: code }
47 }
48
49 pub fn from(code: u16) -> Result<Self, InvalidStatusCode> {
53 Ok(Self {
54 inner: StatusCode::from_u16(code)?,
55 })
56 }
57
58 #[inline]
60 #[must_use]
61 pub const fn as_u16(&self) -> u16 {
62 self.inner.as_u16()
63 }
64
65 #[inline]
67 #[must_use]
68 pub fn as_str(&self) -> &str {
69 self.inner.as_str()
70 }
71
72 #[inline]
74 #[must_use]
75 pub fn is_informational(&self) -> bool {
76 self.inner.is_informational()
77 }
78
79 #[inline]
81 #[must_use]
82 pub fn is_success(&self) -> bool {
83 self.inner.is_success()
84 }
85
86 #[inline]
88 #[must_use]
89 pub fn is_redirection(&self) -> bool {
90 self.inner.is_redirection()
91 }
92
93 #[inline]
95 #[must_use]
96 pub fn is_client_error(&self) -> bool {
97 self.inner.is_client_error()
98 }
99
100 #[inline]
102 #[must_use]
103 pub fn is_server_error(&self) -> bool {
104 self.inner.is_server_error()
105 }
106}
107
/// The HTTP request methods supported by this client.
///
/// NOTE: variant order defines the integer values exposed to Python via `eq_int`;
/// do not reorder.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(eq, eq_int, module = "nautilus_trader.core.nautilus_pyo3.network")
)]
pub enum HttpMethod {
    /// `GET` request.
    GET,
    /// `POST` request.
    POST,
    /// `PUT` request.
    PUT,
    /// `DELETE` request.
    DELETE,
    /// `PATCH` request.
    PATCH,
}
121
122#[allow(clippy::from_over_into)]
123impl Into<Method> for HttpMethod {
124 fn into(self) -> Method {
125 match self {
126 Self::GET => Method::GET,
127 Self::POST => Method::POST,
128 Self::PUT => Method::PUT,
129 Self::DELETE => Method::DELETE,
130 Self::PATCH => Method::PATCH,
131 }
132 }
133}
134
135#[derive(Clone, Debug)]
140#[cfg_attr(
141 feature = "python",
142 pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
143)]
144pub struct HttpResponse {
145 pub status: HttpStatus,
147 pub headers: HashMap<String, String>,
149 pub body: Bytes,
151}
152
153#[derive(thiserror::Error, Debug)]
157pub enum HttpClientError {
158 #[error("HTTP error occurred: {0}")]
159 Error(String),
160
161 #[error("HTTP request timed out: {0}")]
162 TimeoutError(String),
163}
164
165impl From<reqwest::Error> for HttpClientError {
166 fn from(source: reqwest::Error) -> Self {
167 if source.is_timeout() {
168 Self::TimeoutError(source.to_string())
169 } else {
170 Self::Error(source.to_string())
171 }
172 }
173}
174
175impl From<String> for HttpClientError {
176 fn from(value: String) -> Self {
177 Self::Error(value)
178 }
179}
180
181#[derive(Clone)]
191#[cfg_attr(
192 feature = "python",
193 pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
194)]
195pub struct HttpClient {
196 pub(crate) client: InnerHttpClient,
198 pub(crate) rate_limiter: Arc<RateLimiter<String, MonotonicClock>>,
200}
201
202impl HttpClient {
203 #[must_use]
205 pub fn new(
206 headers: HashMap<String, String>,
207 header_keys: Vec<String>,
208 keyed_quotas: Vec<(String, Quota)>,
209 default_quota: Option<Quota>,
210 ) -> Self {
211 let mut header_map = HeaderMap::new();
213 for (key, value) in headers {
214 let header_name = HeaderName::from_str(&key).expect("Invalid header name");
215 let header_value = HeaderValue::from_str(&value).expect("Invalid header value");
216 header_map.insert(header_name, header_value);
217 }
218
219 let client = reqwest::Client::builder()
220 .default_headers(header_map)
221 .build()
222 .expect("Failed to build reqwest client");
223
224 let client = InnerHttpClient {
225 client,
226 header_keys: Arc::new(header_keys),
227 };
228 let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
229
230 Self {
231 client,
232 rate_limiter,
233 }
234 }
235
236 #[allow(clippy::too_many_arguments)]
248 pub async fn request(
249 &self,
250 method: Method,
251 url: String,
252 headers: Option<HashMap<String, String>>,
253 body: Option<Vec<u8>>,
254 keys: Option<Vec<String>>,
255 timeout_secs: Option<u64>,
256 ) -> Result<HttpResponse, HttpClientError> {
257 let rate_limiter = self.rate_limiter.clone();
258
259 rate_limiter.await_keys_ready(keys).await;
260 self.client
261 .send_request(method, url, headers, body, timeout_secs)
262 .await
263 }
264}
265
266#[derive(Clone, Debug)]
275pub struct InnerHttpClient {
276 pub(crate) client: reqwest::Client,
277 pub(crate) header_keys: Arc<Vec<String>>,
278}
279
280impl InnerHttpClient {
281 pub async fn send_request(
289 &self,
290 method: Method,
291 url: String,
292 headers: Option<HashMap<String, String>>,
293 body: Option<Vec<u8>>,
294 timeout_secs: Option<u64>,
295 ) -> Result<HttpResponse, HttpClientError> {
296 let headers = headers.unwrap_or_default();
297 let reqwest_url = Url::parse(url.as_str())
298 .map_err(|e| HttpClientError::from(format!("URL parse error: {e}")))?;
299
300 let mut header_map = HeaderMap::new();
301 for (header_key, header_value) in &headers {
302 let key = HeaderName::from_bytes(header_key.as_bytes())
303 .map_err(|e| HttpClientError::from(format!("Invalid header name: {e}")))?;
304 let _ = header_map.insert(
305 key,
306 header_value
307 .parse()
308 .map_err(|e| HttpClientError::from(format!("Invalid header value: {e}")))?,
309 );
310 }
311
312 let mut request_builder = self.client.request(method, reqwest_url).headers(header_map);
313
314 if let Some(timeout_secs) = timeout_secs {
315 request_builder = request_builder.timeout(Duration::new(timeout_secs, 0));
316 }
317
318 let request = match body {
319 Some(b) => request_builder
320 .body(b)
321 .build()
322 .map_err(HttpClientError::from)?,
323 None => request_builder.build().map_err(HttpClientError::from)?,
324 };
325
326 tracing::trace!("{request:?}");
327
328 let response = self
329 .client
330 .execute(request)
331 .await
332 .map_err(HttpClientError::from)?;
333
334 self.to_response(response).await
335 }
336
337 pub async fn to_response(&self, response: Response) -> Result<HttpResponse, HttpClientError> {
339 tracing::trace!("{response:?}");
340
341 let headers: HashMap<String, String> = self
342 .header_keys
343 .iter()
344 .filter_map(|key| response.headers().get(key).map(|val| (key, val)))
345 .filter_map(|(key, val)| val.to_str().map(|v| (key, v)).ok())
346 .map(|(k, v)| (k.clone(), v.to_owned()))
347 .collect();
348 let status = HttpStatus::new(response.status());
349 let body = response.bytes().await.map_err(HttpClientError::from)?;
350
351 Ok(HttpResponse {
352 status,
353 headers,
354 body,
355 })
356 }
357}
358
359impl Default for InnerHttpClient {
360 fn default() -> Self {
364 let client = reqwest::Client::new();
365 Self {
366 client,
367 header_keys: Default::default(),
368 }
369 }
370}
371
#[cfg(test)]
#[cfg(target_os = "linux")]
mod tests {
    use std::net::SocketAddr;

    use axum::{
        routing::{delete, get, patch, post},
        serve, Router,
    };
    use http::status::StatusCode;

    use super::*;

    /// Builds the router served by the local test server.
    fn create_router() -> Router {
        Router::new()
            .route("/get", get(|| async { "hello-world!" }))
            .route("/post", post(|| async { StatusCode::OK }))
            .route("/patch", patch(|| async { StatusCode::OK }))
            .route("/delete", delete(|| async { StatusCode::OK }))
            .route("/notfound", get(|| async { StatusCode::NOT_FOUND }))
            .route(
                "/slow",
                get(|| async {
                    tokio::time::sleep(Duration::from_secs(2)).await;
                    "Eventually responded"
                }),
            )
    }

    /// Starts the test server on an OS-assigned loopback port and returns its address.
    ///
    /// Binding directly to port `0` lets the OS choose a free port atomically. Probing with a
    /// throwaway listener and re-binding later is racy, because another process can claim the
    /// port in between and make tests flaky.
    async fn start_test_server() -> Result<SocketAddr, Box<dyn std::error::Error + Send + Sync>> {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?;
        let addr = listener.local_addr()?;

        tokio::spawn(async move {
            serve(listener, create_router()).await.unwrap();
        });

        Ok(addr)
    }

    #[test]
    fn test_http_method_into_reqwest_method() {
        let cases = [
            (HttpMethod::GET, Method::GET),
            (HttpMethod::POST, Method::POST),
            (HttpMethod::PUT, Method::PUT),
            (HttpMethod::DELETE, Method::DELETE),
            (HttpMethod::PATCH, Method::PATCH),
        ];
        for (method, expected) in cases {
            let converted: Method = method.into();
            assert_eq!(converted, expected);
        }
    }

    #[test]
    fn test_http_status_from_u16() {
        let status = HttpStatus::from(404).unwrap();
        assert_eq!(status.as_u16(), 404);
        assert_eq!(status.as_str(), "404");
        assert!(status.is_client_error());
        assert!(!status.is_success());

        assert!(HttpStatus::from(99).is_err());
        assert!(HttpStatus::from(1000).is_err());
    }

    #[tokio::test]
    async fn test_get() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}");

        let client = InnerHttpClient::default();
        let response = client
            .send_request(reqwest::Method::GET, format!("{url}/get"), None, None, None)
            .await
            .unwrap();

        assert!(response.status.is_success());
        assert_eq!(String::from_utf8_lossy(&response.body), "hello-world!");
    }

    #[tokio::test]
    async fn test_post() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}");

        let client = InnerHttpClient::default();
        let response = client
            .send_request(
                reqwest::Method::POST,
                format!("{url}/post"),
                None,
                None,
                None,
            )
            .await
            .unwrap();

        assert!(response.status.is_success());
    }

    #[tokio::test]
    async fn test_post_with_body() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}");

        let client = InnerHttpClient::default();

        let mut body = HashMap::new();
        body.insert(
            "key1".to_string(),
            serde_json::Value::String("value1".to_string()),
        );
        body.insert(
            "key2".to_string(),
            serde_json::Value::String("value2".to_string()),
        );

        let body_string = serde_json::to_string(&body).unwrap();
        let body_bytes = body_string.into_bytes();

        let response = client
            .send_request(
                reqwest::Method::POST,
                format!("{url}/post"),
                None,
                Some(body_bytes),
                None,
            )
            .await
            .unwrap();

        assert!(response.status.is_success());
    }

    #[tokio::test]
    async fn test_patch() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}");

        let client = InnerHttpClient::default();
        let response = client
            .send_request(
                reqwest::Method::PATCH,
                format!("{url}/patch"),
                None,
                None,
                None,
            )
            .await
            .unwrap();

        assert!(response.status.is_success());
    }

    #[tokio::test]
    async fn test_delete() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}");

        let client = InnerHttpClient::default();
        let response = client
            .send_request(
                reqwest::Method::DELETE,
                format!("{url}/delete"),
                None,
                None,
                None,
            )
            .await
            .unwrap();

        assert!(response.status.is_success());
    }

    #[tokio::test]
    async fn test_not_found() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}/notfound");
        let client = InnerHttpClient::default();

        let response = client
            .send_request(reqwest::Method::GET, url, None, None, None)
            .await
            .unwrap();

        assert!(response.status.is_client_error());
        assert_eq!(response.status.as_u16(), 404);
    }

    #[tokio::test]
    async fn test_timeout() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}/slow");
        let client = InnerHttpClient::default();

        // The `/slow` route sleeps 2s; a 1s timeout must trip first.
        let result = client
            .send_request(reqwest::Method::GET, url, None, None, Some(1))
            .await;

        match result {
            Err(HttpClientError::TimeoutError(_)) => {}
            Err(other) => panic!("Expected a timeout error, got: {other:?}"),
            Ok(resp) => panic!("Expected a timeout error, but got a successful response: {resp:?}"),
        }
    }
}