nautilus_indicators/average/
wma.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::fmt::Display;
17
18use arraydeque::{ArrayDeque, Wrapping};
19use nautilus_core::correctness::{FAILED, check_predicate_true};
20use nautilus_model::{
21    data::{Bar, QuoteTick, TradeTick},
22    enums::PriceType,
23};
24
25use crate::indicator::{Indicator, MovingAverage};
26
27const MAX_PERIOD: usize = 8_192;
28
29/// An indicator which calculates a weighted moving average across a rolling window.
30#[repr(C)]
31#[derive(Debug)]
32#[cfg_attr(
33    feature = "python",
34    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.indicators")
35)]
36pub struct WeightedMovingAverage {
37    /// The rolling window period for the indicator (> 0).
38    pub period: usize,
39    /// The weights for the moving average calculation
40    pub weights: Vec<f64>,
41    /// Price type
42    pub price_type: PriceType,
43    /// The last indicator value.
44    pub value: f64,
45    /// Whether the indicator is initialized.
46    pub initialized: bool,
47    /// Inputs
48    pub inputs: ArrayDeque<f64, MAX_PERIOD, Wrapping>,
49}
50
51impl Display for WeightedMovingAverage {
52    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53        write!(f, "{}({},{:?})", self.name(), self.period, self.weights)
54    }
55}
56
57impl WeightedMovingAverage {
58    /// Creates a new [`WeightedMovingAverage`] instance.
59    ///
60    /// # Panics
61    ///
62    /// This function panics if:
63    /// - `period` is zero.
64    /// - `weights.len()` does not equal `period`.
65    /// - `weights` sum is effectively zero.
66    #[must_use]
67    pub fn new(period: usize, weights: Vec<f64>, price_type: Option<PriceType>) -> Self {
68        Self::new_checked(period, weights, price_type).expect(FAILED)
69    }
70
71    /// Creates a new [`WeightedMovingAverage`] instance with the given period and weights.
72    ///
73    /// # Errors
74    ///
75    /// Returns an error if **any** of the validation rules fails:
76    /// - `period` must be **positive**.
77    /// - `weights` must be **exactly** `period` elements long.
78    /// - `weights` must contain at least one non-zero value (∑wᵢ > ε).
79    pub fn new_checked(
80        period: usize,
81        weights: Vec<f64>,
82        price_type: Option<PriceType>,
83    ) -> anyhow::Result<Self> {
84        const EPS: f64 = f64::EPSILON;
85
86        check_predicate_true(period > 0, "`period` must be positive")?;
87
88        check_predicate_true(
89            period == weights.len(),
90            "`period` must equal `weights.len()`",
91        )?;
92
93        let weight_sum: f64 = weights.iter().copied().sum();
94        check_predicate_true(
95            weight_sum > EPS,
96            "`weights` sum must be positive and > f64::EPSILON",
97        )?;
98
99        Ok(Self {
100            period,
101            weights,
102            price_type: price_type.unwrap_or(PriceType::Last),
103            value: 0.0,
104            inputs: ArrayDeque::new(),
105            initialized: false,
106        })
107    }
108
109    fn weighted_average(&self) -> f64 {
110        let n = self.inputs.len();
111        let weights_slice = &self.weights[self.period - n..];
112
113        let mut sum = 0.0;
114        let mut weight_sum = 0.0;
115
116        for (input, weight) in self.inputs.iter().rev().zip(weights_slice.iter().rev()) {
117            sum += input * weight;
118            weight_sum += weight;
119        }
120        sum / weight_sum
121    }
122}
123
124impl Indicator for WeightedMovingAverage {
125    fn name(&self) -> String {
126        stringify!(WeightedMovingAverage).to_string()
127    }
128
129    fn has_inputs(&self) -> bool {
130        !self.inputs.is_empty()
131    }
132
133    fn initialized(&self) -> bool {
134        self.initialized
135    }
136
137    fn handle_quote(&mut self, quote: &QuoteTick) {
138        self.update_raw(quote.extract_price(self.price_type).into());
139    }
140
141    fn handle_trade(&mut self, trade: &TradeTick) {
142        self.update_raw((&trade.price).into());
143    }
144
145    fn handle_bar(&mut self, bar: &Bar) {
146        self.update_raw((&bar.close).into());
147    }
148
149    fn reset(&mut self) {
150        self.value = 0.0;
151        self.initialized = false;
152        self.inputs.clear();
153    }
154}
155
156impl MovingAverage for WeightedMovingAverage {
157    fn value(&self) -> f64 {
158        self.value
159    }
160
161    fn count(&self) -> usize {
162        self.inputs.len()
163    }
164
165    fn update_raw(&mut self, value: f64) {
166        if self.inputs.len() == self.period.min(MAX_PERIOD) {
167            self.inputs.pop_front();
168        }
169        let _ = self.inputs.push_back(value);
170
171        self.value = self.weighted_average();
172        self.initialized = self.count() >= self.period;
173    }
174}
175
176////////////////////////////////////////////////////////////////////////////////
177// Tests
178////////////////////////////////////////////////////////////////////////////////
#[cfg(test)]
mod tests {

