1use std::fmt::{Debug, Display};
17
18use arraydeque::{ArrayDeque, Wrapping};
19use nautilus_model::data::Bar;
20use strum::{AsRefStr, Display as StrumDisplay, EnumIter, EnumString, FromRepr};
21
22use crate::{
23 average::{MovingAverageFactory, MovingAverageType},
24 indicator::{Indicator, MovingAverage},
25};
26
/// Upper bound accepted for `period_k`, `period_d` and `slowing`, and the fixed capacity of
/// the indicator's internal rolling buffers (`ArrayDeque<f64, MAX_PERIOD, Wrapping>`).
const MAX_PERIOD: usize = 1_024;
28
#[repr(C)]
/// Selects how the %D line of [`Stochastics`] is computed.
///
/// String conversions use `SCREAMING_SNAKE_CASE` and parsing is ASCII case-insensitive.
/// Variant order determines the implicit discriminants (`Ratio` = 0, `MovingAverage` = 1)
/// used by `FromRepr` and the derived ordering, so variants must not be reordered.
#[derive(
    Copy,
    Clone,
    Debug,
    Default,
    Hash,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    AsRefStr,
    FromRepr,
    EnumIter,
    EnumString,
    StrumDisplay,
)]
#[strum(ascii_case_insensitive)]
#[strum(serialize_all = "SCREAMING_SNAKE_CASE")]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(
        frozen,
        eq,
        eq_int,
        hash,
        module = "nautilus_trader.core.nautilus_pyo3.indicators"
    )
)]
pub enum StochasticsDMethod {
    /// Default. %D = 100 * sum(close - lowest low) / sum(highest high - lowest low) over the
    /// last `period_d` bars (0 when the summed range is zero).
    #[default]
    Ratio,
    /// %D is a moving average of type `ma_type` with period `period_d`, fed with the
    /// (optionally slowed) %K value.
    MovingAverage,
}
76
/// Stochastic oscillator producing a %K line and a %D line.
///
/// %K = 100 * (close - lowest low) / (highest high - lowest low) over a `period_k` bar
/// lookback, optionally smoothed by a `slowing` moving average. %D is derived according to
/// [`StochasticsDMethod`]. Note: `#[repr(C)]` makes the field order part of the layout.
#[repr(C)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.indicators")
)]
pub struct Stochastics {
    /// Lookback window (in bars) for the highest high and lowest low used by %K.
    pub period_k: usize,
    /// Window for %D: the ratio window for `Ratio`, or the MA period for `MovingAverage`.
    pub period_d: usize,
    /// %K slowing period; `1` disables slowing (no slowing moving average is created).
    pub slowing: usize,
    /// Moving-average type used for %K slowing and for the `MovingAverage` %D method.
    pub ma_type: MovingAverageType,
    /// How %D is computed.
    pub d_method: StochasticsDMethod,
    /// Latest %K value (after slowing, if enabled); `0.0` until first computed.
    pub value_k: f64,
    /// Latest %D value; `0.0` until first computed.
    pub value_d: f64,
    /// Whether enough data has been seen for the configured warm-up.
    pub initialized: bool,
    /// Whether any input has been received since construction or the last reset.
    has_inputs: bool,
    /// Rolling window of the last `period_k` highs.
    highs: ArrayDeque<f64, MAX_PERIOD, Wrapping>,
    /// Rolling window of the last `period_k` lows.
    lows: ArrayDeque<f64, MAX_PERIOD, Wrapping>,
    /// Rolling `close - lowest low` values over `period_d` bars (`Ratio` method only).
    c_sub_1: ArrayDeque<f64, MAX_PERIOD, Wrapping>,
    /// Rolling `highest high - lowest low` values over `period_d` bars (`Ratio` method only).
    h_sub_l: ArrayDeque<f64, MAX_PERIOD, Wrapping>,
    /// %K slowing moving average; `Some` only when `slowing > 1`.
    slowing_ma: Option<Box<dyn MovingAverage + Send + Sync>>,
    /// %D moving average; `Some` only when `d_method` is `MovingAverage` at construction.
    d_ma: Option<Box<dyn MovingAverage + Send + Sync>>,
}
109
110impl Debug for Stochastics {
111 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
112 f.debug_struct("Stochastics")
113 .field("period_k", &self.period_k)
114 .field("period_d", &self.period_d)
115 .field("slowing", &self.slowing)
116 .field("ma_type", &self.ma_type)
117 .field("d_method", &self.d_method)
118 .field("value_k", &self.value_k)
119 .field("value_d", &self.value_d)
120 .field("initialized", &self.initialized)
121 .field("has_inputs", &self.has_inputs)
122 .field(
123 "slowing_ma",
124 &self.slowing_ma.as_ref().map(|_| "MovingAverage"),
125 )
126 .field("d_ma", &self.d_ma.as_ref().map(|_| "MovingAverage"))
127 .finish()
128 }
129}
130
131impl Display for Stochastics {
132 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
133 write!(f, "{}({},{})", self.name(), self.period_k, self.period_d,)
134 }
135}
136
137impl Indicator for Stochastics {
138 fn name(&self) -> String {
139 stringify!(Stochastics).to_string()
140 }
141
142 fn has_inputs(&self) -> bool {
143 self.has_inputs
144 }
145
146 fn initialized(&self) -> bool {
147 self.initialized
148 }
149
150 fn handle_bar(&mut self, bar: &Bar) {
151 self.update_raw((&bar.high).into(), (&bar.low).into(), (&bar.close).into());
152 }
153
154 fn reset(&mut self) {
155 self.highs.clear();
156 self.lows.clear();
157 self.c_sub_1.clear();
158 self.h_sub_l.clear();
159 self.value_k = 0.0;
160 self.value_d = 0.0;
161 self.has_inputs = false;
162 self.initialized = false;
163
164 if let Some(ref mut ma) = self.slowing_ma {
166 ma.reset();
167 }
168
169 if let Some(ref mut ma) = self.d_ma {
171 ma.reset();
172 }
173 }
174}
175
176impl Stochastics {
177 #[must_use]
190 pub fn new(period_k: usize, period_d: usize) -> Self {
191 Self::new_with_params(
192 period_k,
193 period_d,
194 1, MovingAverageType::Exponential, StochasticsDMethod::Ratio, )
198 }
199
200 #[must_use]
215 pub fn new_with_params(
216 period_k: usize,
217 period_d: usize,
218 slowing: usize,
219 ma_type: MovingAverageType,
220 d_method: StochasticsDMethod,
221 ) -> Self {
222 assert!(
223 period_k > 0 && period_k <= MAX_PERIOD,
224 "Stochastics: period_k {period_k} exceeds bounds (1..={MAX_PERIOD})"
225 );
226 assert!(
227 period_d > 0 && period_d <= MAX_PERIOD,
228 "Stochastics: period_d {period_d} exceeds bounds (1..={MAX_PERIOD})"
229 );
230 assert!(
231 slowing > 0 && slowing <= MAX_PERIOD,
232 "Stochastics: slowing {slowing} exceeds bounds (1..={MAX_PERIOD})"
233 );
234
235 let slowing_ma = if slowing > 1 {
237 Some(MovingAverageFactory::create(ma_type, slowing))
238 } else {
239 None
240 };
241
242 let d_ma = match d_method {
244 StochasticsDMethod::MovingAverage => {
245 Some(MovingAverageFactory::create(ma_type, period_d))
246 }
247 StochasticsDMethod::Ratio => None,
248 };
249
250 Self {
251 period_k,
252 period_d,
253 slowing,
254 ma_type,
255 d_method,
256 has_inputs: false,
257 initialized: false,
258 value_k: 0.0,
259 value_d: 0.0,
260 highs: ArrayDeque::new(),
261 lows: ArrayDeque::new(),
262 h_sub_l: ArrayDeque::new(),
263 c_sub_1: ArrayDeque::new(),
264 slowing_ma,
265 d_ma,
266 }
267 }
268
269 pub fn update_raw(&mut self, high: f64, low: f64, close: f64) {
277 if !self.has_inputs {
278 self.has_inputs = true;
279 }
280
281 if self.highs.len() == self.period_k {
283 self.highs.pop_front();
284 self.lows.pop_front();
285 }
286 let _ = self.highs.push_back(high);
287 let _ = self.lows.push_back(low);
288
289 if !self.initialized
291 && self.highs.len() == self.period_k
292 && self.lows.len() == self.period_k
293 {
294 if self.slowing_ma.is_none() && self.d_method == StochasticsDMethod::Ratio {
297 self.initialized = true;
298 }
299 }
300
301 let k_max_high = self.highs.iter().copied().fold(f64::NEG_INFINITY, f64::max);
303 let k_min_low = self.lows.iter().copied().fold(f64::INFINITY, f64::min);
304
305 if self.d_method == StochasticsDMethod::Ratio {
307 if self.c_sub_1.len() == self.period_d {
308 self.c_sub_1.pop_front();
309 self.h_sub_l.pop_front();
310 }
311 let _ = self.c_sub_1.push_back(close - k_min_low);
312 let _ = self.h_sub_l.push_back(k_max_high - k_min_low);
313 }
314
315 if k_max_high == k_min_low {
317 return;
318 }
319
320 let raw_k = 100.0 * ((close - k_min_low) / (k_max_high - k_min_low));
322
323 let slowed_k = match &mut self.slowing_ma {
325 Some(ma) => {
326 ma.update_raw(raw_k);
327 ma.value()
328 }
329 None => raw_k, };
331 self.value_k = slowed_k;
332
333 self.value_d = match self.d_method {
335 StochasticsDMethod::Ratio => {
336 let sum_h_sub_l: f64 = self.h_sub_l.iter().sum();
339 if sum_h_sub_l == 0.0 {
340 0.0
341 } else {
342 100.0 * (self.c_sub_1.iter().sum::<f64>() / sum_h_sub_l)
343 }
344 }
345 StochasticsDMethod::MovingAverage => {
346 if let Some(ref mut ma) = self.d_ma {
348 ma.update_raw(slowed_k);
349 ma.value()
350 } else {
351 50.0 }
353 }
354 };
355
356 if !self.initialized {
360 let base_ready = self.highs.len() == self.period_k;
361 let slowing_ready = match &self.slowing_ma {
362 Some(ma) => ma.initialized(),
363 None => true,
364 };
365 let d_ready = match self.d_method {
366 StochasticsDMethod::Ratio => true, StochasticsDMethod::MovingAverage => match &self.d_ma {
368 Some(ma) => ma.initialized(),
369 None => true,
370 },
371 };
372
373 if base_ready && slowing_ready && d_ready {
374 self.initialized = true;
375 }
376 }
377 }
378}
379
#[cfg(test)]
mod tests {
    //! Unit tests for [`Stochastics`] covering construction, warm-up, the %D methods,
    //! slowing, reset and the flat-range guard.

