1use std::fmt::Display;
17
18use arraydeque::{ArrayDeque, Wrapping};
19use nautilus_model::data::Bar;
20use strum::{AsRefStr, Display as StrumDisplay, EnumIter, EnumString, FromRepr};
21
22use crate::{
23 average::{MovingAverageFactory, MovingAverageType},
24 indicator::{Indicator, MovingAverage},
25};
26
/// Upper bound accepted for `period_k`, `period_d` and `slowing`, and the fixed capacity of
/// the indicator's rolling `ArrayDeque` buffers.
const MAX_PERIOD: usize = 1_024;
28
/// Method used by [`Stochastics`] to derive the %D line.
#[repr(C)]
#[derive(
    Copy,
    Clone,
    Debug,
    Default,
    Hash,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    AsRefStr,
    FromRepr,
    EnumIter,
    EnumString,
    StrumDisplay,
)]
#[strum(ascii_case_insensitive)]
#[strum(serialize_all = "SCREAMING_SNAKE_CASE")]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(
        frozen,
        eq,
        eq_int,
        hash,
        module = "nautilus_trader.core.nautilus_pyo3.indicators"
    )
)]
pub enum StochasticsDMethod {
    /// %D = `100 * sum(close - lowest_low) / sum(highest_high - lowest_low)` over the last
    /// `period_d` updates (the default).
    #[default]
    Ratio,
    /// %D = moving average (type `ma_type`, length `period_d`) of the slowed %K values.
    MovingAverage,
}
76
/// Stochastic oscillator producing a %K line and a %D line.
///
/// %K is `100 * (close - lowest_low) / (highest_high - lowest_low)` over the last `period_k`
/// updates, optionally smoothed by a `slowing`-period moving average. %D is derived according
/// to `d_method`.
#[repr(C)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.indicators")
)]
pub struct Stochastics {
    /// Number of updates in the highest-high / lowest-low lookback window.
    pub period_k: usize,
    /// %D window: the ratio-sum window (`Ratio`) or the moving-average period (`MovingAverage`).
    pub period_d: usize,
    /// Period of the %K slowing moving average; `1` disables slowing.
    pub slowing: usize,
    /// Moving-average type used for %K slowing and, when selected, for %D.
    pub ma_type: MovingAverageType,
    /// How %D is computed.
    pub d_method: StochasticsDMethod,
    /// Latest (slowed) %K value.
    pub value_k: f64,
    /// Latest %D value.
    pub value_d: f64,
    /// Whether the window and any active moving averages have warmed up.
    pub initialized: bool,
    /// Whether at least one update has been received.
    has_inputs: bool,
    /// Rolling window of the last `period_k` highs.
    highs: ArrayDeque<f64, MAX_PERIOD, Wrapping>,
    /// Rolling window of the last `period_k` lows.
    lows: ArrayDeque<f64, MAX_PERIOD, Wrapping>,
    /// Rolling `close - lowest_low` terms for the `Ratio` %D method.
    c_sub_1: ArrayDeque<f64, MAX_PERIOD, Wrapping>,
    /// Rolling `highest_high - lowest_low` terms for the `Ratio` %D method.
    h_sub_l: ArrayDeque<f64, MAX_PERIOD, Wrapping>,
    /// %K slowing moving average; `None` when `slowing == 1`.
    slowing_ma: Option<Box<dyn MovingAverage + Send + Sync>>,
    /// %D moving average; only present for `StochasticsDMethod::MovingAverage`.
    d_ma: Option<Box<dyn MovingAverage + Send + Sync>>,
}
109
110impl std::fmt::Debug for Stochastics {
111 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
112 f.debug_struct("Stochastics")
113 .field("period_k", &self.period_k)
114 .field("period_d", &self.period_d)
115 .field("slowing", &self.slowing)
116 .field("ma_type", &self.ma_type)
117 .field("d_method", &self.d_method)
118 .field("value_k", &self.value_k)
119 .field("value_d", &self.value_d)
120 .field("initialized", &self.initialized)
121 .field("has_inputs", &self.has_inputs)
122 .field(
123 "slowing_ma",
124 &self.slowing_ma.as_ref().map(|_| "MovingAverage"),
125 )
126 .field("d_ma", &self.d_ma.as_ref().map(|_| "MovingAverage"))
127 .finish()
128 }
129}
130
131impl Display for Stochastics {
132 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
133 write!(f, "{}({},{})", self.name(), self.period_k, self.period_d,)
134 }
135}
136
137impl Indicator for Stochastics {
138 fn name(&self) -> String {
139 stringify!(Stochastics).to_string()
140 }
141
142 fn has_inputs(&self) -> bool {
143 self.has_inputs
144 }
145
146 fn initialized(&self) -> bool {
147 self.initialized
148 }
149
150 fn handle_bar(&mut self, bar: &Bar) {
151 self.update_raw((&bar.high).into(), (&bar.low).into(), (&bar.close).into());
152 }
153
154 fn reset(&mut self) {
155 self.highs.clear();
156 self.lows.clear();
157 self.c_sub_1.clear();
158 self.h_sub_l.clear();
159 self.value_k = 0.0;
160 self.value_d = 0.0;
161 self.has_inputs = false;
162 self.initialized = false;
163
164 if let Some(ref mut ma) = self.slowing_ma {
166 ma.reset();
167 }
168
169 if let Some(ref mut ma) = self.d_ma {
171 ma.reset();
172 }
173 }
174}
175
176impl Stochastics {
177 #[must_use]
190 pub fn new(period_k: usize, period_d: usize) -> Self {
191 Self::new_with_params(
192 period_k,
193 period_d,
194 1, MovingAverageType::Exponential, StochasticsDMethod::Ratio, )
198 }
199
200 #[must_use]
215 pub fn new_with_params(
216 period_k: usize,
217 period_d: usize,
218 slowing: usize,
219 ma_type: MovingAverageType,
220 d_method: StochasticsDMethod,
221 ) -> Self {
222 assert!(
223 period_k > 0 && period_k <= MAX_PERIOD,
224 "Stochastics: period_k {period_k} exceeds bounds (1..={MAX_PERIOD})"
225 );
226 assert!(
227 period_d > 0 && period_d <= MAX_PERIOD,
228 "Stochastics: period_d {period_d} exceeds bounds (1..={MAX_PERIOD})"
229 );
230 assert!(
231 slowing > 0 && slowing <= MAX_PERIOD,
232 "Stochastics: slowing {slowing} exceeds bounds (1..={MAX_PERIOD})"
233 );
234
235 let slowing_ma = if slowing > 1 {
237 Some(MovingAverageFactory::create(ma_type, slowing))
238 } else {
239 None
240 };
241
242 let d_ma = match d_method {
244 StochasticsDMethod::MovingAverage => {
245 Some(MovingAverageFactory::create(ma_type, period_d))
246 }
247 StochasticsDMethod::Ratio => None,
248 };
249
250 Self {
251 period_k,
252 period_d,
253 slowing,
254 ma_type,
255 d_method,
256 has_inputs: false,
257 initialized: false,
258 value_k: 0.0,
259 value_d: 0.0,
260 highs: ArrayDeque::new(),
261 lows: ArrayDeque::new(),
262 h_sub_l: ArrayDeque::new(),
263 c_sub_1: ArrayDeque::new(),
264 slowing_ma,
265 d_ma,
266 }
267 }
268
269 pub fn update_raw(&mut self, high: f64, low: f64, close: f64) {
277 if !self.has_inputs {
278 self.has_inputs = true;
279 }
280
281 if self.highs.len() == self.period_k {
283 self.highs.pop_front();
284 self.lows.pop_front();
285 }
286 let _ = self.highs.push_back(high);
287 let _ = self.lows.push_back(low);
288
289 if !self.initialized
291 && self.highs.len() == self.period_k
292 && self.lows.len() == self.period_k
293 {
294 if self.slowing_ma.is_none() && self.d_method == StochasticsDMethod::Ratio {
297 self.initialized = true;
298 }
299 }
300
301 let k_max_high = self.highs.iter().copied().fold(f64::NEG_INFINITY, f64::max);
303 let k_min_low = self.lows.iter().copied().fold(f64::INFINITY, f64::min);
304
305 if self.d_method == StochasticsDMethod::Ratio {
307 if self.c_sub_1.len() == self.period_d {
308 self.c_sub_1.pop_front();
309 self.h_sub_l.pop_front();
310 }
311 let _ = self.c_sub_1.push_back(close - k_min_low);
312 let _ = self.h_sub_l.push_back(k_max_high - k_min_low);
313 }
314
315 if k_max_high == k_min_low {
317 return;
318 }
319
320 let raw_k = 100.0 * ((close - k_min_low) / (k_max_high - k_min_low));
322
323 let slowed_k = match &mut self.slowing_ma {
325 Some(ma) => {
326 ma.update_raw(raw_k);
327 ma.value()
328 }
329 None => raw_k, };
331 self.value_k = slowed_k;
332
333 self.value_d = match self.d_method {
335 StochasticsDMethod::Ratio => {
336 let sum_h_sub_l: f64 = self.h_sub_l.iter().sum();
339 if sum_h_sub_l == 0.0 {
340 0.0
341 } else {
342 100.0 * (self.c_sub_1.iter().sum::<f64>() / sum_h_sub_l)
343 }
344 }
345 StochasticsDMethod::MovingAverage => {
346 if let Some(ref mut ma) = self.d_ma {
348 ma.update_raw(slowed_k);
349 ma.value()
350 } else {
351 50.0 }
353 }
354 };
355
356 if !self.initialized {
360 let base_ready = self.highs.len() == self.period_k;
361 let slowing_ready = match &self.slowing_ma {
362 Some(ma) => ma.initialized(),
363 None => true,
364 };
365 let d_ready = match self.d_method {
366 StochasticsDMethod::Ratio => true, StochasticsDMethod::MovingAverage => match &self.d_ma {
368 Some(ma) => ma.initialized(),
369 None => true,
370 },
371 };
372
373 if base_ready && slowing_ready && d_ready {
374 self.initialized = true;
375 }
376 }
377 }
378}
379
#[cfg(test)]
mod tests {
    use nautilus_model::data::Bar;
    use rstest::rstest;

