1use std::{
17 collections::{HashMap, hash_map::DefaultHasher},
18 hash::{Hash, Hasher},
19 str::FromStr,
20};
21
22use nautilus_core::{
23 UnixNanos,
24 python::{
25 IntoPyObjectNautilusExt,
26 serialization::{from_dict_pyo3, to_dict_pyo3},
27 to_pyvalue_err,
28 },
29 serialization::{
30 Serializable,
31 msgpack::{FromMsgPack, ToMsgPack},
32 },
33};
34use pyo3::{
35 IntoPyObjectExt,
36 prelude::*,
37 pyclass::CompareOp,
38 types::{PyDict, PyInt, PyString, PyTuple},
39};
40
41use crate::{
42 data::{IndexPriceUpdate, MarkPriceUpdate},
43 identifiers::InstrumentId,
44 python::common::PY_MODULE_MODEL,
45 types::price::{Price, PriceRaw},
46};
47
48impl MarkPriceUpdate {
49 pub fn from_pyobject(obj: &Bound<'_, PyAny>) -> PyResult<Self> {
55 let instrument_id_obj: Bound<'_, PyAny> = obj.getattr("instrument_id")?.extract()?;
56 let instrument_id_str: String = instrument_id_obj.getattr("value")?.extract()?;
57 let instrument_id =
58 InstrumentId::from_str(instrument_id_str.as_str()).map_err(to_pyvalue_err)?;
59
60 let value_py: Bound<'_, PyAny> = obj.getattr("value")?.extract()?;
61 let value_raw: PriceRaw = value_py.getattr("raw")?.extract()?;
62 let value_prec: u8 = value_py.getattr("precision")?.extract()?;
63 let value = Price::from_raw(value_raw, value_prec);
64
65 let ts_event: u64 = obj.getattr("ts_event")?.extract()?;
66 let ts_init: u64 = obj.getattr("ts_init")?.extract()?;
67
68 Ok(Self::new(
69 instrument_id,
70 value,
71 ts_event.into(),
72 ts_init.into(),
73 ))
74 }
75}
76
77#[pymethods]
78impl MarkPriceUpdate {
79 #[new]
80 fn py_new(
81 instrument_id: InstrumentId,
82 value: Price,
83 ts_event: u64,
84 ts_init: u64,
85 ) -> PyResult<Self> {
86 Ok(Self::new(
87 instrument_id,
88 value,
89 ts_event.into(),
90 ts_init.into(),
91 ))
92 }
93
94 fn __setstate__(&mut self, state: &Bound<'_, PyAny>) -> PyResult<()> {
95 let py_tuple: &Bound<'_, PyTuple> = state.cast::<PyTuple>()?;
96 let binding = py_tuple.get_item(0)?;
97 let instrument_id_str = binding.cast::<PyString>()?.extract::<&str>()?;
98 let value_raw = py_tuple
99 .get_item(1)?
100 .cast::<PyInt>()?
101 .extract::<PriceRaw>()?;
102 let value_prec = py_tuple.get_item(2)?.cast::<PyInt>()?.extract::<u8>()?;
103
104 let ts_event = py_tuple.get_item(7)?.cast::<PyInt>()?.extract::<u64>()?;
105 let ts_init = py_tuple.get_item(8)?.cast::<PyInt>()?.extract::<u64>()?;
106
107 self.instrument_id = InstrumentId::from_str(instrument_id_str).map_err(to_pyvalue_err)?;
108 self.value = Price::from_raw(value_raw, value_prec);
109 self.ts_event = ts_event.into();
110 self.ts_init = ts_init.into();
111
112 Ok(())
113 }
114
115 fn __getstate__(&self, py: Python) -> PyResult<Py<PyAny>> {
116 (
117 self.instrument_id.to_string(),
118 self.value.raw,
119 self.value.precision,
120 self.ts_event.as_u64(),
121 self.ts_init.as_u64(),
122 )
123 .into_py_any(py)
124 }
125
126 fn __reduce__(&self, py: Python) -> PyResult<Py<PyAny>> {
127 let safe_constructor = py.get_type::<Self>().getattr("_safe_constructor")?;
128 let state = self.__getstate__(py)?;
129 (safe_constructor, PyTuple::empty(py), state).into_py_any(py)
130 }
131
132 #[staticmethod]
133 fn _safe_constructor() -> Self {
134 Self::new(
135 InstrumentId::from("NULL.NULL"),
136 Price::zero(0),
137 UnixNanos::default(),
138 UnixNanos::default(),
139 )
140 }
141
142 fn __richcmp__(&self, other: &Self, op: CompareOp, py: Python<'_>) -> Py<PyAny> {
143 match op {
144 CompareOp::Eq => self.eq(other).into_py_any_unwrap(py),
145 CompareOp::Ne => self.ne(other).into_py_any_unwrap(py),
146 _ => py.NotImplemented(),
147 }
148 }
149
150 fn __hash__(&self) -> isize {
151 let mut h = DefaultHasher::new();
152 self.hash(&mut h);
153 h.finish() as isize
154 }
155
156 fn __repr__(&self) -> String {
157 format!("{}({})", stringify!(MarkPriceUpdate), self)
158 }
159
160 fn __str__(&self) -> String {
161 self.to_string()
162 }
163
164 #[getter]
165 #[pyo3(name = "instrument_id")]
166 fn py_instrument_id(&self) -> InstrumentId {
167 self.instrument_id
168 }
169
170 #[getter]
171 #[pyo3(name = "value")]
172 fn py_value(&self) -> Price {
173 self.value
174 }
175
176 #[getter]
177 #[pyo3(name = "ts_event")]
178 fn py_ts_event(&self) -> u64 {
179 self.ts_event.as_u64()
180 }
181
182 #[getter]
183 #[pyo3(name = "ts_init")]
184 fn py_ts_init(&self) -> u64 {
185 self.ts_init.as_u64()
186 }
187
188 #[staticmethod]
189 #[pyo3(name = "fully_qualified_name")]
190 fn py_fully_qualified_name() -> String {
191 format!("{}:{}", PY_MODULE_MODEL, stringify!(MarkPriceUpdate))
192 }
193
194 #[staticmethod]
195 #[pyo3(name = "get_metadata")]
196 fn py_get_metadata(
197 instrument_id: &InstrumentId,
198 price_precision: u8,
199 ) -> PyResult<HashMap<String, String>> {
200 Ok(Self::get_metadata(instrument_id, price_precision))
201 }
202
203 #[staticmethod]
204 #[pyo3(name = "get_fields")]
205 fn py_get_fields(py: Python<'_>) -> PyResult<Bound<'_, PyDict>> {
206 let py_dict = PyDict::new(py);
207 for (k, v) in Self::get_fields() {
208 py_dict.set_item(k, v)?;
209 }
210
211 Ok(py_dict)
212 }
213
214 #[staticmethod]
216 #[pyo3(name = "from_dict")]
217 fn py_from_dict(py: Python<'_>, values: Py<PyDict>) -> PyResult<Self> {
218 from_dict_pyo3(py, values)
219 }
220
221 #[staticmethod]
222 #[pyo3(name = "from_json")]
223 fn py_from_json(data: Vec<u8>) -> PyResult<Self> {
224 Self::from_json_bytes(&data).map_err(to_pyvalue_err)
225 }
226
227 #[staticmethod]
228 #[pyo3(name = "from_msgpack")]
229 fn py_from_msgpack(data: Vec<u8>) -> PyResult<Self> {
230 Self::from_msgpack_bytes(&data).map_err(to_pyvalue_err)
231 }
232
233 #[pyo3(name = "to_dict")]
235 fn py_to_dict(&self, py: Python<'_>) -> PyResult<Py<PyDict>> {
236 to_dict_pyo3(py, self)
237 }
238
239 #[pyo3(name = "to_json_bytes")]
241 fn py_to_json_bytes(&self, py: Python<'_>) -> Py<PyAny> {
242 self.to_json_bytes().unwrap().into_py_any_unwrap(py)
244 }
245
246 #[pyo3(name = "to_msgpack_bytes")]
248 fn py_to_msgpack_bytes(&self, py: Python<'_>) -> Py<PyAny> {
249 self.to_msgpack_bytes().unwrap().into_py_any_unwrap(py)
251 }
252}
253
254impl IndexPriceUpdate {
255 pub fn from_pyobject(obj: &Bound<'_, PyAny>) -> PyResult<Self> {
261 let instrument_id_obj: Bound<'_, PyAny> = obj.getattr("instrument_id")?.extract()?;
262 let instrument_id_str: String = instrument_id_obj.getattr("value")?.extract()?;
263 let instrument_id =
264 InstrumentId::from_str(instrument_id_str.as_str()).map_err(to_pyvalue_err)?;
265
266 let value_py: Bound<'_, PyAny> = obj.getattr("value")?.extract()?;
267 let value_raw: PriceRaw = value_py.getattr("raw")?.extract()?;
268 let value_prec: u8 = value_py.getattr("precision")?.extract()?;
269 let value = Price::from_raw(value_raw, value_prec);
270
271 let ts_event: u64 = obj.getattr("ts_event")?.extract()?;
272 let ts_init: u64 = obj.getattr("ts_init")?.extract()?;
273
274 Ok(Self::new(
275 instrument_id,
276 value,
277 ts_event.into(),
278 ts_init.into(),
279 ))
280 }
281}
282
283#[pymethods]
284impl IndexPriceUpdate {
285 #[new]
286 fn py_new(
287 instrument_id: InstrumentId,
288 value: Price,
289 ts_event: u64,
290 ts_init: u64,
291 ) -> PyResult<Self> {
292 Ok(Self::new(
293 instrument_id,
294 value,
295 ts_event.into(),
296 ts_init.into(),
297 ))
298 }
299
300 fn __setstate__(&mut self, state: &Bound<'_, PyAny>) -> PyResult<()> {
301 let py_tuple: &Bound<'_, PyTuple> = state.cast::<PyTuple>()?;
302 let binding = py_tuple.get_item(0)?;
303 let instrument_id_str = binding.cast::<PyString>()?.extract::<&str>()?;
304 let value_raw = py_tuple
305 .get_item(1)?
306 .cast::<PyInt>()?
307 .extract::<PriceRaw>()?;
308 let value_prec = py_tuple.get_item(2)?.cast::<PyInt>()?.extract::<u8>()?;
309
310 let ts_event = py_tuple.get_item(7)?.cast::<PyInt>()?.extract::<u64>()?;
311 let ts_init = py_tuple.get_item(8)?.cast::<PyInt>()?.extract::<u64>()?;
312
313 self.instrument_id = InstrumentId::from_str(instrument_id_str).map_err(to_pyvalue_err)?;
314 self.value = Price::from_raw(value_raw, value_prec);
315 self.ts_event = ts_event.into();
316 self.ts_init = ts_init.into();
317
318 Ok(())
319 }
320
321 fn __getstate__(&self, py: Python) -> PyResult<Py<PyAny>> {
322 (
323 self.instrument_id.to_string(),
324 self.value.raw,
325 self.value.precision,
326 self.ts_event.as_u64(),
327 self.ts_init.as_u64(),
328 )
329 .into_py_any(py)
330 }
331
332 fn __reduce__(&self, py: Python) -> PyResult<Py<PyAny>> {
333 let safe_constructor = py.get_type::<Self>().getattr("_safe_constructor")?;
334 let state = self.__getstate__(py)?;
335 (safe_constructor, PyTuple::empty(py), state).into_py_any(py)
336 }
337
338 #[staticmethod]
339 fn _safe_constructor() -> Self {
340 Self::new(
341 InstrumentId::from("NULL.NULL"),
342 Price::zero(0),
343 UnixNanos::default(),
344 UnixNanos::default(),
345 )
346 }
347
348 fn __richcmp__(&self, other: &Self, op: CompareOp, py: Python<'_>) -> Py<PyAny> {
349 match op {
350 CompareOp::Eq => self.eq(other).into_py_any_unwrap(py),
351 CompareOp::Ne => self.ne(other).into_py_any_unwrap(py),
352 _ => py.NotImplemented(),
353 }
354 }
355
356 fn __hash__(&self) -> isize {
357 let mut h = DefaultHasher::new();
358 self.hash(&mut h);
359 h.finish() as isize
360 }
361
362 fn __repr__(&self) -> String {
363 format!("{}({})", stringify!(IndexPriceUpdate), self)
364 }
365
366 fn __str__(&self) -> String {
367 self.to_string()
368 }
369
370 #[getter]
371 #[pyo3(name = "instrument_id")]
372 fn py_instrument_id(&self) -> InstrumentId {
373 self.instrument_id
374 }
375
376 #[getter]
377 #[pyo3(name = "value")]
378 fn py_value(&self) -> Price {
379 self.value
380 }
381
382 #[getter]
383 #[pyo3(name = "ts_event")]
384 fn py_ts_event(&self) -> u64 {
385 self.ts_event.as_u64()
386 }
387
388 #[getter]
389 #[pyo3(name = "ts_init")]
390 fn py_ts_init(&self) -> u64 {
391 self.ts_init.as_u64()
392 }
393
394 #[staticmethod]
395 #[pyo3(name = "fully_qualified_name")]
396 fn py_fully_qualified_name() -> String {
397 format!("{}:{}", PY_MODULE_MODEL, stringify!(IndexPriceUpdate))
398 }
399
400 #[staticmethod]
401 #[pyo3(name = "get_metadata")]
402 fn py_get_metadata(
403 instrument_id: &InstrumentId,
404 price_precision: u8,
405 ) -> PyResult<HashMap<String, String>> {
406 Ok(Self::get_metadata(instrument_id, price_precision))
407 }
408
409 #[staticmethod]
410 #[pyo3(name = "get_fields")]
411 fn py_get_fields(py: Python<'_>) -> PyResult<Bound<'_, PyDict>> {
412 let py_dict = PyDict::new(py);
413 for (k, v) in Self::get_fields() {
414 py_dict.set_item(k, v)?;
415 }
416
417 Ok(py_dict)
418 }
419
420 #[staticmethod]
422 #[pyo3(name = "from_dict")]
423 fn py_from_dict(py: Python<'_>, values: Py<PyDict>) -> PyResult<Self> {
424 from_dict_pyo3(py, values)
425 }
426
427 #[staticmethod]
428 #[pyo3(name = "from_json")]
429 fn py_from_json(data: Vec<u8>) -> PyResult<Self> {
430 Self::from_json_bytes(&data).map_err(to_pyvalue_err)
431 }
432
433 #[staticmethod]
434 #[pyo3(name = "from_msgpack")]
435 fn py_from_msgpack(data: Vec<u8>) -> PyResult<Self> {
436 Self::from_msgpack_bytes(&data).map_err(to_pyvalue_err)
437 }
438
439 #[pyo3(name = "to_dict")]
441 fn py_to_dict(&self, py: Python<'_>) -> PyResult<Py<PyDict>> {
442 to_dict_pyo3(py, self)
443 }
444
445 #[pyo3(name = "to_json_bytes")]
447 fn py_to_json_bytes(&self, py: Python<'_>) -> Py<PyAny> {
448 self.to_json_bytes().unwrap().into_py_any_unwrap(py)
450 }
451
452 #[pyo3(name = "to_msgpack_bytes")]
454 fn py_to_msgpack_bytes(&self, py: Python<'_>) -> Py<PyAny> {
455 self.to_msgpack_bytes().unwrap().into_py_any_unwrap(py)
457 }
458}
459
#[cfg(test)]
mod tests {
    use nautilus_core::python::IntoPyObjectNautilusExt;
    use pyo3::Python;
    use rstest::{fixture, rstest};

    use super::*;
    use crate::{identifiers::InstrumentId, types::Price};

    /// A BTC-USDT mark price of 100,000.00 with `ts_event = 1` and `ts_init = 2`.
    #[fixture]
    fn mark_price() -> MarkPriceUpdate {
        let instrument_id = InstrumentId::from("BTC-USDT.OKX");
        let value = Price::from("100_000.00");
        MarkPriceUpdate::new(instrument_id, value, UnixNanos::from(1), UnixNanos::from(2))
    }

    /// A BTC-USDT index price of 100,000.00 with `ts_event = 1` and `ts_init = 2`.
    #[fixture]
    fn index_price() -> IndexPriceUpdate {
        let instrument_id = InstrumentId::from("BTC-USDT.OKX");
        let value = Price::from("100_000.00");
        IndexPriceUpdate::new(instrument_id, value, UnixNanos::from(1), UnixNanos::from(2))
    }

    #[rstest]
    fn test_mark_price_to_dict(mark_price: MarkPriceUpdate) {
        Python::initialize();
        Python::attach(|py| {
            let expected = r"{'type': 'MarkPriceUpdate', 'instrument_id': 'BTC-USDT.OKX', 'value': '100000.00', 'ts_event': 1, 'ts_init': 2}";
            let actual = mark_price.py_to_dict(py).unwrap().to_string();
            assert_eq!(actual, expected);
        });
    }

    #[rstest]
    fn test_mark_price_from_dict(mark_price: MarkPriceUpdate) {
        Python::initialize();
        Python::attach(|py| {
            let as_dict = mark_price.py_to_dict(py).unwrap();
            let restored = MarkPriceUpdate::py_from_dict(py, as_dict).unwrap();
            assert_eq!(restored, mark_price);
        });
    }

    #[rstest]
    fn test_mark_price_from_pyobject(mark_price: MarkPriceUpdate) {
        Python::initialize();
        Python::attach(|py| {
            let py_obj = mark_price.into_py_any_unwrap(py);
            let restored = MarkPriceUpdate::from_pyobject(py_obj.bind(py)).unwrap();
            assert_eq!(restored, mark_price);
        });
    }

    #[rstest]
    fn test_index_price_to_dict(index_price: IndexPriceUpdate) {
        Python::initialize();
        Python::attach(|py| {
            let expected = r"{'type': 'IndexPriceUpdate', 'instrument_id': 'BTC-USDT.OKX', 'value': '100000.00', 'ts_event': 1, 'ts_init': 2}";
            let actual = index_price.py_to_dict(py).unwrap().to_string();
            assert_eq!(actual, expected);
        });
    }

    #[rstest]
    fn test_index_price_from_dict(index_price: IndexPriceUpdate) {
        Python::initialize();
        Python::attach(|py| {
            let as_dict = index_price.py_to_dict(py).unwrap();
            let restored = IndexPriceUpdate::py_from_dict(py, as_dict).unwrap();
            assert_eq!(restored, index_price);
        });
    }

    #[rstest]
    fn test_index_price_from_pyobject(index_price: IndexPriceUpdate) {
        Python::initialize();
        Python::attach(|py| {
            let py_obj = index_price.into_py_any_unwrap(py);
            let restored = IndexPriceUpdate::from_pyobject(py_obj.bind(py)).unwrap();
            assert_eq!(restored, index_price);
        });
    }
}