nautilus_infrastructure/python/sql/
cache.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::collections::HashMap;
17
18use bytes::Bytes;
19use nautilus_common::{
20    cache::database::CacheDatabaseAdapter, custom::CustomData, runtime::get_runtime, signal::Signal,
21};
22use nautilus_core::python::to_pyruntime_err;
23use nautilus_model::{
24    data::{Bar, DataType, QuoteTick, TradeTick},
25    events::{OrderSnapshot, PositionSnapshot},
26    identifiers::{AccountId, ClientId, ClientOrderId, InstrumentId, PositionId},
27    python::{
28        account::{account_any_to_pyobject, pyobject_to_account_any},
29        events::order::pyobject_to_order_event,
30        instruments::{instrument_any_to_pyobject, pyobject_to_instrument_any},
31        orders::{order_any_to_pyobject, pyobject_to_order_any},
32    },
33    types::Currency,
34};
35use pyo3::{IntoPyObjectExt, prelude::*};
36
37use crate::sql::{cache::PostgresCacheDatabase, queries::DatabaseQueries};
38
39#[pymethods]
40impl PostgresCacheDatabase {
41    #[staticmethod]
42    #[pyo3(name = "connect")]
43    #[pyo3(signature = (host=None, port=None, username=None, password=None, database=None))]
44    fn py_connect(
45        host: Option<String>,
46        port: Option<u16>,
47        username: Option<String>,
48        password: Option<String>,
49        database: Option<String>,
50    ) -> PyResult<Self> {
51        let result = get_runtime().block_on(async {
52            PostgresCacheDatabase::connect(host, port, username, password, database).await
53        });
54        result.map_err(to_pyruntime_err)
55    }
56
57    #[pyo3(name = "close")]
58    fn py_close(&mut self) -> PyResult<()> {
59        self.close().map_err(to_pyruntime_err)
60    }
61
62    #[pyo3(name = "flush_db")]
63    fn py_flush_db(&mut self) -> PyResult<()> {
64        self.flush().map_err(to_pyruntime_err)
65    }
66
67    #[pyo3(name = "load")]
68    fn py_load(&self) -> PyResult<HashMap<String, Vec<u8>>> {
69        get_runtime()
70            .block_on(async { DatabaseQueries::load(&self.pool).await })
71            .map_err(to_pyruntime_err)
72    }
73
74    #[pyo3(name = "load_currency")]
75    fn py_load_currency(&self, code: &str) -> PyResult<Option<Currency>> {
76        let result = get_runtime()
77            .block_on(async { DatabaseQueries::load_currency(&self.pool, code).await });
78        result.map_err(to_pyruntime_err)
79    }
80
81    #[pyo3(name = "load_currencies")]
82    fn py_load_currencies(&self) -> PyResult<Vec<Currency>> {
83        let result =
84            get_runtime().block_on(async { DatabaseQueries::load_currencies(&self.pool).await });
85        result.map_err(to_pyruntime_err)
86    }
87
88    #[pyo3(name = "load_instrument")]
89    fn py_load_instrument(
90        &self,
91        py: Python,
92        instrument_id: InstrumentId,
93    ) -> PyResult<Option<PyObject>> {
94        get_runtime().block_on(async {
95            let result = DatabaseQueries::load_instrument(&self.pool, &instrument_id)
96                .await
97                .unwrap();
98            match result {
99                Some(instrument) => {
100                    let py_object = instrument_any_to_pyobject(py, instrument)?;
101                    Ok(Some(py_object))
102                }
103                None => Ok(None),
104            }
105        })
106    }
107
108    #[pyo3(name = "load_instruments")]
109    fn py_load_instruments(&self, py: Python) -> PyResult<Vec<PyObject>> {
110        get_runtime().block_on(async {
111            let result = DatabaseQueries::load_instruments(&self.pool).await.unwrap();
112            let mut instruments = Vec::new();
113            for instrument in result {
114                let py_object = instrument_any_to_pyobject(py, instrument)?;
115                instruments.push(py_object);
116            }
117            Ok(instruments)
118        })
119    }
120
121    #[pyo3(name = "load_order")]
122    fn py_load_order(
123        &self,
124        py: Python,
125        client_order_id: ClientOrderId,
126    ) -> PyResult<Option<PyObject>> {
127        get_runtime().block_on(async {
128            let result = DatabaseQueries::load_order(&self.pool, &client_order_id)
129                .await
130                .unwrap();
131            match result {
132                Some(order) => {
133                    let py_object = order_any_to_pyobject(py, order)?;
134                    Ok(Some(py_object))
135                }
136                None => Ok(None),
137            }
138        })
139    }
140
141    #[pyo3(name = "load_account")]
142    fn py_load_account(&self, py: Python, account_id: AccountId) -> PyResult<Option<PyObject>> {
143        get_runtime().block_on(async {
144            let result = DatabaseQueries::load_account(&self.pool, &account_id)
145                .await
146                .unwrap();
147            match result {
148                Some(account) => {
149                    let py_object = account_any_to_pyobject(py, account)?;
150                    Ok(Some(py_object))
151                }
152                None => Ok(None),
153            }
154        })
155    }
156
157    #[pyo3(name = "load_quotes")]
158    fn py_load_quotes(&self, py: Python, instrument_id: InstrumentId) -> PyResult<Vec<PyObject>> {
159        get_runtime().block_on(async {
160            let result = DatabaseQueries::load_quotes(&self.pool, &instrument_id)
161                .await
162                .unwrap();
163            let mut quotes = Vec::new();
164            for quote in result {
165                let py_object = quote.into_py_any(py)?;
166                quotes.push(py_object);
167            }
168            Ok(quotes)
169        })
170    }
171
172    #[pyo3(name = "load_trades")]
173    fn py_load_trades(&self, py: Python, instrument_id: InstrumentId) -> PyResult<Vec<PyObject>> {
174        get_runtime().block_on(async {
175            let result = DatabaseQueries::load_trades(&self.pool, &instrument_id)
176                .await
177                .unwrap();
178            let mut trades = Vec::new();
179            for trade in result {
180                let py_object = trade.into_py_any(py)?;
181                trades.push(py_object);
182            }
183            Ok(trades)
184        })
185    }
186
187    #[pyo3(name = "load_bars")]
188    fn py_load_bars(&self, py: Python, instrument_id: InstrumentId) -> PyResult<Vec<PyObject>> {
189        get_runtime().block_on(async {
190            let result = DatabaseQueries::load_bars(&self.pool, &instrument_id)
191                .await
192                .unwrap();
193            let mut bars = Vec::new();
194            for bar in result {
195                let py_object = bar.into_py_any(py)?;
196                bars.push(py_object);
197            }
198            Ok(bars)
199        })
200    }
201
202    #[pyo3(name = "load_signals")]
203    fn py_load_signals(&self, name: &str) -> PyResult<Vec<Signal>> {
204        get_runtime().block_on(async {
205            DatabaseQueries::load_signals(&self.pool, name)
206                .await
207                .map_err(to_pyruntime_err)
208        })
209    }
210
211    #[pyo3(name = "load_custom_data")]
212    fn py_load_custom_data(&self, data_type: DataType) -> PyResult<Vec<CustomData>> {
213        get_runtime().block_on(async {
214            DatabaseQueries::load_custom_data(&self.pool, &data_type)
215                .await
216                .map_err(to_pyruntime_err)
217        })
218    }
219
220    #[pyo3(name = "load_order_snapshot")]
221    fn py_load_order_snapshot(
222        &self,
223        client_order_id: ClientOrderId,
224    ) -> PyResult<Option<OrderSnapshot>> {
225        get_runtime().block_on(async {
226            DatabaseQueries::load_order_snapshot(&self.pool, &client_order_id)
227                .await
228                .map_err(to_pyruntime_err)
229        })
230    }
231
232    #[pyo3(name = "load_position_snapshot")]
233    fn py_load_position_snapshot(
234        &self,
235        position_id: PositionId,
236    ) -> PyResult<Option<PositionSnapshot>> {
237        get_runtime().block_on(async {
238            DatabaseQueries::load_position_snapshot(&self.pool, &position_id)
239                .await
240                .map_err(to_pyruntime_err)
241        })
242    }
243
244    #[pyo3(name = "add")]
245    fn py_add(&self, key: String, value: Vec<u8>) -> PyResult<()> {
246        self.add(key, Bytes::from(value)).map_err(to_pyruntime_err)
247    }
248
249    #[pyo3(name = "add_currency")]
250    fn py_add_currency(&self, currency: Currency) -> PyResult<()> {
251        self.add_currency(&currency).map_err(to_pyruntime_err)
252    }
253
254    #[pyo3(name = "add_instrument")]
255    fn py_add_instrument(&self, py: Python, instrument: PyObject) -> PyResult<()> {
256        let instrument_any = pyobject_to_instrument_any(py, instrument)?;
257        self.add_instrument(&instrument_any)
258            .map_err(to_pyruntime_err)
259    }
260
261    #[pyo3(name = "add_order")]
262    #[pyo3(signature = (order, client_id=None))]
263    fn py_add_order(
264        &self,
265        py: Python,
266        order: PyObject,
267        client_id: Option<ClientId>,
268    ) -> PyResult<()> {
269        let order_any = pyobject_to_order_any(py, order)?;
270        self.add_order(&order_any, client_id)
271            .map_err(to_pyruntime_err)
272    }
273
274    #[pyo3(name = "add_order_snapshot")]
275    fn py_add_order_snapshot(&self, snapshot: OrderSnapshot) -> PyResult<()> {
276        self.add_order_snapshot(&snapshot).map_err(to_pyruntime_err)
277    }
278
279    #[pyo3(name = "add_position_snapshot")]
280    fn py_add_position_snapshot(&self, snapshot: PositionSnapshot) -> PyResult<()> {
281        self.add_position_snapshot(&snapshot)
282            .map_err(to_pyruntime_err)
283    }
284
285    #[pyo3(name = "add_account")]
286    fn py_add_account(&self, py: Python, account: PyObject) -> PyResult<()> {
287        let account_any = pyobject_to_account_any(py, account)?;
288        self.add_account(&account_any).map_err(to_pyruntime_err)
289    }
290
291    #[pyo3(name = "add_quote")]
292    fn py_add_quote(&self, quote: QuoteTick) -> PyResult<()> {
293        self.add_quote(&quote).map_err(to_pyruntime_err)
294    }
295
296    #[pyo3(name = "add_trade")]
297    fn py_add_trade(&self, trade: TradeTick) -> PyResult<()> {
298        self.add_trade(&trade).map_err(to_pyruntime_err)
299    }
300
301    #[pyo3(name = "add_bar")]
302    fn py_add_bar(&self, bar: Bar) -> PyResult<()> {
303        self.add_bar(&bar).map_err(to_pyruntime_err)
304    }
305
306    #[pyo3(name = "add_signal")]
307    fn py_add_signal(&self, signal: Signal) -> PyResult<()> {
308        self.add_signal(&signal).map_err(to_pyruntime_err)
309    }
310
311    #[pyo3(name = "add_custom_data")]
312    fn py_add_custom_data(&self, data: CustomData) -> PyResult<()> {
313        self.add_custom_data(&data).map_err(to_pyruntime_err)
314    }
315
316    #[pyo3(name = "update_order")]
317    fn py_update_order(&self, py: Python, order_event: PyObject) -> PyResult<()> {
318        let event = pyobject_to_order_event(py, order_event)?;
319        self.update_order(&event).map_err(to_pyruntime_err)
320    }
321
322    #[pyo3(name = "update_account")]
323    fn py_update_account(&self, py: Python, order: PyObject) -> PyResult<()> {
324        let order_any = pyobject_to_account_any(py, order)?;
325        self.update_account(&order_any).map_err(to_pyruntime_err)
326    }
327}