    use crate::{
        average::MovingAverageType,
        indicator::Indicator,
        momentum::stochastics::{Stochastics, StochasticsDMethod},
        stubs::{bar_ethusdt_binance_minute_bid, stochastics_10},
    };

    // A freshly built fixture: checks Display output, periods and that no state flags are set.
    #[rstest]
    fn test_stochastics_initialized(stochastics_10: Stochastics) {
        let display_str = format!("{stochastics_10}");
        assert_eq!(display_str, "Stochastics(10,10)");
        assert_eq!(stochastics_10.period_d, 10);
        assert_eq!(stochastics_10.period_k, 10);
        assert!(!stochastics_10.initialized);
        assert!(!stochastics_10.has_inputs);
    }

    // One flat bar (high == low) hits the zero-range guard, so both outputs remain 0.
    #[rstest]
    fn test_value_with_one_input(mut stochastics_10: Stochastics) {
        stochastics_10.update_raw(1.0, 1.0, 1.0);
        assert_eq!(stochastics_10.value_d, 0.0);
        assert_eq!(stochastics_10.value_k, 0.0);
    }

    // Rising closes at the window high drive both %K and %D to 100.
    #[rstest]
    fn test_value_with_three_inputs(mut stochastics_10: Stochastics) {
        stochastics_10.update_raw(1.0, 1.0, 1.0);
        stochastics_10.update_raw(2.0, 2.0, 2.0);
        stochastics_10.update_raw(3.0, 3.0, 3.0);
        assert_eq!(stochastics_10.value_d, 100.0);
        assert_eq!(stochastics_10.value_k, 100.0);
    }

