1use std::{collections::HashMap, str::FromStr, sync::Arc, time::Duration};
19
20use nautilus_core::collections::into_ustr_vec;
21use reqwest::{
22 Method, Response, Url,
23 header::{HeaderMap, HeaderName, HeaderValue},
24};
25use ustr::Ustr;
26
27use super::{HttpClientError, HttpResponse, HttpStatus};
28use crate::ratelimiter::{RateLimiter, clock::MonotonicClock, quota::Quota};
29
30#[derive(Clone, Debug)]
40#[cfg_attr(
41 feature = "python",
42 pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
43)]
44pub struct HttpClient {
45 pub(crate) client: InnerHttpClient,
47 pub(crate) rate_limiter: Arc<RateLimiter<Ustr, MonotonicClock>>,
49}
50
51impl HttpClient {
52 pub fn new(
59 headers: HashMap<String, String>,
60 header_keys: Vec<String>,
61 keyed_quotas: Vec<(String, Quota)>,
62 default_quota: Option<Quota>,
63 timeout_secs: Option<u64>,
64 proxy_url: Option<String>,
65 ) -> Result<Self, HttpClientError> {
66 let mut header_map = HeaderMap::new();
68 for (key, value) in headers {
69 let header_name = HeaderName::from_str(&key)
70 .map_err(|e| HttpClientError::Error(format!("Invalid header name '{key}': {e}")))?;
71 let header_value = HeaderValue::from_str(&value).map_err(|e| {
72 HttpClientError::Error(format!("Invalid header value '{value}': {e}"))
73 })?;
74 header_map.insert(header_name, header_value);
75 }
76
77 let mut client_builder = reqwest::Client::builder().default_headers(header_map);
78 client_builder = client_builder.tcp_nodelay(true);
79
80 if let Some(timeout_secs) = timeout_secs {
81 client_builder = client_builder.timeout(Duration::from_secs(timeout_secs));
82 }
83
84 if let Some(proxy_url) = proxy_url {
86 let proxy = reqwest::Proxy::all(&proxy_url)
87 .map_err(|e| HttpClientError::InvalidProxy(format!("{proxy_url}: {e}")))?;
88 client_builder = client_builder.proxy(proxy);
89 }
90
91 let client = client_builder
92 .build()
93 .map_err(|e| HttpClientError::ClientBuildError(e.to_string()))?;
94
95 let client = InnerHttpClient {
96 client,
97 header_keys: Arc::new(header_keys),
98 };
99
100 let keyed_quotas = keyed_quotas
101 .into_iter()
102 .map(|(key, quota)| (Ustr::from(&key), quota))
103 .collect();
104
105 let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
106
107 Ok(Self {
108 client,
109 rate_limiter,
110 })
111 }
112
113 #[allow(clippy::too_many_arguments)]
123 pub async fn request(
124 &self,
125 method: Method,
126 url: String,
127 params: Option<&HashMap<String, Vec<String>>>,
128 headers: Option<HashMap<String, String>>,
129 body: Option<Vec<u8>>,
130 timeout_secs: Option<u64>,
131 keys: Option<Vec<String>>,
132 ) -> Result<HttpResponse, HttpClientError> {
133 let keys = keys.map(into_ustr_vec);
134
135 self.request_with_ustr_keys(method, url, params, headers, body, timeout_secs, keys)
136 .await
137 }
138
139 #[allow(clippy::too_many_arguments)]
149 pub async fn request_with_params<P: serde::Serialize>(
150 &self,
151 method: Method,
152 url: String,
153 params: Option<&P>,
154 headers: Option<HashMap<String, String>>,
155 body: Option<Vec<u8>>,
156 timeout_secs: Option<u64>,
157 keys: Option<Vec<String>>,
158 ) -> Result<HttpResponse, HttpClientError> {
159 let keys = keys.map(into_ustr_vec);
160 let rate_limiter = self.rate_limiter.clone();
161 rate_limiter.await_keys_ready(keys).await;
162
163 self.client
164 .send_request_with_query(method, url, params, headers, body, timeout_secs)
165 .await
166 }
167
168 #[allow(clippy::too_many_arguments)]
174 pub async fn request_with_ustr_keys(
175 &self,
176 method: Method,
177 url: String,
178 params: Option<&HashMap<String, Vec<String>>>,
179 headers: Option<HashMap<String, String>>,
180 body: Option<Vec<u8>>,
181 timeout_secs: Option<u64>,
182 keys: Option<Vec<Ustr>>,
183 ) -> Result<HttpResponse, HttpClientError> {
184 let rate_limiter = self.rate_limiter.clone();
185 rate_limiter.await_keys_ready(keys).await;
186
187 self.client
188 .send_request(method, url, params, headers, body, timeout_secs)
189 .await
190 }
191
192 pub async fn get(
198 &self,
199 url: String,
200 params: Option<&HashMap<String, Vec<String>>>,
201 headers: Option<HashMap<String, String>>,
202 timeout_secs: Option<u64>,
203 keys: Option<Vec<String>>,
204 ) -> Result<HttpResponse, HttpClientError> {
205 self.request(Method::GET, url, params, headers, None, timeout_secs, keys)
206 .await
207 }
208
209 pub async fn post(
215 &self,
216 url: String,
217 params: Option<&HashMap<String, Vec<String>>>,
218 headers: Option<HashMap<String, String>>,
219 body: Option<Vec<u8>>,
220 timeout_secs: Option<u64>,
221 keys: Option<Vec<String>>,
222 ) -> Result<HttpResponse, HttpClientError> {
223 self.request(Method::POST, url, params, headers, body, timeout_secs, keys)
224 .await
225 }
226
227 pub async fn patch(
233 &self,
234 url: String,
235 params: Option<&HashMap<String, Vec<String>>>,
236 headers: Option<HashMap<String, String>>,
237 body: Option<Vec<u8>>,
238 timeout_secs: Option<u64>,
239 keys: Option<Vec<String>>,
240 ) -> Result<HttpResponse, HttpClientError> {
241 self.request(
242 Method::PATCH,
243 url,
244 params,
245 headers,
246 body,
247 timeout_secs,
248 keys,
249 )
250 .await
251 }
252
253 pub async fn delete(
259 &self,
260 url: String,
261 params: Option<&HashMap<String, Vec<String>>>,
262 headers: Option<HashMap<String, String>>,
263 timeout_secs: Option<u64>,
264 keys: Option<Vec<String>>,
265 ) -> Result<HttpResponse, HttpClientError> {
266 self.request(
267 Method::DELETE,
268 url,
269 params,
270 headers,
271 None,
272 timeout_secs,
273 keys,
274 )
275 .await
276 }
277}
278
279#[derive(Clone, Debug)]
288pub struct InnerHttpClient {
289 pub(crate) client: reqwest::Client,
290 pub(crate) header_keys: Arc<Vec<String>>,
291}
292
293impl InnerHttpClient {
294 pub async fn send_request(
300 &self,
301 method: Method,
302 url: String,
303 params: Option<&HashMap<String, Vec<String>>>,
304 headers: Option<HashMap<String, String>>,
305 body: Option<Vec<u8>>,
306 timeout_secs: Option<u64>,
307 ) -> Result<HttpResponse, HttpClientError> {
308 let full_url = encode_url_params(&url, params)?;
309 self.send_request_internal(method, full_url, None::<&()>, headers, body, timeout_secs)
310 .await
311 }
312
313 pub async fn send_request_with_query<Q: serde::Serialize>(
322 &self,
323 method: Method,
324 url: String,
325 query: Option<&Q>,
326 headers: Option<HashMap<String, String>>,
327 body: Option<Vec<u8>>,
328 timeout_secs: Option<u64>,
329 ) -> Result<HttpResponse, HttpClientError> {
330 self.send_request_internal(method, url, query, headers, body, timeout_secs)
331 .await
332 }
333
334 async fn send_request_internal<Q: serde::Serialize>(
340 &self,
341 method: Method,
342 url: String,
343 query: Option<&Q>,
344 headers: Option<HashMap<String, String>>,
345 body: Option<Vec<u8>>,
346 timeout_secs: Option<u64>,
347 ) -> Result<HttpResponse, HttpClientError> {
348 let headers = headers.unwrap_or_default();
349 let reqwest_url = Url::parse(url.as_str())
350 .map_err(|e| HttpClientError::from(format!("URL parse error: {e}")))?;
351
352 let mut header_map = HeaderMap::new();
353 for (header_key, header_value) in &headers {
354 let key = HeaderName::from_bytes(header_key.as_bytes())
355 .map_err(|e| HttpClientError::from(format!("Invalid header name: {e}")))?;
356 if let Some(old_value) = header_map.insert(
357 key.clone(),
358 header_value
359 .parse()
360 .map_err(|e| HttpClientError::from(format!("Invalid header value: {e}")))?,
361 ) {
362 tracing::trace!("Replaced header '{key}': old={old_value:?}, new={header_value}");
363 }
364 }
365
366 let mut request_builder = self.client.request(method, reqwest_url).headers(header_map);
367
368 if let Some(q) = query {
369 request_builder = request_builder.query(q);
370 }
371
372 if let Some(timeout_secs) = timeout_secs {
373 request_builder = request_builder.timeout(Duration::new(timeout_secs, 0));
374 }
375
376 let request = match body {
377 Some(b) => request_builder
378 .body(b)
379 .build()
380 .map_err(HttpClientError::from)?,
381 None => request_builder.build().map_err(HttpClientError::from)?,
382 };
383
384 tracing::trace!("{request:?}");
385
386 let response = self
387 .client
388 .execute(request)
389 .await
390 .map_err(HttpClientError::from)?;
391
392 self.to_response(response).await
393 }
394
395 pub async fn to_response(&self, response: Response) -> Result<HttpResponse, HttpClientError> {
401 tracing::trace!("{response:?}");
402
403 let headers: HashMap<String, String> = self
404 .header_keys
405 .iter()
406 .filter_map(|key| response.headers().get(key).map(|val| (key, val)))
407 .filter_map(|(key, val)| val.to_str().map(|v| (key, v)).ok())
408 .map(|(k, v)| (k.clone(), v.to_owned()))
409 .collect();
410 let status = HttpStatus::new(response.status());
411 let body = response.bytes().await.map_err(HttpClientError::from)?;
412
413 Ok(HttpResponse {
414 status,
415 headers,
416 body,
417 })
418 }
419}
420
421impl Default for InnerHttpClient {
422 fn default() -> Self {
426 let client = reqwest::Client::new();
427 Self {
428 client,
429 header_keys: Default::default(),
430 }
431 }
432}
433
434fn encode_url_params(
440 url: &str,
441 params: Option<&HashMap<String, Vec<String>>>,
442) -> Result<String, HttpClientError> {
443 let Some(params) = params else {
444 return Ok(url.to_string());
445 };
446
447 let pairs: Vec<(String, String)> = params
449 .iter()
450 .flat_map(|(key, values)| values.iter().map(move |value| (key.clone(), value.clone())))
451 .collect();
452
453 if pairs.is_empty() {
454 return Ok(url.to_string());
455 }
456
457 let query_string = serde_urlencoded::to_string(pairs)
458 .map_err(|e| HttpClientError::Error(format!("Failed to encode params: {e}")))?;
459
460 let separator = if url.contains('?') { '&' } else { '?' };
462 Ok(format!("{}{}{}", url, separator, query_string))
463}
464
#[cfg(test)]
#[cfg(target_os = "linux")]
mod tests {
    use std::net::SocketAddr;

