nautilus_indicators/average/
sma.rs

// -------------------------------------------------------------------------------------------------
//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
//  https://nautechsystems.io
//
//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
//  You may not use this file except in compliance with the License.
//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
//
//  Unless required by applicable law or agreed to in writing, software
//  distributed under the License is distributed on an "AS IS" BASIS,
//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//  See the License for the specific language governing permissions and
//  limitations under the License.
// -------------------------------------------------------------------------------------------------
15
16use std::fmt::Display;
17
18use nautilus_model::{
19    data::{Bar, QuoteTick, TradeTick},
20    enums::PriceType,
21};
22
23use crate::indicator::{Indicator, MovingAverage};
24
/// Simple moving average (SMA): the arithmetic mean of the most recent `period` inputs.
///
/// Inputs arrive either as raw `f64` values or via the quote/trade/bar handlers of the
/// `Indicator` implementation, which convert a price to `f64` before averaging.
#[repr(C)]
#[derive(Debug)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.indicators")
)]
pub struct SimpleMovingAverage {
    /// Size of the averaging window (number of most recent inputs averaged).
    pub period: usize,
    /// Which price `handle_quote` extracts from a quote; trades use their trade price
    /// and bars their close, regardless of this setting.
    pub price_type: PriceType,
    /// Current mean of the retained inputs; `0.0` before any input and after a reset.
    pub value: f64,
    /// Number of inputs currently held in the window.
    pub count: usize,
    /// Retained inputs in arrival order (oldest first).
    pub inputs: Vec<f64>,
    /// Set once at least `period` inputs have been received; cleared by `reset`.
    pub initialized: bool,
}
39
40impl Display for SimpleMovingAverage {
41    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
42        write!(f, "{}({})", self.name(), self.period,)
43    }
44}
45
46impl Indicator for SimpleMovingAverage {
47    fn name(&self) -> String {
48        stringify!(SimpleMovingAverage).to_string()
49    }
50
51    fn has_inputs(&self) -> bool {
52        !self.inputs.is_empty()
53    }
54
55    fn initialized(&self) -> bool {
56        self.initialized
57    }
58
59    fn handle_quote(&mut self, quote: &QuoteTick) {
60        self.update_raw(quote.extract_price(self.price_type).into());
61    }
62
63    fn handle_trade(&mut self, trade: &TradeTick) {
64        self.update_raw((&trade.price).into());
65    }
66
67    fn handle_bar(&mut self, bar: &Bar) {
68        self.update_raw((&bar.close).into());
69    }
70
71    fn reset(&mut self) {
72        self.value = 0.0;
73        self.count = 0;
74        self.inputs.clear();
75        self.initialized = false;
76    }
77}
78
79impl SimpleMovingAverage {
80    /// Creates a new [`SimpleMovingAverage`] instance.
81    #[must_use]
82    pub fn new(period: usize, price_type: Option<PriceType>) -> Self {
83        Self {
84            period,
85            price_type: price_type.unwrap_or(PriceType::Last),
86            value: 0.0,
87            count: 0,
88            inputs: Vec::with_capacity(period),
89            initialized: false,
90        }
91    }
92}
93
94impl MovingAverage for SimpleMovingAverage {
95    fn value(&self) -> f64 {
96        self.value
97    }
98
99    fn count(&self) -> usize {
100        self.count
101    }
102    fn update_raw(&mut self, value: f64) {
103        if self.inputs.len() == self.period {
104            self.inputs.remove(0);
105            self.count -= 1;
106        }
107        self.inputs.push(value);
108        self.count += 1;
109        let sum = self.inputs.iter().sum::<f64>();
110        self.value = sum / self.count as f64;
111
112        if !self.initialized && self.count >= self.period {
113            self.initialized = true;
114        }
115    }
116}
117
////////////////////////////////////////////////////////////////////////////////
// Tests
////////////////////////////////////////////////////////////////////////////////
#[cfg(test)]
mod tests {
    use nautilus_model::{
        data::{QuoteTick, TradeTick},
        enums::PriceType,
    };
    use rstest::rstest;

