1use std::fmt::{Debug, Display};
17
18use arraydeque::{ArrayDeque, Wrapping};
19use nautilus_model::data::Bar;
20use strum::{AsRefStr, Display as StrumDisplay, EnumIter, EnumString, FromRepr};
21
22use crate::{
23 average::{MovingAverageFactory, MovingAverageType},
24 indicator::{Indicator, MovingAverage},
25};
26
/// Upper bound for `period_k`, `period_d` and `slowing`, and the fixed capacity of
/// every internal rolling buffer.
const MAX_PERIOD: usize = 1024;
28
/// Selects how the %D line of [`Stochastics`] is derived.
///
/// Strings are produced/parsed via `strum` in `SCREAMING_SNAKE_CASE` (`RATIO`,
/// `MOVING_AVERAGE`), case-insensitively. The enum is `#[repr(C)]` and derives
/// `FromRepr`, so the variant order defines the integer discriminants — do not reorder.
#[repr(C)]
#[derive(
    Copy,
    Clone,
    Debug,
    Default,
    Hash,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    AsRefStr,
    FromRepr,
    EnumIter,
    EnumString,
    StrumDisplay,
)]
#[strum(ascii_case_insensitive)]
#[strum(serialize_all = "SCREAMING_SNAKE_CASE")]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(
        frozen,
        eq,
        eq_int,
        hash,
        module = "nautilus_trader.core.nautilus_pyo3.indicators",
        from_py_object,
    )
)]
pub enum StochasticsDMethod {
    /// Ratio of sums (the default):
    /// `100 * sum(close - lowest_low) / sum(highest_high - lowest_low)` over the last
    /// `period_d` updates, where each term uses the %K window as it was at that update.
    #[default]
    Ratio,
    /// A moving average of type `ma_type` and length `period_d`, fed with the
    /// (optionally slowed) %K values.
    MovingAverage,
}
77
/// Stochastic oscillator producing %K and %D lines from high/low/close inputs.
///
/// %K locates the close within the highest-high / lowest-low range of the last
/// `period_k` updates, optionally smoothed by a `slowing` moving average. %D is either a
/// ratio of summed ranges or a moving average of %K, selected by `d_method`.
#[repr(C)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.indicators")
)]
pub struct Stochastics {
    /// Lookback length of the highest-high / lowest-low window.
    pub period_k: usize,
    /// Length of the %D window (ratio buffers or moving-average period).
    pub period_d: usize,
    /// %K smoothing period; `1` disables smoothing.
    pub slowing: usize,
    /// Moving-average type used for slowing and, when selected, for %D.
    pub ma_type: MovingAverageType,
    /// How %D is computed.
    pub d_method: StochasticsDMethod,
    /// Latest %K value; stays `0.0` until a non-flat range has been observed.
    pub value_k: f64,
    /// Latest %D value; stays `0.0` until a non-flat range has been observed.
    pub value_d: f64,
    /// Whether every configured component has received enough inputs.
    pub initialized: bool,
    /// Whether at least one update has been received since construction/reset.
    has_inputs: bool,
    /// Rolling highs over the last `period_k` updates.
    highs: ArrayDeque<f64, MAX_PERIOD, Wrapping>,
    /// Rolling lows over the last `period_k` updates.
    lows: ArrayDeque<f64, MAX_PERIOD, Wrapping>,
    /// Ratio method only: `close - lowest_low` for the last `period_d` updates.
    c_sub_1: ArrayDeque<f64, MAX_PERIOD, Wrapping>,
    /// Ratio method only: `highest_high - lowest_low` for the last `period_d` updates.
    h_sub_l: ArrayDeque<f64, MAX_PERIOD, Wrapping>,
    /// %K smoothing average; present only when `slowing > 1`.
    slowing_ma: Option<Box<dyn MovingAverage + Send + Sync>>,
    /// %D moving average; created only when `d_method` is `MovingAverage`.
    d_ma: Option<Box<dyn MovingAverage + Send + Sync>>,
}
110
111impl Debug for Stochastics {
112 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
113 f.debug_struct(stringify!(Stochastics))
114 .field("period_k", &self.period_k)
115 .field("period_d", &self.period_d)
116 .field("slowing", &self.slowing)
117 .field("ma_type", &self.ma_type)
118 .field("d_method", &self.d_method)
119 .field("value_k", &self.value_k)
120 .field("value_d", &self.value_d)
121 .field("initialized", &self.initialized)
122 .field("has_inputs", &self.has_inputs)
123 .field(
124 "slowing_ma",
125 &self.slowing_ma.as_ref().map(|_| "MovingAverage"),
126 )
127 .field("d_ma", &self.d_ma.as_ref().map(|_| "MovingAverage"))
128 .finish()
129 }
130}
131
132impl Display for Stochastics {
133 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
134 write!(f, "{}({},{})", self.name(), self.period_k, self.period_d,)
135 }
136}
137
138impl Indicator for Stochastics {
139 fn name(&self) -> String {
140 stringify!(Stochastics).to_string()
141 }
142
143 fn has_inputs(&self) -> bool {
144 self.has_inputs
145 }
146
147 fn initialized(&self) -> bool {
148 self.initialized
149 }
150
151 fn handle_bar(&mut self, bar: &Bar) {
152 self.update_raw((&bar.high).into(), (&bar.low).into(), (&bar.close).into());
153 }
154
155 fn reset(&mut self) {
156 self.highs.clear();
157 self.lows.clear();
158 self.c_sub_1.clear();
159 self.h_sub_l.clear();
160 self.value_k = 0.0;
161 self.value_d = 0.0;
162 self.has_inputs = false;
163 self.initialized = false;
164
165 if let Some(ref mut ma) = self.slowing_ma {
167 ma.reset();
168 }
169
170 if let Some(ref mut ma) = self.d_ma {
172 ma.reset();
173 }
174 }
175}
176
177impl Stochastics {
178 #[must_use]
191 pub fn new(period_k: usize, period_d: usize) -> Self {
192 Self::new_with_params(
193 period_k,
194 period_d,
195 1, MovingAverageType::Exponential, StochasticsDMethod::Ratio, )
199 }
200
201 #[must_use]
216 pub fn new_with_params(
217 period_k: usize,
218 period_d: usize,
219 slowing: usize,
220 ma_type: MovingAverageType,
221 d_method: StochasticsDMethod,
222 ) -> Self {
223 assert!(
224 period_k > 0 && period_k <= MAX_PERIOD,
225 "Stochastics: period_k {period_k} exceeds bounds (1..={MAX_PERIOD})"
226 );
227 assert!(
228 period_d > 0 && period_d <= MAX_PERIOD,
229 "Stochastics: period_d {period_d} exceeds bounds (1..={MAX_PERIOD})"
230 );
231 assert!(
232 slowing > 0 && slowing <= MAX_PERIOD,
233 "Stochastics: slowing {slowing} exceeds bounds (1..={MAX_PERIOD})"
234 );
235
236 let slowing_ma = if slowing > 1 {
238 Some(MovingAverageFactory::create(ma_type, slowing))
239 } else {
240 None
241 };
242
243 let d_ma = match d_method {
245 StochasticsDMethod::MovingAverage => {
246 Some(MovingAverageFactory::create(ma_type, period_d))
247 }
248 StochasticsDMethod::Ratio => None,
249 };
250
251 Self {
252 period_k,
253 period_d,
254 slowing,
255 ma_type,
256 d_method,
257 has_inputs: false,
258 initialized: false,
259 value_k: 0.0,
260 value_d: 0.0,
261 highs: ArrayDeque::new(),
262 lows: ArrayDeque::new(),
263 h_sub_l: ArrayDeque::new(),
264 c_sub_1: ArrayDeque::new(),
265 slowing_ma,
266 d_ma,
267 }
268 }
269
270 pub fn update_raw(&mut self, high: f64, low: f64, close: f64) {
278 if !self.has_inputs {
279 self.has_inputs = true;
280 }
281
282 if self.highs.len() == self.period_k {
284 self.highs.pop_front();
285 self.lows.pop_front();
286 }
287 let _ = self.highs.push_back(high);
288 let _ = self.lows.push_back(low);
289
290 if !self.initialized
292 && self.highs.len() == self.period_k
293 && self.lows.len() == self.period_k
294 {
295 if self.slowing_ma.is_none() && self.d_method == StochasticsDMethod::Ratio {
298 self.initialized = true;
299 }
300 }
301
302 let k_max_high = self.highs.iter().copied().fold(f64::NEG_INFINITY, f64::max);
304 let k_min_low = self.lows.iter().copied().fold(f64::INFINITY, f64::min);
305
306 if self.d_method == StochasticsDMethod::Ratio {
308 if self.c_sub_1.len() == self.period_d {
309 self.c_sub_1.pop_front();
310 self.h_sub_l.pop_front();
311 }
312 let _ = self.c_sub_1.push_back(close - k_min_low);
313 let _ = self.h_sub_l.push_back(k_max_high - k_min_low);
314 }
315
316 if k_max_high == k_min_low {
318 return;
319 }
320
321 let raw_k = 100.0 * ((close - k_min_low) / (k_max_high - k_min_low));
323
324 let slowed_k = match &mut self.slowing_ma {
326 Some(ma) => {
327 ma.update_raw(raw_k);
328 ma.value()
329 }
330 None => raw_k, };
332 self.value_k = slowed_k;
333
334 self.value_d = match self.d_method {
336 StochasticsDMethod::Ratio => {
337 let sum_h_sub_l: f64 = self.h_sub_l.iter().sum();
340 if sum_h_sub_l == 0.0 {
341 0.0
342 } else {
343 100.0 * (self.c_sub_1.iter().sum::<f64>() / sum_h_sub_l)
344 }
345 }
346 StochasticsDMethod::MovingAverage => {
347 if let Some(ref mut ma) = self.d_ma {
349 ma.update_raw(slowed_k);
350 ma.value()
351 } else {
352 50.0 }
354 }
355 };
356
357 if !self.initialized {
361 let base_ready = self.highs.len() == self.period_k;
362 let slowing_ready = match &self.slowing_ma {
363 Some(ma) => ma.initialized(),
364 None => true,
365 };
366 let d_ready = match self.d_method {
367 StochasticsDMethod::Ratio => true, StochasticsDMethod::MovingAverage => match &self.d_ma {
369 Some(ma) => ma.initialized(),
370 None => true,
371 },
372 };
373
374 if base_ready && slowing_ready && d_ready {
375 self.initialized = true;
376 }
377 }
378 }
379}
380
#[cfg(test)]
mod tests {
    use nautilus_model::data::Bar;
    use rstest::rstest;