    use arraydeque::{ArrayDeque, Wrapping};
    use rstest::rstest;

    use crate::{
        average::wma::WeightedMovingAverage,
        indicator::{Indicator, MovingAverage},
        stubs::*,
    };

    /// Feeds every value of `values` into `wma`, in order.
    fn feed(wma: &mut WeightedMovingAverage, values: &[f64]) {
        for &value in values {
            wma.update_raw(value);
        }
    }

    /// Copies the indicator's rolling window into a `Vec`, oldest input first.
    fn window_of(wma: &WeightedMovingAverage) -> Vec<f64> {
        wma.inputs.iter().copied().collect()
    }

    /// A fresh fixture renders as expected and reports no inputs / not initialized.
    #[rstest]
    fn test_wma_initialized(indicator_wma_10: WeightedMovingAverage) {
        let rendered = indicator_wma_10.to_string();
        assert_eq!(
            rendered,
            "WeightedMovingAverage(10,[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0])"
        );
        assert_eq!(indicator_wma_10.name(), "WeightedMovingAverage");
        assert!(!indicator_wma_10.has_inputs());
        assert!(!indicator_wma_10.initialized());
    }

    /// A weights vector whose length differs from the period is rejected.
    #[rstest]
    #[should_panic]
    fn test_different_weights_len_and_period_error() {
        let _wma = WeightedMovingAverage::new(10, vec![0.5, 0.5, 0.5], None);
    }

    /// With a single input the average equals that input.
    #[rstest]
    fn test_value_with_one_input(mut indicator_wma_10: WeightedMovingAverage) {
        feed(&mut indicator_wma_10, &[1.0]);
        assert_eq!(indicator_wma_10.value, 1.0);
    }

    /// Two inputs with equal weights give the plain mean.
    #[rstest]
    fn test_value_with_two_inputs_equal_weights() {
        let mut wma = WeightedMovingAverage::new(2, vec![0.5, 0.5], None);
        feed(&mut wma, &[1.0, 2.0]);
        assert_eq!(wma.value, 1.5);
    }

    /// Four inputs with equal weights give the plain mean.
    #[rstest]
    fn test_value_with_four_inputs_equal_weights() {
        let mut wma = WeightedMovingAverage::new(4, vec![0.25, 0.25, 0.25, 0.25], None);
        feed(&mut wma, &[1.0, 2.0, 3.0, 4.0]);
        assert_eq!(wma.value, 2.5);
    }

    /// During warm-up with two inputs only the last two weights are applied.
    #[rstest]
    fn test_value_with_two_inputs(mut indicator_wma_10: WeightedMovingAverage) {
        feed(&mut indicator_wma_10, &[1.0, 2.0]);
        let expected = 2.0f64.mul_add(1.0, 1.0 * 0.9) / 1.9;
        assert_eq!(indicator_wma_10.value, expected);
    }

    /// During warm-up with three inputs only the last three weights are applied.
    #[rstest]
    fn test_value_with_three_inputs(mut indicator_wma_10: WeightedMovingAverage) {
        feed(&mut indicator_wma_10, &[1.0, 2.0, 3.0]);
        let expected = 1.0f64.mul_add(0.8, 3.0f64.mul_add(1.0, 2.0 * 0.9)) / (1.0 + 0.9 + 0.8);
        assert_eq!(indicator_wma_10.value, expected);
    }

    /// Feeding exactly `period` inputs (1 through 10) yields 7.0.
    #[rstest]
    fn test_value_expected_with_exact_period(mut indicator_wma_10: WeightedMovingAverage) {
        for n in 1..=10 {
            indicator_wma_10.update_raw(f64::from(n));
        }
        assert_eq!(indicator_wma_10.value, 7.0);
    }

