nautilus_network/
http.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16//! A high-performance HTTP client implementation.
17
18use std::{collections::HashMap, hash::Hash, str::FromStr, sync::Arc, time::Duration};
19
20use bytes::Bytes;
21use http::{HeaderValue, StatusCode, status::InvalidStatusCode};
22use reqwest::{
23    Method, Response, Url,
24    header::{HeaderMap, HeaderName},
25};
26
27use crate::ratelimiter::{RateLimiter, clock::MonotonicClock, quota::Quota};
28
29/// Represents a HTTP status code.
30///
31/// Wraps [`http::StatusCode`] to expose a Python-compatible type and reuse
32/// its validation and convenience methods.
33#[derive(Clone, Debug)]
34#[cfg_attr(
35    feature = "python",
36    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
37)]
38pub struct HttpStatus {
39    inner: StatusCode,
40}
41
42impl HttpStatus {
43    /// Create a new [`HttpStatus`] instance from a given [`StatusCode`].
44    #[must_use]
45    pub const fn new(code: StatusCode) -> Self {
46        Self { inner: code }
47    }
48
49    /// Attempts to construct a [`HttpStatus`] from a `u16`.
50    ///
51    /// # Errors
52    ///
53    /// Returns an error if the code is not in the valid `100..999` range.
54    pub fn from(code: u16) -> Result<Self, InvalidStatusCode> {
55        Ok(Self {
56            inner: StatusCode::from_u16(code)?,
57        })
58    }
59
60    /// Returns the status code as a `u16` (e.g., `200` for OK).
61    #[inline]
62    #[must_use]
63    pub const fn as_u16(&self) -> u16 {
64        self.inner.as_u16()
65    }
66
67    /// Returns the three-digit ASCII representation of this status (e.g., `"200"`).
68    #[inline]
69    #[must_use]
70    pub fn as_str(&self) -> &str {
71        self.inner.as_str()
72    }
73
74    /// Checks if this status is in the 1xx (informational) range.
75    #[inline]
76    #[must_use]
77    pub fn is_informational(&self) -> bool {
78        self.inner.is_informational()
79    }
80
81    /// Checks if this status is in the 2xx (success) range.
82    #[inline]
83    #[must_use]
84    pub fn is_success(&self) -> bool {
85        self.inner.is_success()
86    }
87
88    /// Checks if this status is in the 3xx (redirection) range.
89    #[inline]
90    #[must_use]
91    pub fn is_redirection(&self) -> bool {
92        self.inner.is_redirection()
93    }
94
95    /// Checks if this status is in the 4xx (client error) range.
96    #[inline]
97    #[must_use]
98    pub fn is_client_error(&self) -> bool {
99        self.inner.is_client_error()
100    }
101
102    /// Checks if this status is in the 5xx (server error) range.
103    #[inline]
104    #[must_use]
105    pub fn is_server_error(&self) -> bool {
106        self.inner.is_server_error()
107    }
108}
109
110/// Represents the HTTP methods supported by the `HttpClient`.
111#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
112#[cfg_attr(
113    feature = "python",
114    pyo3::pyclass(eq, eq_int, module = "nautilus_trader.core.nautilus_pyo3.network")
115)]
116pub enum HttpMethod {
117    GET,
118    POST,
119    PUT,
120    DELETE,
121    PATCH,
122}
123
124#[allow(clippy::from_over_into)]
125impl Into<Method> for HttpMethod {
126    fn into(self) -> Method {
127        match self {
128            Self::GET => Method::GET,
129            Self::POST => Method::POST,
130            Self::PUT => Method::PUT,
131            Self::DELETE => Method::DELETE,
132            Self::PATCH => Method::PATCH,
133        }
134    }
135}
136
137/// Represents the response from an HTTP request.
138///
139/// This struct encapsulates the status, headers, and body of an HTTP response,
140/// providing easy access to the key components of the response.
141#[derive(Clone, Debug)]
142#[cfg_attr(
143    feature = "python",
144    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
145)]
146pub struct HttpResponse {
147    /// The HTTP status code.
148    pub status: HttpStatus,
149    /// The response headers as a map of key-value pairs.
150    pub headers: HashMap<String, String>,
151    /// The raw response body.
152    pub body: Bytes,
153}
154
155/// Errors returned by the HTTP client.
156///
157/// Includes generic transport errors and timeouts.
158#[derive(thiserror::Error, Debug)]
159pub enum HttpClientError {
160    #[error("HTTP error occurred: {0}")]
161    Error(String),
162
163    #[error("HTTP request timed out: {0}")]
164    TimeoutError(String),
165}
166
167impl From<reqwest::Error> for HttpClientError {
168    fn from(source: reqwest::Error) -> Self {
169        if source.is_timeout() {
170            Self::TimeoutError(source.to_string())
171        } else {
172            Self::Error(source.to_string())
173        }
174    }
175}
176
177impl From<String> for HttpClientError {
178    fn from(value: String) -> Self {
179        Self::Error(value)
180    }
181}
182
183/// An HTTP client that supports rate limiting and timeouts.
184///
185/// Built on `reqwest` for async I/O. Allows per-endpoint and default quotas
186/// through a rate limiter.
187///
188/// This struct is designed to handle HTTP requests efficiently, providing
189/// support for rate limiting, timeouts, and custom headers. The client is
190/// built on top of `reqwest` and can be used for both synchronous and
191/// asynchronous HTTP requests.
192#[derive(Clone)]
193#[cfg_attr(
194    feature = "python",
195    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
196)]
197pub struct HttpClient {
198    /// The underlying HTTP client used to make requests.
199    pub(crate) client: InnerHttpClient,
200    /// The rate limiter to control the request rate.
201    pub(crate) rate_limiter: Arc<RateLimiter<String, MonotonicClock>>,
202}
203
204impl HttpClient {
205    /// Creates a new [`HttpClient`] instance.
206    #[must_use]
207    pub fn new(
208        headers: HashMap<String, String>,
209        header_keys: Vec<String>,
210        keyed_quotas: Vec<(String, Quota)>,
211        default_quota: Option<Quota>,
212    ) -> Self {
213        // Build default headers
214        let mut header_map = HeaderMap::new();
215        for (key, value) in headers {
216            let header_name = HeaderName::from_str(&key).expect("Invalid header name");
217            let header_value = HeaderValue::from_str(&value).expect("Invalid header value");
218            header_map.insert(header_name, header_value);
219        }
220
221        let client = reqwest::Client::builder()
222            .default_headers(header_map)
223            .build()
224            .expect("Failed to build reqwest client");
225
226        let client = InnerHttpClient {
227            client,
228            header_keys: Arc::new(header_keys),
229        };
230        let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
231
232        Self {
233            client,
234            rate_limiter,
235        }
236    }
237
238    /// Sends an HTTP request.
239    ///
240    /// - `method`: The [`Method`] to use (GET, POST, etc.).
241    /// - `url`: The target URL.
242    /// - `headers`: Additional headers for this request.
243    /// - `body`: Optional request body.
244    /// - `keys`: Rate-limit keys to control request frequency.
245    /// - `timeout_secs`: Optional request timeout in seconds.
246    ///
247    /// # Errors
248    ///
249    /// Returns an error if unable to send request or times out.
250    ///
251    /// # Examples
252    ///
253    /// If requesting `/foo/bar`, pass rate-limit keys `["foo/bar", "foo"]`.
254    #[allow(clippy::too_many_arguments)]
255    pub async fn request(
256        &self,
257        method: Method,
258        url: String,
259        headers: Option<HashMap<String, String>>,
260        body: Option<Vec<u8>>,
261        keys: Option<Vec<String>>,
262        timeout_secs: Option<u64>,
263    ) -> Result<HttpResponse, HttpClientError> {
264        let rate_limiter = self.rate_limiter.clone();
265
266        rate_limiter.await_keys_ready(keys).await;
267        self.client
268            .send_request(method, url, headers, body, timeout_secs)
269            .await
270    }
271}
272
273/// Internal implementation backing [`HttpClient`].
274///
275/// The client is backed by a [`reqwest::Client`] which keeps connections alive and
276/// can be cloned cheaply. The client also has a list of header fields to
277/// extract from the response.
278///
279/// The client returns an [`HttpResponse`]. The client filters only the key value
280/// for the give `header_keys`.
281#[derive(Clone, Debug)]
282pub struct InnerHttpClient {
283    pub(crate) client: reqwest::Client,
284    pub(crate) header_keys: Arc<Vec<String>>,
285}
286
287impl InnerHttpClient {
288    /// Sends an HTTP request and returns an [`HttpResponse`].
289    ///
290    /// - `method`: The HTTP method (e.g. GET, POST).
291    /// - `url`: The target URL.
292    /// - `headers`: Extra headers to send.
293    /// - `body`: Optional request body.
294    /// - `timeout_secs`: Optional request timeout in seconds.
295    ///
296    /// # Errors
297    ///
298    /// Returns an error if unable to send request or times out.
299    pub async fn send_request(
300        &self,
301        method: Method,
302        url: String,
303        headers: Option<HashMap<String, String>>,
304        body: Option<Vec<u8>>,
305        timeout_secs: Option<u64>,
306    ) -> Result<HttpResponse, HttpClientError> {
307        let headers = headers.unwrap_or_default();
308        let reqwest_url = Url::parse(url.as_str())
309            .map_err(|e| HttpClientError::from(format!("URL parse error: {e}")))?;
310
311        let mut header_map = HeaderMap::new();
312        for (header_key, header_value) in &headers {
313            let key = HeaderName::from_bytes(header_key.as_bytes())
314                .map_err(|e| HttpClientError::from(format!("Invalid header name: {e}")))?;
315            let _ = header_map.insert(
316                key,
317                header_value
318                    .parse()
319                    .map_err(|e| HttpClientError::from(format!("Invalid header value: {e}")))?,
320            );
321        }
322
323        let mut request_builder = self.client.request(method, reqwest_url).headers(header_map);
324
325        if let Some(timeout_secs) = timeout_secs {
326            request_builder = request_builder.timeout(Duration::new(timeout_secs, 0));
327        }
328
329        let request = match body {
330            Some(b) => request_builder
331                .body(b)
332                .build()
333                .map_err(HttpClientError::from)?,
334            None => request_builder.build().map_err(HttpClientError::from)?,
335        };
336
337        tracing::trace!("{request:?}");
338
339        let response = self
340            .client
341            .execute(request)
342            .await
343            .map_err(HttpClientError::from)?;
344
345        self.to_response(response).await
346    }
347
348    /// Converts a `reqwest::Response` into an `HttpResponse`.
349    ///
350    /// # Errors
351    ///
352    /// Returns an error if unable to send request or times out.
353    pub async fn to_response(&self, response: Response) -> Result<HttpResponse, HttpClientError> {
354        tracing::trace!("{response:?}");
355
356        let headers: HashMap<String, String> = self
357            .header_keys
358            .iter()
359            .filter_map(|key| response.headers().get(key).map(|val| (key, val)))
360            .filter_map(|(key, val)| val.to_str().map(|v| (key, v)).ok())
361            .map(|(k, v)| (k.clone(), v.to_owned()))
362            .collect();
363        let status = HttpStatus::new(response.status());
364        let body = response.bytes().await.map_err(HttpClientError::from)?;
365
366        Ok(HttpResponse {
367            status,
368            headers,
369            body,
370        })
371    }
372}
373
374impl Default for InnerHttpClient {
375    /// Creates a new default [`InnerHttpClient`] instance.
376    ///
377    /// The default client is initialized with an empty list of header keys and a new `reqwest::Client`.
378    fn default() -> Self {
379        let client = reqwest::Client::new();
380        Self {
381            client,
382            header_keys: Default::default(),
383        }
384    }
385}
386
387////////////////////////////////////////////////////////////////////////////////
388// Tests
389////////////////////////////////////////////////////////////////////////////////
#[cfg(test)]
#[cfg(target_os = "linux")] // Only run network tests on Linux (CI stability)
mod tests {
    use std::net::SocketAddr;

