nautilus_analysis/statistics/
max_drawdown.rs1use std::collections::BTreeMap;
19
20use nautilus_core::UnixNanos;
21
22use crate::statistic::PortfolioStatistic;
23
/// Portfolio statistic that reports the maximum drawdown of a return series.
///
/// The type carries no state; all inputs are supplied to the
/// `PortfolioStatistic` methods implemented for it in this file.
#[repr(C)]
#[derive(Debug, Clone, Default)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.analysis")
)]
pub struct MaxDrawdown {}

impl MaxDrawdown {
    /// Creates a new [`MaxDrawdown`] instance.
    #[must_use]
    pub fn new() -> Self {
        Self {}
    }
}
46
47impl PortfolioStatistic for MaxDrawdown {
48 type Item = f64;
49
50 fn name(&self) -> String {
51 "Max Drawdown".to_string()
52 }
53
54 fn calculate_from_returns(&self, returns: &BTreeMap<UnixNanos, f64>) -> Option<Self::Item> {
55 if returns.is_empty() {
56 return Some(0.0);
57 }
58
59 let mut cumulative = 1.0;
61 let mut running_max = 1.0;
62 let mut max_drawdown = 0.0;
63
64 for &ret in returns.values() {
65 cumulative *= 1.0 + ret;
66
67 if cumulative > running_max {
69 running_max = cumulative;
70 }
71
72 let drawdown = (running_max - cumulative) / running_max;
74
75 if drawdown > max_drawdown {
77 max_drawdown = drawdown;
78 }
79 }
80
81 Some(-max_drawdown)
83 }
84}
85
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;

    /// Absolute tolerance for floating-point comparisons; tight enough that a
    /// result off by a fraction of a percent is reported as a failure.
    const TOLERANCE: f64 = 1e-9;

    /// Builds a return series keyed by consecutive timestamps `0, 1, 2, ...`.
    fn create_returns(values: &[f64]) -> BTreeMap<UnixNanos, f64> {
        values
            .iter()
            .copied()
            .enumerate()
            .map(|(i, v)| (UnixNanos::from(i as u64), v))
            .collect()
    }

    /// Asserts that `actual` matches `expected` within [`TOLERANCE`].
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < TOLERANCE,
            "expected {expected}, was {actual}"
        );
    }

    #[rstest]
    fn test_name() {
        let stat = MaxDrawdown::new();
        assert_eq!(stat.name(), "Max Drawdown");
    }

    #[rstest]
    fn test_empty_returns() {
        let stat = MaxDrawdown::new();
        let returns = BTreeMap::new();
        let result = stat.calculate_from_returns(&returns);
        assert_eq!(result, Some(0.0));
    }

    #[rstest]
    fn test_no_drawdown() {
        let stat = MaxDrawdown::new();
        // Strictly rising equity never falls below its peak
        let returns = create_returns(&[0.01, 0.02, 0.01, 0.015]);
        let result = stat.calculate_from_returns(&returns).unwrap();
        assert_eq!(result, 0.0);
    }

    #[rstest]
    fn test_simple_drawdown() {
        let stat = MaxDrawdown::new();
        // Equity: 1.10 -> 0.99, a 10% decline from the 1.10 peak
        let returns = create_returns(&[0.10, -0.10]);
        let result = stat.calculate_from_returns(&returns).unwrap();

        assert_close(result, -0.10);
    }

    #[rstest]
    fn test_multiple_drawdowns() {
        let stat = MaxDrawdown::new();
        // Equity: 1.10 -> 0.99 -> 1.485 -> 1.188 -> 1.3068;
        // the worst decline is 20% from the 1.485 peak
        let returns = create_returns(&[0.10, -0.10, 0.50, -0.20, 0.10]);
        let result = stat.calculate_from_returns(&returns).unwrap();

        assert_close(result, -0.20);
    }

    #[rstest]
    fn test_initial_loss() {
        let stat = MaxDrawdown::new();
        // Equity: 0.60 -> 0.54, measured against the starting value of 1.0
        let returns = create_returns(&[-0.40, -0.10]);
        let result = stat.calculate_from_returns(&returns).unwrap();

        assert_close(result, -0.46);
    }
}