    use nautilus_model::data::Bar;
    use rstest::rstest;

    use crate::{
        average::MovingAverageType,
        indicator::Indicator,
        momentum::stochastics::{Stochastics, StochasticsDMethod},
        stubs::{bar_ethusdt_binance_minute_bid, stochastics_10},
    };

    #[rstest]
    fn test_stochastics_initialized(stochastics_10: Stochastics) {
        let display_str = format!("{stochastics_10}");
        assert_eq!(display_str, "Stochastics(10,10)");
        assert_eq!(stochastics_10.period_d, 10);
        assert_eq!(stochastics_10.period_k, 10);
        assert!(!stochastics_10.initialized);
        assert!(!stochastics_10.has_inputs);
    }

    #[rstest]
    fn test_value_with_one_input(mut stochastics_10: Stochastics) {
        stochastics_10.update_raw(1.0, 1.0, 1.0);
        assert_eq!(stochastics_10.value_d, 0.0);
        assert_eq!(stochastics_10.value_k, 0.0);
    }

    #[rstest]
    fn test_value_with_three_inputs(mut stochastics_10: Stochastics) {
        stochastics_10.update_raw(1.0, 1.0, 1.0);
        stochastics_10.update_raw(2.0, 2.0, 2.0);
        stochastics_10.update_raw(3.0, 3.0, 3.0);
        assert_eq!(stochastics_10.value_d, 100.0);
        assert_eq!(stochastics_10.value_k, 100.0);
    }

