nautilus_network/
http.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16//! A high-performance HTTP client implementation.
17
18use std::{collections::HashMap, hash::Hash, str::FromStr, sync::Arc, time::Duration};
19
20use bytes::Bytes;
21use http::{HeaderValue, StatusCode, status::InvalidStatusCode};
22use reqwest::{
23    Method, Response, Url,
24    header::{HeaderMap, HeaderName},
25};
26
27use crate::ratelimiter::{RateLimiter, clock::MonotonicClock, quota::Quota};
28
29/// Represents a HTTP status code.
30///
31/// Wraps [`http::StatusCode`] to expose a Python-compatible type and reuse
32/// its validation and convenience methods.
33#[derive(Clone, Debug)]
34#[cfg_attr(
35    feature = "python",
36    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
37)]
38pub struct HttpStatus {
39    inner: StatusCode,
40}
41
42impl HttpStatus {
43    /// Create a new [`HttpStatus`] instance from a given [`StatusCode`].
44    #[must_use]
45    pub const fn new(code: StatusCode) -> Self {
46        Self { inner: code }
47    }
48
49    /// Returns the status code as a `u16` (e.g., `200` for OK).
50    #[inline]
51    #[must_use]
52    pub const fn as_u16(&self) -> u16 {
53        self.inner.as_u16()
54    }
55
56    /// Returns the three-digit ASCII representation of this status (e.g., `"200"`).
57    #[inline]
58    #[must_use]
59    pub fn as_str(&self) -> &str {
60        self.inner.as_str()
61    }
62
63    /// Checks if this status is in the 1xx (informational) range.
64    #[inline]
65    #[must_use]
66    pub fn is_informational(&self) -> bool {
67        self.inner.is_informational()
68    }
69
70    /// Checks if this status is in the 2xx (success) range.
71    #[inline]
72    #[must_use]
73    pub fn is_success(&self) -> bool {
74        self.inner.is_success()
75    }
76
77    /// Checks if this status is in the 3xx (redirection) range.
78    #[inline]
79    #[must_use]
80    pub fn is_redirection(&self) -> bool {
81        self.inner.is_redirection()
82    }
83
84    /// Checks if this status is in the 4xx (client error) range.
85    #[inline]
86    #[must_use]
87    pub fn is_client_error(&self) -> bool {
88        self.inner.is_client_error()
89    }
90
91    /// Checks if this status is in the 5xx (server error) range.
92    #[inline]
93    #[must_use]
94    pub fn is_server_error(&self) -> bool {
95        self.inner.is_server_error()
96    }
97}
98
99impl TryFrom<u16> for HttpStatus {
100    type Error = InvalidStatusCode;
101
102    /// Attempts to construct a [`HttpStatus`] from a `u16`.
103    ///
104    /// # Errors
105    ///
106    /// Returns an error if the code is not in the valid `100..999` range.
107    fn try_from(code: u16) -> Result<Self, Self::Error> {
108        Ok(Self {
109            inner: StatusCode::from_u16(code)?,
110        })
111    }
112}
113
/// The subset of HTTP request methods supported by the `HttpClient`.
///
/// Converts into a `reqwest` [`Method`] via the `From` impl. The variant order is
/// significant: it fixes the integer discriminants seen from Python (`eq_int`).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(eq, eq_int, module = "nautilus_trader.core.nautilus_pyo3.network")
)]
pub enum HttpMethod {
    /// `GET`.
    GET,
    /// `POST`.
    POST,
    /// `PUT`.
    PUT,
    /// `DELETE`.
    DELETE,
    /// `PATCH`.
    PATCH,
}
127
128impl From<HttpMethod> for Method {
129    fn from(value: HttpMethod) -> Self {
130        match value {
131            HttpMethod::GET => Self::GET,
132            HttpMethod::POST => Self::POST,
133            HttpMethod::PUT => Self::PUT,
134            HttpMethod::DELETE => Self::DELETE,
135            HttpMethod::PATCH => Self::PATCH,
136        }
137    }
138}
139
140/// Represents the response from an HTTP request.
141///
142/// This struct encapsulates the status, headers, and body of an HTTP response,
143/// providing easy access to the key components of the response.
144#[derive(Clone, Debug)]
145#[cfg_attr(
146    feature = "python",
147    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
148)]
149pub struct HttpResponse {
150    /// The HTTP status code.
151    pub status: HttpStatus,
152    /// The response headers as a map of key-value pairs.
153    pub headers: HashMap<String, String>,
154    /// The raw response body.
155    pub body: Bytes,
156}
157
158/// Errors returned by the HTTP client.
159///
160/// Includes generic transport errors and timeouts.
161#[derive(thiserror::Error, Debug)]
162pub enum HttpClientError {
163    #[error("HTTP error occurred: {0}")]
164    Error(String),
165
166    #[error("HTTP request timed out: {0}")]
167    TimeoutError(String),
168}
169
170impl From<reqwest::Error> for HttpClientError {
171    fn from(source: reqwest::Error) -> Self {
172        if source.is_timeout() {
173            Self::TimeoutError(source.to_string())
174        } else {
175            Self::Error(source.to_string())
176        }
177    }
178}
179
180impl From<String> for HttpClientError {
181    fn from(value: String) -> Self {
182        Self::Error(value)
183    }
184}
185
186/// An HTTP client that supports rate limiting and timeouts.
187///
188/// Built on `reqwest` for async I/O. Allows per-endpoint and default quotas
189/// through a rate limiter.
190///
191/// This struct is designed to handle HTTP requests efficiently, providing
192/// support for rate limiting, timeouts, and custom headers. The client is
193/// built on top of `reqwest` and can be used for both synchronous and
194/// asynchronous HTTP requests.
195#[derive(Clone, Debug)]
196#[cfg_attr(
197    feature = "python",
198    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
199)]
200pub struct HttpClient {
201    /// The underlying HTTP client used to make requests.
202    pub(crate) client: InnerHttpClient,
203    /// The rate limiter to control the request rate.
204    pub(crate) rate_limiter: Arc<RateLimiter<String, MonotonicClock>>,
205}
206
207impl HttpClient {
208    /// Creates a new [`HttpClient`] instance.
209    ///
210    /// # Panics
211    ///
212    /// Panics if any header key or value is invalid, or if building the underlying `reqwest::Client` fails.
213    #[must_use]
214    pub fn new(
215        headers: HashMap<String, String>,
216        header_keys: Vec<String>,
217        keyed_quotas: Vec<(String, Quota)>,
218        default_quota: Option<Quota>,
219        timeout_secs: Option<u64>,
220    ) -> Self {
221        // Build default headers
222        let mut header_map = HeaderMap::new();
223        for (key, value) in headers {
224            let header_name = HeaderName::from_str(&key).expect("Invalid header name");
225            let header_value = HeaderValue::from_str(&value).expect("Invalid header value");
226            header_map.insert(header_name, header_value);
227        }
228
229        let mut client_builder = reqwest::Client::builder().default_headers(header_map);
230        if let Some(timeout_secs) = timeout_secs {
231            client_builder = client_builder.timeout(Duration::from_secs(timeout_secs));
232        }
233
234        let client = client_builder
235            .build()
236            .expect("Failed to build reqwest client");
237
238        let client = InnerHttpClient {
239            client,
240            header_keys: Arc::new(header_keys),
241        };
242
243        let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
244
245        Self {
246            client,
247            rate_limiter,
248        }
249    }
250
251    /// Sends an HTTP request.
252    ///
253    /// - `method`: The [`Method`] to use (GET, POST, etc.).
254    /// - `url`: The target URL.
255    /// - `headers`: Additional headers for this request.
256    /// - `body`: Optional request body.
257    /// - `keys`: Rate-limit keys to control request frequency.
258    /// - `timeout_secs`: Optional request timeout in seconds.
259    ///
260    /// # Errors
261    ///
262    /// Returns an error if unable to send request or times out.
263    ///
264    /// # Examples
265    ///
266    /// If requesting `/foo/bar`, pass rate-limit keys `["foo/bar", "foo"]`.
267    #[allow(clippy::too_many_arguments)]
268    pub async fn request(
269        &self,
270        method: Method,
271        url: String,
272        headers: Option<HashMap<String, String>>,
273        body: Option<Vec<u8>>,
274        timeout_secs: Option<u64>,
275        keys: Option<Vec<String>>,
276    ) -> Result<HttpResponse, HttpClientError> {
277        let rate_limiter = self.rate_limiter.clone();
278        rate_limiter.await_keys_ready(keys).await;
279
280        self.client
281            .send_request(method, url, headers, body, timeout_secs)
282            .await
283    }
284}
285
286/// Internal implementation backing [`HttpClient`].
287///
288/// The client is backed by a [`reqwest::Client`] which keeps connections alive and
289/// can be cloned cheaply. The client also has a list of header fields to
290/// extract from the response.
291///
292/// The client returns an [`HttpResponse`]. The client filters only the key value
293/// for the give `header_keys`.
294#[derive(Clone, Debug)]
295pub struct InnerHttpClient {
296    pub(crate) client: reqwest::Client,
297    pub(crate) header_keys: Arc<Vec<String>>,
298}
299
300impl InnerHttpClient {
301    /// Sends an HTTP request and returns an [`HttpResponse`].
302    ///
303    /// - `method`: The HTTP method (e.g. GET, POST).
304    /// - `url`: The target URL.
305    /// - `headers`: Extra headers to send.
306    /// - `body`: Optional request body.
307    /// - `timeout_secs`: Optional request timeout in seconds.
308    ///
309    /// # Errors
310    ///
311    /// Returns an error if unable to send request or times out.
312    pub async fn send_request(
313        &self,
314        method: Method,
315        url: String,
316        headers: Option<HashMap<String, String>>,
317        body: Option<Vec<u8>>,
318        timeout_secs: Option<u64>,
319    ) -> Result<HttpResponse, HttpClientError> {
320        let headers = headers.unwrap_or_default();
321        let reqwest_url = Url::parse(url.as_str())
322            .map_err(|e| HttpClientError::from(format!("URL parse error: {e}")))?;
323
324        let mut header_map = HeaderMap::new();
325        for (header_key, header_value) in &headers {
326            let key = HeaderName::from_bytes(header_key.as_bytes())
327                .map_err(|e| HttpClientError::from(format!("Invalid header name: {e}")))?;
328            let _ = header_map.insert(
329                key,
330                header_value
331                    .parse()
332                    .map_err(|e| HttpClientError::from(format!("Invalid header value: {e}")))?,
333            );
334        }
335
336        let mut request_builder = self.client.request(method, reqwest_url).headers(header_map);
337
338        if let Some(timeout_secs) = timeout_secs {
339            request_builder = request_builder.timeout(Duration::new(timeout_secs, 0));
340        }
341
342        let request = match body {
343            Some(b) => request_builder
344                .body(b)
345                .build()
346                .map_err(HttpClientError::from)?,
347            None => request_builder.build().map_err(HttpClientError::from)?,
348        };
349
350        tracing::trace!("{request:?}");
351
352        let response = self
353            .client
354            .execute(request)
355            .await
356            .map_err(HttpClientError::from)?;
357
358        self.to_response(response).await
359    }
360
361    /// Converts a `reqwest::Response` into an `HttpResponse`.
362    ///
363    /// # Errors
364    ///
365    /// Returns an error if unable to send request or times out.
366    pub async fn to_response(&self, response: Response) -> Result<HttpResponse, HttpClientError> {
367        tracing::trace!("{response:?}");
368
369        let headers: HashMap<String, String> = self
370            .header_keys
371            .iter()
372            .filter_map(|key| response.headers().get(key).map(|val| (key, val)))
373            .filter_map(|(key, val)| val.to_str().map(|v| (key, v)).ok())
374            .map(|(k, v)| (k.clone(), v.to_owned()))
375            .collect();
376        let status = HttpStatus::new(response.status());
377        let body = response.bytes().await.map_err(HttpClientError::from)?;
378
379        Ok(HttpResponse {
380            status,
381            headers,
382            body,
383        })
384    }
385}
386
387impl Default for InnerHttpClient {
388    /// Creates a new default [`InnerHttpClient`] instance.
389    ///
390    /// The default client is initialized with an empty list of header keys and a new `reqwest::Client`.
391    fn default() -> Self {
392        let client = reqwest::Client::new();
393        Self {
394            client,
395            header_keys: Default::default(),
396        }
397    }
398}
399
400////////////////////////////////////////////////////////////////////////////////
401// Tests
402////////////////////////////////////////////////////////////////////////////////
#[cfg(test)]
#[cfg(target_os = "linux")] // Only run network tests on Linux (CI stability)
mod tests {
    use std::net::SocketAddr;

