nautilus_indicators/momentum/
rsi.rs1use std::fmt::{Debug, Display};
17
18use nautilus_model::{
19 data::{Bar, QuoteTick, TradeTick},
20 enums::PriceType,
21};
22
23use crate::{
24 average::{MovingAverageFactory, MovingAverageType},
25 indicator::{Indicator, MovingAverage},
26};
27
28#[repr(C)]
30#[derive(Debug)]
31#[cfg_attr(
32 feature = "python",
33 pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.indicators", unsendable)
34)]
35pub struct RelativeStrengthIndex {
36 pub period: usize,
37 pub ma_type: MovingAverageType,
38 pub value: f64,
39 pub count: usize,
40 pub initialized: bool,
41 has_inputs: bool,
42 last_value: f64,
43 average_gain: Box<dyn MovingAverage + Send + 'static>,
44 average_loss: Box<dyn MovingAverage + Send + 'static>,
45 rsi_max: f64,
46}
47
48impl Display for RelativeStrengthIndex {
49 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50 write!(f, "{}({},{})", self.name(), self.period, self.ma_type)
51 }
52}
53
54impl Indicator for RelativeStrengthIndex {
55 fn name(&self) -> String {
56 stringify!(RelativeStrengthIndex).to_string()
57 }
58
59 fn has_inputs(&self) -> bool {
60 self.has_inputs
61 }
62
63 fn initialized(&self) -> bool {
64 self.initialized
65 }
66
67 fn handle_quote(&mut self, quote: &QuoteTick) {
68 self.update_raw(quote.extract_price(PriceType::Mid).into());
69 }
70
71 fn handle_trade(&mut self, trade: &TradeTick) {
72 self.update_raw((trade.price).into());
73 }
74
75 fn handle_bar(&mut self, bar: &Bar) {
76 self.update_raw((&bar.close).into());
77 }
78
79 fn reset(&mut self) {
80 self.value = 0.0;
81 self.last_value = 0.0;
82 self.count = 0;
83 self.has_inputs = false;
84 self.initialized = false;
85 }
86}
87
88impl RelativeStrengthIndex {
89 #[must_use]
91 pub fn new(period: usize, ma_type: Option<MovingAverageType>) -> Self {
92 Self {
93 period,
94 ma_type: ma_type.unwrap_or(MovingAverageType::Exponential),
95 value: 0.0,
96 last_value: 0.0,
97 count: 0,
98 has_inputs: false,
100 average_gain: MovingAverageFactory::create(MovingAverageType::Exponential, period),
101 average_loss: MovingAverageFactory::create(MovingAverageType::Exponential, period),
102 rsi_max: 1.0,
103 initialized: false,
104 }
105 }
106
107 pub fn update_raw(&mut self, value: f64) {
108 if !self.has_inputs {
109 self.last_value = value;
110 self.has_inputs = true;
111 }
112 let gain = value - self.last_value;
113 if gain > 0.0 {
114 self.average_gain.update_raw(gain);
115 self.average_loss.update_raw(0.0);
116 } else if gain < 0.0 {
117 self.average_loss.update_raw(-gain);
118 self.average_gain.update_raw(0.0);
119 } else {
120 self.average_loss.update_raw(0.0);
121 self.average_gain.update_raw(0.0);
122 }
123 self.count = self.average_gain.count();
125 if !self.initialized && self.average_loss.initialized() && self.average_gain.initialized() {
126 self.initialized = true;
127 }
128
129 if self.average_loss.value() == 0.0 {
130 self.value = self.rsi_max;
131 return;
132 }
133
134 let rs = self.average_gain.value() / self.average_loss.value();
135 self.value = self.rsi_max - (self.rsi_max / (1.0 + rs));
136 self.last_value = value;
137
138 if !self.initialized && self.count >= self.period {
139 self.initialized = true;
140 }
141 }
142}
143
#[cfg(test)]
mod tests {
    use nautilus_model::data::{Bar, QuoteTick, TradeTick};
    use rstest::rstest;

    use crate::{indicator::Indicator, momentum::rsi::RelativeStrengthIndex, stubs::*};

    /// Pushes each price in `prices` through `rsi.update_raw`, in order.
    fn feed(rsi: &mut RelativeStrengthIndex, prices: &[f64]) {
        prices.iter().for_each(|&price| rsi.update_raw(price));
    }

    #[rstest]
    fn test_rsi_initialized(rsi_10: RelativeStrengthIndex) {
        assert_eq!(rsi_10.to_string(), "RelativeStrengthIndex(10,EXPONENTIAL)");
        assert_eq!(rsi_10.period, 10);
        assert!(!rsi_10.initialized);
    }

    #[rstest]
    fn test_initialized_with_required_inputs_returns_true(mut rsi_10: RelativeStrengthIndex) {
        (0_i32..12)
            .map(f64::from)
            .for_each(|price| rsi_10.update_raw(price));
        assert!(rsi_10.initialized);
    }

    #[rstest]
    fn test_value_with_one_input_returns_expected_value(mut rsi_10: RelativeStrengthIndex) {
        feed(&mut rsi_10, &[1.0]);
        assert_eq!(rsi_10.value, 1.0);
    }

    #[rstest]
    fn test_value_all_higher_inputs_returns_expected_value(mut rsi_10: RelativeStrengthIndex) {
        feed(&mut rsi_10, &[1.0, 2.0, 3.0]);
        assert_eq!(rsi_10.value, 1.0);
    }

    #[rstest]
    fn test_value_with_all_lower_inputs_returns_expected_value(mut rsi_10: RelativeStrengthIndex) {
        feed(&mut rsi_10, &[3.0, 2.0, 1.0]);
        assert_eq!(rsi_10.value, 0.0);
    }

    #[rstest]
    fn test_value_with_various_input_returns_expected_value(mut rsi_10: RelativeStrengthIndex) {
        feed(&mut rsi_10, &[3.0, 2.0, 5.0, 6.0, 7.0, 6.0]);
        assert_eq!(rsi_10.value, 0.683_736_332_582_526_5);
    }

    #[rstest]
    fn test_value_at_returns_expected_value(mut rsi_10: RelativeStrengthIndex) {
        feed(&mut rsi_10, &[3.0, 2.0, 5.0, 6.0, 7.0, 6.0, 6.0, 7.0]);
        assert_eq!(rsi_10.value, 0.761_534_466_766_272_5);
    }

    #[rstest]
    fn test_reset(mut rsi_10: RelativeStrengthIndex) {
        feed(&mut rsi_10, &[1.0, 2.0]);
        rsi_10.reset();
        assert!(!rsi_10.initialized());
        assert_eq!(rsi_10.count, 0);
    }

    #[rstest]
    fn test_handle_quote_tick(mut rsi_10: RelativeStrengthIndex, stub_quote: QuoteTick) {
        rsi_10.handle_quote(&stub_quote);
        assert_eq!(rsi_10.count, 1);
        assert_eq!(rsi_10.value, 1.0);
    }

    #[rstest]
    fn test_handle_trade_tick(mut rsi_10: RelativeStrengthIndex, stub_trade: TradeTick) {
        rsi_10.handle_trade(&stub_trade);
        assert_eq!(rsi_10.count, 1);
        assert_eq!(rsi_10.value, 1.0);
    }

    #[rstest]
    fn test_handle_bar(mut rsi_10: RelativeStrengthIndex, bar_ethusdt_binance_minute_bid: Bar) {
        rsi_10.handle_bar(&bar_ethusdt_binance_minute_bid);
        assert_eq!(rsi_10.count, 1);
        assert_eq!(rsi_10.value, 1.0);
    }
}
247}