    /// Feeding one input more than the period slides the window forward.
    #[rstest]
    fn test_value_expected_with_more_inputs(mut indicator_wma_10: WeightedMovingAverage) {
        for n in 1..=11 {
            indicator_wma_10.update_raw(f64::from(n));
        }
        assert_eq!(indicator_wma_10.value(), 8.000_000_000_000_002);
    }

    /// `reset` clears value, count and the initialized flag.
    #[rstest]
    fn test_reset(mut indicator_wma_10: WeightedMovingAverage) {
        feed(&mut indicator_wma_10, &[1.0, 2.0]);
        indicator_wma_10.reset();
        assert_eq!(indicator_wma_10.value, 0.0);
        assert_eq!(indicator_wma_10.count(), 0);
        assert!(!indicator_wma_10.initialized);
    }

    /// `new` panics for a zero period.
    #[rstest]
    #[should_panic]
    fn new_panics_on_zero_period() {
        let _wma = WeightedMovingAverage::new(0, vec![1.0], None);
    }

    /// `new_checked` reports an error for a zero period.
    #[rstest]
    fn new_checked_err_on_zero_period() {
        let outcome = WeightedMovingAverage::new_checked(0, vec![1.0], None);
        assert!(outcome.is_err());
    }

    /// `new` panics when all weights are zero.
    #[rstest]
    #[should_panic]
    fn new_panics_on_zero_weight_sum() {
        let _wma = WeightedMovingAverage::new(3, vec![0.0, 0.0, 0.0], None);
    }

    /// `new_checked` reports an error when all weights are zero.
    #[rstest]
    fn new_checked_err_on_zero_weight_sum() {
        let outcome = WeightedMovingAverage::new_checked(3, vec![0.0, 0.0, 0.0], None);
        assert!(outcome.is_err());
    }

    /// `new` panics when the weight sum stays below `f64::EPSILON`.
    #[rstest]
    #[should_panic]
    fn new_panics_when_weight_sum_below_epsilon() {
        let tiny_weight = f64::EPSILON / 10.0;
        let _wma = WeightedMovingAverage::new(3, vec![tiny_weight; 3], None);
    }

    /// The initialized flag flips exactly when the window becomes full.
    #[rstest]
    fn initialized_flag_transitions() {
        let period = 3;
        let mut wma = WeightedMovingAverage::new(period, vec![1.0, 2.0, 3.0], None);

        assert!(!wma.initialized());

        for fed in 1..=period {
            wma.update_raw((fed - 1) as f64);
            let window_full = fed >= period;
            assert_eq!(wma.initialized(), window_full);
        }
        assert!(wma.initialized());
    }

    /// `count` and `has_inputs` track the number of inputs fed so far.
    #[rstest]
    fn count_matches_inputs_and_has_inputs() {
        let mut wma = WeightedMovingAverage::new(4, vec![0.25; 4], None);

        assert_eq!(wma.count(), 0);
        assert!(!wma.has_inputs());

        feed(&mut wma, &[1.0, 2.0]);
        assert_eq!(wma.count(), 2);
        assert!(wma.has_inputs());
    }

    /// After `reset` an initialized indicator looks freshly constructed again.
    #[rstest]
    fn reset_restores_pristine_state() {
        let mut wma = WeightedMovingAverage::new(2, vec![0.5, 0.5], None);
        feed(&mut wma, &[1.0, 2.0]);
        assert!(wma.initialized());

        wma.reset();

        assert_eq!(wma.count(), 0);
        assert_eq!(wma.value(), 0.0);
        assert!(!wma.initialized());
        assert!(!wma.has_inputs());
    }

    /// Non-uniform weights produce the expected weighted mean (within tolerance).
    #[rstest]
    fn weighted_average_with_non_uniform_weights() {
        let mut wma = WeightedMovingAverage::new(3, vec![1.0, 2.0, 3.0], None);
        feed(&mut wma, &[10.0, 20.0, 30.0]);
        let target = 23.333_333_333_333_332;
        let actual = wma.value();
        let deviation = (actual - target).abs();
        assert!(
            deviation < f64::EPSILON.sqrt(),
            "value = {}, expected ≈ {}",
            actual,
            target
        );
    }

    /// The window never holds more than `period` inputs.
    #[rstest]
    fn test_window_never_exceeds_period(mut indicator_wma_10: WeightedMovingAverage) {
        for n in 0..100 {
            indicator_wma_10.update_raw(f64::from(n));
            let held = indicator_wma_10.count();
            assert!(held <= indicator_wma_10.period);
        }
    }

