nautilus_infrastructure/python/sql/cache.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2026 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use bytes::Bytes;
17use nautilus_common::{
18    cache::database::CacheDatabaseAdapter, custom::CustomData, live::get_runtime, signal::Signal,
19};
20use nautilus_core::python::to_pyruntime_err;
21use nautilus_model::{
22    data::{Bar, DataType, QuoteTick, TradeTick},
23    events::{OrderSnapshot, PositionSnapshot},
24    identifiers::{AccountId, ClientId, ClientOrderId, InstrumentId, PositionId},
25    python::{
26        account::{account_any_to_pyobject, pyobject_to_account_any},
27        events::order::pyobject_to_order_event,
28        instruments::{instrument_any_to_pyobject, pyobject_to_instrument_any},
29        orders::{order_any_to_pyobject, pyobject_to_order_any},
30    },
31    types::Currency,
32};
33use pyo3::{IntoPyObjectExt, prelude::*};
34
35use crate::sql::{cache::PostgresCacheDatabase, queries::DatabaseQueries};
36
37#[pymethods]
38impl PostgresCacheDatabase {
39    #[staticmethod]
40    #[pyo3(name = "connect")]
41    #[pyo3(signature = (host=None, port=None, username=None, password=None, database=None))]
42    fn py_connect(
43        host: Option<String>,
44        port: Option<u16>,
45        username: Option<String>,
46        password: Option<String>,
47        database: Option<String>,
48    ) -> PyResult<Self> {
49        let result = get_runtime()
50            .block_on(async { Self::connect(host, port, username, password, database).await });
51        result.map_err(to_pyruntime_err)
52    }
53
54    #[pyo3(name = "close")]
55    fn py_close(&mut self) -> PyResult<()> {
56        self.close().map_err(to_pyruntime_err)
57    }
58
59    #[pyo3(name = "flush_db")]
60    fn py_flush_db(&mut self) -> PyResult<()> {
61        self.flush().map_err(to_pyruntime_err)
62    }
63
64    #[pyo3(name = "load")]
65    fn py_load(&self) -> PyResult<std::collections::HashMap<String, Vec<u8>>> {
66        get_runtime()
67            .block_on(async { DatabaseQueries::load(&self.pool).await })
68            .map(|m| m.into_iter().collect())
69            .map_err(to_pyruntime_err)
70    }
71
72    #[pyo3(name = "load_currency")]
73    fn py_load_currency(&self, code: &str) -> PyResult<Option<Currency>> {
74        let result = get_runtime()
75            .block_on(async { DatabaseQueries::load_currency(&self.pool, code).await });
76        result.map_err(to_pyruntime_err)
77    }
78
79    #[pyo3(name = "load_currencies")]
80    fn py_load_currencies(&self) -> PyResult<Vec<Currency>> {
81        let result =
82            get_runtime().block_on(async { DatabaseQueries::load_currencies(&self.pool).await });
83        result.map_err(to_pyruntime_err)
84    }
85
86    #[pyo3(name = "load_instrument")]
87    fn py_load_instrument(
88        &self,
89        py: Python,
90        instrument_id: InstrumentId,
91    ) -> PyResult<Option<Py<PyAny>>> {
92        get_runtime().block_on(async {
93            let result = DatabaseQueries::load_instrument(&self.pool, &instrument_id)
94                .await
95                .unwrap();
96            match result {
97                Some(instrument) => {
98                    let py_object = instrument_any_to_pyobject(py, instrument)?;
99                    Ok(Some(py_object))
100                }
101                None => Ok(None),
102            }
103        })
104    }
105
106    #[pyo3(name = "load_instruments")]
107    fn py_load_instruments(&self, py: Python) -> PyResult<Vec<Py<PyAny>>> {
108        get_runtime().block_on(async {
109            let result = DatabaseQueries::load_instruments(&self.pool).await.unwrap();
110            let mut instruments = Vec::new();
111            for instrument in result {
112                let py_object = instrument_any_to_pyobject(py, instrument)?;
113                instruments.push(py_object);
114            }
115            Ok(instruments)
116        })
117    }
118
119    #[pyo3(name = "load_order")]
120    fn py_load_order(
121        &self,
122        py: Python,
123        client_order_id: ClientOrderId,
124    ) -> PyResult<Option<Py<PyAny>>> {
125        get_runtime().block_on(async {
126            let result = DatabaseQueries::load_order(&self.pool, &client_order_id)
127                .await
128                .unwrap();
129            match result {
130                Some(order) => {
131                    let py_object = order_any_to_pyobject(py, order)?;
132                    Ok(Some(py_object))
133                }
134                None => Ok(None),
135            }
136        })
137    }
138
139    #[pyo3(name = "load_account")]
140    fn py_load_account(&self, py: Python, account_id: AccountId) -> PyResult<Option<Py<PyAny>>> {
141        get_runtime().block_on(async {
142            let result = DatabaseQueries::load_account(&self.pool, &account_id)
143                .await
144                .unwrap();
145            match result {
146                Some(account) => {
147                    let py_object = account_any_to_pyobject(py, account)?;
148                    Ok(Some(py_object))
149                }
150                None => Ok(None),
151            }
152        })
153    }
154
155    #[pyo3(name = "load_quotes")]
156    fn py_load_quotes(&self, py: Python, instrument_id: InstrumentId) -> PyResult<Vec<Py<PyAny>>> {
157        get_runtime().block_on(async {
158            let result = DatabaseQueries::load_quotes(&self.pool, &instrument_id)
159                .await
160                .unwrap();
161            let mut quotes = Vec::new();
162            for quote in result {
163                let py_object = quote.into_py_any(py)?;
164                quotes.push(py_object);
165            }
166            Ok(quotes)
167        })
168    }
169
170    #[pyo3(name = "load_trades")]
171    fn py_load_trades(&self, py: Python, instrument_id: InstrumentId) -> PyResult<Vec<Py<PyAny>>> {
172        get_runtime().block_on(async {
173            let result = DatabaseQueries::load_trades(&self.pool, &instrument_id)
174                .await
175                .unwrap();
176            let mut trades = Vec::new();
177            for trade in result {
178                let py_object = trade.into_py_any(py)?;
179                trades.push(py_object);
180            }
181            Ok(trades)
182        })
183    }
184
185    #[pyo3(name = "load_bars")]
186    fn py_load_bars(&self, py: Python, instrument_id: InstrumentId) -> PyResult<Vec<Py<PyAny>>> {
187        get_runtime().block_on(async {
188            let result = DatabaseQueries::load_bars(&self.pool, &instrument_id)
189                .await
190                .unwrap();
191            let mut bars = Vec::new();
192            for bar in result {
193                let py_object = bar.into_py_any(py)?;
194                bars.push(py_object);
195            }
196            Ok(bars)
197        })
198    }
199
200    #[pyo3(name = "load_signals")]
201    fn py_load_signals(&self, name: &str) -> PyResult<Vec<Signal>> {
202        get_runtime().block_on(async {
203            DatabaseQueries::load_signals(&self.pool, name)
204                .await
205                .map_err(to_pyruntime_err)
206        })
207    }
208
209    #[pyo3(name = "load_custom_data")]
210    fn py_load_custom_data(&self, data_type: DataType) -> PyResult<Vec<CustomData>> {
211        get_runtime().block_on(async {
212            DatabaseQueries::load_custom_data(&self.pool, &data_type)
213                .await
214                .map_err(to_pyruntime_err)
215        })
216    }
217
218    #[pyo3(name = "load_order_snapshot")]
219    fn py_load_order_snapshot(
220        &self,
221        client_order_id: ClientOrderId,
222    ) -> PyResult<Option<OrderSnapshot>> {
223        get_runtime().block_on(async {
224            DatabaseQueries::load_order_snapshot(&self.pool, &client_order_id)
225                .await
226                .map_err(to_pyruntime_err)
227        })
228    }
229
230    #[pyo3(name = "load_position_snapshot")]
231    fn py_load_position_snapshot(
232        &self,
233        position_id: PositionId,
234    ) -> PyResult<Option<PositionSnapshot>> {
235        get_runtime().block_on(async {
236            DatabaseQueries::load_position_snapshot(&self.pool, &position_id)
237                .await
238                .map_err(to_pyruntime_err)
239        })
240    }
241
242    #[pyo3(name = "add")]
243    fn py_add(&self, key: String, value: Vec<u8>) -> PyResult<()> {
244        self.add(key, Bytes::from(value)).map_err(to_pyruntime_err)
245    }
246
247    #[pyo3(name = "add_currency")]
248    fn py_add_currency(&self, currency: Currency) -> PyResult<()> {
249        self.add_currency(&currency).map_err(to_pyruntime_err)
250    }
251
252    #[pyo3(name = "add_instrument")]
253    fn py_add_instrument(&self, py: Python, instrument: Py<PyAny>) -> PyResult<()> {
254        let instrument_any = pyobject_to_instrument_any(py, instrument)?;
255        self.add_instrument(&instrument_any)
256            .map_err(to_pyruntime_err)
257    }
258
259    #[pyo3(name = "add_order")]
260    #[pyo3(signature = (order, client_id=None))]
261    fn py_add_order(
262        &self,
263        py: Python,
264        order: Py<PyAny>,
265        client_id: Option<ClientId>,
266    ) -> PyResult<()> {
267        let order_any = pyobject_to_order_any(py, order)?;
268        self.add_order(&order_any, client_id)
269            .map_err(to_pyruntime_err)
270    }
271
272    #[pyo3(name = "add_order_snapshot")]
273    fn py_add_order_snapshot(&self, snapshot: OrderSnapshot) -> PyResult<()> {
274        self.add_order_snapshot(&snapshot).map_err(to_pyruntime_err)
275    }
276
277    #[pyo3(name = "add_position_snapshot")]
278    fn py_add_position_snapshot(&self, snapshot: PositionSnapshot) -> PyResult<()> {
279        self.add_position_snapshot(&snapshot)
280            .map_err(to_pyruntime_err)
281    }
282
283    #[pyo3(name = "add_account")]
284    fn py_add_account(&self, py: Python, account: Py<PyAny>) -> PyResult<()> {
285        let account_any = pyobject_to_account_any(py, account)?;
286        self.add_account(&account_any).map_err(to_pyruntime_err)
287    }
288
289    #[pyo3(name = "add_quote")]
290    fn py_add_quote(&self, quote: QuoteTick) -> PyResult<()> {
291        self.add_quote(&quote).map_err(to_pyruntime_err)
292    }
293
294    #[pyo3(name = "add_trade")]
295    fn py_add_trade(&self, trade: TradeTick) -> PyResult<()> {
296        self.add_trade(&trade).map_err(to_pyruntime_err)
297    }
298
299    #[pyo3(name = "add_bar")]
300    fn py_add_bar(&self, bar: Bar) -> PyResult<()> {
301        self.add_bar(&bar).map_err(to_pyruntime_err)
302    }
303
304    #[pyo3(name = "add_signal")]
305    fn py_add_signal(&self, signal: Signal) -> PyResult<()> {
306        self.add_signal(&signal).map_err(to_pyruntime_err)
307    }
308
309    #[pyo3(name = "add_custom_data")]
310    fn py_add_custom_data(&self, data: CustomData) -> PyResult<()> {
311        self.add_custom_data(&data).map_err(to_pyruntime_err)
312    }
313
314    #[pyo3(name = "update_order")]
315    fn py_update_order(&self, py: Python, order_event: Py<PyAny>) -> PyResult<()> {
316        let event = pyobject_to_order_event(py, order_event)?;
317        self.update_order(&event).map_err(to_pyruntime_err)
318    }
319
320    #[pyo3(name = "update_account")]
321    fn py_update_account(&self, py: Python, order: Py<PyAny>) -> PyResult<()> {
322        let order_any = pyobject_to_account_any(py, order)?;
323        self.update_account(&order_any).map_err(to_pyruntime_err)
324    }
325}