nautilus_indicators/average/
sma.rs1use std::fmt::Display;
17
18use nautilus_model::{
19 data::{Bar, QuoteTick, TradeTick},
20 enums::PriceType,
21};
22
23use crate::indicator::{Indicator, MovingAverage};
24
25#[repr(C)]
26#[derive(Debug)]
27#[cfg_attr(
28 feature = "python",
29 pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.indicators")
30)]
31pub struct SimpleMovingAverage {
32 pub period: usize,
33 pub price_type: PriceType,
34 pub value: f64,
35 pub count: usize,
36 pub inputs: Vec<f64>,
37 pub initialized: bool,
38}
39
40impl Display for SimpleMovingAverage {
41 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
42 write!(f, "{}({})", self.name(), self.period,)
43 }
44}
45
46impl Indicator for SimpleMovingAverage {
47 fn name(&self) -> String {
48 stringify!(SimpleMovingAverage).to_string()
49 }
50
51 fn has_inputs(&self) -> bool {
52 !self.inputs.is_empty()
53 }
54
55 fn initialized(&self) -> bool {
56 self.initialized
57 }
58
59 fn handle_quote(&mut self, quote: &QuoteTick) {
60 self.update_raw(quote.extract_price(self.price_type).into());
61 }
62
63 fn handle_trade(&mut self, trade: &TradeTick) {
64 self.update_raw((&trade.price).into());
65 }
66
67 fn handle_bar(&mut self, bar: &Bar) {
68 self.update_raw((&bar.close).into());
69 }
70
71 fn reset(&mut self) {
72 self.value = 0.0;
73 self.count = 0;
74 self.inputs.clear();
75 self.initialized = false;
76 }
77}
78
79impl SimpleMovingAverage {
80 #[must_use]
82 pub fn new(period: usize, price_type: Option<PriceType>) -> Self {
83 Self {
84 period,
85 price_type: price_type.unwrap_or(PriceType::Last),
86 value: 0.0,
87 count: 0,
88 inputs: Vec::with_capacity(period),
89 initialized: false,
90 }
91 }
92}
93
94impl MovingAverage for SimpleMovingAverage {
95 fn value(&self) -> f64 {
96 self.value
97 }
98
99 fn count(&self) -> usize {
100 self.count
101 }
102 fn update_raw(&mut self, value: f64) {
103 if self.inputs.len() == self.period {
104 self.inputs.remove(0);
105 self.count -= 1;
106 }
107 self.inputs.push(value);
108 self.count += 1;
109 let sum = self.inputs.iter().sum::<f64>();
110 self.value = sum / self.count as f64;
111
112 if !self.initialized && self.count >= self.period {
113 self.initialized = true;
114 }
115 }
116}
117
#[cfg(test)]
mod tests {
    use nautilus_model::{
        data::{QuoteTick, TradeTick},
        enums::PriceType,
    };
    use rstest::rstest;

    use crate::{
        average::sma::SimpleMovingAverage,
        indicator::{Indicator, MovingAverage},
        stubs::*,
    };

    // A fresh fixture reports its configuration and a zeroed state.
    #[rstest]
    fn test_sma_initialized(indicator_sma_10: SimpleMovingAverage) {
        let sma = indicator_sma_10;
        assert_eq!(sma.to_string(), "SimpleMovingAverage(10)");
        assert_eq!(sma.period, 10);
        assert_eq!(sma.price_type, PriceType::Mid);
        assert_eq!(sma.value, 0.0);
        assert_eq!(sma.count, 0);
    }

    // Feeding exactly `period` values (1.0 through 10.0) fills the window.
    #[rstest]
    fn test_sma_update_raw_exact_period(indicator_sma_10: SimpleMovingAverage) {
        let mut sma = indicator_sma_10;
        for step in 1..=10u8 {
            sma.update_raw(f64::from(step));
        }

        assert!(sma.has_inputs());
        assert!(sma.initialized());
        assert_eq!(sma.count, 10);
        assert_eq!(sma.value, 5.5);
    }

    // `reset` clears the count, the value and the initialized flag.
    #[rstest]
    fn test_reset(indicator_sma_10: SimpleMovingAverage) {
        let mut sma = indicator_sma_10;
        sma.update_raw(1.0);
        assert_eq!(sma.count, 1);

        sma.reset();
        assert_eq!(sma.count, 0);
        assert_eq!(sma.value, 0.0);
        assert!(!sma.initialized);
    }

    // A single quote from the `stub_quote` fixture yields one input.
    #[rstest]
    fn test_handle_quote_tick_single(indicator_sma_10: SimpleMovingAverage, stub_quote: QuoteTick) {
        let mut sma = indicator_sma_10;
        sma.handle_quote(&stub_quote);
        assert_eq!(sma.count, 1);
        assert_eq!(sma.value, 1501.0);
    }

    // Two quotes are averaged together.
    #[rstest]
    fn test_handle_quote_tick_multi(indicator_sma_10: SimpleMovingAverage) {
        let mut sma = indicator_sma_10;
        let ticks = [
            stub_quote("1500.0", "1502.0"),
            stub_quote("1502.0", "1504.0"),
        ];

        for tick in &ticks {
            sma.handle_quote(tick);
        }
        assert_eq!(sma.count, 2);
        assert_eq!(sma.value, 1502.0);
    }

    // A trade tick contributes its price as a single input.
    #[rstest]
    fn test_handle_trade_tick(indicator_sma_10: SimpleMovingAverage, stub_trade: TradeTick) {
        let mut sma = indicator_sma_10;
        sma.handle_trade(&stub_trade);
        assert_eq!(sma.count, 1);
        assert_eq!(sma.value, 1500.0);
    }
}