nautilus_indicators/average/
ema.rs1use std::fmt::Display;
17
18use nautilus_model::{
19 data::{Bar, QuoteTick, TradeTick},
20 enums::PriceType,
21};
22
23use crate::indicator::{Indicator, MovingAverage};
24
25#[repr(C)]
26#[derive(Debug)]
27#[cfg_attr(
28 feature = "python",
29 pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.indicators")
30)]
31pub struct ExponentialMovingAverage {
32 pub period: usize,
33 pub price_type: PriceType,
34 pub alpha: f64,
35 pub value: f64,
36 pub count: usize,
37 pub initialized: bool,
38 has_inputs: bool,
39}
40
41impl Display for ExponentialMovingAverage {
42 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43 write!(f, "{}({})", self.name(), self.period,)
44 }
45}
46
47impl Indicator for ExponentialMovingAverage {
48 fn name(&self) -> String {
49 stringify!(ExponentialMovingAverage).to_string()
50 }
51
52 fn has_inputs(&self) -> bool {
53 self.has_inputs
54 }
55
56 fn initialized(&self) -> bool {
57 self.initialized
58 }
59
60 fn handle_quote(&mut self, quote: &QuoteTick) {
61 self.update_raw(quote.extract_price(self.price_type).into());
62 }
63
64 fn handle_trade(&mut self, trade: &TradeTick) {
65 self.update_raw((&trade.price).into());
66 }
67
68 fn handle_bar(&mut self, bar: &Bar) {
69 self.update_raw((&bar.close).into());
70 }
71
72 fn reset(&mut self) {
73 self.value = 0.0;
74 self.count = 0;
75 self.has_inputs = false;
76 self.initialized = false;
77 }
78}
79
80impl ExponentialMovingAverage {
81 #[must_use]
83 pub fn new(period: usize, price_type: Option<PriceType>) -> Self {
84 Self {
85 period,
86 price_type: price_type.unwrap_or(PriceType::Last),
87 alpha: 2.0 / (period as f64 + 1.0),
88 value: 0.0,
89 count: 0,
90 has_inputs: false,
91 initialized: false,
92 }
93 }
94}
95
96impl MovingAverage for ExponentialMovingAverage {
97 fn value(&self) -> f64 {
98 self.value
99 }
100
101 fn count(&self) -> usize {
102 self.count
103 }
104 fn update_raw(&mut self, value: f64) {
105 if !self.has_inputs {
106 self.has_inputs = true;
107 self.value = value;
108 }
109
110 self.value = self.alpha.mul_add(value, (1.0 - self.alpha) * self.value);
111 self.count += 1;
112
113 if !self.initialized && self.count >= self.period {
115 self.initialized = true;
116 }
117 }
118}
119
#[cfg(test)]
mod tests {
    use nautilus_model::{
        data::{Bar, QuoteTick, TradeTick},
        enums::PriceType,
    };
    use rstest::rstest;

    use crate::{
        average::ema::ExponentialMovingAverage,
        indicator::{Indicator, MovingAverage},
        stubs::*,
    };

    #[rstest]
    fn test_ema_initialized(indicator_ema_10: ExponentialMovingAverage) {
        let ema = indicator_ema_10;
        let display_str = format!("{ema}");
        assert_eq!(display_str, "ExponentialMovingAverage(10)");
        assert_eq!(ema.period, 10);
        assert_eq!(ema.price_type, PriceType::Mid);
        assert_eq!(ema.alpha, 0.181_818_181_818_181_82);
        assert!(!ema.initialized);
    }

    #[rstest]
    fn test_one_value_input(indicator_ema_10: ExponentialMovingAverage) {
        let mut ema = indicator_ema_10;
        ema.update_raw(1.0);
        assert_eq!(ema.count, 1);
        assert_eq!(ema.value, 1.0);
    }

    #[rstest]
    fn test_ema_update_raw(indicator_ema_10: ExponentialMovingAverage) {
        let mut ema = indicator_ema_10;
        // Feed 1.0, 2.0, ..., 10.0, i.e. exactly `period` inputs.
        for price in 1_i32..=10 {
            ema.update_raw(f64::from(price));
        }

        assert!(ema.has_inputs());
        assert!(ema.initialized());
        assert_eq!(ema.count, 10);
        assert_eq!(ema.value, 6.239_368_480_121_215_5);
    }

    #[rstest]
    fn test_reset(indicator_ema_10: ExponentialMovingAverage) {
        let mut ema = indicator_ema_10;
        ema.update_raw(1.0);
        assert_eq!(ema.count, 1);
        ema.reset();
        assert_eq!(ema.count, 0);
        assert_eq!(ema.value, 0.0);
        assert!(!ema.has_inputs());
        assert!(!ema.initialized);
    }

    #[rstest]
    fn test_handle_quote_tick_single(
        indicator_ema_10: ExponentialMovingAverage,
        stub_quote: QuoteTick,
    ) {
        let mut ema = indicator_ema_10;
        ema.handle_quote(&stub_quote);
        assert!(ema.has_inputs());
        assert_eq!(ema.value, 1501.0);
    }

    #[rstest]
    fn test_handle_quote_tick_multi(indicator_ema_10: ExponentialMovingAverage) {
        let mut ema = indicator_ema_10;
        let tick1 = stub_quote("1500.0", "1502.0");
        let tick2 = stub_quote("1502.0", "1504.0");

        ema.handle_quote(&tick1);
        ema.handle_quote(&tick2);
        assert_eq!(ema.count, 2);
        assert_eq!(ema.value, 1_501.363_636_363_636_3);
    }

    #[rstest]
    fn test_handle_trade_tick(indicator_ema_10: ExponentialMovingAverage, stub_trade: TradeTick) {
        let mut ema = indicator_ema_10;
        ema.handle_trade(&stub_trade);
        assert!(ema.has_inputs());
        assert_eq!(ema.value, 1500.0);
    }

    #[rstest]
    fn test_handle_bar(
        indicator_ema_10: ExponentialMovingAverage,
        bar_ethusdt_binance_minute_bid: Bar,
    ) {
        let mut ema = indicator_ema_10;
        ema.handle_bar(&bar_ethusdt_binance_minute_bid);
        assert!(ema.has_inputs());
        assert!(!ema.initialized());
        assert_eq!(ema.value, 1522.0);
    }
}
227}