    use crate::{
        average::MovingAverageType,
        indicator::Indicator,
        momentum::stochastics::{MAX_PERIOD, Stochastics, StochasticsDMethod},
        stubs::{bar_ethusdt_binance_minute_bid, stochastics_10},
    };

    #[rstest]
    fn test_stochastics_initialized(stochastics_10: Stochastics) {
        let display_str = format!("{stochastics_10}");
        assert_eq!(display_str, "Stochastics(10,10)");
        assert_eq!(stochastics_10.period_d, 10);
        assert_eq!(stochastics_10.period_k, 10);
        assert!(!stochastics_10.initialized);
        assert!(!stochastics_10.has_inputs);
    }

    #[rstest]
    fn test_value_with_one_input(mut stochastics_10: Stochastics) {
        stochastics_10.update_raw(1.0, 1.0, 1.0);
        assert_eq!(stochastics_10.value_d, 0.0);
        assert_eq!(stochastics_10.value_k, 0.0);
    }

    #[rstest]
    fn test_value_with_three_inputs(mut stochastics_10: Stochastics) {
        stochastics_10.update_raw(1.0, 1.0, 1.0);
        stochastics_10.update_raw(2.0, 2.0, 2.0);
        stochastics_10.update_raw(3.0, 3.0, 3.0);
        assert_eq!(stochastics_10.value_d, 100.0);
        assert_eq!(stochastics_10.value_k, 100.0);
    }