    use crate::{
        average::sma::SimpleMovingAverage,
        indicator::{Indicator, MovingAverage},
        stubs::*,
    };

    #[rstest]
    fn test_sma_initialized(indicator_sma_10: SimpleMovingAverage) {
        let display_str = format!("{indicator_sma_10}");
        assert_eq!(display_str, "SimpleMovingAverage(10)");
        assert_eq!(indicator_sma_10.period, 10);
        assert_eq!(indicator_sma_10.price_type, PriceType::Mid);
        assert_eq!(indicator_sma_10.value, 0.0);
        assert_eq!(indicator_sma_10.count, 0);
    }

    #[rstest]
    fn test_sma_update_raw_exact_period(indicator_sma_10: SimpleMovingAverage) {
        let mut sma = indicator_sma_10;
        sma.update_raw(1.0);
        sma.update_raw(2.0);
        sma.update_raw(3.0);
        sma.update_raw(4.0);
        sma.update_raw(5.0);
        sma.update_raw(6.0);
        sma.update_raw(7.0);
        sma.update_raw(8.0);
        sma.update_raw(9.0);
        sma.update_raw(10.0);

        assert!(sma.has_inputs());
        assert!(sma.initialized());
        assert_eq!(sma.count, 10);
        assert_eq!(sma.value, 5.5);
    }

    /// Once more than `period` inputs arrive, only the most recent `period` are averaged.
    #[rstest]
    fn test_sma_update_raw_rolls_window(indicator_sma_10: SimpleMovingAverage) {
        let mut sma = indicator_sma_10;
        for i in 1..=15 {
            sma.update_raw(f64::from(i));
        }

        // Inputs 1.0..=5.0 have been evicted; 6.0..=15.0 remain (sum 105).
        assert!(sma.initialized());
        assert_eq!(sma.count, 10);
        assert_eq!(sma.inputs.len(), 10);
        assert_eq!(sma.inputs.first(), Some(&6.0));
        assert_eq!(sma.value, 10.5);
    }

    #[rstest]
    fn test_reset(indicator_sma_10: SimpleMovingAverage) {
        let mut sma = indicator_sma_10;
        sma.update_raw(1.0);
        assert_eq!(sma.count, 1);
        sma.reset();
        assert_eq!(sma.count, 0);
        assert_eq!(sma.value, 0.0);
        assert!(!sma.has_inputs());
        assert!(!sma.initialized);
    }

    /// A zero-length window has no defined mean; feeding one must panic rather than
    /// silently produce a value.
    #[rstest]
    #[should_panic]
    fn test_sma_zero_period_panics() {
        let mut sma = SimpleMovingAverage::new(0, None);
        sma.update_raw(1.0);
    }

    #[rstest]
    fn test_handle_quote_tick_single(indicator_sma_10: SimpleMovingAverage, stub_quote: QuoteTick) {
        let mut sma = indicator_sma_10;
        sma.handle_quote(&stub_quote);
        assert_eq!(sma.count, 1);
        assert_eq!(sma.value, 1501.0);
    }

    #[rstest]
    fn test_handle_quote_tick_multi(indicator_sma_10: SimpleMovingAverage) {
        let mut sma = indicator_sma_10;
        let tick1 = stub_quote("1500.0", "1502.0");
        let tick2 = stub_quote("1502.0", "1504.0");

        sma.handle_quote(&tick1);
        sma.handle_quote(&tick2);
        assert_eq!(sma.count, 2);
        assert_eq!(sma.value, 1502.0);
    }

    #[rstest]
    fn test_handle_trade_tick(indicator_sma_10: SimpleMovingAverage, stub_trade: TradeTick) {
        let mut sma = indicator_sma_10;
        sma.handle_trade(&stub_trade);
        assert_eq!(sma.count, 1);
        assert_eq!(sma.value, 1500.0);
    }
}