nautilus_indicators/average/
rma.rs1use std::fmt::Display;
17
18use nautilus_model::{
19 data::{Bar, QuoteTick, TradeTick},
20 enums::PriceType,
21};
22
23use crate::indicator::{Indicator, MovingAverage};
24
/// Wilder's moving average: an exponentially weighted average of the input
/// prices whose smoothing factor is `alpha` (`1 / period` when built via `new`).
///
/// Note: the struct is `#[repr(C)]`, so field order is layout-significant.
#[repr(C)]
#[derive(Debug)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.indicators")
)]
pub struct WilderMovingAverage {
    /// Lookback period; the indicator reports `initialized` once `count` reaches it.
    pub period: usize,
    /// Which price is extracted from a quote tick in `handle_quote`.
    pub price_type: PriceType,
    /// Weight given to each new input; the previous value gets `1.0 - alpha`.
    pub alpha: f64,
    /// Current average (`0.0` before the first input and after a reset).
    pub value: f64,
    /// Number of inputs processed since construction or the last reset.
    pub count: usize,
    /// Set once `count >= period`; cleared by `reset`.
    pub initialized: bool,
    /// Whether any input has been received; when false, the next input seeds `value`.
    has_inputs: bool,
}
40
41impl Display for WilderMovingAverage {
42 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43 write!(f, "{}({})", self.name(), self.period,)
44 }
45}
46
47impl Indicator for WilderMovingAverage {
48 fn name(&self) -> String {
49 stringify!(WilderMovingAverage).to_string()
50 }
51
52 fn has_inputs(&self) -> bool {
53 self.has_inputs
54 }
55
56 fn initialized(&self) -> bool {
57 self.initialized
58 }
59
60 fn handle_quote(&mut self, quote: &QuoteTick) {
61 self.update_raw(quote.extract_price(self.price_type).into());
62 }
63
64 fn handle_trade(&mut self, trade: &TradeTick) {
65 self.update_raw((&trade.price).into());
66 }
67
68 fn handle_bar(&mut self, bar: &Bar) {
69 self.update_raw((&bar.close).into());
70 }
71
72 fn reset(&mut self) {
73 self.value = 0.0;
74 self.count = 0;
75 self.has_inputs = false;
76 self.initialized = false;
77 }
78}
79
80impl WilderMovingAverage {
81 #[must_use]
83 pub fn new(period: usize, price_type: Option<PriceType>) -> Self {
84 Self {
88 period,
89 price_type: price_type.unwrap_or(PriceType::Last),
90 alpha: 1.0 / (period as f64),
91 value: 0.0,
92 count: 0,
93 has_inputs: false,
94 initialized: false,
95 }
96 }
97}
98
99impl MovingAverage for WilderMovingAverage {
100 fn value(&self) -> f64 {
101 self.value
102 }
103
104 fn count(&self) -> usize {
105 self.count
106 }
107
108 fn update_raw(&mut self, value: f64) {
109 if !self.has_inputs {
110 self.has_inputs = true;
111 self.value = value;
112 }
113
114 self.value = self.alpha.mul_add(value, (1.0 - self.alpha) * self.value);
115 self.count += 1;
116
117 if !self.initialized && self.count >= self.period {
119 self.initialized = true;
120 }
121 }
122}
123
#[cfg(test)]
mod tests {
    use nautilus_model::{
        data::{Bar, QuoteTick, TradeTick},
        enums::PriceType,
    };
    use rstest::rstest;

    use crate::{
        average::rma::WilderMovingAverage,
        indicator::{Indicator, MovingAverage},
        stubs::*,
    };

    #[rstest]
    fn test_rma_initialized(indicator_rma_10: WilderMovingAverage) {
        let rma = indicator_rma_10;
        let display_str = format!("{rma}");
        assert_eq!(display_str, "WilderMovingAverage(10)");
        assert_eq!(rma.period, 10);
        assert_eq!(rma.price_type, PriceType::Mid);
        assert_eq!(rma.alpha, 0.1);
        assert!(!rma.initialized);
    }

    #[rstest]
    fn test_one_value_input(indicator_rma_10: WilderMovingAverage) {
        let mut rma = indicator_rma_10;
        rma.update_raw(1.0);
        assert_eq!(rma.count, 1);
        assert_eq!(rma.value, 1.0);
    }

    #[rstest]
    fn test_rma_update_raw(indicator_rma_10: WilderMovingAverage) {
        let mut rma = indicator_rma_10;
        rma.update_raw(1.0);
        rma.update_raw(2.0);
        rma.update_raw(3.0);
        rma.update_raw(4.0);
        rma.update_raw(5.0);
        rma.update_raw(6.0);
        rma.update_raw(7.0);
        rma.update_raw(8.0);
        rma.update_raw(9.0);
        rma.update_raw(10.0);

        assert!(rma.has_inputs());
        assert!(rma.initialized());
        assert_eq!(rma.count, 10);
        assert_eq!(rma.value, 4.486_784_401);
    }

    #[rstest]
    fn test_reset(indicator_rma_10: WilderMovingAverage) {
        let mut rma = indicator_rma_10;
        rma.update_raw(1.0);
        assert_eq!(rma.count, 1);
        rma.reset();
        assert_eq!(rma.count, 0);
        assert_eq!(rma.value, 0.0);
        assert!(!rma.has_inputs());
        assert!(!rma.initialized);
    }

    /// After a reset the next input must re-seed the average rather than
    /// being blended with the pre-reset state.
    #[rstest]
    fn test_update_after_reset_reseeds(indicator_rma_10: WilderMovingAverage) {
        let mut rma = indicator_rma_10;
        rma.update_raw(1.0);
        rma.update_raw(2.0);
        rma.reset();
        rma.update_raw(5.0);
        assert!(rma.has_inputs());
        assert_eq!(rma.count, 1);
        assert_eq!(rma.value, 5.0);
    }

    #[rstest]
    fn test_handle_quote_tick_single(indicator_rma_10: WilderMovingAverage, stub_quote: QuoteTick) {
        let mut rma = indicator_rma_10;
        rma.handle_quote(&stub_quote);
        assert!(rma.has_inputs());
        assert_eq!(rma.value, 1501.0);
    }

    #[rstest]
    fn test_handle_quote_tick_multi(mut indicator_rma_10: WilderMovingAverage) {
        let tick1 = stub_quote("1500.0", "1502.0");
        let tick2 = stub_quote("1502.0", "1504.0");

        indicator_rma_10.handle_quote(&tick1);
        indicator_rma_10.handle_quote(&tick2);
        assert_eq!(indicator_rma_10.count, 2);
        assert_eq!(indicator_rma_10.value, 1_501.2);
    }

    #[rstest]
    fn test_handle_trade_tick(indicator_rma_10: WilderMovingAverage, stub_trade: TradeTick) {
        let mut rma = indicator_rma_10;
        rma.handle_trade(&stub_trade);
        assert!(rma.has_inputs());
        assert_eq!(rma.value, 1500.0);
    }

    #[rstest]
    fn test_handle_bar(
        mut indicator_rma_10: WilderMovingAverage,
        bar_ethusdt_binance_minute_bid: Bar,
    ) {
        indicator_rma_10.handle_bar(&bar_ethusdt_binance_minute_bid);
        assert!(indicator_rma_10.has_inputs);
        assert!(!indicator_rma_10.initialized);
        assert_eq!(indicator_rma_10.value, 1522.0);
    }
}