    // After more than `period_k` bars the indicator is initialized and still saturated at 100.
    #[rstest]
    fn test_value_with_ten_inputs(mut stochastics_10: Stochastics) {
        let high_values = [
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
        ];
        let low_values = [
            0.9, 1.9, 2.9, 3.9, 4.9, 5.9, 6.9, 7.9, 8.9, 9.9, 10.1, 10.2, 10.3, 11.1, 11.4,
        ];
        let close_values = [
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
        ];

        for i in 0..15 {
            stochastics_10.update_raw(high_values[i], low_values[i], close_values[i]);
        }

        assert!(stochastics_10.initialized());
        assert_eq!(stochastics_10.value_d, 100.0);
        assert_eq!(stochastics_10.value_k, 100.0);
    }

    // Initialization flips exactly on the `period_k`-th (10th) update.
    #[rstest]
    fn test_initialized_with_required_input(mut stochastics_10: Stochastics) {
        for i in 1..10 {
            stochastics_10.update_raw(f64::from(i), f64::from(i), f64::from(i));
        }
        assert!(!stochastics_10.initialized);
        stochastics_10.update_raw(10.0, 12.0, 14.0);
        assert!(stochastics_10.initialized);
    }

    // A single bar via `handle_bar`: with one ratio term, %D equals %K.
    #[rstest]
    fn test_handle_bar(mut stochastics_10: Stochastics, bar_ethusdt_binance_minute_bid: Bar) {
        stochastics_10.handle_bar(&bar_ethusdt_binance_minute_bid);
        assert_eq!(stochastics_10.value_d, 49.090_909_090_909_09);
        assert_eq!(stochastics_10.value_k, 49.090_909_090_909_09);
        assert!(stochastics_10.has_inputs);
        assert!(!stochastics_10.initialized);
    }

