nautilus_analysis/python/
analyzer.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::{collections::HashMap, sync::Arc};
17
18use nautilus_core::{UnixNanos, python::to_pyvalue_err};
19use nautilus_model::{
20    identifiers::PositionId,
21    position::Position,
22    types::{Currency, Money},
23};
24use pyo3::{exceptions::PyValueError, prelude::*};
25
26use crate::{
27    analyzer::PortfolioAnalyzer,
28    statistics::{
29        expectancy::Expectancy, long_ratio::LongRatio, loser_avg::AvgLoser, loser_max::MaxLoser,
30        loser_min::MinLoser, profit_factor::ProfitFactor, returns_avg::ReturnsAverage,
31        returns_avg_loss::ReturnsAverageLoss, returns_avg_win::ReturnsAverageWin,
32        returns_volatility::ReturnsVolatility, risk_return_ratio::RiskReturnRatio,
33        sharpe_ratio::SharpeRatio, sortino_ratio::SortinoRatio, win_rate::WinRate,
34        winner_avg::AvgWinner, winner_max::MaxWinner, winner_min::MinWinner,
35    },
36};
37
38#[pymethods]
39impl PortfolioAnalyzer {
40    #[new]
41    #[must_use]
42    pub fn py_new() -> Self {
43        Self::new()
44    }
45
46    fn __repr__(&self) -> String {
47        format!("PortfolioAnalyzer(currencies={})", self.currencies().len())
48    }
49
50    #[pyo3(name = "currencies")]
51    fn py_currencies(&self) -> Vec<Currency> {
52        self.currencies().into_iter().copied().collect()
53    }
54
55    #[pyo3(name = "get_performance_stats_returns")]
56    fn py_get_performance_stats_returns(&self) -> HashMap<String, f64> {
57        self.get_performance_stats_returns()
58    }
59
60    #[pyo3(name = "get_performance_stats_pnls")]
61    fn py_get_performance_stats_pnls(
62        &self,
63        currency: Option<&Currency>,
64        unrealized_pnl: Option<&Money>,
65    ) -> PyResult<HashMap<String, f64>> {
66        self.get_performance_stats_pnls(currency, unrealized_pnl)
67            .map_err(to_pyvalue_err)
68    }
69
70    #[pyo3(name = "get_performance_stats_general")]
71    fn py_get_performance_stats_general(&self) -> HashMap<String, f64> {
72        self.get_performance_stats_general()
73    }
74
75    #[pyo3(name = "add_return")]
76    fn py_add_return(&mut self, timestamp: u64, value: f64) {
77        self.add_return(UnixNanos::from(timestamp), value);
78    }
79
80    #[pyo3(name = "reset")]
81    fn py_reset(&mut self) {
82        self.reset();
83    }
84
85    #[pyo3(name = "register_statistic")]
86    fn py_register_statistic(&mut self, py: Python, statistic: Py<PyAny>) -> PyResult<()> {
87        let type_name = statistic
88            .getattr(py, "__class__")?
89            .getattr(py, "__name__")?
90            .extract::<String>(py)?;
91
92        match type_name.as_str() {
93            "MaxWinner" => {
94                let stat = statistic.extract::<MaxWinner>(py)?;
95                self.register_statistic(Arc::new(stat));
96            }
97            "MinWinner" => {
98                let stat = statistic.extract::<MinWinner>(py)?;
99                self.register_statistic(Arc::new(stat));
100            }
101            "AvgWinner" => {
102                let stat = statistic.extract::<AvgWinner>(py)?;
103                self.register_statistic(Arc::new(stat));
104            }
105            "MaxLoser" => {
106                let stat = statistic.extract::<MaxLoser>(py)?;
107                self.register_statistic(Arc::new(stat));
108            }
109            "MinLoser" => {
110                let stat = statistic.extract::<MinLoser>(py)?;
111                self.register_statistic(Arc::new(stat));
112            }
113            "AvgLoser" => {
114                let stat = statistic.extract::<AvgLoser>(py)?;
115                self.register_statistic(Arc::new(stat));
116            }
117            "Expectancy" => {
118                let stat = statistic.extract::<Expectancy>(py)?;
119                self.register_statistic(Arc::new(stat));
120            }
121            "WinRate" => {
122                let stat = statistic.extract::<WinRate>(py)?;
123                self.register_statistic(Arc::new(stat));
124            }
125            "ReturnsVolatility" => {
126                let stat = statistic.extract::<ReturnsVolatility>(py)?;
127                self.register_statistic(Arc::new(stat));
128            }
129            "ReturnsAverage" => {
130                let stat = statistic.extract::<ReturnsAverage>(py)?;
131                self.register_statistic(Arc::new(stat));
132            }
133            "ReturnsAverageLoss" => {
134                let stat = statistic.extract::<ReturnsAverageLoss>(py)?;
135                self.register_statistic(Arc::new(stat));
136            }
137            "ReturnsAverageWin" => {
138                let stat = statistic.extract::<ReturnsAverageWin>(py)?;
139                self.register_statistic(Arc::new(stat));
140            }
141            "SharpeRatio" => {
142                let stat = statistic.extract::<SharpeRatio>(py)?;
143                self.register_statistic(Arc::new(stat));
144            }
145            "SortinoRatio" => {
146                let stat = statistic.extract::<SortinoRatio>(py)?;
147                self.register_statistic(Arc::new(stat));
148            }
149            "ProfitFactor" => {
150                let stat = statistic.extract::<ProfitFactor>(py)?;
151                self.register_statistic(Arc::new(stat));
152            }
153            "RiskReturnRatio" => {
154                let stat = statistic.extract::<RiskReturnRatio>(py)?;
155                self.register_statistic(Arc::new(stat));
156            }
157            "LongRatio" => {
158                let stat = statistic.extract::<LongRatio>(py)?;
159                self.register_statistic(Arc::new(stat));
160            }
161            _ => {
162                return Err(PyValueError::new_err(format!(
163                    "Unknown statistic type: {}",
164                    type_name
165                )));
166            }
167        }
168
169        Ok(())
170    }
171
172    #[pyo3(name = "deregister_statistic")]
173    fn py_deregister_statistic(&mut self, py: Python, statistic: Py<PyAny>) -> PyResult<()> {
174        let type_name = statistic
175            .getattr(py, "__class__")?
176            .getattr(py, "__name__")?
177            .extract::<String>(py)?;
178
179        match type_name.as_str() {
180            "MaxWinner" => {
181                let stat = statistic.extract::<MaxWinner>(py)?;
182                self.deregister_statistic(Arc::new(stat));
183            }
184            "MinWinner" => {
185                let stat = statistic.extract::<MinWinner>(py)?;
186                self.deregister_statistic(Arc::new(stat));
187            }
188            "AvgWinner" => {
189                let stat = statistic.extract::<AvgWinner>(py)?;
190                self.deregister_statistic(Arc::new(stat));
191            }
192            "MaxLoser" => {
193                let stat = statistic.extract::<MaxLoser>(py)?;
194                self.deregister_statistic(Arc::new(stat));
195            }
196            "MinLoser" => {
197                let stat = statistic.extract::<MinLoser>(py)?;
198                self.deregister_statistic(Arc::new(stat));
199            }
200            "AvgLoser" => {
201                let stat = statistic.extract::<AvgLoser>(py)?;
202                self.deregister_statistic(Arc::new(stat));
203            }
204            "Expectancy" => {
205                let stat = statistic.extract::<Expectancy>(py)?;
206                self.deregister_statistic(Arc::new(stat));
207            }
208            "WinRate" => {
209                let stat = statistic.extract::<WinRate>(py)?;
210                self.deregister_statistic(Arc::new(stat));
211            }
212            "ReturnsVolatility" => {
213                let stat = statistic.extract::<ReturnsVolatility>(py)?;
214                self.deregister_statistic(Arc::new(stat));
215            }
216            "ReturnsAverage" => {
217                let stat = statistic.extract::<ReturnsAverage>(py)?;
218                self.deregister_statistic(Arc::new(stat));
219            }
220            "ReturnsAverageLoss" => {
221                let stat = statistic.extract::<ReturnsAverageLoss>(py)?;
222                self.deregister_statistic(Arc::new(stat));
223            }
224            "ReturnsAverageWin" => {
225                let stat = statistic.extract::<ReturnsAverageWin>(py)?;
226                self.deregister_statistic(Arc::new(stat));
227            }
228            "SharpeRatio" => {
229                let stat = statistic.extract::<SharpeRatio>(py)?;
230                self.deregister_statistic(Arc::new(stat));
231            }
232            "SortinoRatio" => {
233                let stat = statistic.extract::<SortinoRatio>(py)?;
234                self.deregister_statistic(Arc::new(stat));
235            }
236            "ProfitFactor" => {
237                let stat = statistic.extract::<ProfitFactor>(py)?;
238                self.deregister_statistic(Arc::new(stat));
239            }
240            "RiskReturnRatio" => {
241                let stat = statistic.extract::<RiskReturnRatio>(py)?;
242                self.deregister_statistic(Arc::new(stat));
243            }
244            "LongRatio" => {
245                let stat = statistic.extract::<LongRatio>(py)?;
246                self.deregister_statistic(Arc::new(stat));
247            }
248            _ => {
249                return Err(PyValueError::new_err(format!(
250                    "Unknown statistic type: {}",
251                    type_name
252                )));
253            }
254        }
255
256        Ok(())
257    }
258
259    #[pyo3(name = "deregister_statistics")]
260    fn py_deregister_statistics(&mut self) {
261        self.deregister_statistics();
262    }
263
264    #[pyo3(name = "add_positions")]
265    fn py_add_positions(&mut self, py: Python, positions: Vec<Py<PyAny>>) -> PyResult<()> {
266        // Extract Position objects from Cython wrappers
267        let positions: Vec<Position> = positions
268            .iter()
269            .map(|p| {
270                // Try to get the underlying Rust Position
271                // For now, we'll need to handle Cython Position by accessing its _mem field
272                p.getattr(py, "_mem")?.extract::<Position>(py)
273            })
274            .collect::<PyResult<Vec<Position>>>()?;
275
276        self.add_positions(&positions);
277        Ok(())
278    }
279
280    #[pyo3(name = "add_trade")]
281    fn py_add_trade(&mut self, position_id: &PositionId, realized_pnl: &Money) {
282        self.add_trade(position_id, realized_pnl);
283    }
284
285    // Note: calculate_statistics is not exposed to Python because it requires
286    // complex conversions of Account and dict types. Use the Python analyzer.py wrapper instead.
287
288    #[pyo3(name = "statistic")]
289    fn py_statistic(&self, name: &str) -> Option<String> {
290        self.statistic(name).map(|s| s.name())
291    }
292
293    #[pyo3(name = "returns")]
294    fn py_returns(&self, py: Python) -> PyResult<Py<PyAny>> {
295        // Convert BTreeMap<UnixNanos, f64> to Python dict
296        let dict = pyo3::types::PyDict::new(py);
297        for (timestamp, value) in self.returns() {
298            dict.set_item(timestamp.as_u64(), value)?;
299        }
300        Ok(dict.into())
301    }
302
303    #[pyo3(name = "realized_pnls")]
304    fn py_realized_pnls(&self, py: Python, currency: Option<&Currency>) -> PyResult<Py<PyAny>> {
305        match self.realized_pnls(currency) {
306            Some(pnls) => {
307                // Convert Vec<(PositionId, f64)> to Python list of tuples or dict
308                let dict = pyo3::types::PyDict::new(py);
309                for (position_id, pnl) in pnls {
310                    dict.set_item(position_id.to_string(), pnl)?;
311                }
312                Ok(dict.into())
313            }
314            None => Ok(py.None()),
315        }
316    }
317
318    #[pyo3(name = "total_pnl")]
319    fn py_total_pnl(
320        &self,
321        currency: Option<&Currency>,
322        unrealized_pnl: Option<&Money>,
323    ) -> PyResult<f64> {
324        self.total_pnl(currency, unrealized_pnl)
325            .map_err(to_pyvalue_err)
326    }
327
328    #[pyo3(name = "total_pnl_percentage")]
329    fn py_total_pnl_percentage(
330        &self,
331        currency: Option<&Currency>,
332        unrealized_pnl: Option<&Money>,
333    ) -> PyResult<f64> {
334        self.total_pnl_percentage(currency, unrealized_pnl)
335            .map_err(to_pyvalue_err)
336    }
337
338    #[pyo3(name = "get_stats_pnls_formatted")]
339    fn py_get_stats_pnls_formatted(
340        &self,
341        currency: Option<&Currency>,
342        unrealized_pnl: Option<&Money>,
343    ) -> PyResult<Vec<String>> {
344        self.get_stats_pnls_formatted(currency, unrealized_pnl)
345            .map_err(|e| PyValueError::new_err(e.to_string()))
346    }
347
348    #[pyo3(name = "get_stats_returns_formatted")]
349    fn py_get_stats_returns_formatted(&self) -> Vec<String> {
350        self.get_stats_returns_formatted()
351    }
352
353    #[pyo3(name = "get_stats_general_formatted")]
354    fn py_get_stats_general_formatted(&self) -> Vec<String> {
355        self.get_stats_general_formatted()
356    }
357}