    use axum::{
        Router,
        routing::{delete, get, patch, post},
        serve,
    };
    use http::status::StatusCode;
    use rstest::rstest;

    use super::*;

    fn create_router() -> Router {
        Router::new()
            .route("/get", get(|| async { "hello-world!" }))
            .route("/post", post(|| async { StatusCode::OK }))
            .route("/patch", patch(|| async { StatusCode::OK }))
            .route("/delete", delete(|| async { StatusCode::OK }))
            .route("/notfound", get(|| async { StatusCode::NOT_FOUND }))
            .route(
                "/slow",
                get(|| async {
                    tokio::time::sleep(Duration::from_secs(2)).await;
                    "Eventually responded"
                }),
            )
    }

    /// Starts the test server on an OS-assigned ephemeral port and returns its address.
    ///
    /// Binding directly to port 0 avoids probing for a free port, dropping the probe
    /// listener and re-binding. That approach races with other processes that may grab
    /// the port in between, which can make the tests flaky.
    async fn start_test_server() -> Result<SocketAddr, Box<dyn std::error::Error + Send + Sync>> {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?;
        let addr = listener.local_addr()?;

        tokio::spawn(async move {
            serve(listener, create_router()).await.unwrap();
        });

        Ok(addr)
    }

    #[tokio::test]
    async fn test_get() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}");

        let client = InnerHttpClient::default();
        let response = client
            .send_request(
                reqwest::Method::GET,
                format!("{url}/get"),
                None,
                None,
                None,
                None,
            )
            .await
            .unwrap();

        assert!(response.status.is_success());
        assert_eq!(String::from_utf8_lossy(&response.body), "hello-world!");
    }

    #[tokio::test]
    async fn test_post() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}");

        let client = InnerHttpClient::default();
        let response = client
            .send_request(
                reqwest::Method::POST,
                format!("{url}/post"),
                None,
                None,
                None,
                None,
            )
            .await
            .unwrap();

        assert!(response.status.is_success());
    }

    #[tokio::test]
    async fn test_post_with_body() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}");

        let client = InnerHttpClient::default();

        let mut body = HashMap::new();
        body.insert(
            "key1".to_string(),
            serde_json::Value::String("value1".to_string()),
        );
        body.insert(
            "key2".to_string(),
            serde_json::Value::String("value2".to_string()),
        );

        let body_string = serde_json::to_string(&body).unwrap();
        let body_bytes = body_string.into_bytes();

        let response = client
            .send_request(
                reqwest::Method::POST,
                format!("{url}/post"),
                None,
                None,
                Some(body_bytes),
                None,
            )
            .await
            .unwrap();

        assert!(response.status.is_success());
    }

    #[tokio::test]
    async fn test_patch() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}");

        let client = InnerHttpClient::default();
        let response = client
            .send_request(
                reqwest::Method::PATCH,
                format!("{url}/patch"),
                None,
                None,
                None,
                None,
            )
            .await
            .unwrap();

        assert!(response.status.is_success());
    }

    #[tokio::test]
    async fn test_delete() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}");

        let client = InnerHttpClient::default();
        let response = client
            .send_request(
                reqwest::Method::DELETE,
                format!("{url}/delete"),
                None,
                None,
                None,
                None,
            )
            .await
            .unwrap();

        assert!(response.status.is_success());
    }

    #[tokio::test]
    async fn test_not_found() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}/notfound");
        let client = InnerHttpClient::default();

        let response = client
            .send_request(reqwest::Method::GET, url, None, None, None, None)
            .await
            .unwrap();

        assert!(response.status.is_client_error());
        assert_eq!(response.status.as_u16(), 404);
    }

    #[tokio::test]
    async fn test_timeout() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}/slow");
        let client = InnerHttpClient::default();

        // `/slow` sleeps for 2s, so a 1s per-request timeout must fire first.
        let result = client
            .send_request(reqwest::Method::GET, url, None, None, None, Some(1))
            .await;

        match result {
            Err(HttpClientError::TimeoutError(msg)) => {
                println!("Got expected timeout error: {msg}");
            }
            Err(e) => panic!("Expected a timeout error, was: {e:?}"),
            Ok(resp) => panic!("Expected a timeout error, but was a successful response: {resp:?}"),
        }
    }

    #[rstest]
    fn test_encode_url_params_none_returns_url_unchanged() {
        let url = encode_url_params("http://localhost/get", None).unwrap();
        assert_eq!(url, "http://localhost/get");
    }

    #[rstest]
    fn test_encode_url_params_without_pairs_returns_url_unchanged() {
        let empty: HashMap<String, Vec<String>> = HashMap::new();
        let url = encode_url_params("http://localhost/get", Some(&empty)).unwrap();
        assert_eq!(url, "http://localhost/get");

        let no_values = HashMap::from([("key".to_string(), Vec::<String>::new())]);
        let url = encode_url_params("http://localhost/get", Some(&no_values)).unwrap();
        assert_eq!(url, "http://localhost/get");
    }

    #[rstest]
    fn test_encode_url_params_repeats_key_for_each_value() {
        let params = HashMap::from([(
            "symbol".to_string(),
            vec!["BTC USD".to_string(), "ETH&USD".to_string()],
        )]);
        let url = encode_url_params("http://localhost/get", Some(&params)).unwrap();
        assert_eq!(url, "http://localhost/get?symbol=BTC+USD&symbol=ETH%26USD");
    }

    #[rstest]
    fn test_encode_url_params_appends_to_existing_query() {
        let params = HashMap::from([("limit".to_string(), vec!["10".to_string()])]);
        let url = encode_url_params("http://localhost/get?a=1", Some(&params)).unwrap();
        assert_eq!(url, "http://localhost/get?a=1&limit=10");
    }

    #[rstest]
    fn test_http_client_with_invalid_header_name() {
        let headers = HashMap::from([("bad header".to_string(), "value".to_string())]);
        let result = HttpClient::new(headers, vec![], vec![], None, None, None);

        assert!(matches!(result, Err(HttpClientError::Error(_))));
    }

    #[rstest]
    fn test_http_client_with_invalid_header_value() {
        let headers = HashMap::from([("X-Api-Key".to_string(), "bad\nvalue".to_string())]);
        let result = HttpClient::new(headers, vec![], vec![], None, None, None);

        assert!(matches!(result, Err(HttpClientError::Error(_))));
    }

    #[rstest]
    fn test_http_client_without_proxy() {
        let result = HttpClient::new(HashMap::new(), vec![], vec![], None, None, None);

        assert!(result.is_ok());
    }

    #[rstest]
    fn test_http_client_with_valid_proxy() {
        let result = HttpClient::new(
            HashMap::new(),
            vec![],
            vec![],
            None,
            None,
            Some("http://proxy.example.com:8080".to_string()),
        );

        assert!(result.is_ok());
    }

    #[rstest]
    fn test_http_client_with_socks5_proxy() {
        let result = HttpClient::new(
            HashMap::new(),
            vec![],
            vec![],
            None,
            None,
            Some("socks5://127.0.0.1:1080".to_string()),
        );

        assert!(result.is_ok());
    }

    #[rstest]
    fn test_http_client_with_malformed_proxy() {
        let result = HttpClient::new(
            HashMap::new(),
            vec![],
            vec![],
            None,
            None,
            Some("://invalid".to_string()),
        );

        assert!(result.is_err());
        assert!(matches!(result, Err(HttpClientError::InvalidProxy(_))));
    }

    #[rstest]
    fn test_http_client_with_empty_proxy_string() {
        let result = HttpClient::new(
            HashMap::new(),
            vec![],
            vec![],
            None,
            None,
            Some(String::new()),
        );

        assert!(result.is_err());
        assert!(matches!(result, Err(HttpClientError::InvalidProxy(_))));
    }

    #[tokio::test]
    async fn test_http_client_get() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}/get");

        let client = HttpClient::new(HashMap::new(), vec![], vec![], None, None, None).unwrap();
        let response = client.get(url, None, None, None, None).await.unwrap();

        assert!(response.status.is_success());
        assert_eq!(String::from_utf8_lossy(&response.body), "hello-world!");
    }

    #[tokio::test]
    async fn test_http_client_post() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}/post");

        let client = HttpClient::new(HashMap::new(), vec![], vec![], None, None, None).unwrap();
        let response = client
            .post(url, None, None, None, None, None)
            .await
            .unwrap();

        assert!(response.status.is_success());
    }

    #[tokio::test]
    async fn test_http_client_patch() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}/patch");

        let client = HttpClient::new(HashMap::new(), vec![], vec![], None, None, None).unwrap();
        let response = client
            .patch(url, None, None, None, None, None)
            .await
            .unwrap();

        assert!(response.status.is_success());
    }

    #[tokio::test]
    async fn test_http_client_delete() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}/delete");

        let client = HttpClient::new(HashMap::new(), vec![], vec![], None, None, None).unwrap();
        let response = client.delete(url, None, None, None, None).await.unwrap();

        assert!(response.status.is_success());
    }
}