    use axum::{
        Router,
        routing::{delete, get, patch, post},
        serve,
    };
    use http::status::StatusCode;

    use super::*;

    fn create_router() -> Router {
        Router::new()
            .route("/get", get(|| async { "hello-world!" }))
            .route("/post", post(|| async { StatusCode::OK }))
            .route("/patch", patch(|| async { StatusCode::OK }))
            .route("/delete", delete(|| async { StatusCode::OK }))
            .route("/notfound", get(|| async { StatusCode::NOT_FOUND }))
            .route(
                "/slow",
                get(|| async {
                    tokio::time::sleep(Duration::from_secs(2)).await;
                    "Eventually responded"
                }),
            )
    }

    /// Starts the test server on an OS-assigned ephemeral port and returns its address.
    ///
    /// Binding directly to port 0 avoids the race in the previous approach, which probed
    /// for a free port, dropped the probe listener and re-bound. In that gap another
    /// process could take the port.
    async fn start_test_server() -> Result<SocketAddr, Box<dyn std::error::Error + Send + Sync>> {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?;
        let addr = listener.local_addr()?;

        tokio::spawn(async move {
            serve(listener, create_router()).await.unwrap();
        });

        Ok(addr)
    }

    #[test]
    fn test_http_status_try_from_bounds() {
        assert!(HttpStatus::try_from(99).is_err());
        assert_eq!(HttpStatus::try_from(100).unwrap().as_u16(), 100);
        assert_eq!(HttpStatus::try_from(999).unwrap().as_u16(), 999);
        assert!(HttpStatus::try_from(1000).is_err());
    }

