1use std::{
17 collections::{HashMap, hash_map::DefaultHasher},
18 hash::{Hash, Hasher},
19 str::FromStr,
20};
21
22use nautilus_core::{
23 python::{
24 IntoPyObjectNautilusExt,
25 serialization::{from_dict_pyo3, to_dict_pyo3},
26 to_pyvalue_err,
27 },
28 serialization::{
29 Serializable,
30 msgpack::{FromMsgPack, ToMsgPack},
31 },
32};
33use pyo3::{prelude::*, pyclass::CompareOp, types::PyDict};
34
35use super::data_to_pycapsule;
36use crate::{
37 data::{
38 Data,
39 bar::{Bar, BarSpecification, BarType},
40 },
41 enums::{AggregationSource, BarAggregation, PriceType},
42 identifiers::InstrumentId,
43 python::common::PY_MODULE_MODEL,
44 types::{
45 price::{Price, PriceRaw},
46 quantity::{Quantity, QuantityRaw},
47 },
48};
49
50#[pymethods]
51impl BarSpecification {
52 #[new]
53 fn py_new(step: usize, aggregation: BarAggregation, price_type: PriceType) -> PyResult<Self> {
54 Self::new_checked(step, aggregation, price_type).map_err(to_pyvalue_err)
55 }
56
57 fn __richcmp__(&self, other: &Self, op: CompareOp, py: Python<'_>) -> Py<PyAny> {
58 match op {
59 CompareOp::Eq => self.eq(other).into_py_any_unwrap(py),
60 CompareOp::Ne => self.ne(other).into_py_any_unwrap(py),
61 _ => py.NotImplemented(),
62 }
63 }
64
65 fn __hash__(&self) -> isize {
66 let mut h = DefaultHasher::new();
67 self.hash(&mut h);
68 h.finish() as isize
69 }
70
71 fn __repr__(&self) -> String {
72 format!("{self:?}")
73 }
74
75 fn __str__(&self) -> String {
76 self.to_string()
77 }
78
79 #[staticmethod]
80 #[pyo3(name = "fully_qualified_name")]
81 fn py_fully_qualified_name() -> String {
82 format!("{}:{}", PY_MODULE_MODEL, stringify!(BarSpecification))
83 }
84
85 #[getter]
86 #[pyo3(name = "timedelta")]
87 fn py_timedelta(&self) -> PyResult<chrono::TimeDelta> {
88 match self.aggregation {
89 BarAggregation::Millisecond
90 | BarAggregation::Second
91 | BarAggregation::Minute
92 | BarAggregation::Hour
93 | BarAggregation::Day => Ok(self.timedelta()),
94 _ => Err(to_pyvalue_err(format!(
95 "Timedelta not supported for aggregation type: {:?}",
96 self.aggregation
97 ))),
98 }
99 }
100}
101
102#[pymethods]
103impl BarType {
104 #[new]
105 #[pyo3(signature = (instrument_id, spec, aggregation_source = AggregationSource::External)
106 )]
107 fn py_new(
108 instrument_id: InstrumentId,
109 spec: BarSpecification,
110 aggregation_source: AggregationSource,
111 ) -> Self {
112 Self::new(instrument_id, spec, aggregation_source)
113 }
114
115 fn __richcmp__(&self, other: &Self, op: CompareOp, py: Python<'_>) -> Py<PyAny> {
116 match op {
117 CompareOp::Eq => self.eq(other).into_py_any_unwrap(py),
118 CompareOp::Ne => self.ne(other).into_py_any_unwrap(py),
119 _ => py.NotImplemented(),
120 }
121 }
122
123 fn __hash__(&self) -> isize {
124 let mut h = DefaultHasher::new();
125 self.hash(&mut h);
126 h.finish() as isize
127 }
128
129 fn __repr__(&self) -> String {
130 format!("{self:?}")
131 }
132
133 fn __str__(&self) -> String {
134 self.to_string()
135 }
136
137 #[staticmethod]
138 #[pyo3(name = "fully_qualified_name")]
139 fn py_fully_qualified_name() -> String {
140 format!("{}:{}", PY_MODULE_MODEL, stringify!(BarType))
141 }
142
143 #[staticmethod]
144 #[pyo3(name = "from_str")]
145 fn py_from_str(value: &str) -> PyResult<Self> {
146 Self::from_str(value).map_err(to_pyvalue_err)
147 }
148
149 #[staticmethod]
150 #[pyo3(name = "new_composite")]
151 fn py_new_composite(
152 instrument_id: InstrumentId,
153 spec: BarSpecification,
154 aggregation_source: AggregationSource,
155 composite_step: usize,
156 composite_aggregation: BarAggregation,
157 composite_aggregation_source: AggregationSource,
158 ) -> Self {
159 Self::new_composite(
160 instrument_id,
161 spec,
162 aggregation_source,
163 composite_step,
164 composite_aggregation,
165 composite_aggregation_source,
166 )
167 }
168
169 #[pyo3(name = "is_standard")]
170 fn py_is_standard(&self) -> bool {
171 self.is_standard()
172 }
173
174 #[pyo3(name = "is_composite")]
175 fn py_is_composite(&self) -> bool {
176 self.is_composite()
177 }
178
179 #[pyo3(name = "standard")]
180 fn py_standard(&self) -> Self {
181 self.standard()
182 }
183
184 #[pyo3(name = "composite")]
185 fn py_composite(&self) -> Self {
186 self.composite()
187 }
188
189 #[pyo3(name = "id_spec_key")]
190 fn py_id_spec_key(&self) -> (InstrumentId, BarSpecification) {
191 self.id_spec_key()
192 }
193}
194
195impl Bar {
196 pub fn from_pyobject(obj: &Bound<'_, PyAny>) -> PyResult<Self> {
202 let bar_type_obj: Bound<'_, PyAny> = obj.getattr("bar_type")?.extract()?;
203 let bar_type_str: String = bar_type_obj.call_method0("__str__")?.extract()?;
204 let bar_type = BarType::from(bar_type_str.as_str());
205
206 let open_py: Bound<'_, PyAny> = obj.getattr("open")?;
207 let price_prec: u8 = open_py.getattr("precision")?.extract()?;
208 let open_raw: PriceRaw = open_py.getattr("raw")?.extract()?;
209 let open = Price::from_raw(open_raw, price_prec);
210
211 let high_py: Bound<'_, PyAny> = obj.getattr("high")?;
212 let high_raw: PriceRaw = high_py.getattr("raw")?.extract()?;
213 let high = Price::from_raw(high_raw, price_prec);
214
215 let low_py: Bound<'_, PyAny> = obj.getattr("low")?;
216 let low_raw: PriceRaw = low_py.getattr("raw")?.extract()?;
217 let low = Price::from_raw(low_raw, price_prec);
218
219 let close_py: Bound<'_, PyAny> = obj.getattr("close")?;
220 let close_raw: PriceRaw = close_py.getattr("raw")?.extract()?;
221 let close = Price::from_raw(close_raw, price_prec);
222
223 let volume_py: Bound<'_, PyAny> = obj.getattr("volume")?;
224 let volume_raw: QuantityRaw = volume_py.getattr("raw")?.extract()?;
225 let volume_prec: u8 = volume_py.getattr("precision")?.extract()?;
226 let volume = Quantity::from_raw(volume_raw, volume_prec);
227
228 let ts_event: u64 = obj.getattr("ts_event")?.extract()?;
229 let ts_init: u64 = obj.getattr("ts_init")?.extract()?;
230
231 Ok(Self::new(
232 bar_type,
233 open,
234 high,
235 low,
236 close,
237 volume,
238 ts_event.into(),
239 ts_init.into(),
240 ))
241 }
242}
243
244#[pymethods]
245#[allow(clippy::too_many_arguments)]
246impl Bar {
247 #[new]
248 fn py_new(
249 bar_type: BarType,
250 open: Price,
251 high: Price,
252 low: Price,
253 close: Price,
254 volume: Quantity,
255 ts_event: u64,
256 ts_init: u64,
257 ) -> PyResult<Self> {
258 Self::new_checked(
259 bar_type,
260 open,
261 high,
262 low,
263 close,
264 volume,
265 ts_event.into(),
266 ts_init.into(),
267 )
268 .map_err(to_pyvalue_err)
269 }
270
271 fn __richcmp__(&self, other: &Self, op: CompareOp, py: Python<'_>) -> Py<PyAny> {
272 match op {
273 CompareOp::Eq => self.eq(other).into_py_any_unwrap(py),
274 CompareOp::Ne => self.ne(other).into_py_any_unwrap(py),
275 _ => py.NotImplemented(),
276 }
277 }
278
279 fn __hash__(&self) -> isize {
280 let mut h = DefaultHasher::new();
281 self.hash(&mut h);
282 h.finish() as isize
283 }
284
285 fn __repr__(&self) -> String {
286 format!("{self:?}")
287 }
288
289 fn __str__(&self) -> String {
290 self.to_string()
291 }
292
293 #[getter]
294 #[pyo3(name = "bar_type")]
295 fn py_bar_type(&self) -> BarType {
296 self.bar_type
297 }
298
299 #[getter]
300 #[pyo3(name = "open")]
301 fn py_open(&self) -> Price {
302 self.open
303 }
304
305 #[getter]
306 #[pyo3(name = "high")]
307 fn py_high(&self) -> Price {
308 self.high
309 }
310
311 #[getter]
312 #[pyo3(name = "low")]
313 fn py_low(&self) -> Price {
314 self.low
315 }
316
317 #[getter]
318 #[pyo3(name = "close")]
319 fn py_close(&self) -> Price {
320 self.close
321 }
322
323 #[getter]
324 #[pyo3(name = "volume")]
325 fn py_volume(&self) -> Quantity {
326 self.volume
327 }
328
329 #[getter]
330 #[pyo3(name = "ts_event")]
331 fn py_ts_event(&self) -> u64 {
332 self.ts_event.as_u64()
333 }
334
335 #[getter]
336 #[pyo3(name = "ts_init")]
337 fn py_ts_init(&self) -> u64 {
338 self.ts_init.as_u64()
339 }
340
341 #[staticmethod]
342 #[pyo3(name = "fully_qualified_name")]
343 fn py_fully_qualified_name() -> String {
344 format!("{}:{}", PY_MODULE_MODEL, stringify!(Bar))
345 }
346
347 #[staticmethod]
348 #[pyo3(name = "get_metadata")]
349 fn py_get_metadata(
350 bar_type: &BarType,
351 price_precision: u8,
352 size_precision: u8,
353 ) -> PyResult<HashMap<String, String>> {
354 Ok(Self::get_metadata(
355 bar_type,
356 price_precision,
357 size_precision,
358 ))
359 }
360
361 #[staticmethod]
362 #[pyo3(name = "get_fields")]
363 fn py_get_fields(py: Python<'_>) -> PyResult<Bound<'_, PyDict>> {
364 let py_dict = PyDict::new(py);
365 for (k, v) in Self::get_fields() {
366 py_dict.set_item(k, v)?;
367 }
368
369 Ok(py_dict)
370 }
371
372 #[staticmethod]
374 #[pyo3(name = "from_dict")]
375 fn py_from_dict(py: Python<'_>, values: Py<PyDict>) -> PyResult<Self> {
376 from_dict_pyo3(py, values)
377 }
378
379 #[staticmethod]
380 #[pyo3(name = "from_json")]
381 fn py_from_json(data: Vec<u8>) -> PyResult<Self> {
382 Self::from_json_bytes(&data).map_err(to_pyvalue_err)
383 }
384
385 #[staticmethod]
386 #[pyo3(name = "from_msgpack")]
387 fn py_from_msgpack(data: Vec<u8>) -> PyResult<Self> {
388 Self::from_msgpack_bytes(&data).map_err(to_pyvalue_err)
389 }
390
391 #[pyo3(name = "as_pycapsule")]
407 fn py_as_pycapsule(&self, py: Python<'_>) -> Py<PyAny> {
408 data_to_pycapsule(py, Data::Bar(*self))
409 }
410
411 #[pyo3(name = "to_dict")]
413 fn py_to_dict(&self, py: Python<'_>) -> PyResult<Py<PyDict>> {
414 to_dict_pyo3(py, self)
415 }
416
417 #[pyo3(name = "to_json_bytes")]
419 fn py_to_json_bytes(&self, py: Python<'_>) -> Py<PyAny> {
420 self.to_json_bytes().unwrap().into_py_any_unwrap(py)
422 }
423
424 #[pyo3(name = "to_msgpack_bytes")]
426 fn py_to_msgpack_bytes(&self, py: Python<'_>) -> Py<PyAny> {
427 self.to_msgpack_bytes().unwrap().into_py_any_unwrap(py)
429 }
430}
431
#[cfg(test)]
mod tests {
    use nautilus_core::python::IntoPyObjectNautilusExt;
    use pyo3::Python;
    use rstest::rstest;

    use crate::{
        data::{Bar, BarType},
        types::{Price, Quantity},
    };

    /// Each case supplies an inconsistent OHLC combination that the constructor must reject.
    #[rstest]
    #[case("10.0000", "10.0010", "10.0020", "10.0005")] // low is above high
    #[case("10.0000", "10.0010", "10.0005", "10.0030")] // close is above high
    #[case("10.0000", "9.9990", "9.9980", "9.9995")] // high is below open
    #[case("10.0000", "10.0010", "10.0015", "10.0020")] // low and close are above high
    #[case("10.0000", "10.0000", "10.0001", "10.0002")] // low and close are above high
    fn test_bar_py_new_invalid(
        #[case] open: &str,
        #[case] high: &str,
        #[case] low: &str,
        #[case] close: &str,
    ) {
        let outcome = Bar::py_new(
            BarType::from("AUDUSD.SIM-1-MINUTE-LAST-INTERNAL"),
            Price::from(open),
            Price::from(high),
            Price::from(low),
            Price::from(close),
            Quantity::from(100_000),
            0,
            1,
        );

        assert!(outcome.is_err());
    }

    /// A consistent OHLC combination is accepted by the constructor.
    #[rstest]
    fn test_bar_py_new() {
        let outcome = Bar::py_new(
            BarType::from("AUDUSD.SIM-1-MINUTE-LAST-INTERNAL"),
            Price::from("1.00005"),
            Price::from("1.00010"),
            Price::from("1.00000"),
            Price::from("1.00007"),
            Quantity::from(100_000),
            0,
            1,
        );

        assert!(outcome.is_ok());
    }

    /// The dict form of the default bar renders to the expected string.
    #[rstest]
    fn test_to_dict() {
        let bar = Bar::default();
        let expected_string = r"{'type': 'Bar', 'bar_type': 'AUDUSD.SIM-1-MINUTE-LAST-INTERNAL', 'open': '1.00010', 'high': '1.00020', 'low': '1.00000', 'close': '1.00010', 'volume': '100000', 'ts_event': 0, 'ts_init': 0}";

        Python::initialize();
        Python::attach(|py| {
            let dict_string = bar.py_to_dict(py).unwrap().to_string();
            assert_eq!(dict_string, expected_string);
        });
    }

    /// Converting to a dict and back yields an equal bar.
    #[rstest]
    fn test_as_from_dict() {
        let bar = Bar::default();

        Python::initialize();
        Python::attach(|py| {
            let as_dict = bar.py_to_dict(py).unwrap();
            let round_tripped = Bar::py_from_dict(py, as_dict).unwrap();
            assert_eq!(round_tripped, bar);
        });
    }

    /// Converting to a Python object and reading it back yields an equal bar.
    #[rstest]
    fn test_from_pyobject() {
        let bar = Bar::default();

        Python::initialize();
        Python::attach(|py| {
            let py_bar = bar.into_py_any_unwrap(py);
            let round_tripped = Bar::from_pyobject(py_bar.bind(py)).unwrap();
            assert_eq!(round_tripped, bar);
        });
    }
}