nautilus_indicators/momentum/
rsi.rs1use std::fmt::{Debug, Display};
17
18use nautilus_model::{
19 data::{Bar, QuoteTick, TradeTick},
20 enums::PriceType,
21};
22
23use crate::{
24 average::{MovingAverageFactory, MovingAverageType},
25 indicator::{Indicator, MovingAverage},
26};
27
28#[repr(C)]
30#[derive(Debug)]
31#[cfg_attr(
32 feature = "python",
33 pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.indicators", unsendable)
34)]
35pub struct RelativeStrengthIndex {
36 pub period: usize,
37 pub ma_type: MovingAverageType,
38 pub value: f64,
39 pub count: usize,
40 pub initialized: bool,
41 has_inputs: bool,
42 last_value: f64,
43 average_gain: Box<dyn MovingAverage + Send + 'static>,
44 average_loss: Box<dyn MovingAverage + Send + 'static>,
45 rsi_max: f64,
46}
47
48impl Display for RelativeStrengthIndex {
49 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50 write!(f, "{}({},{})", self.name(), self.period, self.ma_type)
51 }
52}
53
54impl Indicator for RelativeStrengthIndex {
55 fn name(&self) -> String {
56 stringify!(RelativeStrengthIndex).to_string()
57 }
58
59 fn has_inputs(&self) -> bool {
60 self.has_inputs
61 }
62
63 fn initialized(&self) -> bool {
64 self.initialized
65 }
66
67 fn handle_quote(&mut self, quote: &QuoteTick) {
68 self.update_raw(quote.extract_price(PriceType::Mid).into());
69 }
70
71 fn handle_trade(&mut self, trade: &TradeTick) {
72 self.update_raw((trade.price).into());
73 }
74
75 fn handle_bar(&mut self, bar: &Bar) {
76 self.update_raw((&bar.close).into());
77 }
78
79 fn reset(&mut self) {
80 self.value = 0.0;
81 self.last_value = 0.0;
82 self.count = 0;
83 self.has_inputs = false;
84 self.initialized = false;
85 self.average_gain.reset();
86 self.average_loss.reset();
87 }
88}
89
90impl RelativeStrengthIndex {
91 #[must_use]
93 pub fn new(period: usize, ma_type: Option<MovingAverageType>) -> Self {
94 Self {
95 period,
96 ma_type: ma_type.unwrap_or(MovingAverageType::Exponential),
97 value: 0.0,
98 last_value: 0.0,
99 count: 0,
100 has_inputs: false,
101 average_gain: MovingAverageFactory::create(MovingAverageType::Exponential, period),
102 average_loss: MovingAverageFactory::create(MovingAverageType::Exponential, period),
103 rsi_max: 1.0,
104 initialized: false,
105 }
106 }
107
108 pub fn update_raw(&mut self, value: f64) {
109 if !self.has_inputs {
110 self.last_value = value;
111 self.has_inputs = true;
112 }
113 let gain = value - self.last_value;
114 if gain > 0.0 {
115 self.average_gain.update_raw(gain);
116 self.average_loss.update_raw(0.0);
117 } else if gain < 0.0 {
118 self.average_loss.update_raw(-gain);
119 self.average_gain.update_raw(0.0);
120 } else {
121 self.average_loss.update_raw(0.0);
122 self.average_gain.update_raw(0.0);
123 }
124 self.count = self.average_gain.count();
125 if !self.initialized && self.average_loss.initialized() && self.average_gain.initialized() {
126 self.initialized = true;
127 }
128
129 if self.average_loss.value() == 0.0 {
130 self.value = self.rsi_max;
131 return;
132 }
133
134 let rs = self.average_gain.value() / self.average_loss.value();
135 self.value = self.rsi_max - (self.rsi_max / (1.0 + rs));
136 self.last_value = value;
137
138 if !self.initialized && self.count >= self.period {
139 self.initialized = true;
140 }
141 }
142}
143
#[cfg(test)]
mod tests {
    use nautilus_model::data::{Bar, QuoteTick, TradeTick};
    use rstest::rstest;

    use crate::{indicator::Indicator, momentum::rsi::RelativeStrengthIndex, stubs::*};

    /// Pushes every value in `inputs` through `update_raw`, in order.
    fn feed(rsi: &mut RelativeStrengthIndex, inputs: &[f64]) {
        inputs.iter().for_each(|&x| rsi.update_raw(x));
    }

