// nautilus_indicators/average/wma.rs

// -------------------------------------------------------------------------------------------------
//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
//  https://nautechsystems.io
//
//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
//  You may not use this file except in compliance with the License.
//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
//
//  Unless required by applicable law or agreed to in writing, software
//  distributed under the License is distributed on an "AS IS" BASIS,
//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//  See the License for the specific language governing permissions and
//  limitations under the License.
// -------------------------------------------------------------------------------------------------
15
16use std::fmt::Display;
17
18use nautilus_core::correctness::{FAILED, check_predicate_true};
19use nautilus_model::{
20    data::{Bar, QuoteTick, TradeTick},
21    enums::PriceType,
22};
23
24use crate::indicator::{Indicator, MovingAverage};
25
/// An indicator which calculates a weighted moving average across a rolling window.
///
/// Inputs and weights are aligned from the end: the most recent input is paired with
/// the last element of `weights`, the input before it with the second-to-last weight,
/// and so on. The value is the weighted sum of the windowed inputs divided by the sum
/// of the weights that were actually paired with an input.
#[repr(C)]
#[derive(Debug)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.indicators")
)]
pub struct WeightedMovingAverage {
    /// The rolling window period for the indicator (> 0).
    pub period: usize,
    /// The weights for the moving average calculation, one per window slot
    /// (the constructor requires `weights.len() == period`). The last weight
    /// applies to the most recent input.
    pub weights: Vec<f64>,
    /// The price type extracted from quote ticks in `handle_quote`
    /// (`PriceType::Last` when constructed with `None`).
    pub price_type: PriceType,
    /// The last indicator value. The first input after construction or reset is
    /// used as-is; later inputs produce the weighted average of the window.
    pub value: f64,
    /// Whether the indicator is initialized, i.e. the window has been filled with
    /// `period` inputs since construction or the last reset.
    pub initialized: bool,
    /// The rolling window of recent inputs, oldest first. Once the window is full,
    /// the oldest input is dropped before each new one is appended.
    pub inputs: Vec<f64>,
    /// Whether at least one input has been received since construction or the last reset.
    has_inputs: bool,
}
48
49impl Display for WeightedMovingAverage {
50    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51        write!(f, "{}({},{:?})", self.name(), self.period, self.weights)
52    }
53}
54
55impl WeightedMovingAverage {
56    /// Creates a new [`WeightedMovingAverage`] instance.
57    #[must_use]
58    pub fn new(period: usize, weights: Vec<f64>, price_type: Option<PriceType>) -> Self {
59        Self::new_checked(period, weights, price_type).expect(FAILED)
60    }
61
62    pub fn new_checked(
63        period: usize,
64        weights: Vec<f64>,
65        price_type: Option<PriceType>,
66    ) -> anyhow::Result<Self> {
67        check_predicate_true(
68            period == weights.len(),
69            "`period` must be equal to `weights` length",
70        )?;
71
72        Ok(Self {
73            period,
74            weights,
75            price_type: price_type.unwrap_or(PriceType::Last),
76            value: 0.0,
77            inputs: Vec::with_capacity(period),
78            initialized: false,
79            has_inputs: false,
80        })
81    }
82
83    fn weighted_average(&self) -> f64 {
84        let mut sum = 0.0;
85        let mut weight_sum = 0.0;
86        let reverse_weights: Vec<f64> = self.weights.iter().copied().rev().collect();
87        for (index, input) in self.inputs.iter().rev().enumerate() {
88            let weight = reverse_weights.get(index).unwrap();
89            sum += input * weight;
90            weight_sum += weight;
91        }
92        sum / weight_sum
93    }
94}
95
96impl Indicator for WeightedMovingAverage {
97    fn name(&self) -> String {
98        stringify!(WeightedMovingAverage).to_string()
99    }
100
101    fn has_inputs(&self) -> bool {
102        self.has_inputs
103    }
104    fn initialized(&self) -> bool {
105        self.initialized
106    }
107
108    fn handle_quote(&mut self, quote: &QuoteTick) {
109        self.update_raw(quote.extract_price(self.price_type).into());
110    }
111
112    fn handle_trade(&mut self, trade: &TradeTick) {
113        self.update_raw((&trade.price).into());
114    }
115
116    fn handle_bar(&mut self, bar: &Bar) {
117        self.update_raw((&bar.close).into());
118    }
119
120    fn reset(&mut self) {
121        self.value = 0.0;
122        self.has_inputs = false;
123        self.initialized = false;
124        self.inputs.clear();
125    }
126}
127
128impl MovingAverage for WeightedMovingAverage {
129    fn value(&self) -> f64 {
130        self.value
131    }
132
133    fn count(&self) -> usize {
134        self.inputs.len()
135    }
136    fn update_raw(&mut self, value: f64) {
137        if !self.has_inputs {
138            self.has_inputs = true;
139            self.inputs.push(value);
140            self.value = value;
141            return;
142        }
143        if self.inputs.len() == self.period {
144            self.inputs.remove(0);
145        }
146        self.inputs.push(value);
147        self.value = self.weighted_average();
148        if !self.initialized && self.count() >= self.period {
149            self.initialized = true;
150        }
151    }
152}
153
////////////////////////////////////////////////////////////////////////////////
// Tests
////////////////////////////////////////////////////////////////////////////////
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use crate::{
        average::wma::WeightedMovingAverage,
        indicator::{Indicator, MovingAverage},
        stubs::*,
    };

    /// Feeds every value in `values` into `wma`, in order.
    fn feed(wma: &mut WeightedMovingAverage, values: &[f64]) {
        for &v in values {
            wma.update_raw(v);
        }
    }

    #[rstest]
    fn test_wma_initialized(indicator_wma_10: WeightedMovingAverage) {
        let expected =
            "WeightedMovingAverage(10,[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0])";
        assert_eq!(indicator_wma_10.to_string(), expected);
        assert_eq!(indicator_wma_10.name(), "WeightedMovingAverage");
        assert!(!indicator_wma_10.has_inputs());
        assert!(!indicator_wma_10.initialized());
    }

    #[rstest]
    #[should_panic]
    fn test_different_weights_len_and_period_error() {
        // Three weights for a period of ten must be rejected by the constructor.
        let _ = WeightedMovingAverage::new(10, vec![0.5; 3], None);
    }

    #[rstest]
    fn test_value_with_one_input(mut indicator_wma_10: WeightedMovingAverage) {
        feed(&mut indicator_wma_10, &[1.0]);
        assert_eq!(indicator_wma_10.value(), 1.0);
    }

    #[rstest]
    fn test_value_with_two_inputs_equal_weights() {
        let mut wma = WeightedMovingAverage::new(2, vec![0.5; 2], None);
        feed(&mut wma, &[1.0, 2.0]);
        assert_eq!(wma.value(), 1.5);
    }

    #[rstest]
    fn test_value_with_four_inputs_equal_weights() {
        let mut wma = WeightedMovingAverage::new(4, vec![0.25; 4], None);
        feed(&mut wma, &[1.0, 2.0, 3.0, 4.0]);
        assert_eq!(wma.value(), 2.5);
    }

    #[rstest]
    fn test_value_with_two_inputs(mut indicator_wma_10: WeightedMovingAverage) {
        feed(&mut indicator_wma_10, &[1.0, 2.0]);
        // Newest input (2.0) takes the last weight (1.0), the older one (1.0) takes 0.9.
        let expected = 2.0f64.mul_add(1.0, 1.0 * 0.9) / 1.9;
        assert_eq!(indicator_wma_10.value(), expected);
    }

    #[rstest]
    fn test_value_with_three_inputs(mut indicator_wma_10: WeightedMovingAverage) {
        feed(&mut indicator_wma_10, &[1.0, 2.0, 3.0]);
        let expected = 1.0f64.mul_add(0.8, 3.0f64.mul_add(1.0, 2.0 * 0.9)) / (1.0 + 0.9 + 0.8);
        assert_eq!(indicator_wma_10.value(), expected);
    }

    #[rstest]
    fn test_value_expected_with_exact_period(mut indicator_wma_10: WeightedMovingAverage) {
        let values: Vec<f64> = (1..=10).map(f64::from).collect();
        feed(&mut indicator_wma_10, &values);
        assert_eq!(indicator_wma_10.value(), 7.0);
    }

    #[rstest]
    fn test_value_expected_with_more_inputs(mut indicator_wma_10: WeightedMovingAverage) {
        let values: Vec<f64> = (1..=11).map(f64::from).collect();
        feed(&mut indicator_wma_10, &values);
        assert_eq!(indicator_wma_10.value(), 8.000_000_000_000_002);
    }

    #[rstest]
    fn test_reset(mut indicator_wma_10: WeightedMovingAverage) {
        feed(&mut indicator_wma_10, &[1.0, 2.0]);
        indicator_wma_10.reset();
        assert_eq!(indicator_wma_10.value(), 0.0);
        assert_eq!(indicator_wma_10.count(), 0);
        assert!(!indicator_wma_10.has_inputs());
        assert!(!indicator_wma_10.initialized());
    }
}