    #[rstest]
    fn test_value_with_ten_inputs(mut stochastics_10: Stochastics) {
        let high_values = [
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
        ];
        let low_values = [
            0.9, 1.9, 2.9, 3.9, 4.9, 5.9, 6.9, 7.9, 8.9, 9.9, 10.1, 10.2, 10.3, 11.1, 11.4,
        ];
        let close_values = [
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
        ];

        for ((high, low), close) in high_values.into_iter().zip(low_values).zip(close_values) {
            stochastics_10.update_raw(high, low, close);
        }

        assert!(stochastics_10.initialized());
        assert_eq!(stochastics_10.value_d, 100.0);
        assert_eq!(stochastics_10.value_k, 100.0);
    }

    #[rstest]
    fn test_initialized_with_required_input(mut stochastics_10: Stochastics) {
        for i in 1..10 {
            stochastics_10.update_raw(f64::from(i), f64::from(i), f64::from(i));
        }
        assert!(!stochastics_10.initialized);
        stochastics_10.update_raw(10.0, 12.0, 14.0);
        assert!(stochastics_10.initialized);
    }

    #[rstest]
    fn test_handle_bar(mut stochastics_10: Stochastics, bar_ethusdt_binance_minute_bid: Bar) {
        stochastics_10.handle_bar(&bar_ethusdt_binance_minute_bid);
        assert_eq!(stochastics_10.value_d, 49.090_909_090_909_09);
        assert_eq!(stochastics_10.value_k, 49.090_909_090_909_09);
        assert!(stochastics_10.has_inputs);
        assert!(!stochastics_10.initialized);
    }

