// nautilus_analysis/python/analyzer.rs

// -------------------------------------------------------------------------------------------------
//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
//  https://nautechsystems.io
//
//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
//  You may not use this file except in compliance with the License.
//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
//
//  Unless required by applicable law or agreed to in writing, software
//  distributed under the License is distributed on an "AS IS" BASIS,
//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//  See the License for the specific language governing permissions and
//  limitations under the License.
// -------------------------------------------------------------------------------------------------
15
16use std::{collections::HashMap, sync::Arc};
17
18use nautilus_core::{UnixNanos, python::to_pyvalue_err};
19use nautilus_model::{
20    identifiers::PositionId,
21    position::Position,
22    types::{Currency, Money},
23};
24use pyo3::{exceptions::PyValueError, prelude::*};
25
26use crate::{
27    analyzer::PortfolioAnalyzer,
28    statistics::{
29        expectancy::Expectancy, long_ratio::LongRatio, loser_avg::AvgLoser, loser_max::MaxLoser,
30        loser_min::MinLoser, profit_factor::ProfitFactor, returns_avg::ReturnsAverage,
31        returns_avg_loss::ReturnsAverageLoss, returns_avg_win::ReturnsAverageWin,
32        returns_volatility::ReturnsVolatility, risk_return_ratio::RiskReturnRatio,
33        sharpe_ratio::SharpeRatio, sortino_ratio::SortinoRatio, win_rate::WinRate,
34        winner_avg::AvgWinner, winner_max::MaxWinner, winner_min::MinWinner,
35    },
36};
37
38#[pymethods]
39impl PortfolioAnalyzer {
40    #[new]
41    #[must_use]
42    pub fn py_new() -> Self {
43        Self::new()
44    }
45
46    fn __repr__(&self) -> String {
47        format!("PortfolioAnalyzer(currencies={})", self.currencies().len())
48    }
49
50    #[pyo3(name = "currencies")]
51    fn py_currencies(&self) -> Vec<Currency> {
52        self.currencies().into_iter().copied().collect()
53    }
54
55    #[pyo3(name = "get_performance_stats_returns")]
56    fn py_get_performance_stats_returns(&self) -> HashMap<String, f64> {
57        self.get_performance_stats_returns().into_iter().collect()
58    }
59
60    #[pyo3(name = "get_performance_stats_pnls")]
61    fn py_get_performance_stats_pnls(
62        &self,
63        currency: Option<&Currency>,
64        unrealized_pnl: Option<&Money>,
65    ) -> PyResult<HashMap<String, f64>> {
66        self.get_performance_stats_pnls(currency, unrealized_pnl)
67            .map(|m| m.into_iter().collect())
68            .map_err(to_pyvalue_err)
69    }
70
71    #[pyo3(name = "get_performance_stats_general")]
72    fn py_get_performance_stats_general(&self) -> HashMap<String, f64> {
73        self.get_performance_stats_general().into_iter().collect()
74    }
75
76    #[pyo3(name = "add_return")]
77    fn py_add_return(&mut self, timestamp: u64, value: f64) {
78        self.add_return(UnixNanos::from(timestamp), value);
79    }
80
81    #[pyo3(name = "reset")]
82    fn py_reset(&mut self) {
83        self.reset();
84    }
85
86    #[pyo3(name = "register_statistic")]
87    fn py_register_statistic(&mut self, py: Python, statistic: Py<PyAny>) -> PyResult<()> {
88        let type_name = statistic
89            .getattr(py, "__class__")?
90            .getattr(py, "__name__")?
91            .extract::<String>(py)?;
92
93        match type_name.as_str() {
94            "MaxWinner" => {
95                let stat = statistic.extract::<MaxWinner>(py)?;
96                self.register_statistic(Arc::new(stat));
97            }
98            "MinWinner" => {
99                let stat = statistic.extract::<MinWinner>(py)?;
100                self.register_statistic(Arc::new(stat));
101            }
102            "AvgWinner" => {
103                let stat = statistic.extract::<AvgWinner>(py)?;
104                self.register_statistic(Arc::new(stat));
105            }
106            "MaxLoser" => {
107                let stat = statistic.extract::<MaxLoser>(py)?;
108                self.register_statistic(Arc::new(stat));
109            }
110            "MinLoser" => {
111                let stat = statistic.extract::<MinLoser>(py)?;
112                self.register_statistic(Arc::new(stat));
113            }
114            "AvgLoser" => {
115                let stat = statistic.extract::<AvgLoser>(py)?;
116                self.register_statistic(Arc::new(stat));
117            }
118            "Expectancy" => {
119                let stat = statistic.extract::<Expectancy>(py)?;
120                self.register_statistic(Arc::new(stat));
121            }
122            "WinRate" => {
123                let stat = statistic.extract::<WinRate>(py)?;
124                self.register_statistic(Arc::new(stat));
125            }
126            "ReturnsVolatility" => {
127                let stat = statistic.extract::<ReturnsVolatility>(py)?;
128                self.register_statistic(Arc::new(stat));
129            }
130            "ReturnsAverage" => {
131                let stat = statistic.extract::<ReturnsAverage>(py)?;
132                self.register_statistic(Arc::new(stat));
133            }
134            "ReturnsAverageLoss" => {
135                let stat = statistic.extract::<ReturnsAverageLoss>(py)?;
136                self.register_statistic(Arc::new(stat));
137            }
138            "ReturnsAverageWin" => {
139                let stat = statistic.extract::<ReturnsAverageWin>(py)?;
140                self.register_statistic(Arc::new(stat));
141            }
142            "SharpeRatio" => {
143                let stat = statistic.extract::<SharpeRatio>(py)?;
144                self.register_statistic(Arc::new(stat));
145            }
146            "SortinoRatio" => {
147                let stat = statistic.extract::<SortinoRatio>(py)?;
148                self.register_statistic(Arc::new(stat));
149            }
150            "ProfitFactor" => {
151                let stat = statistic.extract::<ProfitFactor>(py)?;
152                self.register_statistic(Arc::new(stat));
153            }
154            "RiskReturnRatio" => {
155                let stat = statistic.extract::<RiskReturnRatio>(py)?;
156                self.register_statistic(Arc::new(stat));
157            }
158            "LongRatio" => {
159                let stat = statistic.extract::<LongRatio>(py)?;
160                self.register_statistic(Arc::new(stat));
161            }
162            _ => {
163                return Err(PyValueError::new_err(format!(
164                    "Unknown statistic type: {type_name}"
165                )));
166            }
167        }
168
169        Ok(())
170    }
171
172    #[pyo3(name = "deregister_statistic")]
173    fn py_deregister_statistic(&mut self, py: Python, statistic: Py<PyAny>) -> PyResult<()> {
174        let type_name = statistic
175            .getattr(py, "__class__")?
176            .getattr(py, "__name__")?
177            .extract::<String>(py)?;
178
179        match type_name.as_str() {
180            "MaxWinner" => {
181                let stat = statistic.extract::<MaxWinner>(py)?;
182                self.deregister_statistic(Arc::new(stat));
183            }
184            "MinWinner" => {
185                let stat = statistic.extract::<MinWinner>(py)?;
186                self.deregister_statistic(Arc::new(stat));
187            }
188            "AvgWinner" => {
189                let stat = statistic.extract::<AvgWinner>(py)?;
190                self.deregister_statistic(Arc::new(stat));
191            }
192            "MaxLoser" => {
193                let stat = statistic.extract::<MaxLoser>(py)?;
194                self.deregister_statistic(Arc::new(stat));
195            }
196            "MinLoser" => {
197                let stat = statistic.extract::<MinLoser>(py)?;
198                self.deregister_statistic(Arc::new(stat));
199            }
200            "AvgLoser" => {
201                let stat = statistic.extract::<AvgLoser>(py)?;
202                self.deregister_statistic(Arc::new(stat));
203            }
204            "Expectancy" => {
205                let stat = statistic.extract::<Expectancy>(py)?;
206                self.deregister_statistic(Arc::new(stat));
207            }
208            "WinRate" => {
209                let stat = statistic.extract::<WinRate>(py)?;
210                self.deregister_statistic(Arc::new(stat));
211            }
212            "ReturnsVolatility" => {
213                let stat = statistic.extract::<ReturnsVolatility>(py)?;
214                self.deregister_statistic(Arc::new(stat));
215            }
216            "ReturnsAverage" => {
217                let stat = statistic.extract::<ReturnsAverage>(py)?;
218                self.deregister_statistic(Arc::new(stat));
219            }
220            "ReturnsAverageLoss" => {
221                let stat = statistic.extract::<ReturnsAverageLoss>(py)?;
222                self.deregister_statistic(Arc::new(stat));
223            }
224            "ReturnsAverageWin" => {
225                let stat = statistic.extract::<ReturnsAverageWin>(py)?;
226                self.deregister_statistic(Arc::new(stat));
227            }
228            "SharpeRatio" => {
229                let stat = statistic.extract::<SharpeRatio>(py)?;
230                self.deregister_statistic(Arc::new(stat));
231            }
232            "SortinoRatio" => {
233                let stat = statistic.extract::<SortinoRatio>(py)?;
234                self.deregister_statistic(Arc::new(stat));
235            }
236            "ProfitFactor" => {
237                let stat = statistic.extract::<ProfitFactor>(py)?;
238                self.deregister_statistic(Arc::new(stat));
239            }
240            "RiskReturnRatio" => {
241                let stat = statistic.extract::<RiskReturnRatio>(py)?;
242                self.deregister_statistic(Arc::new(stat));
243            }
244            "LongRatio" => {
245                let stat = statistic.extract::<LongRatio>(py)?;
246                self.deregister_statistic(Arc::new(stat));
247            }
248            _ => {
249                return Err(PyValueError::new_err(format!(
250                    "Unknown statistic type: {type_name}"
251                )));
252            }
253        }
254
255        Ok(())
256    }
257
258    #[pyo3(name = "deregister_statistics")]
259    fn py_deregister_statistics(&mut self) {
260        self.deregister_statistics();
261    }
262
263    #[pyo3(name = "add_positions")]
264    fn py_add_positions(&mut self, py: Python, positions: Vec<Py<PyAny>>) -> PyResult<()> {
265        // Extract Position objects from Cython wrappers
266        let positions: Vec<Position> = positions
267            .iter()
268            .map(|p| {
269                // Try to get the underlying Rust Position
270                // For now, we'll need to handle Cython Position by accessing its _mem field
271                p.getattr(py, "_mem")?
272                    .extract::<Position>(py)
273                    .map_err(Into::into)
274            })
275            .collect::<PyResult<Vec<Position>>>()?;
276
277        self.add_positions(&positions);
278        Ok(())
279    }
280
281    #[pyo3(name = "add_trade")]
282    fn py_add_trade(&mut self, position_id: &PositionId, realized_pnl: &Money) {
283        self.add_trade(position_id, realized_pnl);
284    }
285
286    // Note: calculate_statistics is not exposed to Python because it requires
287    // complex conversions of Account and dict types. Use the Python analyzer.py wrapper instead.
288
289    #[pyo3(name = "statistic")]
290    fn py_statistic(&self, name: &str) -> Option<String> {
291        self.statistic(name).map(|s| s.name())
292    }
293
294    #[pyo3(name = "returns")]
295    fn py_returns(&self, py: Python) -> PyResult<Py<PyAny>> {
296        // Convert BTreeMap<UnixNanos, f64> to Python dict
297        let dict = pyo3::types::PyDict::new(py);
298        for (timestamp, value) in self.returns() {
299            dict.set_item(timestamp.as_u64(), value)?;
300        }
301        Ok(dict.into())
302    }
303
304    #[pyo3(name = "realized_pnls")]
305    fn py_realized_pnls(&self, py: Python, currency: Option<&Currency>) -> PyResult<Py<PyAny>> {
306        match self.realized_pnls(currency) {
307            Some(pnls) => {
308                // Convert Vec<(PositionId, f64)> to Python list of tuples or dict
309                let dict = pyo3::types::PyDict::new(py);
310                for (position_id, pnl) in pnls {
311                    dict.set_item(position_id.to_string(), pnl)?;
312                }
313                Ok(dict.into())
314            }
315            None => Ok(py.None()),
316        }
317    }
318
319    #[pyo3(name = "total_pnl")]
320    fn py_total_pnl(
321        &self,
322        currency: Option<&Currency>,
323        unrealized_pnl: Option<&Money>,
324    ) -> PyResult<f64> {
325        self.total_pnl(currency, unrealized_pnl)
326            .map_err(to_pyvalue_err)
327    }
328
329    #[pyo3(name = "total_pnl_percentage")]
330    fn py_total_pnl_percentage(
331        &self,
332        currency: Option<&Currency>,
333        unrealized_pnl: Option<&Money>,
334    ) -> PyResult<f64> {
335        self.total_pnl_percentage(currency, unrealized_pnl)
336            .map_err(to_pyvalue_err)
337    }
338
339    #[pyo3(name = "get_stats_pnls_formatted")]
340    fn py_get_stats_pnls_formatted(
341        &self,
342        currency: Option<&Currency>,
343        unrealized_pnl: Option<&Money>,
344    ) -> PyResult<Vec<String>> {
345        self.get_stats_pnls_formatted(currency, unrealized_pnl)
346            .map_err(PyValueError::new_err)
347    }
348
349    #[pyo3(name = "get_stats_returns_formatted")]
350    fn py_get_stats_returns_formatted(&self) -> Vec<String> {
351        self.get_stats_returns_formatted()
352    }
353
354    #[pyo3(name = "get_stats_general_formatted")]
355    fn py_get_stats_general_formatted(&self) -> Vec<String> {
356        self.get_stats_general_formatted()
357    }
358}