    #[rstest]
    fn test_value_with_fifteen_inputs(mut stochastics_10: Stochastics) {
        let high_values = [
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
        ];
        let low_values = [
            0.9, 1.9, 2.9, 3.9, 4.9, 5.9, 6.9, 7.9, 8.9, 9.9, 10.1, 10.2, 10.3, 11.1, 11.4,
        ];
        let close_values = [
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
        ];

        for ((high, low), close) in high_values.into_iter().zip(low_values).zip(close_values) {
            stochastics_10.update_raw(high, low, close);
        }

        assert!(stochastics_10.initialized());
        assert_eq!(stochastics_10.value_d, 100.0);
        assert_eq!(stochastics_10.value_k, 100.0);
    }

    #[rstest]
    fn test_initialized_with_required_input(mut stochastics_10: Stochastics) {
        for i in 1..10 {
            stochastics_10.update_raw(f64::from(i), f64::from(i), f64::from(i));
        }
        assert!(!stochastics_10.initialized);
        stochastics_10.update_raw(10.0, 12.0, 14.0);
        assert!(stochastics_10.initialized);
    }

    #[rstest]
    fn test_handle_bar(mut stochastics_10: Stochastics, bar_ethusdt_binance_minute_bid: Bar) {
        stochastics_10.handle_bar(&bar_ethusdt_binance_minute_bid);
        assert_eq!(stochastics_10.value_d, 49.090_909_090_909_09);
        assert_eq!(stochastics_10.value_k, 49.090_909_090_909_09);
        assert!(stochastics_10.has_inputs);
        assert!(!stochastics_10.initialized);
    }