    // `reset` empties the ratio buffers and clears outputs and flags.
    #[rstest]
    fn test_reset(mut stochastics_10: Stochastics) {
        stochastics_10.update_raw(1.0, 1.0, 1.0);
        assert_eq!(stochastics_10.c_sub_1.len(), 1);
        assert_eq!(stochastics_10.h_sub_l.len(), 1);

        stochastics_10.reset();
        assert_eq!(stochastics_10.value_d, 0.0);
        assert_eq!(stochastics_10.value_k, 0.0);
        assert_eq!(stochastics_10.h_sub_l.len(), 0);
        assert_eq!(stochastics_10.c_sub_1.len(), 0);
        assert!(!stochastics_10.has_inputs);
        assert!(!stochastics_10.initialized);
    }

    // `new` uses slowing = 1, Exponential MA type and the Ratio %D method, creating no MAs.
    #[rstest]
    fn test_new_defaults_slowing_1_ratio() {
        let stoch = Stochastics::new(10, 3);
        assert_eq!(stoch.period_k, 10);
        assert_eq!(stoch.period_d, 3);
        assert_eq!(stoch.slowing, 1);
        assert_eq!(stoch.ma_type, MovingAverageType::Exponential);
        assert_eq!(stoch.d_method, StochasticsDMethod::Ratio);
        assert!(
            stoch.slowing_ma.is_none(),
            "slowing_ma should be None when slowing == 1"
        );
        assert!(
            stoch.d_ma.is_none(),
            "d_ma should be None when d_method == Ratio"
        );
    }

    // `new_with_params` stores every parameter and builds both MAs when requested.
    #[rstest]
    fn test_new_with_params_accepts_all_params() {
        let stoch = Stochastics::new_with_params(
            11,
            3,
            3,
            MovingAverageType::Exponential,
            StochasticsDMethod::MovingAverage,
        );
        assert_eq!(stoch.period_k, 11);
        assert_eq!(stoch.period_d, 3);
        assert_eq!(stoch.slowing, 3);
        assert_eq!(stoch.ma_type, MovingAverageType::Exponential);
        assert_eq!(stoch.d_method, StochasticsDMethod::MovingAverage);
        assert!(
            stoch.slowing_ma.is_some(),
            "slowing_ma should exist when slowing > 1"
        );
        assert!(
            stoch.d_ma.is_some(),
            "d_ma should exist when d_method == MovingAverage"
        );
    }

