nautilus_network/
http.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16//! A high-performance HTTP client implementation.
17
18use std::{collections::HashMap, hash::Hash, str::FromStr, sync::Arc, time::Duration};
19
20use bytes::Bytes;
21use http::{status::InvalidStatusCode, HeaderValue, StatusCode};
22use reqwest::{
23    header::{HeaderMap, HeaderName},
24    Method, Response, Url,
25};
26
27use crate::ratelimiter::{clock::MonotonicClock, quota::Quota, RateLimiter};
28
29/// Represents a HTTP status code.
30///
31/// Wraps [`http::StatusCode`] to expose a Python-compatible type and reuse
32/// its validation and convenience methods.
33#[derive(Clone, Debug)]
34#[cfg_attr(
35    feature = "python",
36    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
37)]
38pub struct HttpStatus {
39    inner: StatusCode,
40}
41
42impl HttpStatus {
43    /// Create a new [`HttpStatus`] instance from a given [`StatusCode`].
44    #[must_use]
45    pub const fn new(code: StatusCode) -> Self {
46        Self { inner: code }
47    }
48
49    /// Attempts to construct a [`HttpStatus`] from a `u16`.
50    ///
51    /// Returns an error if the code is not in the valid `100..999` range.
52    pub fn from(code: u16) -> Result<Self, InvalidStatusCode> {
53        Ok(Self {
54            inner: StatusCode::from_u16(code)?,
55        })
56    }
57
58    /// Returns the status code as a `u16` (e.g., `200` for OK).
59    #[inline]
60    #[must_use]
61    pub const fn as_u16(&self) -> u16 {
62        self.inner.as_u16()
63    }
64
65    /// Returns the three-digit ASCII representation of this status (e.g., `"200"`).
66    #[inline]
67    #[must_use]
68    pub fn as_str(&self) -> &str {
69        self.inner.as_str()
70    }
71
72    /// Checks if this status is in the 1xx (informational) range.
73    #[inline]
74    #[must_use]
75    pub fn is_informational(&self) -> bool {
76        self.inner.is_informational()
77    }
78
79    /// Checks if this status is in the 2xx (success) range.
80    #[inline]
81    #[must_use]
82    pub fn is_success(&self) -> bool {
83        self.inner.is_success()
84    }
85
86    /// Checks if this status is in the 3xx (redirection) range.
87    #[inline]
88    #[must_use]
89    pub fn is_redirection(&self) -> bool {
90        self.inner.is_redirection()
91    }
92
93    /// Checks if this status is in the 4xx (client error) range.
94    #[inline]
95    #[must_use]
96    pub fn is_client_error(&self) -> bool {
97        self.inner.is_client_error()
98    }
99
100    /// Checks if this status is in the 5xx (server error) range.
101    #[inline]
102    #[must_use]
103    pub fn is_server_error(&self) -> bool {
104        self.inner.is_server_error()
105    }
106}
107
/// The HTTP request methods that `HttpClient` can issue.
///
/// Discriminants are implicit, so they follow declaration order (`GET` = 0 ..
/// `PATCH` = 4). Because the Python binding is declared with `eq_int`, these
/// integers are visible from Python. Append new variants rather than reordering.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(eq, eq_int, module = "nautilus_trader.core.nautilus_pyo3.network")
)]
pub enum HttpMethod {
    /// The `GET` method.
    GET,
    /// The `POST` method.
    POST,
    /// The `PUT` method.
    PUT,
    /// The `DELETE` method.
    DELETE,
    /// The `PATCH` method.
    PATCH,
}
121
122#[allow(clippy::from_over_into)]
123impl Into<Method> for HttpMethod {
124    fn into(self) -> Method {
125        match self {
126            Self::GET => Method::GET,
127            Self::POST => Method::POST,
128            Self::PUT => Method::PUT,
129            Self::DELETE => Method::DELETE,
130            Self::PATCH => Method::PATCH,
131        }
132    }
133}
134
135/// Represents the response from an HTTP request.
136///
137/// This struct encapsulates the status, headers, and body of an HTTP response,
138/// providing easy access to the key components of the response.
139#[derive(Clone, Debug)]
140#[cfg_attr(
141    feature = "python",
142    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
143)]
144pub struct HttpResponse {
145    /// The HTTP status code.
146    pub status: HttpStatus,
147    /// The response headers as a map of key-value pairs.
148    pub headers: HashMap<String, String>,
149    /// The raw response body.
150    pub body: Bytes,
151}
152
153/// Errors returned by the HTTP client.
154///
155/// Includes generic transport errors and timeouts.
156#[derive(thiserror::Error, Debug)]
157pub enum HttpClientError {
158    #[error("HTTP error occurred: {0}")]
159    Error(String),
160
161    #[error("HTTP request timed out: {0}")]
162    TimeoutError(String),
163}
164
165impl From<reqwest::Error> for HttpClientError {
166    fn from(source: reqwest::Error) -> Self {
167        if source.is_timeout() {
168            Self::TimeoutError(source.to_string())
169        } else {
170            Self::Error(source.to_string())
171        }
172    }
173}
174
175impl From<String> for HttpClientError {
176    fn from(value: String) -> Self {
177        Self::Error(value)
178    }
179}
180
181/// An HTTP client that supports rate limiting and timeouts.
182///
183/// Built on `reqwest` for async I/O. Allows per-endpoint and default quotas
184/// through a rate limiter.
185///
186/// This struct is designed to handle HTTP requests efficiently, providing
187/// support for rate limiting, timeouts, and custom headers. The client is
188/// built on top of `reqwest` and can be used for both synchronous and
189/// asynchronous HTTP requests.
190#[derive(Clone)]
191#[cfg_attr(
192    feature = "python",
193    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
194)]
195pub struct HttpClient {
196    /// The underlying HTTP client used to make requests.
197    pub(crate) client: InnerHttpClient,
198    /// The rate limiter to control the request rate.
199    pub(crate) rate_limiter: Arc<RateLimiter<String, MonotonicClock>>,
200}
201
202impl HttpClient {
203    /// Creates a new [`HttpClient`] instance.
204    #[must_use]
205    pub fn new(
206        headers: HashMap<String, String>,
207        header_keys: Vec<String>,
208        keyed_quotas: Vec<(String, Quota)>,
209        default_quota: Option<Quota>,
210    ) -> Self {
211        // Build default headers
212        let mut header_map = HeaderMap::new();
213        for (key, value) in headers {
214            let header_name = HeaderName::from_str(&key).expect("Invalid header name");
215            let header_value = HeaderValue::from_str(&value).expect("Invalid header value");
216            header_map.insert(header_name, header_value);
217        }
218
219        let client = reqwest::Client::builder()
220            .default_headers(header_map)
221            .build()
222            .expect("Failed to build reqwest client");
223
224        let client = InnerHttpClient {
225            client,
226            header_keys: Arc::new(header_keys),
227        };
228        let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
229
230        Self {
231            client,
232            rate_limiter,
233        }
234    }
235
236    /// Sends an HTTP request.
237    ///
238    /// - `method`: The [`Method`] to use (GET, POST, etc.).
239    /// - `url`: The target URL.
240    /// - `headers`: Additional headers for this request.
241    /// - `body`: Optional request body.
242    /// - `keys`: Rate-limit keys to control request frequency.
243    /// - `timeout_secs`: Optional request timeout in seconds.
244    ///
245    /// # Example
246    /// If requesting `/foo/bar`, pass rate-limit keys `["foo/bar", "foo"]`.
247    #[allow(clippy::too_many_arguments)]
248    pub async fn request(
249        &self,
250        method: Method,
251        url: String,
252        headers: Option<HashMap<String, String>>,
253        body: Option<Vec<u8>>,
254        keys: Option<Vec<String>>,
255        timeout_secs: Option<u64>,
256    ) -> Result<HttpResponse, HttpClientError> {
257        let rate_limiter = self.rate_limiter.clone();
258
259        rate_limiter.await_keys_ready(keys).await;
260        self.client
261            .send_request(method, url, headers, body, timeout_secs)
262            .await
263    }
264}
265
266/// Internal implementation backing [`HttpClient`].
267///
268/// The client is backed by a [`reqwest::Client`] which keeps connections alive and
269/// can be cloned cheaply. The client also has a list of header fields to
270/// extract from the response.
271///
272/// The client returns an [`HttpResponse`]. The client filters only the key value
273/// for the give `header_keys`.
274#[derive(Clone, Debug)]
275pub struct InnerHttpClient {
276    pub(crate) client: reqwest::Client,
277    pub(crate) header_keys: Arc<Vec<String>>,
278}
279
280impl InnerHttpClient {
281    /// Sends an HTTP request and returns an [`HttpResponse`].
282    ///
283    /// - `method`: The HTTP method (e.g. GET, POST).
284    /// - `url`: The target URL.
285    /// - `headers`: Extra headers to send.
286    /// - `body`: Optional request body.
287    /// - `timeout_secs`: Optional request timeout in seconds.
288    pub async fn send_request(
289        &self,
290        method: Method,
291        url: String,
292        headers: Option<HashMap<String, String>>,
293        body: Option<Vec<u8>>,
294        timeout_secs: Option<u64>,
295    ) -> Result<HttpResponse, HttpClientError> {
296        let headers = headers.unwrap_or_default();
297        let reqwest_url = Url::parse(url.as_str())
298            .map_err(|e| HttpClientError::from(format!("URL parse error: {e}")))?;
299
300        let mut header_map = HeaderMap::new();
301        for (header_key, header_value) in &headers {
302            let key = HeaderName::from_bytes(header_key.as_bytes())
303                .map_err(|e| HttpClientError::from(format!("Invalid header name: {e}")))?;
304            let _ = header_map.insert(
305                key,
306                header_value
307                    .parse()
308                    .map_err(|e| HttpClientError::from(format!("Invalid header value: {e}")))?,
309            );
310        }
311
312        let mut request_builder = self.client.request(method, reqwest_url).headers(header_map);
313
314        if let Some(timeout_secs) = timeout_secs {
315            request_builder = request_builder.timeout(Duration::new(timeout_secs, 0));
316        }
317
318        let request = match body {
319            Some(b) => request_builder
320                .body(b)
321                .build()
322                .map_err(HttpClientError::from)?,
323            None => request_builder.build().map_err(HttpClientError::from)?,
324        };
325
326        tracing::trace!("{request:?}");
327
328        let response = self
329            .client
330            .execute(request)
331            .await
332            .map_err(HttpClientError::from)?;
333
334        self.to_response(response).await
335    }
336
337    /// Converts a `reqwest::Response` into an `HttpResponse`.
338    pub async fn to_response(&self, response: Response) -> Result<HttpResponse, HttpClientError> {
339        tracing::trace!("{response:?}");
340
341        let headers: HashMap<String, String> = self
342            .header_keys
343            .iter()
344            .filter_map(|key| response.headers().get(key).map(|val| (key, val)))
345            .filter_map(|(key, val)| val.to_str().map(|v| (key, v)).ok())
346            .map(|(k, v)| (k.clone(), v.to_owned()))
347            .collect();
348        let status = HttpStatus::new(response.status());
349        let body = response.bytes().await.map_err(HttpClientError::from)?;
350
351        Ok(HttpResponse {
352            status,
353            headers,
354            body,
355        })
356    }
357}
358
359impl Default for InnerHttpClient {
360    /// Creates a new default [`InnerHttpClient`] instance.
361    ///
362    /// The default client is initialized with an empty list of header keys and a new `reqwest::Client`.
363    fn default() -> Self {
364        let client = reqwest::Client::new();
365        Self {
366            client,
367            header_keys: Default::default(),
368        }
369    }
370}
371
372////////////////////////////////////////////////////////////////////////////////
373// Tests
374////////////////////////////////////////////////////////////////////////////////
#[cfg(test)]
#[cfg(target_os = "linux")] // Only run network tests on Linux (CI stability)
mod tests {
    use std::net::SocketAddr;

