nautilus_indicators/average/
lr.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::fmt::{Debug, Display};
17
18use arraydeque::{ArrayDeque, Wrapping};
19use nautilus_model::data::Bar;
20
21use crate::indicator::Indicator;
22
23const MAX_PERIOD: usize = 16_384;
24
/// Rolling least-squares linear regression over the most recent `period` values.
///
/// All output fields stay at `0.0` until `period` values have been received; from
/// then on every update refits the line over the current window, using
/// x-coordinates `1..=period` (oldest sample first).
#[repr(C)]
#[derive(Debug)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.indicators")
)]
pub struct LinearRegression {
    /// Number of samples in the rolling window.
    pub period: usize,
    /// Slope of the fitted line.
    pub slope: f64,
    /// Intercept of the fitted line (its value at `x = 0`).
    pub intercept: f64,
    /// Angle of the fitted line in degrees, computed as `atan(slope)`.
    pub degree: f64,
    /// `100 * (fitted - actual) / actual` at the most recent sample; `NaN` when that
    /// sample is exactly zero.
    // NOTE(review): presumably the "Chande Forecast Oscillator" — confirm the intended sign.
    pub cfo: f64,
    /// Coefficient of determination, `1 - SSE / SST`; `NaN` when the window has
    /// (numerically) zero variance.
    pub r2: f64,
    /// Value of the fitted line at the most recent sample.
    pub value: f64,
    /// `true` once the window has held `period` samples since construction or the last reset.
    pub initialized: bool,
    // `true` once at least one sample has been received since construction or the last reset.
    has_inputs: bool,
    // Rolling window of samples, oldest at the front; capacity is fixed at `MAX_PERIOD`.
    inputs: ArrayDeque<f64, MAX_PERIOD, Wrapping>,
    // Cached sum of the x-coordinates 1..=period, i.e. n(n + 1) / 2.
    x_sum: f64,
    // Cached sum of the squared x-coordinates, i.e. n(n + 1)(2n + 1) / 6.
    x_mul_sum: f64,
    // Cached denominator shared by the slope and intercept formulas: n * x_mul_sum - x_sum^2.
    divisor: f64,
}
46
47impl Display for LinearRegression {
48    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
49        write!(f, "{}({})", self.name(), self.period)
50    }
51}
52
53impl Indicator for LinearRegression {
54    fn name(&self) -> String {
55        stringify!(LinearRegression).into()
56    }
57
58    fn has_inputs(&self) -> bool {
59        self.has_inputs
60    }
61
62    fn initialized(&self) -> bool {
63        self.initialized
64    }
65
66    fn handle_bar(&mut self, bar: &Bar) {
67        self.update_raw(bar.close.into());
68    }
69
70    fn reset(&mut self) {
71        self.slope = 0.0;
72        self.intercept = 0.0;
73        self.degree = 0.0;
74        self.cfo = 0.0;
75        self.r2 = 0.0;
76        self.value = 0.0;
77        self.inputs.clear();
78        self.has_inputs = false;
79        self.initialized = false;
80    }
81}
82
83impl LinearRegression {
84    /// Creates a new [`LinearRegression`] instance.
85    ///
86    /// # Panics
87    ///
88    /// This function panics if:
89    /// `period` is zero.
90    /// `period` exceeds `MAX_PERIOD` (16,384).
91    #[must_use]
92    pub fn new(period: usize) -> Self {
93        assert!(
94            period > 0,
95            "LinearRegression: period must be > 0 (received {period})"
96        );
97        assert!(
98            period <= MAX_PERIOD,
99            "LinearRegression: period {period} exceeds MAX_PERIOD ({MAX_PERIOD})"
100        );
101
102        let n = period as f64;
103        let x_sum = 0.5 * n * (n + 1.0);
104        let x_mul_sum = x_sum * 2.0f64.mul_add(n, 1.0) / 3.0;
105        let divisor = n.mul_add(x_mul_sum, -(x_sum * x_sum));
106
107        Self {
108            period,
109            slope: 0.0,
110            intercept: 0.0,
111            degree: 0.0,
112            cfo: 0.0,
113            r2: 0.0,
114            value: 0.0,
115            initialized: false,
116            has_inputs: false,
117            inputs: ArrayDeque::new(),
118            x_sum,
119            x_mul_sum,
120            divisor,
121        }
122    }
123
124    /// Updates the linear regression with a new data point.
125    ///
126    /// # Panics
127    ///
128    /// Panics if called with an empty window – this is protected against by the logic
129    /// that returns early until enough samples have been collected.
130    pub fn update_raw(&mut self, close: f64) {
131        if self.inputs.len() == self.period {
132            let _ = self.inputs.pop_front();
133        }
134        let _ = self.inputs.push_back(close);
135
136        self.has_inputs = true;
137        if self.inputs.len() < self.period {
138            return;
139        }
140        self.initialized = true;
141
142        let n = self.period as f64;
143        let x_sum = self.x_sum;
144        let x_mul_sum = self.x_mul_sum;
145        let divisor = self.divisor;
146
147        let (mut y_sum, mut xy_sum) = (0.0, 0.0);
148        for (i, &y) in self.inputs.iter().enumerate() {
149            let x = (i + 1) as f64;
150            y_sum += y;
151            xy_sum += x * y;
152        }
153
154        self.slope = n.mul_add(xy_sum, -(x_sum * y_sum)) / divisor;
155        self.intercept = y_sum.mul_add(x_mul_sum, -(x_sum * xy_sum)) / divisor;
156
157        let (mut sse, mut y_last, mut e_last) = (0.0, 0.0, 0.0);
158        for (i, &y) in self.inputs.iter().enumerate() {
159            let x = (i + 1) as f64;
160            let y_hat = self.slope.mul_add(x, self.intercept);
161            let resid = y_hat - y;
162            sse += resid * resid;
163            y_last = y;
164            e_last = resid;
165        }
166
167        self.value = y_last + e_last;
168        self.degree = self.slope.atan().to_degrees();
169        self.cfo = if y_last == 0.0 {
170            f64::NAN
171        } else {
172            100.0 * e_last / y_last
173        };
174
175        let mean = y_sum / n;
176        let sst: f64 = self
177            .inputs
178            .iter()
179            .map(|&y| {
180                let d = y - mean;
181                d * d
182            })
183            .sum();
184
185        self.r2 = if sst.abs() < f64::EPSILON {
186            f64::NAN
187        } else {
188            1.0 - sse / sst
189        };
190    }
191}
192
// Unit tests for `LinearRegression`: construction and validation, warm-up and
// initialization, rolling-window eviction, the derived outputs (slope, degree, cfo,
// r2, value), reset behaviour, and the cached x-axis constants.
#[cfg(test)]
mod tests {
    use nautilus_model::data::Bar;
    use rstest::rstest;

