nautilus_network/
http.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16//! A high-performance HTTP client implementation.
17
18use std::{collections::HashMap, hash::Hash, str::FromStr, sync::Arc, time::Duration};
19
20use bytes::Bytes;
21use http::{HeaderValue, StatusCode, status::InvalidStatusCode};
22use reqwest::{
23    Method, Response, Url,
24    header::{HeaderMap, HeaderName},
25};
26
27use crate::ratelimiter::{RateLimiter, clock::MonotonicClock, quota::Quota};
28
29/// Represents a HTTP status code.
30///
31/// Wraps [`http::StatusCode`] to expose a Python-compatible type and reuse
32/// its validation and convenience methods.
33#[derive(Clone, Debug)]
34#[cfg_attr(
35    feature = "python",
36    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
37)]
38pub struct HttpStatus {
39    inner: StatusCode,
40}
41
42impl HttpStatus {
43    /// Create a new [`HttpStatus`] instance from a given [`StatusCode`].
44    #[must_use]
45    pub const fn new(code: StatusCode) -> Self {
46        Self { inner: code }
47    }
48
49    /// Attempts to construct a [`HttpStatus`] from a `u16`.
50    ///
51    /// # Errors
52    ///
53    /// Returns an error if the code is not in the valid `100..999` range.
54    pub fn from(code: u16) -> Result<Self, InvalidStatusCode> {
55        Ok(Self {
56            inner: StatusCode::from_u16(code)?,
57        })
58    }
59
60    /// Returns the status code as a `u16` (e.g., `200` for OK).
61    #[inline]
62    #[must_use]
63    pub const fn as_u16(&self) -> u16 {
64        self.inner.as_u16()
65    }
66
67    /// Returns the three-digit ASCII representation of this status (e.g., `"200"`).
68    #[inline]
69    #[must_use]
70    pub fn as_str(&self) -> &str {
71        self.inner.as_str()
72    }
73
74    /// Checks if this status is in the 1xx (informational) range.
75    #[inline]
76    #[must_use]
77    pub fn is_informational(&self) -> bool {
78        self.inner.is_informational()
79    }
80
81    /// Checks if this status is in the 2xx (success) range.
82    #[inline]
83    #[must_use]
84    pub fn is_success(&self) -> bool {
85        self.inner.is_success()
86    }
87
88    /// Checks if this status is in the 3xx (redirection) range.
89    #[inline]
90    #[must_use]
91    pub fn is_redirection(&self) -> bool {
92        self.inner.is_redirection()
93    }
94
95    /// Checks if this status is in the 4xx (client error) range.
96    #[inline]
97    #[must_use]
98    pub fn is_client_error(&self) -> bool {
99        self.inner.is_client_error()
100    }
101
102    /// Checks if this status is in the 5xx (server error) range.
103    #[inline]
104    #[must_use]
105    pub fn is_server_error(&self) -> bool {
106        self.inner.is_server_error()
107    }
108}
109
/// The HTTP request methods that can be used with the `HttpClient`.
///
/// Each variant converts to the `reqwest` [`Method`] constant of the same name.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(eq, eq_int, module = "nautilus_trader.core.nautilus_pyo3.network")
)]
pub enum HttpMethod {
    /// The HTTP `GET` method.
    GET,
    /// The HTTP `POST` method.
    POST,
    /// The HTTP `PUT` method.
    PUT,
    /// The HTTP `DELETE` method.
    DELETE,
    /// The HTTP `PATCH` method.
    PATCH,
}
123
124#[allow(clippy::from_over_into)]
125impl Into<Method> for HttpMethod {
126    fn into(self) -> Method {
127        match self {
128            Self::GET => Method::GET,
129            Self::POST => Method::POST,
130            Self::PUT => Method::PUT,
131            Self::DELETE => Method::DELETE,
132            Self::PATCH => Method::PATCH,
133        }
134    }
135}
136
137/// Represents the response from an HTTP request.
138///
139/// This struct encapsulates the status, headers, and body of an HTTP response,
140/// providing easy access to the key components of the response.
141#[derive(Clone, Debug)]
142#[cfg_attr(
143    feature = "python",
144    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
145)]
146pub struct HttpResponse {
147    /// The HTTP status code.
148    pub status: HttpStatus,
149    /// The response headers as a map of key-value pairs.
150    pub headers: HashMap<String, String>,
151    /// The raw response body.
152    pub body: Bytes,
153}
154
155/// Errors returned by the HTTP client.
156///
157/// Includes generic transport errors and timeouts.
158#[derive(thiserror::Error, Debug)]
159pub enum HttpClientError {
160    #[error("HTTP error occurred: {0}")]
161    Error(String),
162
163    #[error("HTTP request timed out: {0}")]
164    TimeoutError(String),
165}
166
167impl From<reqwest::Error> for HttpClientError {
168    fn from(source: reqwest::Error) -> Self {
169        if source.is_timeout() {
170            Self::TimeoutError(source.to_string())
171        } else {
172            Self::Error(source.to_string())
173        }
174    }
175}
176
177impl From<String> for HttpClientError {
178    fn from(value: String) -> Self {
179        Self::Error(value)
180    }
181}
182
183/// An HTTP client that supports rate limiting and timeouts.
184///
185/// Built on `reqwest` for async I/O. Allows per-endpoint and default quotas
186/// through a rate limiter.
187///
188/// This struct is designed to handle HTTP requests efficiently, providing
189/// support for rate limiting, timeouts, and custom headers. The client is
190/// built on top of `reqwest` and can be used for both synchronous and
191/// asynchronous HTTP requests.
192#[derive(Clone)]
193#[cfg_attr(
194    feature = "python",
195    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
196)]
197pub struct HttpClient {
198    /// The underlying HTTP client used to make requests.
199    pub(crate) client: InnerHttpClient,
200    /// The rate limiter to control the request rate.
201    pub(crate) rate_limiter: Arc<RateLimiter<String, MonotonicClock>>,
202}
203
204impl HttpClient {
205    /// Creates a new [`HttpClient`] instance.
206    #[must_use]
207    pub fn new(
208        headers: HashMap<String, String>,
209        header_keys: Vec<String>,
210        keyed_quotas: Vec<(String, Quota)>,
211        default_quota: Option<Quota>,
212        timeout_secs: Option<u64>,
213    ) -> Self {
214        // Build default headers
215        let mut header_map = HeaderMap::new();
216        for (key, value) in headers {
217            let header_name = HeaderName::from_str(&key).expect("Invalid header name");
218            let header_value = HeaderValue::from_str(&value).expect("Invalid header value");
219            header_map.insert(header_name, header_value);
220        }
221
222        let mut client_builder = reqwest::Client::builder().default_headers(header_map);
223        if let Some(timeout_secs) = timeout_secs {
224            client_builder = client_builder.timeout(Duration::from_secs(timeout_secs));
225        }
226
227        let client = client_builder
228            .build()
229            .expect("Failed to build reqwest client");
230
231        let client = InnerHttpClient {
232            client,
233            header_keys: Arc::new(header_keys),
234        };
235        let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
236
237        Self {
238            client,
239            rate_limiter,
240        }
241    }
242
243    /// Sends an HTTP request.
244    ///
245    /// - `method`: The [`Method`] to use (GET, POST, etc.).
246    /// - `url`: The target URL.
247    /// - `headers`: Additional headers for this request.
248    /// - `body`: Optional request body.
249    /// - `keys`: Rate-limit keys to control request frequency.
250    /// - `timeout_secs`: Optional request timeout in seconds.
251    ///
252    /// # Errors
253    ///
254    /// Returns an error if unable to send request or times out.
255    ///
256    /// # Examples
257    ///
258    /// If requesting `/foo/bar`, pass rate-limit keys `["foo/bar", "foo"]`.
259    #[allow(clippy::too_many_arguments)]
260    pub async fn request(
261        &self,
262        method: Method,
263        url: String,
264        headers: Option<HashMap<String, String>>,
265        body: Option<Vec<u8>>,
266        keys: Option<Vec<String>>,
267        timeout_secs: Option<u64>,
268    ) -> Result<HttpResponse, HttpClientError> {
269        let rate_limiter = self.rate_limiter.clone();
270
271        rate_limiter.await_keys_ready(keys).await;
272        self.client
273            .send_request(method, url, headers, body, timeout_secs)
274            .await
275    }
276}
277
278/// Internal implementation backing [`HttpClient`].
279///
280/// The client is backed by a [`reqwest::Client`] which keeps connections alive and
281/// can be cloned cheaply. The client also has a list of header fields to
282/// extract from the response.
283///
284/// The client returns an [`HttpResponse`]. The client filters only the key value
285/// for the give `header_keys`.
286#[derive(Clone, Debug)]
287pub struct InnerHttpClient {
288    pub(crate) client: reqwest::Client,
289    pub(crate) header_keys: Arc<Vec<String>>,
290}
291
292impl InnerHttpClient {
293    /// Sends an HTTP request and returns an [`HttpResponse`].
294    ///
295    /// - `method`: The HTTP method (e.g. GET, POST).
296    /// - `url`: The target URL.
297    /// - `headers`: Extra headers to send.
298    /// - `body`: Optional request body.
299    /// - `timeout_secs`: Optional request timeout in seconds.
300    ///
301    /// # Errors
302    ///
303    /// Returns an error if unable to send request or times out.
304    pub async fn send_request(
305        &self,
306        method: Method,
307        url: String,
308        headers: Option<HashMap<String, String>>,
309        body: Option<Vec<u8>>,
310        timeout_secs: Option<u64>,
311    ) -> Result<HttpResponse, HttpClientError> {
312        let headers = headers.unwrap_or_default();
313        let reqwest_url = Url::parse(url.as_str())
314            .map_err(|e| HttpClientError::from(format!("URL parse error: {e}")))?;
315
316        let mut header_map = HeaderMap::new();
317        for (header_key, header_value) in &headers {
318            let key = HeaderName::from_bytes(header_key.as_bytes())
319                .map_err(|e| HttpClientError::from(format!("Invalid header name: {e}")))?;
320            let _ = header_map.insert(
321                key,
322                header_value
323                    .parse()
324                    .map_err(|e| HttpClientError::from(format!("Invalid header value: {e}")))?,
325            );
326        }
327
328        let mut request_builder = self.client.request(method, reqwest_url).headers(header_map);
329
330        if let Some(timeout_secs) = timeout_secs {
331            request_builder = request_builder.timeout(Duration::new(timeout_secs, 0));
332        }
333
334        let request = match body {
335            Some(b) => request_builder
336                .body(b)
337                .build()
338                .map_err(HttpClientError::from)?,
339            None => request_builder.build().map_err(HttpClientError::from)?,
340        };
341
342        tracing::trace!("{request:?}");
343
344        let response = self
345            .client
346            .execute(request)
347            .await
348            .map_err(HttpClientError::from)?;
349
350        self.to_response(response).await
351    }
352
353    /// Converts a `reqwest::Response` into an `HttpResponse`.
354    ///
355    /// # Errors
356    ///
357    /// Returns an error if unable to send request or times out.
358    pub async fn to_response(&self, response: Response) -> Result<HttpResponse, HttpClientError> {
359        tracing::trace!("{response:?}");
360
361        let headers: HashMap<String, String> = self
362            .header_keys
363            .iter()
364            .filter_map(|key| response.headers().get(key).map(|val| (key, val)))
365            .filter_map(|(key, val)| val.to_str().map(|v| (key, v)).ok())
366            .map(|(k, v)| (k.clone(), v.to_owned()))
367            .collect();
368        let status = HttpStatus::new(response.status());
369        let body = response.bytes().await.map_err(HttpClientError::from)?;
370
371        Ok(HttpResponse {
372            status,
373            headers,
374            body,
375        })
376    }
377}
378
379impl Default for InnerHttpClient {
380    /// Creates a new default [`InnerHttpClient`] instance.
381    ///
382    /// The default client is initialized with an empty list of header keys and a new `reqwest::Client`.
383    fn default() -> Self {
384        let client = reqwest::Client::new();
385        Self {
386            client,
387            header_keys: Default::default(),
388        }
389    }
390}
391
392////////////////////////////////////////////////////////////////////////////////
393// Tests
394////////////////////////////////////////////////////////////////////////////////
#[cfg(test)]
#[cfg(target_os = "linux")] // Only run network tests on Linux (CI stability)
mod tests {
    use std::net::SocketAddr;