    use axum::{
        routing::{delete, get, patch, post},
        serve, Router,
    };
    use http::status::StatusCode;

    use super::*;

    /// Builds the router served by the local test server.
    fn create_router() -> Router {
        Router::new()
            .route("/get", get(|| async { "hello-world!" }))
            .route("/post", post(|| async { StatusCode::OK }))
            .route("/patch", patch(|| async { StatusCode::OK }))
            .route("/delete", delete(|| async { StatusCode::OK }))
            .route("/notfound", get(|| async { StatusCode::NOT_FOUND }))
            .route(
                "/slow",
                get(|| async {
                    tokio::time::sleep(Duration::from_secs(2)).await;
                    "Eventually responded"
                }),
            )
    }

    /// Spawns the test server on an OS-assigned port and returns its address.
    ///
    /// The server's own listener binds to port 0. This avoids the race in
    /// "probe a free port, release it, re-bind it", where another process can
    /// claim the port in between.
    async fn start_test_server() -> Result<SocketAddr, Box<dyn std::error::Error + Send + Sync>> {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?;
        let addr = listener.local_addr()?;

        tokio::spawn(async move {
            serve(listener, create_router()).await.unwrap();
        });

        Ok(addr)
    }

    #[test]
    fn test_http_status_from_accepts_inclusive_range() {
        assert!(HttpStatus::from(99).is_err());
        assert_eq!(HttpStatus::from(100).unwrap().as_u16(), 100);
        assert_eq!(HttpStatus::from(999).unwrap().as_u16(), 999);
        assert!(HttpStatus::from(1000).is_err());
    }