    use super::*;
    use crate::{
        average::lr::LinearRegression,
        indicator::Indicator,
        stubs::{bar_ethusdt_binance_minute_bid, indicator_lr_10},
    };

    // A freshly built period-10 indicator displays as `LinearRegression(10)` and has
    // received nothing yet. (The `psl` in the name does not relate to this indicator.)
    #[rstest]
    fn test_psl_initialized(indicator_lr_10: LinearRegression) {
        let display_str = format!("{indicator_lr_10}");
        assert_eq!(display_str, "LinearRegression(10)");
        assert_eq!(indicator_lr_10.period, 10);
        assert!(!indicator_lr_10.initialized);
        assert!(!indicator_lr_10.has_inputs);
    }

    #[rstest]
    #[should_panic(expected = "LinearRegression: period must be > 0")]
    fn test_new_with_zero_period_panics() {
        let _ = LinearRegression::new(0);
    }

    // Outputs stay at their zero defaults while the window is still warming up.
    #[rstest]
    fn test_value_with_one_input(mut indicator_lr_10: LinearRegression) {
        indicator_lr_10.update_raw(1.0);
        assert_eq!(indicator_lr_10.value, 0.0);
    }

    #[rstest]
    fn test_value_with_three_inputs(mut indicator_lr_10: LinearRegression) {
        indicator_lr_10.update_raw(1.0);
        indicator_lr_10.update_raw(2.0);
        indicator_lr_10.update_raw(3.0);
        assert_eq!(indicator_lr_10.value, 0.0);
    }

    // Nine samples are not enough for a period of ten; the tenth flips `initialized`.
    #[rstest]
    fn test_initialized_with_required_input(mut indicator_lr_10: LinearRegression) {
        for i in 1..10 {
            indicator_lr_10.update_raw(f64::from(i));
        }
        assert!(!indicator_lr_10.initialized);
        indicator_lr_10.update_raw(10.0);
        assert!(indicator_lr_10.initialized);
    }

