1use std::{collections::HashMap, str::FromStr, sync::Arc, time::Duration};
19
20use nautilus_core::collections::into_ustr_vec;
21use nautilus_cryptography::providers::install_cryptographic_provider;
22use reqwest::{
23 Method, Response, Url,
24 header::{HeaderMap, HeaderName, HeaderValue},
25};
26use ustr::Ustr;
27
28use super::{HttpClientError, HttpResponse, HttpStatus};
29use crate::ratelimiter::{RateLimiter, clock::MonotonicClock, quota::Quota};
30
31#[derive(Clone, Debug)]
41#[cfg_attr(
42 feature = "python",
43 pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
44)]
45pub struct HttpClient {
46 pub(crate) client: InnerHttpClient,
48 pub(crate) rate_limiter: Arc<RateLimiter<Ustr, MonotonicClock>>,
50}
51
52impl HttpClient {
53 pub fn new(
60 headers: HashMap<String, String>,
61 header_keys: Vec<String>,
62 keyed_quotas: Vec<(String, Quota)>,
63 default_quota: Option<Quota>,
64 timeout_secs: Option<u64>,
65 proxy_url: Option<String>,
66 ) -> Result<Self, HttpClientError> {
67 install_cryptographic_provider();
68
69 let mut header_map = HeaderMap::new();
71 for (key, value) in headers {
72 let header_name = HeaderName::from_str(&key)
73 .map_err(|e| HttpClientError::Error(format!("Invalid header name '{key}': {e}")))?;
74 let header_value = HeaderValue::from_str(&value).map_err(|e| {
75 HttpClientError::Error(format!("Invalid header value '{value}': {e}"))
76 })?;
77 header_map.insert(header_name, header_value);
78 }
79
80 let mut client_builder = reqwest::Client::builder().default_headers(header_map);
81 client_builder = client_builder.tcp_nodelay(true);
82
83 if let Some(timeout_secs) = timeout_secs {
84 client_builder = client_builder.timeout(Duration::from_secs(timeout_secs));
85 }
86
87 if let Some(proxy_url) = proxy_url {
89 let proxy = reqwest::Proxy::all(&proxy_url)
90 .map_err(|e| HttpClientError::InvalidProxy(format!("{proxy_url}: {e}")))?;
91 client_builder = client_builder.proxy(proxy);
92 }
93
94 let client = client_builder
95 .build()
96 .map_err(|e| HttpClientError::ClientBuildError(e.to_string()))?;
97
98 let client = InnerHttpClient {
99 client,
100 header_keys: Arc::new(header_keys),
101 };
102
103 let keyed_quotas = keyed_quotas
104 .into_iter()
105 .map(|(key, quota)| (Ustr::from(&key), quota))
106 .collect();
107
108 let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
109
110 Ok(Self {
111 client,
112 rate_limiter,
113 })
114 }
115
116 #[allow(clippy::too_many_arguments)]
126 pub async fn request(
127 &self,
128 method: Method,
129 url: String,
130 params: Option<&HashMap<String, Vec<String>>>,
131 headers: Option<HashMap<String, String>>,
132 body: Option<Vec<u8>>,
133 timeout_secs: Option<u64>,
134 keys: Option<Vec<String>>,
135 ) -> Result<HttpResponse, HttpClientError> {
136 let keys = keys.map(into_ustr_vec);
137
138 self.request_with_ustr_keys(method, url, params, headers, body, timeout_secs, keys)
139 .await
140 }
141
142 #[allow(clippy::too_many_arguments)]
152 pub async fn request_with_params<P: serde::Serialize>(
153 &self,
154 method: Method,
155 url: String,
156 params: Option<&P>,
157 headers: Option<HashMap<String, String>>,
158 body: Option<Vec<u8>>,
159 timeout_secs: Option<u64>,
160 keys: Option<Vec<String>>,
161 ) -> Result<HttpResponse, HttpClientError> {
162 let keys = keys.map(into_ustr_vec);
163 let rate_limiter = self.rate_limiter.clone();
164 rate_limiter.await_keys_ready(keys).await;
165
166 self.client
167 .send_request_with_query(method, url, params, headers, body, timeout_secs)
168 .await
169 }
170
171 #[allow(clippy::too_many_arguments)]
177 pub async fn request_with_ustr_keys(
178 &self,
179 method: Method,
180 url: String,
181 params: Option<&HashMap<String, Vec<String>>>,
182 headers: Option<HashMap<String, String>>,
183 body: Option<Vec<u8>>,
184 timeout_secs: Option<u64>,
185 keys: Option<Vec<Ustr>>,
186 ) -> Result<HttpResponse, HttpClientError> {
187 let rate_limiter = self.rate_limiter.clone();
188 rate_limiter.await_keys_ready(keys).await;
189
190 self.client
191 .send_request(method, url, params, headers, body, timeout_secs)
192 .await
193 }
194
195 pub async fn get(
201 &self,
202 url: String,
203 params: Option<&HashMap<String, Vec<String>>>,
204 headers: Option<HashMap<String, String>>,
205 timeout_secs: Option<u64>,
206 keys: Option<Vec<String>>,
207 ) -> Result<HttpResponse, HttpClientError> {
208 self.request(Method::GET, url, params, headers, None, timeout_secs, keys)
209 .await
210 }
211
212 pub async fn post(
218 &self,
219 url: String,
220 params: Option<&HashMap<String, Vec<String>>>,
221 headers: Option<HashMap<String, String>>,
222 body: Option<Vec<u8>>,
223 timeout_secs: Option<u64>,
224 keys: Option<Vec<String>>,
225 ) -> Result<HttpResponse, HttpClientError> {
226 self.request(Method::POST, url, params, headers, body, timeout_secs, keys)
227 .await
228 }
229
230 pub async fn patch(
236 &self,
237 url: String,
238 params: Option<&HashMap<String, Vec<String>>>,
239 headers: Option<HashMap<String, String>>,
240 body: Option<Vec<u8>>,
241 timeout_secs: Option<u64>,
242 keys: Option<Vec<String>>,
243 ) -> Result<HttpResponse, HttpClientError> {
244 self.request(
245 Method::PATCH,
246 url,
247 params,
248 headers,
249 body,
250 timeout_secs,
251 keys,
252 )
253 .await
254 }
255
256 pub async fn delete(
262 &self,
263 url: String,
264 params: Option<&HashMap<String, Vec<String>>>,
265 headers: Option<HashMap<String, String>>,
266 timeout_secs: Option<u64>,
267 keys: Option<Vec<String>>,
268 ) -> Result<HttpResponse, HttpClientError> {
269 self.request(
270 Method::DELETE,
271 url,
272 params,
273 headers,
274 None,
275 timeout_secs,
276 keys,
277 )
278 .await
279 }
280}
281
282#[derive(Clone, Debug)]
291pub struct InnerHttpClient {
292 pub(crate) client: reqwest::Client,
293 pub(crate) header_keys: Arc<Vec<String>>,
294}
295
296impl InnerHttpClient {
297 pub async fn send_request(
303 &self,
304 method: Method,
305 url: String,
306 params: Option<&HashMap<String, Vec<String>>>,
307 headers: Option<HashMap<String, String>>,
308 body: Option<Vec<u8>>,
309 timeout_secs: Option<u64>,
310 ) -> Result<HttpResponse, HttpClientError> {
311 let full_url = encode_url_params(&url, params)?;
312 self.send_request_internal(method, full_url, None::<&()>, headers, body, timeout_secs)
313 .await
314 }
315
316 pub async fn send_request_with_query<Q: serde::Serialize>(
325 &self,
326 method: Method,
327 url: String,
328 query: Option<&Q>,
329 headers: Option<HashMap<String, String>>,
330 body: Option<Vec<u8>>,
331 timeout_secs: Option<u64>,
332 ) -> Result<HttpResponse, HttpClientError> {
333 self.send_request_internal(method, url, query, headers, body, timeout_secs)
334 .await
335 }
336
337 async fn send_request_internal<Q: serde::Serialize>(
343 &self,
344 method: Method,
345 url: String,
346 query: Option<&Q>,
347 headers: Option<HashMap<String, String>>,
348 body: Option<Vec<u8>>,
349 timeout_secs: Option<u64>,
350 ) -> Result<HttpResponse, HttpClientError> {
351 let headers = headers.unwrap_or_default();
352 let reqwest_url = Url::parse(url.as_str())
353 .map_err(|e| HttpClientError::from(format!("URL parse error: {e}")))?;
354
355 let mut header_map = HeaderMap::new();
356 for (header_key, header_value) in &headers {
357 let key = HeaderName::from_bytes(header_key.as_bytes())
358 .map_err(|e| HttpClientError::from(format!("Invalid header name: {e}")))?;
359 if let Some(old_value) = header_map.insert(
360 key.clone(),
361 header_value
362 .parse()
363 .map_err(|e| HttpClientError::from(format!("Invalid header value: {e}")))?,
364 ) {
365 tracing::trace!("Replaced header '{key}': old={old_value:?}, new={header_value}");
366 }
367 }
368
369 let mut request_builder = self.client.request(method, reqwest_url).headers(header_map);
370
371 if let Some(q) = query {
372 request_builder = request_builder.query(q);
373 }
374
375 if let Some(timeout_secs) = timeout_secs {
376 request_builder = request_builder.timeout(Duration::new(timeout_secs, 0));
377 }
378
379 let request = match body {
380 Some(b) => request_builder
381 .body(b)
382 .build()
383 .map_err(HttpClientError::from)?,
384 None => request_builder.build().map_err(HttpClientError::from)?,
385 };
386
387 tracing::trace!("{request:?}");
388
389 let response = self
390 .client
391 .execute(request)
392 .await
393 .map_err(HttpClientError::from)?;
394
395 self.to_response(response).await
396 }
397
398 pub async fn to_response(&self, response: Response) -> Result<HttpResponse, HttpClientError> {
404 tracing::trace!("{response:?}");
405
406 let headers: HashMap<String, String> = self
407 .header_keys
408 .iter()
409 .filter_map(|key| response.headers().get(key).map(|val| (key, val)))
410 .filter_map(|(key, val)| val.to_str().map(|v| (key, v)).ok())
411 .map(|(k, v)| (k.clone(), v.to_owned()))
412 .collect();
413 let status = HttpStatus::new(response.status());
414 let body = response.bytes().await.map_err(HttpClientError::from)?;
415
416 Ok(HttpResponse {
417 status,
418 headers,
419 body,
420 })
421 }
422}
423
424impl Default for InnerHttpClient {
425 fn default() -> Self {
429 install_cryptographic_provider();
430 let client = reqwest::Client::new();
431 Self {
432 client,
433 header_keys: Default::default(),
434 }
435 }
436}
437
438fn encode_url_params(
444 url: &str,
445 params: Option<&HashMap<String, Vec<String>>>,
446) -> Result<String, HttpClientError> {
447 let Some(params) = params else {
448 return Ok(url.to_string());
449 };
450
451 let pairs: Vec<(String, String)> = params
453 .iter()
454 .flat_map(|(key, values)| values.iter().map(move |value| (key.clone(), value.clone())))
455 .collect();
456
457 if pairs.is_empty() {
458 return Ok(url.to_string());
459 }
460
461 let query_string = serde_urlencoded::to_string(pairs)
462 .map_err(|e| HttpClientError::Error(format!("Failed to encode params: {e}")))?;
463
464 let separator = if url.contains('?') { '&' } else { '?' };
466 Ok(format!("{url}{separator}{query_string}"))
467}
468
#[cfg(test)]
#[cfg(target_os = "linux")]
mod tests {
    use std::net::SocketAddr;