    #[rstest]
    fn test_reset(mut stochastics_10: Stochastics) {
        stochastics_10.update_raw(1.0, 1.0, 1.0);
        assert_eq!(stochastics_10.c_sub_1.len(), 1);
        assert_eq!(stochastics_10.h_sub_l.len(), 1);

        stochastics_10.reset();
        assert_eq!(stochastics_10.value_d, 0.0);
        assert_eq!(stochastics_10.value_k, 0.0);
        assert_eq!(stochastics_10.h_sub_l.len(), 0);
        assert_eq!(stochastics_10.c_sub_1.len(), 0);
        assert!(!stochastics_10.has_inputs);
        assert!(!stochastics_10.initialized);
    }

    #[rstest]
    fn test_new_defaults_slowing_1_ratio() {
        let stoch = Stochastics::new(10, 3);
        assert_eq!(stoch.period_k, 10);
        assert_eq!(stoch.period_d, 3);
        assert_eq!(stoch.slowing, 1);
        assert_eq!(stoch.ma_type, MovingAverageType::Exponential);
        assert_eq!(stoch.d_method, StochasticsDMethod::Ratio);
        assert!(
            stoch.slowing_ma.is_none(),
            "slowing_ma should be None when slowing == 1"
        );
        assert!(
            stoch.d_ma.is_none(),
            "d_ma should be None when d_method == Ratio"
        );
    }

    #[rstest]
    fn test_new_with_params_accepts_all_params() {
        let stoch = Stochastics::new_with_params(
            11,
            3,
            3,
            MovingAverageType::Exponential,
            StochasticsDMethod::MovingAverage,
        );
        assert_eq!(stoch.period_k, 11);
        assert_eq!(stoch.period_d, 3);
        assert_eq!(stoch.slowing, 3);
        assert_eq!(stoch.ma_type, MovingAverageType::Exponential);
        assert_eq!(stoch.d_method, StochasticsDMethod::MovingAverage);
        assert!(
            stoch.slowing_ma.is_some(),
            "slowing_ma should exist when slowing > 1"
        );
        assert!(
            stoch.d_ma.is_some(),
            "d_ma should exist when d_method == MovingAverage"
        );
    }

    #[rstest]
    fn test_backward_compatibility_identical_output() {
        // `new` must behave exactly like `new_with_params` with the default settings.
        let mut stoch_old = Stochastics::new(10, 10);
        let mut stoch_new = Stochastics::new_with_params(
            10,
            10,
            1,
            MovingAverageType::Exponential,
            StochasticsDMethod::Ratio,
        );

        let high_values = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let low_values = [0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5];
        let close_values = [0.8, 1.8, 2.8, 3.8, 4.8, 5.8, 6.8, 7.8, 8.8, 9.8];

        for ((high, low), close) in high_values.into_iter().zip(low_values).zip(close_values) {
            stoch_old.update_raw(high, low, close);
            stoch_new.update_raw(high, low, close);
        }

        assert_eq!(stoch_old.value_k, stoch_new.value_k, "value_k mismatch");
        assert_eq!(stoch_old.value_d, stoch_new.value_d, "value_d mismatch");
        assert_eq!(stoch_old.initialized, stoch_new.initialized);
    }