    // A single bar registers as input but does not initialize a period-10 indicator.
    #[rstest]
    fn test_handle_bar(mut indicator_lr_10: LinearRegression, bar_ethusdt_binance_minute_bid: Bar) {
        indicator_lr_10.handle_bar(&bar_ethusdt_binance_minute_bid);
        assert_eq!(indicator_lr_10.value, 0.0);
        assert!(indicator_lr_10.has_inputs);
        assert!(!indicator_lr_10.initialized);
    }

    #[rstest]
    fn test_reset(mut indicator_lr_10: LinearRegression) {
        indicator_lr_10.update_raw(1.0);
        indicator_lr_10.reset();
        assert_eq!(indicator_lr_10.value, 0.0);
        assert_eq!(indicator_lr_10.inputs.len(), 0);
        assert_eq!(indicator_lr_10.slope, 0.0);
        assert_eq!(indicator_lr_10.intercept, 0.0);
        assert_eq!(indicator_lr_10.degree, 0.0);
        assert_eq!(indicator_lr_10.cfo, 0.0);
        assert_eq!(indicator_lr_10.r2, 0.0);
        assert!(!indicator_lr_10.has_inputs);
        assert!(!indicator_lr_10.initialized);
    }

    // The window is capped at `period` no matter how many samples arrive.
    #[rstest]
    fn test_inputs_len_never_exceeds_period() {
        let mut lr = LinearRegression::new(3);
        for i in 0..10 {
            lr.update_raw(f64::from(i));
        }
        assert_eq!(lr.inputs.len(), lr.period);
    }

    // Eviction is first-in, first-out.
    #[rstest]
    fn test_oldest_element_evicted() {
        let mut lr = LinearRegression::new(4);
        for v in 1..=5 {
            lr.update_raw(f64::from(v));
        }
        assert!(!lr.inputs.contains(&1.0));
        assert_eq!(lr.inputs.front(), Some(&2.0));
    }

    #[rstest]
    fn test_recent_elements_preserved() {
        let mut lr = LinearRegression::new(5);
        for v in 0..5 {
            lr.update_raw(f64::from(v));
        }
        lr.update_raw(99.0);
        let expected = vec![1.0, 2.0, 3.0, 4.0, 99.0];
        assert_eq!(lr.inputs.iter().copied().collect::<Vec<_>>(), expected);
    }

    #[rstest]
    fn test_multiple_evictions() {
        let mut lr = LinearRegression::new(2);
        lr.update_raw(10.0);
        lr.update_raw(20.0);
        lr.update_raw(30.0);
        lr.update_raw(40.0);
        assert_eq!(
            lr.inputs.iter().copied().collect::<Vec<_>>(),
            vec![30.0, 40.0]
        );
    }

    // After the window slides, the fitted value is recomputed and remains finite.
    #[rstest]
    fn test_value_stable_after_eviction() {
        let mut lr = LinearRegression::new(3);
        lr.update_raw(1.0);
        lr.update_raw(2.0);
        lr.update_raw(3.0);
        let before = lr.value;
        lr.update_raw(4.0);
        let after = lr.value;
        assert!(after.is_finite());
        assert_ne!(before, after);
    }

    // Note: despite the name, eleven samples are fed in, so the first `1.00000` has
    // been evicted by the time the value is checked.
    #[rstest]
    fn test_value_with_ten_inputs(mut indicator_lr_10: LinearRegression) {
        indicator_lr_10.update_raw(1.00000);
        indicator_lr_10.update_raw(1.00010);
        indicator_lr_10.update_raw(1.00030);
        indicator_lr_10.update_raw(1.00040);
        indicator_lr_10.update_raw(1.00050);
        indicator_lr_10.update_raw(1.00060);
        indicator_lr_10.update_raw(1.00050);
        indicator_lr_10.update_raw(1.00040);
        indicator_lr_10.update_raw(1.00030);
        indicator_lr_10.update_raw(1.00010);
        indicator_lr_10.update_raw(1.00000);

        assert!((indicator_lr_10.value - 1.000_232_727_272_727_6).abs() < 1e-12);
    }

    // Zero variance in the window makes R² undefined.
    #[rstest]
    fn r2_nan_for_constant_series() {
        let mut lr = LinearRegression::new(5);
        for _ in 0..5 {
            lr.update_raw(42.0);
        }
        assert!(lr.initialized);
        assert!(
            lr.r2.is_nan(),
            "R² should be NaN for a constant-value input series"
        );
    }

