nautilus_network/
http.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16//! A high-performance HTTP client implementation.
17
18use std::{collections::HashMap, hash::Hash, str::FromStr, sync::Arc, time::Duration};
19
20use bytes::Bytes;
21use http::{HeaderValue, StatusCode, status::InvalidStatusCode};
22use nautilus_core::collections::into_ustr_vec;
23use reqwest::{
24    Method, Response, Url,
25    header::{HeaderMap, HeaderName},
26};
27use ustr::Ustr;
28
29use crate::ratelimiter::{RateLimiter, clock::MonotonicClock, quota::Quota};
30
31/// Represents a HTTP status code.
32///
33/// Wraps [`http::StatusCode`] to expose a Python-compatible type and reuse
34/// its validation and convenience methods.
35#[derive(Clone, Debug)]
36#[cfg_attr(
37    feature = "python",
38    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
39)]
40pub struct HttpStatus {
41    inner: StatusCode,
42}
43
44impl HttpStatus {
45    /// Create a new [`HttpStatus`] instance from a given [`StatusCode`].
46    #[must_use]
47    pub const fn new(code: StatusCode) -> Self {
48        Self { inner: code }
49    }
50
51    /// Returns the status code as a `u16` (e.g., `200` for OK).
52    #[inline]
53    #[must_use]
54    pub const fn as_u16(&self) -> u16 {
55        self.inner.as_u16()
56    }
57
58    /// Returns the three-digit ASCII representation of this status (e.g., `"200"`).
59    #[inline]
60    #[must_use]
61    pub fn as_str(&self) -> &str {
62        self.inner.as_str()
63    }
64
65    /// Checks if this status is in the 1xx (informational) range.
66    #[inline]
67    #[must_use]
68    pub fn is_informational(&self) -> bool {
69        self.inner.is_informational()
70    }
71
72    /// Checks if this status is in the 2xx (success) range.
73    #[inline]
74    #[must_use]
75    pub fn is_success(&self) -> bool {
76        self.inner.is_success()
77    }
78
79    /// Checks if this status is in the 3xx (redirection) range.
80    #[inline]
81    #[must_use]
82    pub fn is_redirection(&self) -> bool {
83        self.inner.is_redirection()
84    }
85
86    /// Checks if this status is in the 4xx (client error) range.
87    #[inline]
88    #[must_use]
89    pub fn is_client_error(&self) -> bool {
90        self.inner.is_client_error()
91    }
92
93    /// Checks if this status is in the 5xx (server error) range.
94    #[inline]
95    #[must_use]
96    pub fn is_server_error(&self) -> bool {
97        self.inner.is_server_error()
98    }
99}
100
101impl TryFrom<u16> for HttpStatus {
102    type Error = InvalidStatusCode;
103
104    /// Attempts to construct a [`HttpStatus`] from a `u16`.
105    ///
106    /// # Errors
107    ///
108    /// Returns an error if the code is not in the valid `100..999` range.
109    fn try_from(code: u16) -> Result<Self, Self::Error> {
110        Ok(Self {
111            inner: StatusCode::from_u16(code)?,
112        })
113    }
114}
115
/// The set of HTTP request methods supported by the `HttpClient`.
///
/// Do not reorder the variants: with the `python` feature the enum is exposed with
/// `eq_int`, so the implicit discriminants (`GET = 0` … `PATCH = 4`) are observable.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(eq, eq_int, module = "nautilus_trader.core.nautilus_pyo3.network")
)]
pub enum HttpMethod {
    /// The `GET` method.
    GET,
    /// The `POST` method.
    POST,
    /// The `PUT` method.
    PUT,
    /// The `DELETE` method.
    DELETE,
    /// The `PATCH` method.
    PATCH,
}
129
130impl From<HttpMethod> for Method {
131    fn from(value: HttpMethod) -> Self {
132        match value {
133            HttpMethod::GET => Self::GET,
134            HttpMethod::POST => Self::POST,
135            HttpMethod::PUT => Self::PUT,
136            HttpMethod::DELETE => Self::DELETE,
137            HttpMethod::PATCH => Self::PATCH,
138        }
139    }
140}
141
142/// Represents the response from an HTTP request.
143///
144/// This struct encapsulates the status, headers, and body of an HTTP response,
145/// providing easy access to the key components of the response.
146#[derive(Clone, Debug)]
147#[cfg_attr(
148    feature = "python",
149    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
150)]
151pub struct HttpResponse {
152    /// The HTTP status code.
153    pub status: HttpStatus,
154    /// The response headers as a map of key-value pairs.
155    pub headers: HashMap<String, String>,
156    /// The raw response body.
157    pub body: Bytes,
158}
159
160/// Errors returned by the HTTP client.
161///
162/// Includes generic transport errors and timeouts.
163#[derive(thiserror::Error, Debug)]
164pub enum HttpClientError {
165    #[error("HTTP error occurred: {0}")]
166    Error(String),
167
168    #[error("HTTP request timed out: {0}")]
169    TimeoutError(String),
170}
171
172impl From<reqwest::Error> for HttpClientError {
173    fn from(source: reqwest::Error) -> Self {
174        if source.is_timeout() {
175            Self::TimeoutError(source.to_string())
176        } else {
177            Self::Error(source.to_string())
178        }
179    }
180}
181
182impl From<String> for HttpClientError {
183    fn from(value: String) -> Self {
184        Self::Error(value)
185    }
186}
187
188/// An HTTP client that supports rate limiting and timeouts.
189///
190/// Built on `reqwest` for async I/O. Allows per-endpoint and default quotas
191/// through a rate limiter.
192///
193/// This struct is designed to handle HTTP requests efficiently, providing
194/// support for rate limiting, timeouts, and custom headers. The client is
195/// built on top of `reqwest` and can be used for both synchronous and
196/// asynchronous HTTP requests.
197#[derive(Clone, Debug)]
198#[cfg_attr(
199    feature = "python",
200    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
201)]
202pub struct HttpClient {
203    /// The underlying HTTP client used to make requests.
204    pub(crate) client: InnerHttpClient,
205    /// The rate limiter to control the request rate.
206    pub(crate) rate_limiter: Arc<RateLimiter<Ustr, MonotonicClock>>,
207}
208
209impl HttpClient {
210    /// Creates a new [`HttpClient`] instance.
211    ///
212    /// # Panics
213    ///
214    /// Panics if any header key or value is invalid, or if building the underlying `reqwest::Client` fails.
215    #[must_use]
216    pub fn new(
217        headers: HashMap<String, String>,
218        header_keys: Vec<String>,
219        keyed_quotas: Vec<(String, Quota)>,
220        default_quota: Option<Quota>,
221        timeout_secs: Option<u64>,
222    ) -> Self {
223        // Build default headers
224        let mut header_map = HeaderMap::new();
225        for (key, value) in headers {
226            let header_name = HeaderName::from_str(&key).expect("Invalid header name");
227            let header_value = HeaderValue::from_str(&value).expect("Invalid header value");
228            header_map.insert(header_name, header_value);
229        }
230
231        let mut client_builder = reqwest::Client::builder().default_headers(header_map);
232        if let Some(timeout_secs) = timeout_secs {
233            client_builder = client_builder.timeout(Duration::from_secs(timeout_secs));
234        }
235
236        let client = client_builder
237            .build()
238            .expect("Failed to build reqwest client");
239
240        let client = InnerHttpClient {
241            client,
242            header_keys: Arc::new(header_keys),
243        };
244
245        let keyed_quotas = keyed_quotas
246            .into_iter()
247            .map(|(key, quota)| (Ustr::from(&key), quota))
248            .collect();
249
250        let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
251
252        Self {
253            client,
254            rate_limiter,
255        }
256    }
257
258    /// Sends an HTTP request.
259    ///
260    /// - `method`: The [`Method`] to use (GET, POST, etc.).
261    /// - `url`: The target URL.
262    /// - `headers`: Additional headers for this request.
263    /// - `body`: Optional request body.
264    /// - `keys`: Rate-limit keys to control request frequency.
265    /// - `timeout_secs`: Optional request timeout in seconds.
266    ///
267    /// # Errors
268    ///
269    /// Returns an error if unable to send request or times out.
270    ///
271    /// # Examples
272    ///
273    /// If requesting `/foo/bar`, pass rate-limit keys `["foo/bar", "foo"]`.
274    #[allow(clippy::too_many_arguments)]
275    pub async fn request(
276        &self,
277        method: Method,
278        url: String,
279        headers: Option<HashMap<String, String>>,
280        body: Option<Vec<u8>>,
281        timeout_secs: Option<u64>,
282        keys: Option<Vec<String>>,
283    ) -> Result<HttpResponse, HttpClientError> {
284        let keys = keys.map(into_ustr_vec);
285
286        self.request_with_ustr_keys(method, url, headers, body, timeout_secs, keys)
287            .await
288    }
289
290    /// Sends an HTTP request using pre-interned rate limiter keys.
291    ///
292    /// # Errors
293    ///
294    /// Returns an error if unable to send the request or the request times out.
295    #[allow(clippy::too_many_arguments)]
296    pub async fn request_with_ustr_keys(
297        &self,
298        method: Method,
299        url: String,
300        headers: Option<HashMap<String, String>>,
301        body: Option<Vec<u8>>,
302        timeout_secs: Option<u64>,
303        keys: Option<Vec<Ustr>>,
304    ) -> Result<HttpResponse, HttpClientError> {
305        let rate_limiter = self.rate_limiter.clone();
306        rate_limiter.await_keys_ready(keys).await;
307
308        self.client
309            .send_request(method, url, headers, body, timeout_secs)
310            .await
311    }
312}
313
314/// Internal implementation backing [`HttpClient`].
315///
316/// The client is backed by a [`reqwest::Client`] which keeps connections alive and
317/// can be cloned cheaply. The client also has a list of header fields to
318/// extract from the response.
319///
320/// The client returns an [`HttpResponse`]. The client filters only the key value
321/// for the give `header_keys`.
322#[derive(Clone, Debug)]
323pub struct InnerHttpClient {
324    pub(crate) client: reqwest::Client,
325    pub(crate) header_keys: Arc<Vec<String>>,
326}
327
328impl InnerHttpClient {
329    /// Sends an HTTP request and returns an [`HttpResponse`].
330    ///
331    /// - `method`: The HTTP method (e.g. GET, POST).
332    /// - `url`: The target URL.
333    /// - `headers`: Extra headers to send.
334    /// - `body`: Optional request body.
335    /// - `timeout_secs`: Optional request timeout in seconds.
336    ///
337    /// # Errors
338    ///
339    /// Returns an error if unable to send request or times out.
340    pub async fn send_request(
341        &self,
342        method: Method,
343        url: String,
344        headers: Option<HashMap<String, String>>,
345        body: Option<Vec<u8>>,
346        timeout_secs: Option<u64>,
347    ) -> Result<HttpResponse, HttpClientError> {
348        let headers = headers.unwrap_or_default();
349        let reqwest_url = Url::parse(url.as_str())
350            .map_err(|e| HttpClientError::from(format!("URL parse error: {e}")))?;
351
352        let mut header_map = HeaderMap::new();
353        for (header_key, header_value) in &headers {
354            let key = HeaderName::from_bytes(header_key.as_bytes())
355                .map_err(|e| HttpClientError::from(format!("Invalid header name: {e}")))?;
356            if let Some(old_value) = header_map.insert(
357                key.clone(),
358                header_value
359                    .parse()
360                    .map_err(|e| HttpClientError::from(format!("Invalid header value: {e}")))?,
361            ) {
362                tracing::trace!("Replaced header '{key}': old={old_value:?}, new={header_value}");
363            }
364        }
365
366        let mut request_builder = self.client.request(method, reqwest_url).headers(header_map);
367
368        if let Some(timeout_secs) = timeout_secs {
369            request_builder = request_builder.timeout(Duration::new(timeout_secs, 0));
370        }
371
372        let request = match body {
373            Some(b) => request_builder
374                .body(b)
375                .build()
376                .map_err(HttpClientError::from)?,
377            None => request_builder.build().map_err(HttpClientError::from)?,
378        };
379
380        tracing::trace!("{request:?}");
381
382        let response = self
383            .client
384            .execute(request)
385            .await
386            .map_err(HttpClientError::from)?;
387
388        self.to_response(response).await
389    }
390
391    /// Converts a `reqwest::Response` into an `HttpResponse`.
392    ///
393    /// # Errors
394    ///
395    /// Returns an error if unable to send request or times out.
396    pub async fn to_response(&self, response: Response) -> Result<HttpResponse, HttpClientError> {
397        tracing::trace!("{response:?}");
398
399        let headers: HashMap<String, String> = self
400            .header_keys
401            .iter()
402            .filter_map(|key| response.headers().get(key).map(|val| (key, val)))
403            .filter_map(|(key, val)| val.to_str().map(|v| (key, v)).ok())
404            .map(|(k, v)| (k.clone(), v.to_owned()))
405            .collect();
406        let status = HttpStatus::new(response.status());
407        let body = response.bytes().await.map_err(HttpClientError::from)?;
408
409        Ok(HttpResponse {
410            status,
411            headers,
412            body,
413        })
414    }
415}
416
417impl Default for InnerHttpClient {
418    /// Creates a new default [`InnerHttpClient`] instance.
419    ///
420    /// The default client is initialized with an empty list of header keys and a new `reqwest::Client`.
421    fn default() -> Self {
422        let client = reqwest::Client::new();
423        Self {
424            client,
425            header_keys: Default::default(),
426        }
427    }
428}
429
430////////////////////////////////////////////////////////////////////////////////
431// Tests
432////////////////////////////////////////////////////////////////////////////////
#[cfg(test)]
#[cfg(target_os = "linux")] // Only run network tests on Linux (CI stability)
mod tests {
    use std::net::SocketAddr;