    #[rstest]
    fn test_slowing_3_smoothes_k() {
        let mut stoch_no_slowing = Stochastics::new(5, 3);
        let mut stoch_with_slowing = Stochastics::new_with_params(
            5,
            3,
            3,
            MovingAverageType::Exponential,
            StochasticsDMethod::Ratio,
        );

        let data = [
            (10.0, 5.0, 8.0),
            (12.0, 6.0, 7.0),
            (11.0, 4.0, 9.0),
            (13.0, 7.0, 8.0),
            (14.0, 8.0, 10.0),
            (12.0, 6.0, 7.0),
            (15.0, 9.0, 14.0),
            (16.0, 10.0, 11.0),
        ];

        for (high, low, close) in data {
            stoch_no_slowing.update_raw(high, low, close);
            stoch_with_slowing.update_raw(high, low, close);
        }

        assert!(
            (stoch_no_slowing.value_k - stoch_with_slowing.value_k).abs() > 0.01,
            "Slowing should produce different %K values"
        );
    }

    #[rstest]
    #[case(MovingAverageType::Simple)]
    #[case(MovingAverageType::Exponential)]
    #[case(MovingAverageType::Wilder)]
    #[case(MovingAverageType::Hull)]
    fn test_slowing_with_different_ma_types(#[case] ma_type: MovingAverageType) {
        let mut stoch = Stochastics::new_with_params(5, 3, 3, ma_type, StochasticsDMethod::Ratio);

        for i in 1..=10 {
            stoch.update_raw(f64::from(i) + 5.0, f64::from(i), f64::from(i) + 2.0);
        }

        assert!(
            stoch.value_k.is_finite(),
            "value_k should be finite with {ma_type:?}"
        );
        assert!(
            stoch.value_d.is_finite(),
            "value_d should be finite with {ma_type:?}"
        );
        assert!(
            stoch.value_k >= 0.0 && stoch.value_k <= 100.0,
            "value_k out of range with {ma_type:?}"
        );
    }

    #[rstest]
    fn test_d_method_ratio_preserves_nautilus_behavior() {
        let mut stoch = Stochastics::new_with_params(
            10,
            3,
            1,
            MovingAverageType::Exponential,
            StochasticsDMethod::Ratio,
        );

        for i in 1..=15 {
            stoch.update_raw(f64::from(i), f64::from(i) - 0.1, f64::from(i));
        }

        assert!(stoch.initialized);
        assert!(stoch.value_d > 0.0);
    }

    #[rstest]
    fn test_ratio_d_exact_values() {
        // Bar 1: range 10, close - low 5 -> %K 50, %D 5/10 -> 50.
        // Bar 2: range 10, close - low 10 -> %K 100, %D (5 + 10)/(10 + 10) -> 75.
        let mut stoch = Stochastics::new(2, 2);

        stoch.update_raw(10.0, 0.0, 5.0);
        assert!(!stoch.initialized);
        assert_eq!(stoch.value_k, 50.0);
        assert_eq!(stoch.value_d, 50.0);

        stoch.update_raw(10.0, 0.0, 10.0);
        assert!(stoch.initialized);
        assert_eq!(stoch.value_k, 100.0);
        assert_eq!(stoch.value_d, 75.0);
    }

    #[rstest]
    fn test_flat_window_initializes_ratio_mode_without_values() {
        // With no slowing and the ratio %D, a full constant-price window initializes the
        // indicator even though the divide-by-zero guard leaves the outputs untouched.
        let mut stoch = Stochastics::new(3, 3);
        for _ in 0..3 {
            stoch.update_raw(5.0, 5.0, 5.0);
        }

        assert!(stoch.initialized);
        assert_eq!(stoch.value_k, 0.0);
        assert_eq!(stoch.value_d, 0.0);
    }

    #[rstest]
    fn test_rolling_windows_are_bounded() {
        let mut stoch = Stochastics::new(5, 3);
        for i in 1..=20 {
            let base = f64::from(i);
            stoch.update_raw(base + 1.0, base, base + 0.5);
        }

        assert_eq!(stoch.highs.len(), 5);
        assert_eq!(stoch.lows.len(), 5);
        assert_eq!(stoch.c_sub_1.len(), 3);
        assert_eq!(stoch.h_sub_l.len(), 3);
    }

    #[rstest]
    fn test_ma_d_method_skips_ratio_buffers() {
        let mut stoch = Stochastics::new_with_params(
            5,
            3,
            1,
            MovingAverageType::Exponential,
            StochasticsDMethod::MovingAverage,
        );
        for i in 1..=10 {
            stoch.update_raw(f64::from(i) + 5.0, f64::from(i), f64::from(i) + 2.0);
        }

        assert_eq!(stoch.c_sub_1.len(), 0);
        assert_eq!(stoch.h_sub_l.len(), 0);
    }