    // The CFO divides by the latest sample, so a zero sample yields NaN.
    #[rstest]
    fn cfo_nan_when_last_price_zero() {
        let mut lr = LinearRegression::new(3);
        lr.update_raw(1.0);
        lr.update_raw(2.0);
        lr.update_raw(0.0);
        assert!(lr.initialized);
        assert!(
            lr.cfo.is_nan(),
            "CFO should be NaN when the most-recent price equals zero"
        );
    }

    #[rstest]
    fn positive_slope_and_degree_for_uptrend() {
        let mut lr = LinearRegression::new(4);
        for v in 1..=4 {
            lr.update_raw(f64::from(v));
        }
        assert!(lr.slope > 0.0, "slope expected positive for up-trend");
        assert!(lr.degree > 0.0, "degree expected positive for up-trend");
    }

    #[rstest]
    fn negative_slope_and_degree_for_downtrend() {
        let mut lr = LinearRegression::new(4);
        for v in (1..=4).rev() {
            lr.update_raw(f64::from(v));
        }
        assert!(lr.slope < 0.0, "slope expected negative for down-trend");
        assert!(lr.degree < 0.0, "degree expected negative for down-trend");
    }

    #[rstest]
    fn not_initialized_until_enough_samples() {
        let mut lr = LinearRegression::new(6);
        for v in 0..5 {
            lr.update_raw(f64::from(v));
        }
        assert!(
            !lr.initialized,
            "indicator should remain uninitialised with fewer than `period` inputs"
        );
    }

    // Exercises large windows, including the maximum supported period.
    #[rstest]
    #[case(128)]
    #[case(1_024)]
    #[case(16_384)]
    fn large_period_initialisation_and_window_size(#[case] period: usize) {
        let mut lr = LinearRegression::new(period);
        for v in 0..period {
            lr.update_raw(v as f64);
        }
        assert!(
            lr.initialized,
            "indicator should initialise after exactly `period` samples"
        );
        assert_eq!(
            lr.inputs.len(),
            period,
            "internal window length must equal the configured period"
        );
    }

    // The constants cached by `new` match the closed-form sums they stand for.
    #[rstest]
    fn cached_constants_correct() {
        let period = 10;
        let lr = LinearRegression::new(period);

        let n = period as f64;
        let expected_x_sum = 0.5 * n * (n + 1.0);
        let expected_x_mul_sum = expected_x_sum * 2.0f64.mul_add(n, 1.0) / 3.0;
        let expected_divisor = n.mul_add(expected_x_mul_sum, -(expected_x_sum * expected_x_sum));

        assert!((lr.x_sum - expected_x_sum).abs() < 1e-12, "x_sum mismatch");
        assert!(
            (lr.x_mul_sum - expected_x_mul_sum).abs() < 1e-12,
            "x_mul_sum mismatch"
        );
        assert!(
            (lr.divisor - expected_divisor).abs() < 1e-12,
            "divisor mismatch"
        );
    }

    #[rstest]
    fn cached_constants_immutable_through_updates() {
        let mut lr = LinearRegression::new(5);

        let (x_sum, x_mul_sum, divisor) = (lr.x_sum, lr.x_mul_sum, lr.divisor);

        for v in 0..20 {
            lr.update_raw(f64::from(v));
        }

        assert_eq!(lr.x_sum, x_sum, "x_sum must remain unchanged after updates");
        assert_eq!(
            lr.x_mul_sum, x_mul_sum,
            "x_mul_sum must remain unchanged after updates"
        );
        assert_eq!(
            lr.divisor, divisor,
            "divisor must remain unchanged after updates"
        );
    }

    #[rstest]
    fn cached_constants_immutable_after_reset() {
        let mut lr = LinearRegression::new(8);

        let (x_sum, x_mul_sum, divisor) = (lr.x_sum, lr.x_mul_sum, lr.divisor);

        for v in 0..8 {
            lr.update_raw(f64::from(v));
        }
        lr.reset();

        assert_eq!(lr.x_sum, x_sum, "x_sum must survive reset()");
        assert_eq!(lr.x_mul_sum, x_mul_sum, "x_mul_sum must survive reset()");
        assert_eq!(lr.divisor, divisor, "divisor must survive reset()");
    }

    // Absolute tolerance used by the floating-point comparisons below.
    const EPS: f64 = 1e-12;

    // `expected` pins the panic to the constructor's own validation message, so the
    // test cannot pass because of an unrelated panic.
    #[rstest]
    #[should_panic(expected = "LinearRegression: period must be > 0")]
    fn new_zero_period_panics() {
        let _ = LinearRegression::new(0);
    }