    #[rstest]
    fn test_reset(mut stochastics_10: Stochastics) {
        stochastics_10.update_raw(1.0, 1.0, 1.0);
        assert_eq!(stochastics_10.c_sub_1.len(), 1);
        assert_eq!(stochastics_10.h_sub_l.len(), 1);

        stochastics_10.reset();
        assert_eq!(stochastics_10.value_d, 0.0);
        assert_eq!(stochastics_10.value_k, 0.0);
        assert_eq!(stochastics_10.h_sub_l.len(), 0);
        assert_eq!(stochastics_10.c_sub_1.len(), 0);
        assert!(!stochastics_10.has_inputs);
        assert!(!stochastics_10.initialized);
    }

    #[rstest]
    fn test_new_defaults_slowing_1_ratio() {
        let stoch = Stochastics::new(10, 3);
        assert_eq!(stoch.period_k, 10);
        assert_eq!(stoch.period_d, 3);
        assert_eq!(stoch.slowing, 1);
        assert_eq!(stoch.ma_type, MovingAverageType::Exponential);
        assert_eq!(stoch.d_method, StochasticsDMethod::Ratio);
        assert!(
            stoch.slowing_ma.is_none(),
            "slowing_ma should be None when slowing == 1"
        );
        assert!(
            stoch.d_ma.is_none(),
            "d_ma should be None when d_method == Ratio"
        );
    }

    #[rstest]
    fn test_new_with_params_accepts_all_params() {
        let stoch = Stochastics::new_with_params(
            11,
            3,
            3,
            MovingAverageType::Exponential,
            StochasticsDMethod::MovingAverage,
        );
        assert_eq!(stoch.period_k, 11);
        assert_eq!(stoch.period_d, 3);
        assert_eq!(stoch.slowing, 3);
        assert_eq!(stoch.ma_type, MovingAverageType::Exponential);
        assert_eq!(stoch.d_method, StochasticsDMethod::MovingAverage);
        assert!(
            stoch.slowing_ma.is_some(),
            "slowing_ma should exist when slowing > 1"
        );
        assert!(
            stoch.d_ma.is_some(),
            "d_ma should exist when d_method == MovingAverage"
        );
    }

    #[rstest]
    fn test_backward_compatibility_identical_output() {
        let mut stoch_old = Stochastics::new(10, 10);
        let mut stoch_new = Stochastics::new_with_params(
            10,
            10,
            1,
            MovingAverageType::Exponential,
            StochasticsDMethod::Ratio,
        );

        let high_values = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let low_values = [0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5];
        let close_values = [0.8, 1.8, 2.8, 3.8, 4.8, 5.8, 6.8, 7.8, 8.8, 9.8];

        for ((high, low), close) in high_values.into_iter().zip(low_values).zip(close_values) {
            stoch_old.update_raw(high, low, close);
            stoch_new.update_raw(high, low, close);
        }

        assert_eq!(stoch_old.value_k, stoch_new.value_k, "value_k mismatch");
        assert_eq!(stoch_old.value_d, stoch_new.value_d, "value_d mismatch");
        assert_eq!(stoch_old.initialized, stoch_new.initialized);
    }

    #[rstest]
    fn test_slowing_3_smoothes_k() {
        let mut stoch_no_slowing = Stochastics::new(5, 3);
        let mut stoch_with_slowing = Stochastics::new_with_params(
            5,
            3,
            3,
            MovingAverageType::Exponential,
            StochasticsDMethod::Ratio,
        );

        let data = [
            (10.0, 5.0, 8.0),
            (12.0, 6.0, 7.0),
            (11.0, 4.0, 9.0),
            (13.0, 7.0, 8.0),
            (14.0, 8.0, 10.0),
            (12.0, 6.0, 7.0),
            (15.0, 9.0, 14.0),
            (16.0, 10.0, 11.0),
        ];

        for (high, low, close) in data {
            stoch_no_slowing.update_raw(high, low, close);
            stoch_with_slowing.update_raw(high, low, close);
        }

        assert!(
            (stoch_no_slowing.value_k - stoch_with_slowing.value_k).abs() > 0.01,
            "Slowing should produce different %K values"
        );
    }