    #[rstest]
    fn test_d_method_ma_produces_smoothed_k() {
        let mut stoch = Stochastics::new_with_params(
            5,
            3,
            3,
            MovingAverageType::Exponential,
            StochasticsDMethod::MovingAverage,
        );

        let data = [
            (10.0, 5.0, 8.0),
            (12.0, 6.0, 7.0),
            (11.0, 4.0, 9.0),
            (13.0, 7.0, 8.0),
            (14.0, 8.0, 10.0),
            (12.0, 6.0, 7.0),
            (15.0, 9.0, 14.0),
            (16.0, 10.0, 11.0),
            (14.0, 8.0, 12.0),
            (13.0, 7.0, 10.0),
        ];

        for (high, low, close) in data {
            stoch.update_raw(high, low, close);
        }

        assert!(stoch.value_d.is_finite());
        assert!(stoch.value_d >= 0.0 && stoch.value_d <= 100.0);
    }

    #[rstest]
    fn test_warmup_period_with_slowing() {
        let mut stoch = Stochastics::new_with_params(
            5,
            3,
            3,
            MovingAverageType::Exponential,
            StochasticsDMethod::Ratio,
        );

        // The %K lookback window is not yet full for the first `period_k - 1` bars.
        for i in 1..=4 {
            stoch.update_raw(f64::from(i) + 5.0, f64::from(i), f64::from(i) + 2.0);
            assert!(!stoch.initialized, "Should not be initialized at bar {i}");
        }

        for i in 5..=15 {
            stoch.update_raw(f64::from(i) + 5.0, f64::from(i), f64::from(i) + 2.0);
        }

        assert!(
            stoch.initialized,
            "Should be initialized after sufficient bars"
        );
    }

    #[rstest]
    fn test_warmup_period_with_ma_d_method() {
        let mut stoch = Stochastics::new_with_params(
            5,
            3,
            3,
            MovingAverageType::Exponential,
            StochasticsDMethod::MovingAverage,
        );

        for i in 1..=4 {
            stoch.update_raw(f64::from(i) + 5.0, f64::from(i), f64::from(i) + 2.0);
        }
        assert!(!stoch.initialized);

        for i in 5..=20 {
            stoch.update_raw(f64::from(i) + 5.0, f64::from(i), f64::from(i) + 2.0);
        }

        assert!(
            stoch.initialized,
            "Should be initialized after sufficient bars"
        );
    }

    #[rstest]
    fn test_reset_clears_slowing_ma_state() {
        let mut stoch = Stochastics::new_with_params(
            5,
            3,
            3,
            MovingAverageType::Exponential,
            StochasticsDMethod::MovingAverage,
        );

        for i in 1..=10 {
            stoch.update_raw(f64::from(i) + 5.0, f64::from(i), f64::from(i) + 2.0);
        }

        assert!(stoch.has_inputs);

        stoch.reset();

        assert!(!stoch.has_inputs);
        assert!(!stoch.initialized);
        assert_eq!(stoch.value_k, 0.0);
        assert_eq!(stoch.value_d, 0.0);
        assert_eq!(stoch.highs.len(), 0);
        assert_eq!(stoch.lows.len(), 0);

        // The indicator must be fully usable again after a reset.
        for i in 1..=10 {
            stoch.update_raw(f64::from(i) + 5.0, f64::from(i), f64::from(i) + 2.0);
        }
        assert!(stoch.value_k > 0.0);
    }

    #[rstest]
    fn test_slowing_1_bypasses_ma() {
        let stoch = Stochastics::new_with_params(
            10,
            3,
            1,
            MovingAverageType::Exponential,
            StochasticsDMethod::Ratio,
        );

        assert!(
            stoch.slowing_ma.is_none(),
            "slowing = 1 should not create MA"
        );
    }

    #[rstest]
    #[should_panic(expected = "slowing")]
    fn test_slowing_0_panics() {
        let _ = Stochastics::new_with_params(
            10,
            3,
            0,
            MovingAverageType::Exponential,
            StochasticsDMethod::Ratio,
        );
    }

    #[rstest]
    fn test_division_by_zero_protection() {
        let mut stoch = Stochastics::new_with_params(
            5,
            3,
            3,
            MovingAverageType::Exponential,
            StochasticsDMethod::MovingAverage,
        );

        // Constant prices make highest high == lowest low on every bar.
        for _ in 0..10 {
            stoch.update_raw(100.0, 100.0, 100.0);
        }

        assert!(stoch.value_k.is_finite());
        assert!(stoch.value_d.is_finite());
    }
}