nautilus_model/python/data/
trade.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::{
17    collections::{HashMap, hash_map::DefaultHasher},
18    hash::{Hash, Hasher},
19    str::FromStr,
20};
21
22use nautilus_core::{
23    UnixNanos,
24    python::{
25        IntoPyObjectNautilusExt,
26        serialization::{from_dict_pyo3, to_dict_pyo3},
27        to_pyvalue_err,
28    },
29    serialization::{
30        Serializable,
31        msgpack::{FromMsgPack, ToMsgPack},
32    },
33};
34use pyo3::{
35    IntoPyObjectExt,
36    prelude::*,
37    pyclass::CompareOp,
38    types::{PyDict, PyInt, PyString, PyTuple},
39};
40
41use super::data_to_pycapsule;
42use crate::{
43    data::{Data, TradeTick},
44    enums::{AggressorSide, FromU8},
45    identifiers::{InstrumentId, TradeId},
46    python::common::PY_MODULE_MODEL,
47    types::{
48        price::{Price, PriceRaw},
49        quantity::{Quantity, QuantityRaw},
50    },
51};
52
53impl TradeTick {
54    /// Creates a new [`TradeTick`] from a Python object.
55    ///
56    /// # Panics
57    ///
58    /// Panics if converting `aggressor_side_u8` to `AggressorSide` fails.
59    ///
60    /// # Errors
61    ///
62    /// Returns a `PyErr` if attribute extraction or type conversion fails.
63    pub fn from_pyobject(obj: &Bound<'_, PyAny>) -> PyResult<Self> {
64        let instrument_id_obj: Bound<'_, PyAny> = obj.getattr("instrument_id")?.extract()?;
65        let instrument_id_str: String = instrument_id_obj.getattr("value")?.extract()?;
66        let instrument_id =
67            InstrumentId::from_str(instrument_id_str.as_str()).map_err(to_pyvalue_err)?;
68
69        let price_py: Bound<'_, PyAny> = obj.getattr("price")?.extract()?;
70        let price_raw: PriceRaw = price_py.getattr("raw")?.extract()?;
71        let price_prec: u8 = price_py.getattr("precision")?.extract()?;
72        let price = Price::from_raw(price_raw, price_prec);
73
74        let size_py: Bound<'_, PyAny> = obj.getattr("size")?.extract()?;
75        let size_raw: QuantityRaw = size_py.getattr("raw")?.extract()?;
76        let size_prec: u8 = size_py.getattr("precision")?.extract()?;
77        let size = Quantity::from_raw(size_raw, size_prec);
78
79        let aggressor_side_obj: Bound<'_, PyAny> = obj.getattr("aggressor_side")?.extract()?;
80        let aggressor_side_u8 = aggressor_side_obj.getattr("value")?.extract()?;
81        let aggressor_side = AggressorSide::from_u8(aggressor_side_u8).unwrap();
82
83        let trade_id_obj: Bound<'_, PyAny> = obj.getattr("trade_id")?.extract()?;
84        let trade_id_str: String = trade_id_obj.getattr("value")?.extract()?;
85        let trade_id = TradeId::from(trade_id_str.as_str());
86
87        let ts_event: u64 = obj.getattr("ts_event")?.extract()?;
88        let ts_init: u64 = obj.getattr("ts_init")?.extract()?;
89
90        Ok(Self::new(
91            instrument_id,
92            price,
93            size,
94            aggressor_side,
95            trade_id,
96            ts_event.into(),
97            ts_init.into(),
98        ))
99    }
100}
101
102#[pymethods]
103impl TradeTick {
104    #[new]
105    fn py_new(
106        instrument_id: InstrumentId,
107        price: Price,
108        size: Quantity,
109        aggressor_side: AggressorSide,
110        trade_id: TradeId,
111        ts_event: u64,
112        ts_init: u64,
113    ) -> PyResult<Self> {
114        Self::new_checked(
115            instrument_id,
116            price,
117            size,
118            aggressor_side,
119            trade_id,
120            ts_event.into(),
121            ts_init.into(),
122        )
123        .map_err(to_pyvalue_err)
124    }
125
126    fn __setstate__(&mut self, state: &Bound<'_, PyAny>) -> PyResult<()> {
127        let py_tuple: &Bound<'_, PyTuple> = state.cast::<PyTuple>()?;
128        let binding = py_tuple.get_item(0)?;
129        let instrument_id_str = binding.cast::<PyString>()?.extract::<&str>()?;
130        let price_raw = py_tuple
131            .get_item(1)?
132            .cast::<PyInt>()?
133            .extract::<PriceRaw>()?;
134        let price_prec = py_tuple.get_item(2)?.cast::<PyInt>()?.extract::<u8>()?;
135        let size_raw = py_tuple
136            .get_item(3)?
137            .cast::<PyInt>()?
138            .extract::<QuantityRaw>()?;
139        let size_prec = py_tuple.get_item(4)?.cast::<PyInt>()?.extract::<u8>()?;
140
141        let aggressor_side_u8 = py_tuple.get_item(5)?.cast::<PyInt>()?.extract::<u8>()?;
142        let binding = py_tuple.get_item(6)?;
143        let trade_id_str = binding.cast::<PyString>()?.extract::<&str>()?;
144        let ts_event = py_tuple.get_item(7)?.cast::<PyInt>()?.extract::<u64>()?;
145        let ts_init = py_tuple.get_item(8)?.cast::<PyInt>()?.extract::<u64>()?;
146
147        self.instrument_id = InstrumentId::from_str(instrument_id_str).map_err(to_pyvalue_err)?;
148        self.price = Price::from_raw(price_raw, price_prec);
149        self.size = Quantity::from_raw(size_raw, size_prec);
150        self.aggressor_side = AggressorSide::from_u8(aggressor_side_u8).unwrap();
151        self.trade_id = TradeId::from(trade_id_str);
152        self.ts_event = ts_event.into();
153        self.ts_init = ts_init.into();
154
155        Ok(())
156    }
157
158    fn __getstate__(&self, py: Python) -> PyResult<Py<PyAny>> {
159        (
160            self.instrument_id.to_string(),
161            self.price.raw,
162            self.price.precision,
163            self.size.raw,
164            self.size.precision,
165            self.aggressor_side as u8,
166            self.trade_id.to_string(),
167            self.ts_event.as_u64(),
168            self.ts_init.as_u64(),
169        )
170            .into_py_any(py)
171    }
172
173    fn __reduce__(&self, py: Python) -> PyResult<Py<PyAny>> {
174        let safe_constructor = py.get_type::<Self>().getattr("_safe_constructor")?;
175        let state = self.__getstate__(py)?;
176        (safe_constructor, PyTuple::empty(py), state).into_py_any(py)
177    }
178
179    #[staticmethod]
180    fn _safe_constructor() -> Self {
181        Self::new(
182            InstrumentId::from("NULL.NULL"),
183            Price::zero(0),
184            Quantity::from(1), // size cannot be zero
185            AggressorSide::NoAggressor,
186            TradeId::from("NULL"),
187            UnixNanos::default(),
188            UnixNanos::default(),
189        )
190    }
191
192    fn __richcmp__(&self, other: &Self, op: CompareOp, py: Python<'_>) -> Py<PyAny> {
193        match op {
194            CompareOp::Eq => self.eq(other).into_py_any_unwrap(py),
195            CompareOp::Ne => self.ne(other).into_py_any_unwrap(py),
196            _ => py.NotImplemented(),
197        }
198    }
199
200    fn __hash__(&self) -> isize {
201        let mut h = DefaultHasher::new();
202        self.hash(&mut h);
203        h.finish() as isize
204    }
205
206    fn __repr__(&self) -> String {
207        format!("{}({})", stringify!(TradeTick), self)
208    }
209
210    fn __str__(&self) -> String {
211        self.to_string()
212    }
213
214    #[getter]
215    #[pyo3(name = "instrument_id")]
216    fn py_instrument_id(&self) -> InstrumentId {
217        self.instrument_id
218    }
219
220    #[getter]
221    #[pyo3(name = "price")]
222    fn py_price(&self) -> Price {
223        self.price
224    }
225
226    #[getter]
227    #[pyo3(name = "size")]
228    fn py_size(&self) -> Quantity {
229        self.size
230    }
231
232    #[getter]
233    #[pyo3(name = "aggressor_side")]
234    fn py_aggressor_side(&self) -> AggressorSide {
235        self.aggressor_side
236    }
237
238    #[getter]
239    #[pyo3(name = "trade_id")]
240    fn py_trade_id(&self) -> TradeId {
241        self.trade_id
242    }
243
244    #[getter]
245    #[pyo3(name = "ts_event")]
246    fn py_ts_event(&self) -> u64 {
247        self.ts_event.as_u64()
248    }
249
250    #[getter]
251    #[pyo3(name = "ts_init")]
252    fn py_ts_init(&self) -> u64 {
253        self.ts_init.as_u64()
254    }
255
256    #[staticmethod]
257    #[pyo3(name = "fully_qualified_name")]
258    fn py_fully_qualified_name() -> String {
259        format!("{}:{}", PY_MODULE_MODEL, stringify!(TradeTick))
260    }
261
262    #[staticmethod]
263    #[pyo3(name = "get_metadata")]
264    fn py_get_metadata(
265        instrument_id: &InstrumentId,
266        price_precision: u8,
267        size_precision: u8,
268    ) -> PyResult<HashMap<String, String>> {
269        Ok(Self::get_metadata(
270            instrument_id,
271            price_precision,
272            size_precision,
273        ))
274    }
275
276    #[staticmethod]
277    #[pyo3(name = "get_fields")]
278    fn py_get_fields(py: Python<'_>) -> PyResult<Bound<'_, PyDict>> {
279        let py_dict = PyDict::new(py);
280        for (k, v) in Self::get_fields() {
281            py_dict.set_item(k, v)?;
282        }
283
284        Ok(py_dict)
285    }
286
287    /// Returns a new object from the given dictionary representation.
288    #[staticmethod]
289    #[pyo3(name = "from_dict")]
290    fn py_from_dict(py: Python<'_>, values: Py<PyDict>) -> PyResult<Self> {
291        from_dict_pyo3(py, values)
292    }
293
294    #[staticmethod]
295    #[pyo3(name = "from_json")]
296    fn py_from_json(data: Vec<u8>) -> PyResult<Self> {
297        Self::from_json_bytes(&data).map_err(to_pyvalue_err)
298    }
299
300    #[staticmethod]
301    #[pyo3(name = "from_msgpack")]
302    fn py_from_msgpack(data: Vec<u8>) -> PyResult<Self> {
303        Self::from_msgpack_bytes(&data).map_err(to_pyvalue_err)
304    }
305
306    /// Creates a `PyCapsule` containing a raw pointer to a `Data::Trade` object.
307    ///
308    /// This function takes the current object (assumed to be of a type that can be represented as
309    /// `Data::Trade`), and encapsulates a raw pointer to it within a `PyCapsule`.
310    ///
311    /// # Safety
312    ///
313    /// This function is safe as long as the following conditions are met:
314    /// - The `Data::Trade` object pointed to by the capsule must remain valid for the lifetime of the capsule.
315    /// - The consumer of the capsule must ensure proper handling to avoid dereferencing a dangling pointer.
316    ///
317    /// # Panics
318    ///
319    /// The function will panic if the `PyCapsule` creation fails, which can occur if the
320    /// `Data::Trade` object cannot be converted into a raw pointer.
321    #[pyo3(name = "as_pycapsule")]
322    fn py_as_pycapsule(&self, py: Python<'_>) -> Py<PyAny> {
323        data_to_pycapsule(py, Data::Trade(*self))
324    }
325
326    /// Return a dictionary representation of the object.
327    #[pyo3(name = "to_dict")]
328    fn py_to_dict(&self, py: Python<'_>) -> PyResult<Py<PyDict>> {
329        to_dict_pyo3(py, self)
330    }
331
332    /// Return JSON encoded bytes representation of the object.
333    #[pyo3(name = "to_json_bytes")]
334    fn py_to_json_bytes(&self, py: Python<'_>) -> Py<PyAny> {
335        // SAFETY: Unwrap safe when serializing a valid object
336        self.to_json_bytes().unwrap().into_py_any_unwrap(py)
337    }
338
339    /// Return MsgPack encoded bytes representation of the object.
340    #[pyo3(name = "to_msgpack_bytes")]
341    fn py_to_msgpack_bytes(&self, py: Python<'_>) -> Py<PyAny> {
342        // SAFETY: Unwrap safe when serializing a valid object
343        self.to_msgpack_bytes().unwrap().into_py_any_unwrap(py)
344    }
345}
346
#[cfg(test)]
mod tests {
    use nautilus_core::python::IntoPyObjectNautilusExt;
    use pyo3::Python;
    use rstest::rstest;

    use crate::{
        data::{TradeTick, stubs::stub_trade_ethusdt_buyer},
        enums::AggressorSide,
        identifiers::{InstrumentId, TradeId},
        types::{Price, Quantity},
    };

    /// The Python constructor must reject a tick whose size is zero.
    #[rstest]
    fn test_trade_tick_py_new_with_zero_size() {
        let outcome = TradeTick::py_new(
            InstrumentId::from("ETH-USDT-SWAP.OKX"),
            Price::from("10000.00"),
            Quantity::from(0),
            AggressorSide::Buyer,
            TradeId::from("123456789"),
            1,
            2,
        );

        assert!(outcome.is_err(), "zero size should be rejected");
    }

    /// `to_dict` renders every field with its expected Python representation.
    #[rstest]
    fn test_to_dict(stub_trade_ethusdt_buyer: TradeTick) {
        const EXPECTED: &str = r"{'type': 'TradeTick', 'instrument_id': 'ETHUSDT-PERP.BINANCE', 'price': '10000.0000', 'size': '1.00000000', 'aggressor_side': 'BUYER', 'trade_id': '123456789', 'ts_event': 0, 'ts_init': 1}";

        Python::initialize();
        Python::attach(|py| {
            let rendered = stub_trade_ethusdt_buyer.py_to_dict(py).unwrap().to_string();
            assert_eq!(rendered, EXPECTED);
        });
    }

    /// A `to_dict` -> `from_dict` round trip reproduces the original tick.
    #[rstest]
    fn test_from_dict(stub_trade_ethusdt_buyer: TradeTick) {
        let original = stub_trade_ethusdt_buyer;

        Python::initialize();
        Python::attach(|py| {
            let round_tripped =
                TradeTick::py_from_dict(py, original.py_to_dict(py).unwrap()).unwrap();
            assert_eq!(round_tripped, original);
        });
    }

    /// Converting to a Python object and back via `from_pyobject` is lossless.
    #[rstest]
    fn test_from_pyobject(stub_trade_ethusdt_buyer: TradeTick) {
        let original = stub_trade_ethusdt_buyer;

        Python::initialize();
        Python::attach(|py| {
            let py_obj = original.into_py_any_unwrap(py);
            let restored = TradeTick::from_pyobject(py_obj.bind(py)).unwrap();
            assert_eq!(restored, original);
        });
    }
}
418}