    #[rstest]
    #[case(MovingAverageType::Simple)]
    #[case(MovingAverageType::Exponential)]
    #[case(MovingAverageType::Wilder)]
    #[case(MovingAverageType::Hull)]
    fn test_slowing_with_different_ma_types(#[case] ma_type: MovingAverageType) {
        let mut stoch = Stochastics::new_with_params(5, 3, 3, ma_type, StochasticsDMethod::Ratio);

        for i in 1..=10 {
            stoch.update_raw(f64::from(i) + 5.0, f64::from(i), f64::from(i) + 2.0);
        }

        assert!(
            stoch.value_k.is_finite(),
            "value_k should be finite with {ma_type:?}"
        );
        assert!(
            stoch.value_d.is_finite(),
            "value_d should be finite with {ma_type:?}"
        );
        assert!(
            stoch.value_k >= 0.0 && stoch.value_k <= 100.0,
            "value_k out of range with {ma_type:?}"
        );
    }

    #[rstest]
    fn test_d_method_ratio_preserves_nautilus_behavior() {
        let mut stoch = Stochastics::new_with_params(
            10,
            3,
            1,
            MovingAverageType::Exponential,
            StochasticsDMethod::Ratio,
        );

        for i in 1..=15 {
            stoch.update_raw(f64::from(i), f64::from(i) - 0.1, f64::from(i));
        }

        assert!(stoch.initialized);
        assert!(stoch.value_d > 0.0);
    }

    #[rstest]
    fn test_d_method_ma_produces_smoothed_k() {
        let mut stoch = Stochastics::new_with_params(
            5,
            3,
            3,
            MovingAverageType::Exponential,
            StochasticsDMethod::MovingAverage,
        );

        let data = [
            (10.0, 5.0, 8.0),
            (12.0, 6.0, 7.0),
            (11.0, 4.0, 9.0),
            (13.0, 7.0, 8.0),
            (14.0, 8.0, 10.0),
            (12.0, 6.0, 7.0),
            (15.0, 9.0, 14.0),
            (16.0, 10.0, 11.0),
            (14.0, 8.0, 12.0),
            (13.0, 7.0, 10.0),
        ];

        for (high, low, close) in data {
            stoch.update_raw(high, low, close);
        }

        assert!(stoch.value_d.is_finite());
        assert!(stoch.value_d >= 0.0 && stoch.value_d <= 100.0);
    }

    #[rstest]
    fn test_warmup_period_with_slowing() {
        let mut stoch = Stochastics::new_with_params(
            5,
            3,
            3,
            MovingAverageType::Exponential,
            StochasticsDMethod::Ratio,
        );

        for i in 1..=4 {
            stoch.update_raw(f64::from(i) + 5.0, f64::from(i), f64::from(i) + 2.0);
            assert!(!stoch.initialized, "Should not be initialized at bar {i}");
        }

        for i in 5..=15 {
            stoch.update_raw(f64::from(i) + 5.0, f64::from(i), f64::from(i) + 2.0);
        }

        assert!(
            stoch.initialized,
            "Should be initialized after sufficient bars"
        );
    }

    #[rstest]
    fn test_warmup_period_with_ma_d_method() {
        let mut stoch = Stochastics::new_with_params(
            5,
            3,
            3,
            MovingAverageType::Exponential,
            StochasticsDMethod::MovingAverage,
        );

        for i in 1..=4 {
            stoch.update_raw(f64::from(i) + 5.0, f64::from(i), f64::from(i) + 2.0);
        }
        assert!(!stoch.initialized);

        for i in 5..=20 {
            stoch.update_raw(f64::from(i) + 5.0, f64::from(i), f64::from(i) + 2.0);
        }

        assert!(
            stoch.initialized,
            "Should be initialized after sufficient bars"
        );
    }