    use axum::{
        Router,
        routing::{delete, get, patch, post},
        serve,
    };
    use http::status::StatusCode;

    use super::*;

    /// Builds the router served by the in-process test server.
    fn create_router() -> Router {
        Router::new()
            .route("/get", get(|| async { "hello-world!" }))
            .route("/post", post(|| async { StatusCode::OK }))
            .route("/patch", patch(|| async { StatusCode::OK }))
            .route("/delete", delete(|| async { StatusCode::OK }))
            .route("/notfound", get(|| async { StatusCode::NOT_FOUND }))
            .route(
                "/slow",
                get(|| async {
                    tokio::time::sleep(Duration::from_secs(2)).await;
                    "Eventually responded"
                }),
            )
    }

    /// Starts the test server on an OS-assigned port and returns its address.
    ///
    /// The listener is bound directly to port `0` and kept alive by the server
    /// task. Probing for a free port, releasing it and binding it again would
    /// leave a window in which another process or a parallel test could take
    /// the port.
    async fn start_test_server() -> Result<SocketAddr, Box<dyn std::error::Error + Send + Sync>> {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?;
        let addr = listener.local_addr()?;

        tokio::spawn(async move {
            serve(listener, create_router()).await.unwrap();
        });

        Ok(addr)
    }

    #[tokio::test]
    async fn test_get() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}");

        let client = InnerHttpClient::default();
        let response = client
            .send_request(reqwest::Method::GET, format!("{url}/get"), None, None, None)
            .await
            .unwrap();

        assert!(response.status.is_success());
        assert_eq!(String::from_utf8_lossy(&response.body), "hello-world!");
    }

    #[tokio::test]
    async fn test_post() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}");

        let client = InnerHttpClient::default();
        let response = client
            .send_request(
                reqwest::Method::POST,
                format!("{url}/post"),
                None,
                None,
                None,
            )
            .await
            .unwrap();

        assert!(response.status.is_success());
    }

    #[tokio::test]
    async fn test_post_with_body() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}");

        let client = InnerHttpClient::default();

        let mut body = HashMap::new();
        body.insert(
            "key1".to_string(),
            serde_json::Value::String("value1".to_string()),
        );
        body.insert(
            "key2".to_string(),
            serde_json::Value::String("value2".to_string()),
        );

        let body_string = serde_json::to_string(&body).unwrap();
        let body_bytes = body_string.into_bytes();

        let response = client
            .send_request(
                reqwest::Method::POST,
                format!("{url}/post"),
                None,
                Some(body_bytes),
                None,
            )
            .await
            .unwrap();

        assert!(response.status.is_success());
    }

    #[tokio::test]
    async fn test_patch() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}");

        let client = InnerHttpClient::default();
        let response = client
            .send_request(
                reqwest::Method::PATCH,
                format!("{url}/patch"),
                None,
                None,
                None,
            )
            .await
            .unwrap();

        assert!(response.status.is_success());
    }

    #[tokio::test]
    async fn test_delete() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}");

        let client = InnerHttpClient::default();
        let response = client
            .send_request(
                reqwest::Method::DELETE,
                format!("{url}/delete"),
                None,
                None,
                None,
            )
            .await
            .unwrap();

        assert!(response.status.is_success());
    }

    #[tokio::test]
    async fn test_not_found() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}/notfound");
        let client = InnerHttpClient::default();

        let response = client
            .send_request(reqwest::Method::GET, url, None, None, None)
            .await
            .unwrap();

        assert!(response.status.is_client_error());
        assert_eq!(response.status.as_u16(), 404);
    }

    #[tokio::test]
    async fn test_timeout() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}/slow");
        let client = InnerHttpClient::default();

        // A 1-second timeout against a route that sleeps for 2 seconds
        let result = client
            .send_request(reqwest::Method::GET, url, None, None, Some(1))
            .await;

        assert!(
            matches!(result, Err(HttpClientError::TimeoutError(_))),
            "Expected a timeout error, got: {result:?}"
        );
    }
}