nautilus_indicators/average/
wma.rs1use std::fmt::Display;
17
18use nautilus_core::correctness::{FAILED, check_predicate_true};
19use nautilus_model::{
20 data::{Bar, QuoteTick, TradeTick},
21 enums::PriceType,
22};
23
24use crate::indicator::{Indicator, MovingAverage};
25
/// An indicator which calculates a weighted moving average across a rolling window.
///
/// Each buffered input is multiplied by its positional weight, and the weighted sum is
/// divided by the sum of the weights in use (the very first input is taken as-is).
#[repr(C)]
#[derive(Debug)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.indicators")
)]
pub struct WeightedMovingAverage {
    /// The rolling window period for the indicator (validated by `new_checked` to equal
    /// `weights.len()`).
    pub period: usize,
    /// One weight per window position, oldest first; the last weight applies to the most
    /// recent input.
    pub weights: Vec<f64>,
    /// The price type extracted from quotes passed to `handle_quote`.
    pub price_type: PriceType,
    /// The current output value (`0.0` before the first input and after `reset`).
    pub value: f64,
    /// Whether the indicator has warmed up; set by `update_raw` once it observes a full window.
    pub initialized: bool,
    /// The buffered inputs, oldest first, holding at most `period` values.
    pub inputs: Vec<f64>,
    /// Whether any input has been received since construction or the last `reset`.
    has_inputs: bool,
}
48
49impl Display for WeightedMovingAverage {
50 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51 write!(f, "{}({},{:?})", self.name(), self.period, self.weights)
52 }
53}
54
55impl WeightedMovingAverage {
56 #[must_use]
58 pub fn new(period: usize, weights: Vec<f64>, price_type: Option<PriceType>) -> Self {
59 Self::new_checked(period, weights, price_type).expect(FAILED)
60 }
61
62 pub fn new_checked(
63 period: usize,
64 weights: Vec<f64>,
65 price_type: Option<PriceType>,
66 ) -> anyhow::Result<Self> {
67 check_predicate_true(
68 period == weights.len(),
69 "`period` must be equal to `weights` length",
70 )?;
71
72 Ok(Self {
73 period,
74 weights,
75 price_type: price_type.unwrap_or(PriceType::Last),
76 value: 0.0,
77 inputs: Vec::with_capacity(period),
78 initialized: false,
79 has_inputs: false,
80 })
81 }
82
83 fn weighted_average(&self) -> f64 {
84 let mut sum = 0.0;
85 let mut weight_sum = 0.0;
86 let reverse_weights: Vec<f64> = self.weights.iter().copied().rev().collect();
87 for (index, input) in self.inputs.iter().rev().enumerate() {
88 let weight = reverse_weights.get(index).unwrap();
89 sum += input * weight;
90 weight_sum += weight;
91 }
92 sum / weight_sum
93 }
94}
95
96impl Indicator for WeightedMovingAverage {
97 fn name(&self) -> String {
98 stringify!(WeightedMovingAverage).to_string()
99 }
100
101 fn has_inputs(&self) -> bool {
102 self.has_inputs
103 }
104 fn initialized(&self) -> bool {
105 self.initialized
106 }
107
108 fn handle_quote(&mut self, quote: &QuoteTick) {
109 self.update_raw(quote.extract_price(self.price_type).into());
110 }
111
112 fn handle_trade(&mut self, trade: &TradeTick) {
113 self.update_raw((&trade.price).into());
114 }
115
116 fn handle_bar(&mut self, bar: &Bar) {
117 self.update_raw((&bar.close).into());
118 }
119
120 fn reset(&mut self) {
121 self.value = 0.0;
122 self.has_inputs = false;
123 self.initialized = false;
124 self.inputs.clear();
125 }
126}
127
128impl MovingAverage for WeightedMovingAverage {
129 fn value(&self) -> f64 {
130 self.value
131 }
132
133 fn count(&self) -> usize {
134 self.inputs.len()
135 }
136 fn update_raw(&mut self, value: f64) {
137 if !self.has_inputs {
138 self.has_inputs = true;
139 self.inputs.push(value);
140 self.value = value;
141 return;
142 }
143 if self.inputs.len() == self.period {
144 self.inputs.remove(0);
145 }
146 self.inputs.push(value);
147 self.value = self.weighted_average();
148 if !self.initialized && self.count() >= self.period {
149 self.initialized = true;
150 }
151 }
152}
153
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use crate::{
        average::wma::WeightedMovingAverage,
        indicator::{Indicator, MovingAverage},
        stubs::*,
    };

    /// Feeds every value, in order, into the indicator.
    fn feed(wma: &mut WeightedMovingAverage, values: &[f64]) {
        for &v in values {
            wma.update_raw(v);
        }
    }

    #[rstest]
    fn test_wma_initialized(indicator_wma_10: WeightedMovingAverage) {
        let expected =
            "WeightedMovingAverage(10,[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0])";
        assert_eq!(indicator_wma_10.to_string(), expected);
        assert_eq!(indicator_wma_10.name(), "WeightedMovingAverage");
        assert!(!indicator_wma_10.has_inputs());
        assert!(!indicator_wma_10.initialized());
    }

    #[rstest]
    #[should_panic]
    fn test_different_weights_len_and_period_error() {
        let _wma = WeightedMovingAverage::new(10, vec![0.5; 3], None);
    }

    #[rstest]
    fn test_value_with_one_input(mut indicator_wma_10: WeightedMovingAverage) {
        feed(&mut indicator_wma_10, &[1.0]);
        assert_eq!(indicator_wma_10.value(), 1.0);
    }

    #[rstest]
    fn test_value_with_two_inputs_equal_weights() {
        let mut wma = WeightedMovingAverage::new(2, vec![0.5; 2], None);
        feed(&mut wma, &[1.0, 2.0]);
        assert_eq!(wma.value(), 1.5);
    }

    #[rstest]
    fn test_value_with_four_inputs_equal_weights() {
        let mut wma = WeightedMovingAverage::new(4, vec![0.25; 4], None);
        feed(&mut wma, &[1.0, 2.0, 3.0, 4.0]);
        assert_eq!(wma.value(), 2.5);
    }

    #[rstest]
    fn test_value_with_two_inputs(mut indicator_wma_10: WeightedMovingAverage) {
        feed(&mut indicator_wma_10, &[1.0, 2.0]);
        let expected = 2.0f64.mul_add(1.0, 1.0 * 0.9) / 1.9;
        assert_eq!(indicator_wma_10.value(), expected);
    }

    #[rstest]
    fn test_value_with_three_inputs(mut indicator_wma_10: WeightedMovingAverage) {
        feed(&mut indicator_wma_10, &[1.0, 2.0, 3.0]);
        let expected = 1.0f64.mul_add(0.8, 3.0f64.mul_add(1.0, 2.0 * 0.9)) / (1.0 + 0.9 + 0.8);
        assert_eq!(indicator_wma_10.value(), expected);
    }

    #[rstest]
    fn test_value_expected_with_exact_period(mut indicator_wma_10: WeightedMovingAverage) {
        let inputs: Vec<f64> = (1..=10).map(f64::from).collect();
        feed(&mut indicator_wma_10, &inputs);
        assert_eq!(indicator_wma_10.value(), 7.0);
    }

    #[rstest]
    fn test_value_expected_with_more_inputs(mut indicator_wma_10: WeightedMovingAverage) {
        let inputs: Vec<f64> = (1..=11).map(f64::from).collect();
        feed(&mut indicator_wma_10, &inputs);
        assert_eq!(indicator_wma_10.value(), 8.000_000_000_000_002);
    }

    #[rstest]
    fn test_reset(mut indicator_wma_10: WeightedMovingAverage) {
        feed(&mut indicator_wma_10, &[1.0, 2.0]);
        indicator_wma_10.reset();
        assert_eq!(indicator_wma_10.value(), 0.0);
        assert_eq!(indicator_wma_10.count(), 0);
        assert!(!indicator_wma_10.has_inputs());
        assert!(!indicator_wma_10.initialized());
    }
}