    #[test]
    fn test_http_method_into_reqwest_method() {
        let pairs = [
            (HttpMethod::GET, Method::GET),
            (HttpMethod::POST, Method::POST),
            (HttpMethod::PUT, Method::PUT),
            (HttpMethod::DELETE, Method::DELETE),
            (HttpMethod::PATCH, Method::PATCH),
        ];
        for (method, expected) in pairs {
            let converted: Method = method.into();
            assert_eq!(converted, expected);
        }
    }

    #[tokio::test]
    async fn test_get() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}");

        let client = InnerHttpClient::default();
        let response = client
            .send_request(reqwest::Method::GET, format!("{url}/get"), None, None, None)
            .await
            .unwrap();

        assert!(response.status.is_success());
        assert_eq!(String::from_utf8_lossy(&response.body), "hello-world!");
    }

    #[tokio::test]
    async fn test_post() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}");

        let client = InnerHttpClient::default();
        let response = client
            .send_request(
                reqwest::Method::POST,
                format!("{url}/post"),
                None,
                None,
                None,
            )
            .await
            .unwrap();

        assert!(response.status.is_success());
    }

    #[tokio::test]
    async fn test_post_with_body() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}");

        let client = InnerHttpClient::default();

        let mut body = HashMap::new();
        body.insert(
            "key1".to_string(),
            serde_json::Value::String("value1".to_string()),
        );
        body.insert(
            "key2".to_string(),
            serde_json::Value::String("value2".to_string()),
        );

        let body_string = serde_json::to_string(&body).unwrap();
        let body_bytes = body_string.into_bytes();

        let response = client
            .send_request(
                reqwest::Method::POST,
                format!("{url}/post"),
                None,
                Some(body_bytes),
                None,
            )
            .await
            .unwrap();

        assert!(response.status.is_success());
    }

    #[tokio::test]
    async fn test_patch() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}");

        let client = InnerHttpClient::default();
        let response = client
            .send_request(
                reqwest::Method::PATCH,
                format!("{url}/patch"),
                None,
                None,
                None,
            )
            .await
            .unwrap();

        assert!(response.status.is_success());
    }

    #[tokio::test]
    async fn test_delete() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}");

        let client = InnerHttpClient::default();
        let response = client
            .send_request(
                reqwest::Method::DELETE,
                format!("{url}/delete"),
                None,
                None,
                None,
            )
            .await
            .unwrap();

        assert!(response.status.is_success());
    }

    #[tokio::test]
    async fn test_not_found() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}/notfound");
        let client = InnerHttpClient::default();

        let response = client
            .send_request(reqwest::Method::GET, url, None, None, None)
            .await
            .unwrap();

        assert!(response.status.is_client_error());
        assert_eq!(response.status.as_u16(), 404);
    }

    #[tokio::test]
    async fn test_timeout() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}/slow");
        let client = InnerHttpClient::default();

        // We'll set a 1-second timeout for a route that sleeps 2 seconds
        let result = client
            .send_request(reqwest::Method::GET, url, None, None, Some(1))
            .await;

        match result {
            Err(HttpClientError::TimeoutError(msg)) => {
                println!("Got expected timeout error: {msg}");
            }
            Err(other) => panic!("Expected a timeout error, got: {other:?}"),
            Ok(resp) => panic!("Expected a timeout error, but got a successful response: {resp:?}"),
        }
    }
}