    /// Negative weights are accepted as long as the weight sum is positive.
    #[rstest]
    fn test_negative_weights_positive_sum() {
        let mut wma = WeightedMovingAverage::new(3, vec![-1.0, 2.0, 2.0], None);
        feed(&mut wma, &[1.0, 2.0, 3.0]);

        let expected = 2.0f64.mul_add(3.0, 2.0f64.mul_add(2.0, -1.0)) / 3.0;
        let deviation = (wma.value() - expected).abs();
        assert!(deviation < f64::EPSILON.sqrt());
    }

    /// A NaN input makes the average NaN.
    #[rstest]
    fn test_nan_input_propagates() {
        let mut wma = WeightedMovingAverage::new(2, vec![0.5, 0.5], None);
        feed(&mut wma, &[1.0, f64::NAN]);

        assert!(wma.value().is_nan());
    }

    /// `new` panics when the weight sum is exactly `f64::EPSILON` (the bound is strict).
    #[rstest]
    #[should_panic]
    fn new_panics_when_weight_sum_equals_epsilon() {
        let third_of_eps = f64::EPSILON / 3.0;
        let _wma = WeightedMovingAverage::new(3, vec![third_of_eps; 3], None);
    }

    /// `new_checked` reports an error when the weight sum is exactly `f64::EPSILON`.
    #[rstest]
    fn new_checked_err_when_weight_sum_equals_epsilon() {
        let third_of_eps = f64::EPSILON / 3.0;
        let outcome = WeightedMovingAverage::new_checked(3, vec![third_of_eps; 3], None);
        assert!(outcome.is_err());
    }

    /// `new_checked` reports an error when the single weight is below `f64::EPSILON`.
    #[rstest]
    fn new_checked_err_when_weight_sum_below_epsilon() {
        let weight = f64::EPSILON * 0.9;
        let outcome = WeightedMovingAverage::new_checked(1, vec![weight], None);
        assert!(outcome.is_err());
    }

    /// `new_checked` succeeds when the single weight is above `f64::EPSILON`.
    #[rstest]
    fn new_ok_when_weight_sum_above_epsilon() {
        let weight = f64::EPSILON * 1.1;
        let outcome = WeightedMovingAverage::new_checked(1, vec![weight], None);
        assert!(outcome.is_ok());
    }

    /// `new` panics when positive and negative weights cancel to zero.
    #[rstest]
    #[should_panic]
    fn new_panics_on_cancelled_weights_sum() {
        let _wma = WeightedMovingAverage::new(3, vec![1.0, -1.0, 0.0], None);
    }

    /// `new_checked` reports an error when positive and negative weights cancel to zero.
    #[rstest]
    fn new_checked_err_on_cancelled_weights_sum() {
        let outcome = WeightedMovingAverage::new_checked(3, vec![1.0, -1.0, 0.0], None);
        assert!(outcome.is_err());
    }

    /// With a period of one the average is always the latest input.
    #[rstest]
    fn single_period_returns_latest_input() {
        let mut wma = WeightedMovingAverage::new(1, vec![1.0], None);
        for n in 0..5 {
            let latest = f64::from(n);
            wma.update_raw(latest);
            assert_eq!(wma.value(), latest);
        }
    }

    /// A single non-zero weight selects the corresponding input.
    #[rstest]
    fn value_with_sparse_weights() {
        let mut wma = WeightedMovingAverage::new(3, vec![0.0, 1.0, 0.0], None);
        feed(&mut wma, &[10.0, 20.0, 30.0]);
        assert_eq!(wma.value(), 20.0);
    }

    /// Warm-up with one input: the average equals that input.
    #[rstest]
    fn warm_up_len1() {
        let mut wma = WeightedMovingAverage::new(4, vec![1.0, 2.0, 3.0, 4.0], None);
        feed(&mut wma, &[42.0]);
        assert_eq!(wma.value(), 42.0);
    }

    /// Warm-up with two inputs: the last two weights are applied.
    #[rstest]
    fn warm_up_len2() {
        let mut wma = WeightedMovingAverage::new(4, vec![1.0, 2.0, 3.0, 4.0], None);
        feed(&mut wma, &[10.0, 20.0]);
        let expected = 20.0f64.mul_add(4.0, 10.0 * 3.0) / (4.0 + 3.0);
        assert_eq!(wma.value(), expected);
    }