    use axum::{
        Router,
        routing::{delete, get, patch, post},
        serve,
    };
    use http::status::StatusCode;

    use super::*;

    fn create_router() -> Router {
        Router::new()
            .route("/get", get(|| async { "hello-world!" }))
            .route("/post", post(|| async { StatusCode::OK }))
            .route("/patch", patch(|| async { StatusCode::OK }))
            .route("/delete", delete(|| async { StatusCode::OK }))
            .route("/notfound", get(|| async { StatusCode::NOT_FOUND }))
            .route(
                "/slow",
                get(|| async {
                    tokio::time::sleep(Duration::from_secs(2)).await;
                    "Eventually responded"
                }),
            )
    }

    /// Starts the test server on an OS-assigned port and returns its address.
    ///
    /// Binding directly to port 0 avoids the race of probing for a free port,
    /// releasing it, and then binding it again.
    async fn start_test_server() -> Result<SocketAddr, Box<dyn std::error::Error + Send + Sync>> {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?;
        let addr = listener.local_addr()?;

        tokio::spawn(async move {
            serve(listener, create_router()).await.unwrap();
        });

        Ok(addr)
    }

    #[test]
    fn test_http_method_into_reqwest_method() {
        let cases = [
            (HttpMethod::GET, Method::GET),
            (HttpMethod::POST, Method::POST),
            (HttpMethod::PUT, Method::PUT),
            (HttpMethod::DELETE, Method::DELETE),
            (HttpMethod::PATCH, Method::PATCH),
        ];
        for (method, expected) in cases {
            let converted: Method = method.into();
            assert_eq!(converted, expected);
        }
    }

