nautilus_network/python/
http.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
// -------------------------------------------------------------------------------------------------
//  Copyright (C) 2015-2024 Nautech Systems Pty Ltd. All rights reserved.
//  https://nautechsystems.io
//
//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
//  You may not use this file except in compliance with the License.
//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
//
//  Unless required by applicable law or agreed to in writing, software
//  distributed under the License is distributed on an "AS IS" BASIS,
//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//  See the License for the specific language governing permissions and
//  limitations under the License.
// -------------------------------------------------------------------------------------------------

use std::{
    collections::{hash_map::DefaultHasher, HashMap},
    hash::{Hash, Hasher},
    sync::Arc,
};

use bytes::Bytes;
use futures_util::{stream, StreamExt};
use pyo3::{create_exception, exceptions::PyException, prelude::*, types::PyBytes};

use crate::{
    http::{HttpClient, HttpClientError, HttpMethod, HttpResponse, InnerHttpClient},
    ratelimiter::{quota::Quota, RateLimiter},
};

// Python exception class raised for generic (non-timeout) HTTP client errors.
// Created with `network` as its module name and `Exception` as its base class.
create_exception!(network, HttpError, PyException);

// Python exception class raised for HTTP timeout errors.
// Also derives directly from `Exception` (not from `HttpError`), so Python callers
// that want to handle both must catch each class explicitly.
create_exception!(network, HttpTimeoutError, PyException);

impl HttpClientError {
    #[must_use]
    pub fn into_py_err(self) -> PyErr {
        match self {
            Self::Error(e) => PyErr::new::<HttpError, _>(e),
            Self::TimeoutError(e) => PyErr::new::<HttpTimeoutError, _>(e),
        }
    }
}

#[pymethods]
impl HttpMethod {
    /// Python `__hash__` implementation.
    ///
    /// Feeds the method through the standard library's `DefaultHasher` and returns
    /// the 64-bit digest reinterpreted as the signed `isize` that `__hash__` returns.
    fn __hash__(&self) -> isize {
        let mut hasher = DefaultHasher::new();
        Hash::hash(self, &mut hasher);
        let digest: u64 = hasher.finish();
        // Wrapping reinterpretation of the bits is intentional; only determinism matters.
        digest as isize
    }
}

#[pymethods]
impl HttpResponse {
    /// Builds a response from a status code and the raw body bytes.
    ///
    /// The header map of a response created this way starts out empty.
    #[new]
    #[must_use]
    pub fn py_new(status: u16, body: Vec<u8>) -> Self {
        let body = Bytes::from(body);
        let headers = HashMap::new();
        Self {
            status,
            headers,
            body,
        }
    }

    /// The HTTP status code of the response.
    #[getter]
    #[pyo3(name = "status")]
    pub const fn py_status(&self) -> u16 {
        self.status
    }

    /// A copy of the headers stored on this response.
    #[getter]
    #[pyo3(name = "headers")]
    pub fn py_headers(&self) -> HashMap<String, String> {
        Clone::clone(&self.headers)
    }

    /// The raw response body, borrowed as a byte slice.
    #[getter]
    #[pyo3(name = "body")]
    pub fn py_body(&self) -> &[u8] {
        &self.body[..]
    }
}

#[pymethods]
impl HttpClient {
    /// Create a new HttpClient.
    ///
    /// `header_keys`: The key value pairs for the given `header_keys` are retained from the responses.
    /// `keyed_quota`: A list of string quota pairs that gives quota for specific key values.
    /// `default_quota`: The default rate limiting quota for any request.
    /// Default quota is optional and no quota is passthrough.
    ///
    /// Rate limiting can be configured on a per-endpoint basis by passing
    /// key-value pairs of endpoint URLs and their respective quotas.
    ///
    /// For /foo -> 10 reqs/sec configure limit with ("foo", Quota.rate_per_second(10))
    ///
    /// Hierarchical rate limiting can be achieved by configuring the quotas for
    /// each level.
    ///
    /// For /foo/bar -> 10 reqs/sec and /foo -> 20 reqs/sec configure limits for
    /// keys "foo/bar" and "foo" respectively.
    ///
    /// When a request is made the URL should be split into all the keys within it.
    ///
    /// For request /foo/bar, should pass keys ["foo/bar", "foo"] for rate limiting.
    #[new]
    #[pyo3(signature = (header_keys = Vec::new(), keyed_quotas = Vec::new(), default_quota = None))]
    #[must_use]
    pub fn py_new(
        header_keys: Vec<String>,
        keyed_quotas: Vec<(String, Quota)>,
        default_quota: Option<Quota>,
    ) -> Self {
        let client = reqwest::Client::new();
        let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));

        let client = InnerHttpClient {
            client,
            header_keys,
        };

        Self {
            rate_limiter,
            client,
        }
    }

    /// Send an HTTP request.
    ///
    /// `method`: The HTTP method to call.
    /// `url`: The request is sent to this url.
    /// `headers`: The header key value pairs in the request.
    /// `body`: The bytes sent in the body of request.
    /// `keys`: The keys used for rate limiting the request.
    ///
    /// # Example
    ///
    /// When a request is made the URL should be split into all relevant keys within it.
    ///
    /// For request /foo/bar, should pass keys ["foo/bar", "foo"] for rate limiting.
    #[pyo3(name = "request")]
    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (method, url, headers=None, body=None, keys=None, timeout_secs=None))]
    fn py_request<'py>(
        &self,
        method: HttpMethod,
        url: String,
        headers: Option<HashMap<String, String>>,
        body: Option<Bound<'py, PyBytes>>,
        keys: Option<Vec<String>>,
        timeout_secs: Option<u64>,
        py: Python<'py>,
    ) -> PyResult<Bound<'py, PyAny>> {
        let headers = headers.unwrap_or_default();
        let body_vec = body.map(|py_bytes| py_bytes.as_bytes().to_vec());
        let keys = keys.unwrap_or_default();
        let client = self.client.clone();
        let rate_limiter = self.rate_limiter.clone();
        let method = method.into();
        pyo3_async_runtimes::tokio::future_into_py(py, async move {
            // Check keys for rate limiting quota
            let tasks = keys.iter().map(|key| rate_limiter.until_key_ready(key));
            stream::iter(tasks)
                .for_each(|key| async move {
                    key.await;
                })
                .await;
            client
                .send_request(method, url, headers, body_vec, timeout_secs)
                .await
                .map_err(super::super::http::HttpClientError::into_py_err)
        })
    }
}