    /// Warm-up with three inputs: the last three weights are applied.
    #[rstest]
    fn warm_up_len3() {
        let mut wma = WeightedMovingAverage::new(4, vec![1.0, 2.0, 3.0, 4.0], None);
        feed(&mut wma, &[1.0, 2.0, 3.0]);
        let expected = 1.0f64.mul_add(2.0, 3.0f64.mul_add(4.0, 2.0 * 3.0)) / (4.0 + 3.0 + 2.0);
        assert_eq!(wma.value(), expected);
    }

    /// After overflowing the window it holds exactly the latest `period` inputs.
    #[rstest]
    fn input_window_contains_latest_period() {
        let period = 3;
        let mut wma = WeightedMovingAverage::new(period, vec![1.0; period], None);
        let samples = [1.0, 2.0, 3.0, 4.0];
        feed(&mut wma, &samples);
        let newest: Vec<f64> = samples[samples.len() - period..].to_vec();
        assert_eq!(window_of(&wma), newest);
    }

    /// The window grows to `period` and then slides one input at a time.
    #[rstest]
    fn window_slides_correctly() {
        let mut wma = WeightedMovingAverage::new(2, vec![1.0; 2], None);
        wma.update_raw(1.0);
        assert_eq!(window_of(&wma), vec![1.0]);
        wma.update_raw(2.0);
        assert_eq!(window_of(&wma), vec![1.0, 2.0]);
        wma.update_raw(3.0);
        assert_eq!(window_of(&wma), vec![2.0, 3.0]);
    }

    /// The window length is `min(period, inputs fed)` after every update.
    #[rstest]
    fn window_len_constant_after_many_updates() {
        let period = 5;
        let mut wma = WeightedMovingAverage::new(period, vec![1.0; period], None);
        for fed in 1..=100 {
            wma.update_raw((fed - 1) as f64);
            assert_eq!(wma.inputs.len(), period.min(fed));
        }
    }

    /// A wrapping `ArrayDeque` drops its oldest element when pushed to while full.
    #[rstest]
    fn arraydeque_wraps_when_full() {
        const CAP: usize = 3;
        let mut ring: ArrayDeque<usize, CAP, Wrapping> = ArrayDeque::new();
        for item in 0..=CAP {
            let _ = ring.push_back(item);
        }
        assert_eq!(ring.len(), CAP);
        assert_eq!(ring.front(), Some(&1));
        assert_eq!(ring.back(), Some(&3));
    }

    /// Popping before pushing keeps an `ArrayDeque` at or below its capacity.
    #[rstest]
    fn arraydeque_sliding_window_with_pop() {
        const CAP: usize = 3;
        let mut ring: ArrayDeque<usize, CAP, Wrapping> = ArrayDeque::new();
        for item in 0..10 {
            let is_full = ring.len() == CAP;
            if is_full {
                ring.pop_front();
            }
            let _ = ring.push_back(item);
            assert!(ring.len() <= CAP);
        }
        assert_eq!(ring.len(), CAP);
    }

    /// An infinite weight passes validation (its sum is greater than epsilon).
    #[rstest]
    fn new_ok_with_infinite_weight() {
        let outcome = WeightedMovingAverage::new_checked(2, vec![f64::INFINITY, 1.0], None);
        assert!(outcome.is_ok());
    }

    /// `new` panics when a weight is NaN.
    #[rstest]
    #[should_panic]
    fn new_panics_on_nan_weight() {
        let _wma = WeightedMovingAverage::new(2, vec![f64::NAN, 1.0], None);
    }

    /// `new` panics when the weights vector is empty.
    #[rstest]
    #[should_panic]
    fn new_panics_on_empty_weights() {
        let _wma = WeightedMovingAverage::new(1, Vec::new(), None);
    }

    /// An infinite input makes the average infinite.
    #[rstest]
    fn inf_input_propagates() {
        let mut wma = WeightedMovingAverage::new(2, vec![0.5, 0.5], None);
        feed(&mut wma, &[1.0, f64::INFINITY]);
        assert!(wma.value().is_infinite());
    }

    /// Leading zero weights do not affect warm-up, which uses the trailing weights.
    #[rstest]
    fn warm_up_with_front_zero_weights() {
        let mut wma = WeightedMovingAverage::new(4, vec![0.0, 0.0, 1.0, 1.0], None);
        feed(&mut wma, &[10.0, 20.0]);
        let expected = 20.0f64.mul_add(1.0, 10.0 * 1.0) / 2.0;
        assert_eq!(wma.value(), expected);
    }
}