    // `new` and the equivalent explicit `new_with_params` call must produce identical output.
    #[rstest]
    fn test_backward_compatibility_identical_output() {
        let mut stoch_old = Stochastics::new(10, 10);
        let mut stoch_new = Stochastics::new_with_params(
            10,
            10,
            1,
            MovingAverageType::Exponential,
            StochasticsDMethod::Ratio,
        );

        let high_values = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let low_values = [0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5];
        let close_values = [0.8, 1.8, 2.8, 3.8, 4.8, 5.8, 6.8, 7.8, 8.8, 9.8];

        for i in 0..10 {
            stoch_old.update_raw(high_values[i], low_values[i], close_values[i]);
            stoch_new.update_raw(high_values[i], low_values[i], close_values[i]);
        }

        assert_eq!(stoch_old.value_k, stoch_new.value_k, "value_k mismatch");
        assert_eq!(stoch_old.value_d, stoch_new.value_d, "value_d mismatch");
        assert_eq!(stoch_old.initialized, stoch_new.initialized);
    }

    // A slowing period of 3 must change %K relative to the unslowed indicator.
    #[rstest]
    fn test_slowing_3_smoothes_k() {
        let mut stoch_no_slowing = Stochastics::new(5, 3);
        let mut stoch_with_slowing = Stochastics::new_with_params(
            5,
            3,
            3,
            MovingAverageType::Exponential,
            StochasticsDMethod::Ratio,
        );

        let data = [
            (10.0, 5.0, 8.0),
            (12.0, 6.0, 7.0),
            (11.0, 4.0, 9.0),
            (13.0, 7.0, 8.0),
            (14.0, 8.0, 10.0),
            (12.0, 6.0, 7.0),
            (15.0, 9.0, 14.0),
            (16.0, 10.0, 11.0),
        ];

        for (high, low, close) in data {
            stoch_no_slowing.update_raw(high, low, close);
            stoch_with_slowing.update_raw(high, low, close);
        }

        assert!(
            (stoch_no_slowing.value_k - stoch_with_slowing.value_k).abs() > 0.01,
            "Slowing should produce different %K values"
        );
    }

    // Every supported MA type yields finite outputs with %K inside [0, 100].
    #[rstest]
    #[case(MovingAverageType::Simple)]
    #[case(MovingAverageType::Exponential)]
    #[case(MovingAverageType::Wilder)]
    #[case(MovingAverageType::Hull)]
    fn test_slowing_with_different_ma_types(#[case] ma_type: MovingAverageType) {
        let mut stoch = Stochastics::new_with_params(5, 3, 3, ma_type, StochasticsDMethod::Ratio);

        for i in 1..=10 {
            stoch.update_raw(f64::from(i) + 5.0, f64::from(i), f64::from(i) + 2.0);
        }

        assert!(
            stoch.value_k.is_finite(),
            "value_k should be finite with {ma_type:?}"
        );
        assert!(
            stoch.value_d.is_finite(),
            "value_d should be finite with {ma_type:?}"
        );
        assert!(
            stoch.value_k >= 0.0 && stoch.value_k <= 100.0,
            "value_k out of range with {ma_type:?}"
        );
    }

    // The Ratio %D method initializes and produces a positive %D on rising data.
    #[rstest]
    fn test_d_method_ratio_preserves_nautilus_behavior() {
        let mut stoch = Stochastics::new_with_params(
            10,
            3,
            1,
            MovingAverageType::Exponential,
            StochasticsDMethod::Ratio,
        );

        for i in 1..=15 {
            stoch.update_raw(f64::from(i), f64::from(i) - 0.1, f64::from(i));
        }

        assert!(stoch.initialized);
        assert!(stoch.value_d > 0.0);
    }

    // The MovingAverage %D method keeps %D finite and inside [0, 100].
    #[rstest]
    fn test_d_method_ma_produces_smoothed_k() {
        let mut stoch = Stochastics::new_with_params(
            5,
            3,
            3,
            MovingAverageType::Exponential,
            StochasticsDMethod::MovingAverage,
        );

        let data = [
            (10.0, 5.0, 8.0),
            (12.0, 6.0, 7.0),
            (11.0, 4.0, 9.0),
            (13.0, 7.0, 8.0),
            (14.0, 8.0, 10.0),
            (12.0, 6.0, 7.0),
            (15.0, 9.0, 14.0),
            (16.0, 10.0, 11.0),
            (14.0, 8.0, 12.0),
            (13.0, 7.0, 10.0),
        ];

        for (high, low, close) in data {
            stoch.update_raw(high, low, close);
        }

        assert!(stoch.value_d.is_finite());
        assert!(stoch.value_d >= 0.0 && stoch.value_d <= 100.0);
    }