    use axum::{
        Router,
        routing::{delete, get, patch, post},
        serve,
    };
    use http::status::StatusCode;

    use super::*;

    fn create_router() -> Router {
        Router::new()
            .route("/get", get(|| async { "hello-world!" }))
            .route("/post", post(|| async { StatusCode::OK }))
            .route("/patch", patch(|| async { StatusCode::OK }))
            .route("/delete", delete(|| async { StatusCode::OK }))
            .route("/notfound", get(|| async { StatusCode::NOT_FOUND }))
            .route(
                "/slow",
                get(|| async {
                    tokio::time::sleep(Duration::from_secs(2)).await;
                    "Eventually responded"
                }),
            )
    }

    /// Starts the test router on an OS-assigned loopback port and returns its address.
    ///
    /// Binding the serving listener directly to port 0 avoids the race of probing for
    /// a free port, releasing it, and re-binding: another process could claim the
    /// port in between.
    async fn start_test_server() -> Result<SocketAddr, Box<dyn std::error::Error + Send + Sync>> {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?;
        let addr = listener.local_addr()?;

        tokio::spawn(async move {
            serve(listener, create_router()).await.unwrap();
        });

        Ok(addr)
    }

    #[tokio::test]
    async fn test_get() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}");

        let client = InnerHttpClient::default();
        let response = client
            .send_request(reqwest::Method::GET, format!("{url}/get"), None, None, None)
            .await
            .unwrap();

        assert!(response.status.is_success());
        assert_eq!(String::from_utf8_lossy(&response.body), "hello-world!");
    }

    #[tokio::test]
    async fn test_post() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}");

        let client = InnerHttpClient::default();
        let response = client
            .send_request(
                reqwest::Method::POST,
                format!("{url}/post"),
                None,
                None,
                None,
            )
            .await
            .unwrap();

        assert!(response.status.is_success());
    }

    #[tokio::test]
    async fn test_post_with_body() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}");

        let client = InnerHttpClient::default();

        let mut body = HashMap::new();
        body.insert(
            "key1".to_string(),
            serde_json::Value::String("value1".to_string()),
        );
        body.insert(
            "key2".to_string(),
            serde_json::Value::String("value2".to_string()),
        );

        let body_string = serde_json::to_string(&body).unwrap();
        let body_bytes = body_string.into_bytes();

        let response = client
            .send_request(
                reqwest::Method::POST,
                format!("{url}/post"),
                None,
                Some(body_bytes),
                None,
            )
            .await
            .unwrap();

        assert!(response.status.is_success());
    }

    #[tokio::test]
    async fn test_patch() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}");

        let client = InnerHttpClient::default();
        let response = client
            .send_request(
                reqwest::Method::PATCH,
                format!("{url}/patch"),
                None,
                None,
                None,
            )
            .await
            .unwrap();

        assert!(response.status.is_success());
    }

    #[tokio::test]
    async fn test_delete() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}");

        let client = InnerHttpClient::default();
        let response = client
            .send_request(
                reqwest::Method::DELETE,
                format!("{url}/delete"),
                None,
                None,
                None,
            )
            .await
            .unwrap();

        assert!(response.status.is_success());
    }

    #[tokio::test]
    async fn test_not_found() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}/notfound");
        let client = InnerHttpClient::default();

        let response = client
            .send_request(reqwest::Method::GET, url, None, None, None)
            .await
            .unwrap();

        assert!(response.status.is_client_error());
        assert_eq!(response.status.as_u16(), 404);
    }

    #[tokio::test]
    async fn test_timeout() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}/slow");
        let client = InnerHttpClient::default();

        // We'll set a 1-second timeout for a route that sleeps 2 seconds
        let result = client
            .send_request(reqwest::Method::GET, url, None, None, Some(1))
            .await;

        match result {
            Err(HttpClientError::TimeoutError(msg)) => {
                println!("Got expected timeout error: {msg}");
            }
            Err(other) => panic!("Expected a timeout error, was: {other:?}"),
            Ok(resp) => panic!("Expected a timeout error, but was a successful response: {resp:?}"),
        }
    }
}