    #[rstest]
    #[should_panic(expected = "exceeds MAX_PERIOD")]
    fn new_period_exceeds_max_panics() {
        let _ = LinearRegression::new(MAX_PERIOD + 1);
    }

    // A constant, non-zero window fits a flat line through the constant: zero slope
    // and residual, intercept and value equal to the constant, and undefined R².
    #[rstest(
        period, value,
        case(8, 5.0),
        case(16, -std::f64::consts::PI)
    )]
    fn constant_non_zero_series(period: usize, value: f64) {
        let mut lr = LinearRegression::new(period);

        for _ in 0..period {
            lr.update_raw(value);
        }

        assert!(lr.initialized());
        assert!(lr.slope.abs() < EPS);
        assert!((lr.intercept - value).abs() < EPS);
        assert!(lr.degree.abs() < EPS);
        assert!(lr.r2.is_nan());
        assert!((lr.cfo).abs() < EPS);
        assert!((lr.value - value).abs() < EPS);
    }

    #[rstest(period, case(4), case(32))]
    fn constant_zero_series_cfo_nan(period: usize) {
        let mut lr = LinearRegression::new(period);

        for _ in 0..period {
            lr.update_raw(0.0);
        }

        assert!(lr.initialized());
        assert!(lr.cfo.is_nan());
    }

    // `reset` zeroes the outputs and flags but leaves the cached constants alone.
    #[rstest(period, case(6), case(13))]
    fn reset_clears_state_but_keeps_constants(period: usize) {
        let mut lr = LinearRegression::new(period);

        for i in 1..=period {
            lr.update_raw(i as f64);
        }

        let x_sum_before = lr.x_sum;
        let x_mul_sum_before = lr.x_mul_sum;
        let divisor_before = lr.divisor;

        lr.reset();

        assert!(!lr.initialized());
        assert!(!lr.has_inputs());

        assert!(lr.slope.abs() < EPS);
        assert!(lr.intercept.abs() < EPS);
        assert!(lr.degree.abs() < EPS);
        assert!(lr.cfo.abs() < EPS);
        assert!(lr.r2.abs() < EPS);
        assert!(lr.value.abs() < EPS);

        assert_eq!(lr.x_sum, x_sum_before);
        assert_eq!(lr.x_mul_sum, x_mul_sum_before);
        assert_eq!(lr.divisor, divisor_before);
    }

    // Samples lying exactly on y = A * x + B must be recovered with R² == 1.
    #[rstest(period, case(5), case(31))]
    fn perfect_linear_series(period: usize) {
        const A: f64 = 2.0;
        const B: f64 = -3.0;
        let mut lr = LinearRegression::new(period);

        for x in 1..=period {
            lr.update_raw(A.mul_add(x as f64, B));
        }

        assert!(lr.initialized());
        assert!((lr.slope - A).abs() < EPS);
        assert!((lr.intercept - B).abs() < EPS);
        assert!((lr.r2 - 1.0).abs() < EPS);
        assert!((lr.degree.to_radians().tan() - A).abs() < EPS);
    }

    // A sharply lower sample entering the window must pull the slope down, and the
    // window must have dropped its oldest value.
    #[rstest]
    fn sliding_window_keeps_last_period() {
        const P: usize = 4;
        let mut lr = LinearRegression::new(P);
        for i in 1..=P {
            lr.update_raw(i as f64);
        }
        let slope_first_window = lr.slope;

        lr.update_raw(-100.0);
        assert!(lr.slope < slope_first_window);
        assert_eq!(lr.inputs.len(), P);
        assert_eq!(lr.inputs.front(), Some(&2.0));
    }

    // A noisy but strongly trending series gives an R² strictly between 0 and 1.
    #[rstest]
    fn r2_between_zero_and_one() {
        const P: usize = 32;
        let mut lr = LinearRegression::new(P);
        for x in 1..=P {
            let noise = if x.is_multiple_of(2) { 0.5 } else { -0.5 };
            lr.update_raw(3.0f64.mul_add(x as f64, noise));
        }
        assert!(lr.r2 > 0.0 && lr.r2 < 1.0);
    }

    // Resetting during warm-up clears the partially filled window as well.
    #[rstest]
    fn reset_before_initialized() {
        let mut lr = LinearRegression::new(10);
        lr.update_raw(1.0);
        lr.reset();

        assert!(!lr.initialized());
        assert!(!lr.has_inputs());
        assert_eq!(lr.inputs.len(), 0);
    }
}