    // With slowing enabled the indicator stays uninitialized before the %K window fills.
    #[rstest]
    fn test_warmup_period_with_slowing() {
        let mut stoch = Stochastics::new_with_params(
            5,
            3,
            3,
            MovingAverageType::Exponential,
            StochasticsDMethod::Ratio,
        );

        for i in 1..=4 {
            stoch.update_raw(f64::from(i) + 5.0, f64::from(i), f64::from(i) + 2.0);
            assert!(!stoch.initialized, "Should not be initialized at bar {i}");
        }

        for i in 5..=15 {
            stoch.update_raw(f64::from(i) + 5.0, f64::from(i), f64::from(i) + 2.0);
        }

        assert!(
            stoch.initialized,
            "Should be initialized after sufficient bars"
        );
    }

    // With the MovingAverage %D method, initialization also waits for the %D average.
    #[rstest]
    fn test_warmup_period_with_ma_d_method() {
        let mut stoch = Stochastics::new_with_params(
            5,
            3,
            3,
            MovingAverageType::Exponential,
            StochasticsDMethod::MovingAverage,
        );

        for i in 1..=4 {
            stoch.update_raw(f64::from(i) + 5.0, f64::from(i), f64::from(i) + 2.0);
        }
        assert!(!stoch.initialized);

        for i in 5..=20 {
            stoch.update_raw(f64::from(i) + 5.0, f64::from(i), f64::from(i) + 2.0);
        }

        assert!(
            stoch.initialized,
            "Should be initialized after sufficient bars"
        );
    }

    // `reset` clears windows and flags, and the indicator works again afterwards.
    #[rstest]
    fn test_reset_clears_slowing_ma_state() {
        let mut stoch = Stochastics::new_with_params(
            5,
            3,
            3,
            MovingAverageType::Exponential,
            StochasticsDMethod::MovingAverage,
        );

        for i in 1..=10 {
            stoch.update_raw(f64::from(i) + 5.0, f64::from(i), f64::from(i) + 2.0);
        }

        assert!(stoch.has_inputs);

        stoch.reset();

        assert!(!stoch.has_inputs);
        assert!(!stoch.initialized);
        assert_eq!(stoch.value_k, 0.0);
        assert_eq!(stoch.value_d, 0.0);
        assert_eq!(stoch.highs.len(), 0);
        assert_eq!(stoch.lows.len(), 0);

        for i in 1..=10 {
            stoch.update_raw(f64::from(i) + 5.0, f64::from(i), f64::from(i) + 2.0);
        }
        assert!(stoch.value_k > 0.0);
    }

    // `slowing == 1` must not allocate a slowing moving average.
    #[rstest]
    fn test_slowing_1_bypasses_ma() {
        let stoch = Stochastics::new_with_params(
            10,
            3,
            1,
            MovingAverageType::Exponential,
            StochasticsDMethod::Ratio,
        );

        assert!(
            stoch.slowing_ma.is_none(),
            "slowing = 1 should not create MA"
        );
    }

    // `slowing == 0` violates the constructor bounds check and panics.
    #[rstest]
    #[should_panic(expected = "slowing")]
    fn test_slowing_0_panics() {
        let _ = Stochastics::new_with_params(
            10,
            3,
            0,
            MovingAverageType::Exponential,
            StochasticsDMethod::Ratio,
        );
    }

    // A perfectly flat series must not produce NaN/inf outputs.
    #[rstest]
    fn test_division_by_zero_protection() {
        let mut stoch = Stochastics::new_with_params(
            5,
            3,
            3,
            MovingAverageType::Exponential,
            StochasticsDMethod::MovingAverage,
        );

        for _ in 0..10 {
            stoch.update_raw(100.0, 100.0, 100.0);
        }

        assert!(stoch.value_k.is_finite());
        assert!(stoch.value_d.is_finite());
    }
}