    #[rstest]
    fn test_reset_clears_slowing_ma_state() {
        let mut stoch = Stochastics::new_with_params(
            5,
            3,
            3,
            MovingAverageType::Exponential,
            StochasticsDMethod::MovingAverage,
        );

        for i in 1..=10 {
            stoch.update_raw(f64::from(i) + 5.0, f64::from(i), f64::from(i) + 2.0);
        }

        assert!(stoch.has_inputs);

        stoch.reset();

        assert!(!stoch.has_inputs);
        assert!(!stoch.initialized);
        assert_eq!(stoch.value_k, 0.0);
        assert_eq!(stoch.value_d, 0.0);
        assert_eq!(stoch.highs.len(), 0);
        assert_eq!(stoch.lows.len(), 0);

        for i in 1..=10 {
            stoch.update_raw(f64::from(i) + 5.0, f64::from(i), f64::from(i) + 2.0);
        }
        assert!(stoch.value_k > 0.0);
    }

    #[rstest]
    fn test_slowing_1_bypasses_ma() {
        let stoch = Stochastics::new_with_params(
            10,
            3,
            1,
            MovingAverageType::Exponential,
            StochasticsDMethod::Ratio,
        );

        assert!(
            stoch.slowing_ma.is_none(),
            "slowing = 1 should not create MA"
        );
    }

    #[rstest]
    #[should_panic(expected = "slowing")]
    fn test_slowing_0_panics() {
        let _ = Stochastics::new_with_params(
            10,
            3,
            0,
            MovingAverageType::Exponential,
            StochasticsDMethod::Ratio,
        );
    }

    #[rstest]
    #[should_panic(expected = "slowing")]
    fn test_slowing_above_max_panics() {
        let _ = Stochastics::new_with_params(
            10,
            3,
            MAX_PERIOD + 1,
            MovingAverageType::Exponential,
            StochasticsDMethod::Ratio,
        );
    }

    #[rstest]
    #[should_panic(expected = "period_k")]
    fn test_period_k_0_panics() {
        let _ = Stochastics::new(0, 3);
    }

    #[rstest]
    #[should_panic(expected = "period_k")]
    fn test_period_k_above_max_panics() {
        let _ = Stochastics::new(MAX_PERIOD + 1, 3);
    }

    #[rstest]
    #[should_panic(expected = "period_d")]
    fn test_period_d_0_panics() {
        let _ = Stochastics::new(10, 0);
    }

    #[rstest]
    #[should_panic(expected = "period_d")]
    fn test_period_d_above_max_panics() {
        let _ = Stochastics::new(10, MAX_PERIOD + 1);
    }

    #[rstest]
    fn test_max_period_window_stays_bounded() {
        // The rolling buffers have a fixed capacity of MAX_PERIOD; one update past a
        // full window must slide it rather than grow or lose initialization.
        let mut stoch = Stochastics::new(MAX_PERIOD, MAX_PERIOD);
        let updates = u32::try_from(MAX_PERIOD).unwrap() + 1;

        for i in 0..updates {
            let x = f64::from(i);
            stoch.update_raw(x + 1.0, x, x + 0.5);
        }

        assert!(stoch.initialized);
        assert_eq!(stoch.highs.len(), MAX_PERIOD);
        assert_eq!(stoch.lows.len(), MAX_PERIOD);
        assert_eq!(stoch.c_sub_1.len(), MAX_PERIOD);
        assert_eq!(stoch.h_sub_l.len(), MAX_PERIOD);
    }

    #[rstest]
    fn test_flat_range_keeps_previous_values() {
        let mut stoch = Stochastics::new(1, 1);

        stoch.update_raw(2.0, 1.0, 2.0);
        assert!(stoch.initialized);
        assert_eq!(stoch.value_k, 100.0);
        assert_eq!(stoch.value_d, 100.0);

        // A zero-width range would divide by zero; the previous outputs are kept.
        stoch.update_raw(5.0, 5.0, 5.0);
        assert_eq!(stoch.value_k, 100.0);
        assert_eq!(stoch.value_d, 100.0);

        stoch.update_raw(6.0, 4.0, 5.0);
        assert_eq!(stoch.value_k, 50.0);
        assert_eq!(stoch.value_d, 50.0);
    }

    #[rstest]
    fn test_division_by_zero_protection() {
        let mut stoch = Stochastics::new_with_params(
            5,
            3,
            3,
            MovingAverageType::Exponential,
            StochasticsDMethod::MovingAverage,
        );

        for _ in 0..10 {
            stoch.update_raw(100.0, 100.0, 100.0);
        }

        assert!(stoch.value_k.is_finite());
        assert!(stoch.value_d.is_finite());
    }
}