nautilus_infrastructure/python/sql/
cache.rs1use std::collections::HashMap;
17
18use bytes::Bytes;
19use nautilus_common::{
20 cache::database::CacheDatabaseAdapter, custom::CustomData, runtime::get_runtime, signal::Signal,
21};
22use nautilus_core::python::to_pyruntime_err;
23use nautilus_model::{
24 data::{Bar, DataType, QuoteTick, TradeTick},
25 events::{OrderSnapshot, PositionSnapshot},
26 identifiers::{AccountId, ClientId, ClientOrderId, InstrumentId, PositionId},
27 python::{
28 account::{account_any_to_pyobject, pyobject_to_account_any},
29 events::order::pyobject_to_order_event,
30 instruments::{instrument_any_to_pyobject, pyobject_to_instrument_any},
31 orders::{order_any_to_pyobject, pyobject_to_order_any},
32 },
33 types::Currency,
34};
35use pyo3::{IntoPyObjectExt, prelude::*};
36
37use crate::sql::{cache::PostgresCacheDatabase, queries::DatabaseQueries};
38
39#[pymethods]
40impl PostgresCacheDatabase {
41 #[staticmethod]
42 #[pyo3(name = "connect")]
43 #[pyo3(signature = (host=None, port=None, username=None, password=None, database=None))]
44 fn py_connect(
45 host: Option<String>,
46 port: Option<u16>,
47 username: Option<String>,
48 password: Option<String>,
49 database: Option<String>,
50 ) -> PyResult<Self> {
51 let result = get_runtime().block_on(async {
52 PostgresCacheDatabase::connect(host, port, username, password, database).await
53 });
54 result.map_err(to_pyruntime_err)
55 }
56
57 #[pyo3(name = "close")]
58 fn py_close(&mut self) -> PyResult<()> {
59 self.close().map_err(to_pyruntime_err)
60 }
61
62 #[pyo3(name = "flush_db")]
63 fn py_flush_db(&mut self) -> PyResult<()> {
64 self.flush().map_err(to_pyruntime_err)
65 }
66
67 #[pyo3(name = "load")]
68 fn py_load(&self) -> PyResult<HashMap<String, Vec<u8>>> {
69 get_runtime()
70 .block_on(async { DatabaseQueries::load(&self.pool).await })
71 .map_err(to_pyruntime_err)
72 }
73
74 #[pyo3(name = "load_currency")]
75 fn py_load_currency(&self, code: &str) -> PyResult<Option<Currency>> {
76 let result = get_runtime()
77 .block_on(async { DatabaseQueries::load_currency(&self.pool, code).await });
78 result.map_err(to_pyruntime_err)
79 }
80
81 #[pyo3(name = "load_currencies")]
82 fn py_load_currencies(&self) -> PyResult<Vec<Currency>> {
83 let result =
84 get_runtime().block_on(async { DatabaseQueries::load_currencies(&self.pool).await });
85 result.map_err(to_pyruntime_err)
86 }
87
88 #[pyo3(name = "load_instrument")]
89 fn py_load_instrument(
90 &self,
91 py: Python,
92 instrument_id: InstrumentId,
93 ) -> PyResult<Option<PyObject>> {
94 get_runtime().block_on(async {
95 let result = DatabaseQueries::load_instrument(&self.pool, &instrument_id)
96 .await
97 .unwrap();
98 match result {
99 Some(instrument) => {
100 let py_object = instrument_any_to_pyobject(py, instrument)?;
101 Ok(Some(py_object))
102 }
103 None => Ok(None),
104 }
105 })
106 }
107
108 #[pyo3(name = "load_instruments")]
109 fn py_load_instruments(&self, py: Python) -> PyResult<Vec<PyObject>> {
110 get_runtime().block_on(async {
111 let result = DatabaseQueries::load_instruments(&self.pool).await.unwrap();
112 let mut instruments = Vec::new();
113 for instrument in result {
114 let py_object = instrument_any_to_pyobject(py, instrument)?;
115 instruments.push(py_object);
116 }
117 Ok(instruments)
118 })
119 }
120
121 #[pyo3(name = "load_order")]
122 fn py_load_order(
123 &self,
124 py: Python,
125 client_order_id: ClientOrderId,
126 ) -> PyResult<Option<PyObject>> {
127 get_runtime().block_on(async {
128 let result = DatabaseQueries::load_order(&self.pool, &client_order_id)
129 .await
130 .unwrap();
131 match result {
132 Some(order) => {
133 let py_object = order_any_to_pyobject(py, order)?;
134 Ok(Some(py_object))
135 }
136 None => Ok(None),
137 }
138 })
139 }
140
141 #[pyo3(name = "load_account")]
142 fn py_load_account(&self, py: Python, account_id: AccountId) -> PyResult<Option<PyObject>> {
143 get_runtime().block_on(async {
144 let result = DatabaseQueries::load_account(&self.pool, &account_id)
145 .await
146 .unwrap();
147 match result {
148 Some(account) => {
149 let py_object = account_any_to_pyobject(py, account)?;
150 Ok(Some(py_object))
151 }
152 None => Ok(None),
153 }
154 })
155 }
156
157 #[pyo3(name = "load_quotes")]
158 fn py_load_quotes(&self, py: Python, instrument_id: InstrumentId) -> PyResult<Vec<PyObject>> {
159 get_runtime().block_on(async {
160 let result = DatabaseQueries::load_quotes(&self.pool, &instrument_id)
161 .await
162 .unwrap();
163 let mut quotes = Vec::new();
164 for quote in result {
165 let py_object = quote.into_py_any(py)?;
166 quotes.push(py_object);
167 }
168 Ok(quotes)
169 })
170 }
171
172 #[pyo3(name = "load_trades")]
173 fn py_load_trades(&self, py: Python, instrument_id: InstrumentId) -> PyResult<Vec<PyObject>> {
174 get_runtime().block_on(async {
175 let result = DatabaseQueries::load_trades(&self.pool, &instrument_id)
176 .await
177 .unwrap();
178 let mut trades = Vec::new();
179 for trade in result {
180 let py_object = trade.into_py_any(py)?;
181 trades.push(py_object);
182 }
183 Ok(trades)
184 })
185 }
186
187 #[pyo3(name = "load_bars")]
188 fn py_load_bars(&self, py: Python, instrument_id: InstrumentId) -> PyResult<Vec<PyObject>> {
189 get_runtime().block_on(async {
190 let result = DatabaseQueries::load_bars(&self.pool, &instrument_id)
191 .await
192 .unwrap();
193 let mut bars = Vec::new();
194 for bar in result {
195 let py_object = bar.into_py_any(py)?;
196 bars.push(py_object);
197 }
198 Ok(bars)
199 })
200 }
201
202 #[pyo3(name = "load_signals")]
203 fn py_load_signals(&self, name: &str) -> PyResult<Vec<Signal>> {
204 get_runtime().block_on(async {
205 DatabaseQueries::load_signals(&self.pool, name)
206 .await
207 .map_err(to_pyruntime_err)
208 })
209 }
210
211 #[pyo3(name = "load_custom_data")]
212 fn py_load_custom_data(&self, data_type: DataType) -> PyResult<Vec<CustomData>> {
213 get_runtime().block_on(async {
214 DatabaseQueries::load_custom_data(&self.pool, &data_type)
215 .await
216 .map_err(to_pyruntime_err)
217 })
218 }
219
220 #[pyo3(name = "load_order_snapshot")]
221 fn py_load_order_snapshot(
222 &self,
223 client_order_id: ClientOrderId,
224 ) -> PyResult<Option<OrderSnapshot>> {
225 get_runtime().block_on(async {
226 DatabaseQueries::load_order_snapshot(&self.pool, &client_order_id)
227 .await
228 .map_err(to_pyruntime_err)
229 })
230 }
231
232 #[pyo3(name = "load_position_snapshot")]
233 fn py_load_position_snapshot(
234 &self,
235 position_id: PositionId,
236 ) -> PyResult<Option<PositionSnapshot>> {
237 get_runtime().block_on(async {
238 DatabaseQueries::load_position_snapshot(&self.pool, &position_id)
239 .await
240 .map_err(to_pyruntime_err)
241 })
242 }
243
244 #[pyo3(name = "add")]
245 fn py_add(&self, key: String, value: Vec<u8>) -> PyResult<()> {
246 self.add(key, Bytes::from(value)).map_err(to_pyruntime_err)
247 }
248
249 #[pyo3(name = "add_currency")]
250 fn py_add_currency(&self, currency: Currency) -> PyResult<()> {
251 self.add_currency(¤cy).map_err(to_pyruntime_err)
252 }
253
254 #[pyo3(name = "add_instrument")]
255 fn py_add_instrument(&self, py: Python, instrument: PyObject) -> PyResult<()> {
256 let instrument_any = pyobject_to_instrument_any(py, instrument)?;
257 self.add_instrument(&instrument_any)
258 .map_err(to_pyruntime_err)
259 }
260
261 #[pyo3(name = "add_order")]
262 #[pyo3(signature = (order, client_id=None))]
263 fn py_add_order(
264 &self,
265 py: Python,
266 order: PyObject,
267 client_id: Option<ClientId>,
268 ) -> PyResult<()> {
269 let order_any = pyobject_to_order_any(py, order)?;
270 self.add_order(&order_any, client_id)
271 .map_err(to_pyruntime_err)
272 }
273
274 #[pyo3(name = "add_order_snapshot")]
275 fn py_add_order_snapshot(&self, snapshot: OrderSnapshot) -> PyResult<()> {
276 self.add_order_snapshot(&snapshot).map_err(to_pyruntime_err)
277 }
278
279 #[pyo3(name = "add_position_snapshot")]
280 fn py_add_position_snapshot(&self, snapshot: PositionSnapshot) -> PyResult<()> {
281 self.add_position_snapshot(&snapshot)
282 .map_err(to_pyruntime_err)
283 }
284
285 #[pyo3(name = "add_account")]
286 fn py_add_account(&self, py: Python, account: PyObject) -> PyResult<()> {
287 let account_any = pyobject_to_account_any(py, account)?;
288 self.add_account(&account_any).map_err(to_pyruntime_err)
289 }
290
291 #[pyo3(name = "add_quote")]
292 fn py_add_quote(&self, quote: QuoteTick) -> PyResult<()> {
293 self.add_quote("e).map_err(to_pyruntime_err)
294 }
295
296 #[pyo3(name = "add_trade")]
297 fn py_add_trade(&self, trade: TradeTick) -> PyResult<()> {
298 self.add_trade(&trade).map_err(to_pyruntime_err)
299 }
300
301 #[pyo3(name = "add_bar")]
302 fn py_add_bar(&self, bar: Bar) -> PyResult<()> {
303 self.add_bar(&bar).map_err(to_pyruntime_err)
304 }
305
306 #[pyo3(name = "add_signal")]
307 fn py_add_signal(&self, signal: Signal) -> PyResult<()> {
308 self.add_signal(&signal).map_err(to_pyruntime_err)
309 }
310
311 #[pyo3(name = "add_custom_data")]
312 fn py_add_custom_data(&self, data: CustomData) -> PyResult<()> {
313 self.add_custom_data(&data).map_err(to_pyruntime_err)
314 }
315
316 #[pyo3(name = "update_order")]
317 fn py_update_order(&self, py: Python, order_event: PyObject) -> PyResult<()> {
318 let event = pyobject_to_order_event(py, order_event)?;
319 self.update_order(&event).map_err(to_pyruntime_err)
320 }
321
322 #[pyo3(name = "update_account")]
323 fn py_update_account(&self, py: Python, order: PyObject) -> PyResult<()> {
324 let order_any = pyobject_to_account_any(py, order)?;
325 self.update_account(&order_any).map_err(to_pyruntime_err)
326 }
327}