    use axum::{
        Router,
        routing::{delete, get, patch, post},
        serve,
    };
    use http::status::StatusCode;
    use rstest::rstest;

    use super::*;

    /// Router with one endpoint per HTTP verb, plus `/notfound` (404) and `/slow`
    /// (responds after 2 seconds).
    fn create_router() -> Router {
        Router::new()
            .route("/get", get(|| async { "hello-world!" }))
            .route("/post", post(|| async { StatusCode::OK }))
            .route("/patch", patch(|| async { StatusCode::OK }))
            .route("/delete", delete(|| async { StatusCode::OK }))
            .route("/notfound", get(|| async { StatusCode::NOT_FOUND }))
            .route(
                "/slow",
                get(|| async {
                    tokio::time::sleep(Duration::from_secs(2)).await;
                    "Eventually responded"
                }),
            )
    }

    /// Starts the test server on an OS-assigned ephemeral port and returns its address.
    ///
    /// Binding directly to port 0 avoids a race in the two-step approach. There, a free
    /// port is probed with a throwaway listener, released, and re-bound later; another
    /// process or parallel test can claim the port in between.
    async fn start_test_server() -> Result<SocketAddr, Box<dyn std::error::Error + Send + Sync>> {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?;
        let addr = listener.local_addr()?;

        tokio::spawn(async move {
            serve(listener, create_router()).await.unwrap();
        });

        Ok(addr)
    }