    #[rstest]
    fn test_rsi_initialized(rsi_10: RelativeStrengthIndex) {
        assert_eq!(format!("{rsi_10}"), "RelativeStrengthIndex(10,EXPONENTIAL)");
        assert_eq!(rsi_10.period, 10);
        assert!(!rsi_10.initialized);
    }

    #[rstest]
    fn test_initialized_with_required_inputs_returns_true(mut rsi_10: RelativeStrengthIndex) {
        (0..12_i32).for_each(|step| rsi_10.update_raw(f64::from(step)));
        assert!(rsi_10.initialized);
    }

    #[rstest]
    fn test_value_with_one_input_returns_expected_value(mut rsi_10: RelativeStrengthIndex) {
        feed(&mut rsi_10, &[1.0]);
        assert_eq!(rsi_10.value, 1.0);
    }

    #[rstest]
    fn test_value_all_higher_inputs_returns_expected_value(mut rsi_10: RelativeStrengthIndex) {
        feed(&mut rsi_10, &[1.0, 2.0, 3.0]);
        assert_eq!(rsi_10.value, 1.0);
    }

    #[rstest]
    fn test_value_with_all_lower_inputs_returns_expected_value(mut rsi_10: RelativeStrengthIndex) {
        feed(&mut rsi_10, &[3.0, 2.0, 1.0]);
        assert_eq!(rsi_10.value, 0.0);
    }

    #[rstest]
    fn test_value_with_various_input_returns_expected_value(mut rsi_10: RelativeStrengthIndex) {
        feed(&mut rsi_10, &[3.0, 2.0, 5.0, 6.0, 7.0, 6.0]);
        assert_eq!(rsi_10.value, 0.683_736_332_582_526_5);
    }

    #[rstest]
    fn test_value_at_returns_expected_value(mut rsi_10: RelativeStrengthIndex) {
        feed(&mut rsi_10, &[3.0, 2.0, 5.0, 6.0, 7.0, 6.0, 6.0, 7.0]);
        assert_eq!(rsi_10.value, 0.761_534_466_766_272_5);
    }

    #[rstest]
    fn test_reset(mut rsi_10: RelativeStrengthIndex) {
        feed(&mut rsi_10, &[1.0, 2.0]);
        rsi_10.reset();
        assert!(!rsi_10.initialized());
        assert_eq!(rsi_10.count, 0);
    }

    #[rstest]
    fn test_reset_resets_inner_mas(mut rsi_10: RelativeStrengthIndex) {
        feed(&mut rsi_10, &[1.0, 2.0]);
        rsi_10.reset();
        assert_eq!(rsi_10.average_gain.count(), 0);
        assert_eq!(rsi_10.average_loss.count(), 0);
    }

    #[rstest]
    fn test_handle_quote_tick(mut rsi_10: RelativeStrengthIndex, stub_quote: QuoteTick) {
        rsi_10.handle_quote(&stub_quote);
        assert_eq!(rsi_10.count, 1);
        assert_eq!(rsi_10.value, 1.0);
    }

    #[rstest]
    fn test_handle_trade_tick(mut rsi_10: RelativeStrengthIndex, stub_trade: TradeTick) {
        rsi_10.handle_trade(&stub_trade);
        assert_eq!(rsi_10.count, 1);
        assert_eq!(rsi_10.value, 1.0);
    }

    #[rstest]
    fn test_handle_bar(mut rsi_10: RelativeStrengthIndex, bar_ethusdt_binance_minute_bid: Bar) {
        rsi_10.handle_bar(&bar_ethusdt_binance_minute_bid);
        assert_eq!(rsi_10.count, 1);
        assert_eq!(rsi_10.value, 1.0);
    }

    #[rstest]
    fn test_constant_inputs_initializes_and_value_max(mut rsi_10: RelativeStrengthIndex) {
        feed(&mut rsi_10, &[5.0; 12]);
        assert!(rsi_10.initialized);
        assert_eq!(rsi_10.value, 1.0);
    }

    #[rstest]
    fn test_reset_resets_has_inputs_and_value(mut rsi_10: RelativeStrengthIndex) {
        feed(&mut rsi_10, &[1.0]);
        rsi_10.reset();
        assert!(!rsi_10.has_inputs());
        assert_eq!(rsi_10.value, 0.0);
    }
}