nautilus_indicators/average/
lr.rs1use std::fmt::{Debug, Display};
17
18use arraydeque::{ArrayDeque, Wrapping};
19use nautilus_model::data::Bar;
20
21use crate::indicator::Indicator;
22
/// Largest accepted `period` (2^14); also the fixed capacity of the `inputs` buffer.
const MAX_PERIOD: usize = 1 << 14;
24
/// Rolling least-squares linear regression over the most recent `period` values.
///
/// Each value in the window is regressed against its position `x = 1..=period`
/// (oldest value first). Outputs are recomputed on every update once the window
/// holds `period` values.
#[repr(C)]
#[derive(Debug)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.indicators")
)]
pub struct LinearRegression {
    /// Number of values in each regression window.
    pub period: usize,
    /// Slope of the fitted line.
    pub slope: f64,
    /// Intercept of the fitted line (its value at `x = 0`).
    pub intercept: f64,
    /// Angle of the fitted line in degrees, `atan(slope)`.
    pub degree: f64,
    /// `100 * (fitted - actual) / actual` for the newest value; NaN when the newest value is zero.
    pub cfo: f64,
    /// Coefficient of determination (R²) of the fit; NaN when the window has (near-)zero variance.
    pub r2: f64,
    /// Fitted value of the regression line at the newest point of the window.
    pub value: f64,
    /// True once `period` values have been received (since construction or the last reset).
    pub initialized: bool,
    // True once at least one value has been received.
    has_inputs: bool,
    // Rolling window of the last `period` inputs, oldest at the front.
    inputs: ArrayDeque<f64, MAX_PERIOD, Wrapping>,
    // Cached Σx for x = 1..=period.
    x_sum: f64,
    // Cached Σx² for x = 1..=period.
    x_mul_sum: f64,
    // Cached denominator n·Σx² − (Σx)² of the least-squares solution.
    divisor: f64,
}
46
47impl Display for LinearRegression {
48 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
49 write!(f, "{}({})", self.name(), self.period)
50 }
51}
52
53impl Indicator for LinearRegression {
54 fn name(&self) -> String {
55 stringify!(LinearRegression).into()
56 }
57
58 fn has_inputs(&self) -> bool {
59 self.has_inputs
60 }
61
62 fn initialized(&self) -> bool {
63 self.initialized
64 }
65
66 fn handle_bar(&mut self, bar: &Bar) {
67 self.update_raw(bar.close.into());
68 }
69
70 fn reset(&mut self) {
71 self.slope = 0.0;
72 self.intercept = 0.0;
73 self.degree = 0.0;
74 self.cfo = 0.0;
75 self.r2 = 0.0;
76 self.value = 0.0;
77 self.inputs.clear();
78 self.has_inputs = false;
79 self.initialized = false;
80 }
81}
82
83impl LinearRegression {
84 #[must_use]
92 pub fn new(period: usize) -> Self {
93 assert!(
94 period > 0,
95 "LinearRegression: period must be > 0 (received {period})"
96 );
97 assert!(
98 period <= MAX_PERIOD,
99 "LinearRegression: period {period} exceeds MAX_PERIOD ({MAX_PERIOD})"
100 );
101
102 let n = period as f64;
103 let x_sum = 0.5 * n * (n + 1.0);
104 let x_mul_sum = x_sum * 2.0f64.mul_add(n, 1.0) / 3.0;
105 let divisor = n.mul_add(x_mul_sum, -(x_sum * x_sum));
106
107 Self {
108 period,
109 slope: 0.0,
110 intercept: 0.0,
111 degree: 0.0,
112 cfo: 0.0,
113 r2: 0.0,
114 value: 0.0,
115 initialized: false,
116 has_inputs: false,
117 inputs: ArrayDeque::new(),
118 x_sum,
119 x_mul_sum,
120 divisor,
121 }
122 }
123
124 pub fn update_raw(&mut self, close: f64) {
131 if self.inputs.len() == self.period {
132 let _ = self.inputs.pop_front();
133 }
134 let _ = self.inputs.push_back(close);
135
136 self.has_inputs = true;
137 if self.inputs.len() < self.period {
138 return;
139 }
140 self.initialized = true;
141
142 let n = self.period as f64;
143 let x_sum = self.x_sum;
144 let x_mul_sum = self.x_mul_sum;
145 let divisor = self.divisor;
146
147 let (mut y_sum, mut xy_sum) = (0.0, 0.0);
148 for (i, &y) in self.inputs.iter().enumerate() {
149 let x = (i + 1) as f64;
150 y_sum += y;
151 xy_sum += x * y;
152 }
153
154 self.slope = n.mul_add(xy_sum, -(x_sum * y_sum)) / divisor;
155 self.intercept = y_sum.mul_add(x_mul_sum, -(x_sum * xy_sum)) / divisor;
156
157 let (mut sse, mut y_last, mut e_last) = (0.0, 0.0, 0.0);
158 for (i, &y) in self.inputs.iter().enumerate() {
159 let x = (i + 1) as f64;
160 let y_hat = self.slope.mul_add(x, self.intercept);
161 let resid = y_hat - y;
162 sse += resid * resid;
163 y_last = y;
164 e_last = resid;
165 }
166
167 self.value = y_last + e_last;
168 self.degree = self.slope.atan().to_degrees();
169 self.cfo = if y_last == 0.0 {
170 f64::NAN
171 } else {
172 100.0 * e_last / y_last
173 };
174
175 let mean = y_sum / n;
176 let sst: f64 = self
177 .inputs
178 .iter()
179 .map(|&y| {
180 let d = y - mean;
181 d * d
182 })
183 .sum();
184
185 self.r2 = if sst.abs() < f64::EPSILON {
186 f64::NAN
187 } else {
188 1.0 - sse / sst
189 };
190 }
191}
192
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;
    use crate::stubs::{bar_ethusdt_binance_minute_bid, indicator_lr_10};

    /// Absolute tolerance for floating-point comparisons.
    const EPS: f64 = 1e-12;

    #[rstest]
    fn test_lr_initialized(indicator_lr_10: LinearRegression) {
        let display_str = format!("{indicator_lr_10}");
        assert_eq!(display_str, "LinearRegression(10)");
        assert_eq!(indicator_lr_10.period, 10);
        assert!(!indicator_lr_10.initialized);
        assert!(!indicator_lr_10.has_inputs);
    }

    #[rstest]
    #[should_panic(expected = "LinearRegression: period must be > 0")]
    fn test_new_with_zero_period_panics() {
        let _ = LinearRegression::new(0);
    }

    #[rstest]
    #[should_panic(expected = "exceeds MAX_PERIOD")]
    fn test_new_with_period_exceeding_max_panics() {
        let _ = LinearRegression::new(MAX_PERIOD + 1);
    }

    #[rstest]
    fn test_value_with_one_input(mut indicator_lr_10: LinearRegression) {
        indicator_lr_10.update_raw(1.0);
        assert_eq!(indicator_lr_10.value, 0.0);
    }

    #[rstest]
    fn test_value_with_three_inputs(mut indicator_lr_10: LinearRegression) {
        indicator_lr_10.update_raw(1.0);
        indicator_lr_10.update_raw(2.0);
        indicator_lr_10.update_raw(3.0);
        assert_eq!(indicator_lr_10.value, 0.0);
    }

    #[rstest]
    fn test_initialized_with_required_input(mut indicator_lr_10: LinearRegression) {
        for i in 1..10 {
            indicator_lr_10.update_raw(f64::from(i));
        }
        assert!(!indicator_lr_10.initialized);
        indicator_lr_10.update_raw(10.0);
        assert!(indicator_lr_10.initialized);
    }

    #[rstest]
    fn test_handle_bar(mut indicator_lr_10: LinearRegression, bar_ethusdt_binance_minute_bid: Bar) {
        indicator_lr_10.handle_bar(&bar_ethusdt_binance_minute_bid);
        assert_eq!(indicator_lr_10.value, 0.0);
        assert!(indicator_lr_10.has_inputs);
        assert!(!indicator_lr_10.initialized);
    }

    #[rstest]
    fn test_reset(mut indicator_lr_10: LinearRegression) {
        indicator_lr_10.update_raw(1.0);
        indicator_lr_10.reset();
        assert_eq!(indicator_lr_10.value, 0.0);
        assert_eq!(indicator_lr_10.inputs.len(), 0);
        assert_eq!(indicator_lr_10.slope, 0.0);
        assert_eq!(indicator_lr_10.intercept, 0.0);
        assert_eq!(indicator_lr_10.degree, 0.0);
        assert_eq!(indicator_lr_10.cfo, 0.0);
        assert_eq!(indicator_lr_10.r2, 0.0);
        assert!(!indicator_lr_10.has_inputs);
        assert!(!indicator_lr_10.initialized);
    }

    #[rstest]
    fn test_inputs_len_never_exceeds_period() {
        let mut lr = LinearRegression::new(3);
        for i in 0..10 {
            lr.update_raw(f64::from(i));
        }
        assert_eq!(lr.inputs.len(), lr.period);
    }

    #[rstest]
    fn test_oldest_element_evicted() {
        let mut lr = LinearRegression::new(4);
        for v in 1..=5 {
            lr.update_raw(f64::from(v));
        }
        assert!(!lr.inputs.contains(&1.0));
        assert_eq!(lr.inputs.front(), Some(&2.0));
    }

    #[rstest]
    fn test_recent_elements_preserved() {
        let mut lr = LinearRegression::new(5);
        for v in 0..5 {
            lr.update_raw(f64::from(v));
        }
        lr.update_raw(99.0);
        let expected = vec![1.0, 2.0, 3.0, 4.0, 99.0];
        assert_eq!(lr.inputs.iter().copied().collect::<Vec<_>>(), expected);
    }

    #[rstest]
    fn test_multiple_evictions() {
        let mut lr = LinearRegression::new(2);
        lr.update_raw(10.0);
        lr.update_raw(20.0);
        lr.update_raw(30.0);
        lr.update_raw(40.0);
        assert_eq!(
            lr.inputs.iter().copied().collect::<Vec<_>>(),
            vec![30.0, 40.0]
        );
    }

    #[rstest]
    fn test_value_changes_after_eviction() {
        let mut lr = LinearRegression::new(3);
        lr.update_raw(1.0);
        lr.update_raw(2.0);
        lr.update_raw(3.0);
        let before = lr.value;
        lr.update_raw(4.0);
        let after = lr.value;
        assert!(after.is_finite());
        assert_ne!(before, after);
    }

    #[rstest]
    fn test_value_with_eleven_inputs(mut indicator_lr_10: LinearRegression) {
        indicator_lr_10.update_raw(1.00000);
        indicator_lr_10.update_raw(1.00010);
        indicator_lr_10.update_raw(1.00030);
        indicator_lr_10.update_raw(1.00040);
        indicator_lr_10.update_raw(1.00050);
        indicator_lr_10.update_raw(1.00060);
        indicator_lr_10.update_raw(1.00050);
        indicator_lr_10.update_raw(1.00040);
        indicator_lr_10.update_raw(1.00030);
        indicator_lr_10.update_raw(1.00010);
        indicator_lr_10.update_raw(1.00000);

        assert!((indicator_lr_10.value - 1.000_232_727_272_727_6).abs() < 1e-12);
    }

    #[rstest]
    fn r2_nan_for_constant_series() {
        let mut lr = LinearRegression::new(5);
        for _ in 0..5 {
            lr.update_raw(42.0);
        }
        assert!(lr.initialized);
        assert!(
            lr.r2.is_nan(),
            "R² should be NaN for a constant-value input series"
        );
    }

    #[rstest]
    fn cfo_nan_when_last_price_zero() {
        let mut lr = LinearRegression::new(3);
        lr.update_raw(1.0);
        lr.update_raw(2.0);
        lr.update_raw(0.0);
        assert!(lr.initialized);
        assert!(
            lr.cfo.is_nan(),
            "CFO should be NaN when the most-recent price equals zero"
        );
    }

    #[rstest]
    fn positive_slope_and_degree_for_uptrend() {
        let mut lr = LinearRegression::new(4);
        for v in 1..=4 {
            lr.update_raw(f64::from(v));
        }
        assert!(lr.slope > 0.0, "slope expected positive for up-trend");
        assert!(lr.degree > 0.0, "degree expected positive for up-trend");
    }

    #[rstest]
    fn negative_slope_and_degree_for_downtrend() {
        let mut lr = LinearRegression::new(4);
        for v in (1..=4).rev() {
            lr.update_raw(f64::from(v));
        }
        assert!(lr.slope < 0.0, "slope expected negative for down-trend");
        assert!(lr.degree < 0.0, "degree expected negative for down-trend");
    }

    #[rstest]
    fn not_initialized_until_enough_samples() {
        let mut lr = LinearRegression::new(6);
        for v in 0..5 {
            lr.update_raw(f64::from(v));
        }
        assert!(
            !lr.initialized,
            "indicator should remain uninitialised with fewer than `period` inputs"
        );
    }

    #[rstest]
    #[case(128)]
    #[case(1_024)]
    #[case(16_384)]
    fn large_period_initialisation_and_window_size(#[case] period: usize) {
        let mut lr = LinearRegression::new(period);
        for v in 0..period {
            lr.update_raw(v as f64);
        }
        assert!(
            lr.initialized,
            "indicator should initialise after exactly `period` samples"
        );
        assert_eq!(
            lr.inputs.len(),
            period,
            "internal window length must equal the configured period"
        );
    }

    #[rstest]
    fn cached_constants_correct() {
        let period = 10;
        let lr = LinearRegression::new(period);

        let n = period as f64;
        let expected_x_sum = 0.5 * n * (n + 1.0);
        let expected_x_mul_sum = expected_x_sum * 2.0f64.mul_add(n, 1.0) / 3.0;
        let expected_divisor = n.mul_add(expected_x_mul_sum, -(expected_x_sum * expected_x_sum));

        assert!((lr.x_sum - expected_x_sum).abs() < 1e-12, "x_sum mismatch");
        assert!(
            (lr.x_mul_sum - expected_x_mul_sum).abs() < 1e-12,
            "x_mul_sum mismatch"
        );
        assert!(
            (lr.divisor - expected_divisor).abs() < 1e-12,
            "divisor mismatch"
        );
    }

    #[rstest]
    fn cached_constants_immutable_through_updates() {
        let mut lr = LinearRegression::new(5);

        let (x_sum, x_mul_sum, divisor) = (lr.x_sum, lr.x_mul_sum, lr.divisor);

        for v in 0..20 {
            lr.update_raw(f64::from(v));
        }

        assert_eq!(lr.x_sum, x_sum, "x_sum must remain unchanged after updates");
        assert_eq!(
            lr.x_mul_sum, x_mul_sum,
            "x_mul_sum must remain unchanged after updates"
        );
        assert_eq!(
            lr.divisor, divisor,
            "divisor must remain unchanged after updates"
        );
    }

    #[rstest]
    fn cached_constants_immutable_after_reset() {
        let mut lr = LinearRegression::new(8);

        let (x_sum, x_mul_sum, divisor) = (lr.x_sum, lr.x_mul_sum, lr.divisor);

        for v in 0..8 {
            lr.update_raw(f64::from(v));
        }
        lr.reset();

        assert_eq!(lr.x_sum, x_sum, "x_sum must survive reset()");
        assert_eq!(lr.x_mul_sum, x_mul_sum, "x_mul_sum must survive reset()");
        assert_eq!(lr.divisor, divisor, "divisor must survive reset()");
    }

    #[rstest(
        period, value,
        case(8, 5.0),
        case(16, -std::f64::consts::PI)
    )]
    fn constant_non_zero_series(period: usize, value: f64) {
        let mut lr = LinearRegression::new(period);

        for _ in 0..period {
            lr.update_raw(value);
        }

        assert!(lr.initialized());
        assert!(lr.slope.abs() < EPS);
        assert!((lr.intercept - value).abs() < EPS);
        assert!(lr.degree.abs() < EPS);
        assert!(lr.r2.is_nan());
        assert!((lr.cfo).abs() < EPS);
        assert!((lr.value - value).abs() < EPS);
    }

    #[rstest(period, case(4), case(32))]
    fn constant_zero_series_cfo_nan(period: usize) {
        let mut lr = LinearRegression::new(period);

        for _ in 0..period {
            lr.update_raw(0.0);
        }

        assert!(lr.initialized());
        assert!(lr.cfo.is_nan());
    }

    #[rstest(period, case(6), case(13))]
    fn reset_clears_state_but_keeps_constants(period: usize) {
        let mut lr = LinearRegression::new(period);

        for i in 1..=period {
            lr.update_raw(i as f64);
        }

        let x_sum_before = lr.x_sum;
        let x_mul_sum_before = lr.x_mul_sum;
        let divisor_before = lr.divisor;

        lr.reset();

        assert!(!lr.initialized());
        assert!(!lr.has_inputs());

        assert!(lr.slope.abs() < EPS);
        assert!(lr.intercept.abs() < EPS);
        assert!(lr.degree.abs() < EPS);
        assert!(lr.cfo.abs() < EPS);
        assert!(lr.r2.abs() < EPS);
        assert!(lr.value.abs() < EPS);

        assert_eq!(lr.x_sum, x_sum_before);
        assert_eq!(lr.x_mul_sum, x_mul_sum_before);
        assert_eq!(lr.divisor, divisor_before);
    }

    #[rstest(period, case(5), case(31))]
    fn perfect_linear_series(period: usize) {
        const A: f64 = 2.0;
        const B: f64 = -3.0;
        let mut lr = LinearRegression::new(period);

        for x in 1..=period {
            lr.update_raw(A.mul_add(x as f64, B));
        }

        assert!(lr.initialized());
        assert!((lr.slope - A).abs() < EPS);
        assert!((lr.intercept - B).abs() < EPS);
        assert!((lr.r2 - 1.0).abs() < EPS);
        assert!((lr.degree.to_radians().tan() - A).abs() < EPS);
    }

    #[rstest]
    fn sliding_window_keeps_last_period() {
        const P: usize = 4;
        let mut lr = LinearRegression::new(P);
        for i in 1..=P {
            lr.update_raw(i as f64);
        }
        let slope_first_window = lr.slope;

        lr.update_raw(-100.0);
        assert!(lr.slope < slope_first_window);
        assert_eq!(lr.inputs.len(), P);
        assert_eq!(lr.inputs.front(), Some(&2.0));
    }

    #[rstest]
    fn r2_between_zero_and_one() {
        const P: usize = 32;
        let mut lr = LinearRegression::new(P);
        for x in 1..=P {
            let noise = if x.is_multiple_of(2) { 0.5 } else { -0.5 };
            lr.update_raw(3.0f64.mul_add(x as f64, noise));
        }
        assert!(lr.r2 > 0.0 && lr.r2 < 1.0);
    }

    #[rstest]
    fn reset_before_initialized() {
        let mut lr = LinearRegression::new(10);
        lr.update_raw(1.0);
        lr.reset();

        assert!(!lr.initialized());
        assert!(!lr.has_inputs());
        assert_eq!(lr.inputs.len(), 0);
    }
}
605}