    #[tokio::test]
    async fn test_get() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}");

        let client = InnerHttpClient::default();
        let response = client
            .send_request(
                reqwest::Method::GET,
                format!("{url}/get"),
                None,
                None,
                None,
                None,
            )
            .await
            .unwrap();

        assert!(response.status.is_success());
        assert_eq!(String::from_utf8_lossy(&response.body), "hello-world!");
    }

    #[tokio::test]
    async fn test_post() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}");

        let client = InnerHttpClient::default();
        let response = client
            .send_request(
                reqwest::Method::POST,
                format!("{url}/post"),
                None,
                None,
                None,
                None,
            )
            .await
            .unwrap();

        assert!(response.status.is_success());
    }

    #[tokio::test]
    async fn test_post_with_body() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}");

        let client = InnerHttpClient::default();

        let mut body = HashMap::new();
        body.insert(
            "key1".to_string(),
            serde_json::Value::String("value1".to_string()),
        );
        body.insert(
            "key2".to_string(),
            serde_json::Value::String("value2".to_string()),
        );

        let body_string = serde_json::to_string(&body).unwrap();
        let body_bytes = body_string.into_bytes();

        let response = client
            .send_request(
                reqwest::Method::POST,
                format!("{url}/post"),
                None,
                None,
                Some(body_bytes),
                None,
            )
            .await
            .unwrap();

        assert!(response.status.is_success());
    }

    #[tokio::test]
    async fn test_patch() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}");

        let client = InnerHttpClient::default();
        let response = client
            .send_request(
                reqwest::Method::PATCH,
                format!("{url}/patch"),
                None,
                None,
                None,
                None,
            )
            .await
            .unwrap();

        assert!(response.status.is_success());
    }

    #[tokio::test]
    async fn test_delete() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}");

        let client = InnerHttpClient::default();
        let response = client
            .send_request(
                reqwest::Method::DELETE,
                format!("{url}/delete"),
                None,
                None,
                None,
                None,
            )
            .await
            .unwrap();

        assert!(response.status.is_success());
    }

    #[tokio::test]
    async fn test_not_found() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}/notfound");
        let client = InnerHttpClient::default();

        let response = client
            .send_request(reqwest::Method::GET, url, None, None, None, None)
            .await
            .unwrap();

        assert!(response.status.is_client_error());
        assert_eq!(response.status.as_u16(), 404);
    }

    #[tokio::test]
    async fn test_timeout() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}/slow");
        let client = InnerHttpClient::default();

        // 1s per-request timeout against an endpoint that takes 2s to respond.
        let result = client
            .send_request(reqwest::Method::GET, url, None, None, None, Some(1))
            .await;

        match result {
            Err(HttpClientError::TimeoutError(msg)) => {
                println!("Got expected timeout error: {msg}");
            }
            Err(e) => panic!("Expected a timeout error, was: {e:?}"),
            Ok(resp) => panic!("Expected a timeout error, but was a successful response: {resp:?}"),
        }
    }

    #[rstest]
    fn test_http_client_without_proxy() {
        let result = HttpClient::new(HashMap::new(), vec![], vec![], None, None, None);

        assert!(result.is_ok());
    }

    #[rstest]
    fn test_http_client_with_valid_proxy() {
        let result = HttpClient::new(
            HashMap::new(),
            vec![],
            vec![],
            None,
            None,
            Some("http://proxy.example.com:8080".to_string()),
        );

        assert!(result.is_ok());
    }

    #[rstest]
    fn test_http_client_with_socks5_proxy() {
        let result = HttpClient::new(
            HashMap::new(),
            vec![],
            vec![],
            None,
            None,
            Some("socks5://127.0.0.1:1080".to_string()),
        );

        assert!(result.is_ok());
    }

    #[rstest]
    fn test_http_client_with_malformed_proxy() {
        let result = HttpClient::new(
            HashMap::new(),
            vec![],
            vec![],
            None,
            None,
            Some("://invalid".to_string()),
        );

        assert!(result.is_err());
        assert!(matches!(result, Err(HttpClientError::InvalidProxy(_))));
    }

    #[rstest]
    fn test_http_client_with_empty_proxy_string() {
        let result = HttpClient::new(
            HashMap::new(),
            vec![],
            vec![],
            None,
            None,
            Some(String::new()),
        );

        assert!(result.is_err());
        assert!(matches!(result, Err(HttpClientError::InvalidProxy(_))));
    }

    #[tokio::test]
    async fn test_http_client_get() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}/get");

        let client = HttpClient::new(HashMap::new(), vec![], vec![], None, None, None).unwrap();
        let response = client.get(url, None, None, None, None).await.unwrap();

        assert!(response.status.is_success());
        assert_eq!(String::from_utf8_lossy(&response.body), "hello-world!");
    }

    #[tokio::test]
    async fn test_http_client_post() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}/post");

        let client = HttpClient::new(HashMap::new(), vec![], vec![], None, None, None).unwrap();
        let response = client
            .post(url, None, None, None, None, None)
            .await
            .unwrap();

        assert!(response.status.is_success());
    }

    #[tokio::test]
    async fn test_http_client_patch() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}/patch");

        let client = HttpClient::new(HashMap::new(), vec![], vec![], None, None, None).unwrap();
        let response = client
            .patch(url, None, None, None, None, None)
            .await
            .unwrap();

        assert!(response.status.is_success());
    }

    #[tokio::test]
    async fn test_http_client_delete() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}/delete");

        let client = HttpClient::new(HashMap::new(), vec![], vec![], None, None, None).unwrap();
        let response = client.delete(url, None, None, None, None).await.unwrap();

        assert!(response.status.is_success());
    }
}