    #[test]
    fn test_http_status_from_validates_range() {
        assert!(HttpStatus::from(99).is_err());
        assert!(HttpStatus::from(1000).is_err());
        assert_eq!(HttpStatus::from(100).unwrap().as_u16(), 100);
        assert_eq!(HttpStatus::from(999).unwrap().as_u16(), 999);
        assert!(HttpStatus::from(204).unwrap().is_success());
        assert_eq!(HttpStatus::from(404).unwrap().as_str(), "404");
    }

    #[tokio::test]
    async fn test_get() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}");

        let client = InnerHttpClient::default();
        let response = client
            .send_request(reqwest::Method::GET, format!("{url}/get"), None, None, None)
            .await
            .unwrap();

        assert!(response.status.is_success());
        assert_eq!(String::from_utf8_lossy(&response.body), "hello-world!");
    }

    #[tokio::test]
    async fn test_response_headers_filtered_by_header_keys() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}/get");

        let client = InnerHttpClient {
            client: reqwest::Client::new(),
            header_keys: Arc::new(vec!["content-type".to_string(), "x-missing".to_string()]),
        };
        let response = client
            .send_request(reqwest::Method::GET, url, None, None, None)
            .await
            .unwrap();

        assert!(response.headers["content-type"].starts_with("text/plain"));
        assert!(!response.headers.contains_key("x-missing"));
        assert_eq!(response.headers.len(), 1);
    }

    #[tokio::test]
    async fn test_invalid_url_returns_error() {
        let client = InnerHttpClient::default();
        let result = client
            .send_request(
                reqwest::Method::GET,
                "not a url".to_string(),
                None,
                None,
                None,
            )
            .await;

        assert!(
            matches!(&result, Err(HttpClientError::Error(msg)) if msg.contains("URL parse error"))
        );
    }

    #[tokio::test]
    async fn test_invalid_header_name_returns_error() {
        let client = InnerHttpClient::default();
        let headers = HashMap::from([("bad header".to_string(), "value".to_string())]);
        let result = client
            .send_request(
                reqwest::Method::GET,
                "http://127.0.0.1:1/".to_string(),
                Some(headers),
                None,
                None,
            )
            .await;

        assert!(
            matches!(&result, Err(HttpClientError::Error(msg)) if msg.contains("Invalid header name"))
        );
    }

    #[tokio::test]
    async fn test_post() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}");

        let client = InnerHttpClient::default();
        let response = client
            .send_request(
                reqwest::Method::POST,
                format!("{url}/post"),
                None,
                None,
                None,
            )
            .await
            .unwrap();

        assert!(response.status.is_success());
    }

    #[tokio::test]
    async fn test_post_with_body() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}");

        let client = InnerHttpClient::default();

        let mut body = HashMap::new();
        body.insert(
            "key1".to_string(),
            serde_json::Value::String("value1".to_string()),
        );
        body.insert(
            "key2".to_string(),
            serde_json::Value::String("value2".to_string()),
        );

        let body_string = serde_json::to_string(&body).unwrap();
        let body_bytes = body_string.into_bytes();

        let response = client
            .send_request(
                reqwest::Method::POST,
                format!("{url}/post"),
                None,
                Some(body_bytes),
                None,
            )
            .await
            .unwrap();

        assert!(response.status.is_success());
    }

    #[tokio::test]
    async fn test_patch() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}");

        let client = InnerHttpClient::default();
        let response = client
            .send_request(
                reqwest::Method::PATCH,
                format!("{url}/patch"),
                None,
                None,
                None,
            )
            .await
            .unwrap();

        assert!(response.status.is_success());
    }

    #[tokio::test]
    async fn test_delete() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}");

        let client = InnerHttpClient::default();
        let response = client
            .send_request(
                reqwest::Method::DELETE,
                format!("{url}/delete"),
                None,
                None,
                None,
            )
            .await
            .unwrap();

        assert!(response.status.is_success());
    }

    #[tokio::test]
    async fn test_not_found() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}/notfound");
        let client = InnerHttpClient::default();

        let response = client
            .send_request(reqwest::Method::GET, url, None, None, None)
            .await
            .unwrap();

        assert!(response.status.is_client_error());
        assert_eq!(response.status.as_u16(), 404);
    }

    #[tokio::test]
    async fn test_timeout() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}/slow");
        let client = InnerHttpClient::default();

        // We'll set a 1-second timeout for a route that sleeps 2 seconds
        let result = client
            .send_request(reqwest::Method::GET, url, None, None, Some(1))
            .await;

        match result {
            Err(HttpClientError::TimeoutError(msg)) => {
                println!("Got expected timeout error: {msg}");
            }
            Err(other) => panic!("Expected a timeout error, got: {other:?}"),
            Ok(resp) => panic!("Expected a timeout error, but got a successful response: {resp:?}"),
        }
    }
}