    #[test]
    fn test_http_status_classes() {
        let status = |code: u16| HttpStatus::try_from(code).unwrap();
        assert!(status(101).is_informational());
        assert!(status(200).is_success());
        assert!(status(299).is_success());
        assert!(!status(300).is_success());
        assert!(status(301).is_redirection());
        assert!(status(404).is_client_error());
        assert!(status(503).is_server_error());
        assert!(!status(600).is_server_error());
        assert_eq!(status(418).as_str(), "418");
    }

    #[tokio::test]
    async fn test_get() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}");

        let client = InnerHttpClient::default();
        let response = client
            .send_request(reqwest::Method::GET, format!("{url}/get"), None, None, None)
            .await
            .unwrap();

        assert!(response.status.is_success());
        assert_eq!(String::from_utf8_lossy(&response.body), "hello-world!");
    }

    #[tokio::test]
    async fn test_post() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}");

        let client = InnerHttpClient::default();
        let response = client
            .send_request(
                reqwest::Method::POST,
                format!("{url}/post"),
                None,
                None,
                None,
            )
            .await
            .unwrap();

        assert!(response.status.is_success());
    }

    #[tokio::test]
    async fn test_post_with_body() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}");

        let client = InnerHttpClient::default();

        let mut body = HashMap::new();
        body.insert(
            "key1".to_string(),
            serde_json::Value::String("value1".to_string()),
        );
        body.insert(
            "key2".to_string(),
            serde_json::Value::String("value2".to_string()),
        );

        let body_string = serde_json::to_string(&body).unwrap();
        let body_bytes = body_string.into_bytes();

        let response = client
            .send_request(
                reqwest::Method::POST,
                format!("{url}/post"),
                None,
                Some(body_bytes),
                None,
            )
            .await
            .unwrap();

        assert!(response.status.is_success());
    }

    #[tokio::test]
    async fn test_patch() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}");

        let client = InnerHttpClient::default();
        let response = client
            .send_request(
                reqwest::Method::PATCH,
                format!("{url}/patch"),
                None,
                None,
                None,
            )
            .await
            .unwrap();

        assert!(response.status.is_success());
    }

    #[tokio::test]
    async fn test_delete() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}");

        let client = InnerHttpClient::default();
        let response = client
            .send_request(
                reqwest::Method::DELETE,
                format!("{url}/delete"),
                None,
                None,
                None,
            )
            .await
            .unwrap();

        assert!(response.status.is_success());
    }

    #[tokio::test]
    async fn test_not_found() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}/notfound");
        let client = InnerHttpClient::default();

        let response = client
            .send_request(reqwest::Method::GET, url, None, None, None)
            .await
            .unwrap();

        assert!(response.status.is_client_error());
        assert_eq!(response.status.as_u16(), 404);
    }

    #[tokio::test]
    async fn test_timeout() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}/slow");
        let client = InnerHttpClient::default();

        // We'll set a 1-second timeout for a route that sleeps 2 seconds
        let result = client
            .send_request(reqwest::Method::GET, url, None, None, Some(1))
            .await;

        match result {
            Err(HttpClientError::TimeoutError(msg)) => {
                println!("Got expected timeout error: {msg}");
            }
            Err(other) => panic!("Expected a timeout error, got: {other:?}"),
            Ok(resp) => panic!("Expected a timeout error, but got a successful response: {resp:?}"),
        }
    }
}