nautilus_indicators/volatility/
atr.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::fmt::{Debug, Display};
17
18use nautilus_model::data::Bar;
19
20use crate::{
21    average::{MovingAverageFactory, MovingAverageType},
22    indicator::{Indicator, MovingAverage},
23};
24
/// An indicator which calculates an Average True Range (ATR) across a rolling window.
///
/// Each update feeds one bar range into an internal moving average; the
/// average, after applying `value_floor`, is exposed as `value`.
#[repr(C)]
#[derive(Debug)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.indicators", unsendable)
)]
pub struct AverageTrueRange {
    /// Window length handed to the internal moving average; also the number
    /// of updates required before the indicator reports itself initialized.
    pub period: usize,
    /// Moving average type requested at construction (`Simple` when omitted).
    pub ma_type: MovingAverageType,
    /// When `true`, the previous close is folded into the range so that gaps
    /// between bars are captured; when `false`, the range is just `high - low`.
    pub use_previous: bool,
    /// Lower bound applied to `value`; `0.0` leaves the average unfloored.
    pub value_floor: f64,
    /// Latest indicator value (the averaged range after flooring).
    pub value: f64,
    /// Number of updates processed since construction or the last reset.
    pub count: usize,
    /// Set once `count` has reached `period`.
    pub initialized: bool,
    /// Internal moving average that smooths the per-bar ranges.
    ma: Box<dyn MovingAverage + Send + 'static>,
    /// Set after the first update has been processed.
    has_inputs: bool,
    /// Close supplied with the previous update (seeded from the first update's close).
    previous_close: f64,
}
44
45impl Display for AverageTrueRange {
46    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47        write!(
48            f,
49            "{}({},{},{},{})",
50            self.name(),
51            self.period,
52            self.ma_type,
53            self.use_previous,
54            self.value_floor,
55        )
56    }
57}
58
59impl Indicator for AverageTrueRange {
60    fn name(&self) -> String {
61        stringify!(AverageTrueRange).to_string()
62    }
63
64    fn has_inputs(&self) -> bool {
65        self.has_inputs
66    }
67
68    fn initialized(&self) -> bool {
69        self.initialized
70    }
71
72    fn handle_bar(&mut self, bar: &Bar) {
73        self.update_raw((&bar.high).into(), (&bar.low).into(), (&bar.close).into());
74    }
75
76    fn reset(&mut self) {
77        self.previous_close = 0.0;
78        self.value = 0.0;
79        self.count = 0;
80        self.has_inputs = false;
81        self.initialized = false;
82    }
83}
84
85impl AverageTrueRange {
86    /// Creates a new [`AverageTrueRange`] instance.
87    #[must_use]
88    pub fn new(
89        period: usize,
90        ma_type: Option<MovingAverageType>,
91        use_previous: Option<bool>,
92        value_floor: Option<f64>,
93    ) -> Self {
94        Self {
95            period,
96            ma_type: ma_type.unwrap_or(MovingAverageType::Simple),
97            use_previous: use_previous.unwrap_or(true),
98            value_floor: value_floor.unwrap_or(0.0),
99            value: 0.0,
100            count: 0,
101            previous_close: 0.0,
102            ma: MovingAverageFactory::create(MovingAverageType::Simple, period),
103            has_inputs: false,
104            initialized: false,
105        }
106    }
107
108    pub fn update_raw(&mut self, high: f64, low: f64, close: f64) {
109        if self.use_previous {
110            if !self.has_inputs {
111                self.previous_close = close;
112            }
113            self.ma.update_raw(
114                f64::max(self.previous_close, high) - f64::min(low, self.previous_close),
115            );
116            self.previous_close = close;
117        } else {
118            self.ma.update_raw(high - low);
119        }
120
121        self._floor_value();
122        self.increment_count();
123    }
124
125    fn _floor_value(&mut self) {
126        if self.value_floor == 0.0 || self.value_floor < self.ma.value() {
127            self.value = self.ma.value();
128        } else {
129            // Floor the value
130            self.value = self.value_floor;
131        }
132    }
133
134    const fn increment_count(&mut self) {
135        self.count += 1;
136
137        if !self.initialized {
138            self.has_inputs = true;
139            if self.count >= self.period {
140                self.initialized = true;
141            }
142        }
143    }
144}
145
146////////////////////////////////////////////////////////////////////////////////
147// Tests
148////////////////////////////////////////////////////////////////////////////////
#[cfg(test)]
mod tests {
    //! Unit tests for [`AverageTrueRange`]: naming/formatting, initialization,
    //! value calculation, flooring and reset behavior.

    use rstest::rstest;

    use super::*;
    use crate::testing::approx_equal;

    #[rstest]
    fn test_name_returns_expected_string() {
        let atr = AverageTrueRange::new(10, Some(MovingAverageType::Simple), None, None);
        assert_eq!(atr.name(), "AverageTrueRange");
    }

    #[rstest]
    fn test_str_repr_returns_expected_string() {
        let atr = AverageTrueRange::new(10, Some(MovingAverageType::Simple), Some(true), Some(0.0));
        assert_eq!(format!("{atr}"), "AverageTrueRange(10,SIMPLE,true,0)");
    }

    #[rstest]
    fn test_period() {
        let atr = AverageTrueRange::new(10, Some(MovingAverageType::Simple), None, None);
        assert_eq!(atr.period, 10);
    }

    #[rstest]
    fn test_initialized_without_inputs_returns_false() {
        let atr = AverageTrueRange::new(10, Some(MovingAverageType::Simple), None, None);
        assert!(!atr.initialized());
    }

    #[rstest]
    fn test_initialized_with_required_inputs_returns_true() {
        let mut atr = AverageTrueRange::new(10, Some(MovingAverageType::Simple), None, None);
        for _ in 0..10 {
            atr.update_raw(1.0, 1.0, 1.0);
        }
        assert!(atr.initialized());
    }

    #[rstest]
    fn test_value_with_no_inputs_returns_zero() {
        let atr = AverageTrueRange::new(10, Some(MovingAverageType::Simple), None, None);
        assert_eq!(atr.value, 0.0);
    }

    #[rstest]
    fn test_value_with_epsilon_input() {
        let mut atr = AverageTrueRange::new(10, Some(MovingAverageType::Simple), None, None);
        // Associated constant rather than the legacy `std::f64::EPSILON` module path.
        let epsilon = f64::EPSILON;
        atr.update_raw(epsilon, epsilon, epsilon);
        assert_eq!(atr.value, 0.0);
    }

    #[rstest]
    fn test_value_with_one_ones_input() {
        let mut atr = AverageTrueRange::new(10, Some(MovingAverageType::Simple), None, None);
        atr.update_raw(1.0, 1.0, 1.0);
        assert_eq!(atr.value, 0.0);
    }

    #[rstest]
    fn test_value_with_one_input() {
        let mut atr = AverageTrueRange::new(10, Some(MovingAverageType::Simple), None, None);
        atr.update_raw(1.00020, 1.0, 1.00010);
        assert!(approx_equal(atr.value, 0.0002));
    }

    #[rstest]
    fn test_value_with_three_inputs() {
        let mut atr = AverageTrueRange::new(10, Some(MovingAverageType::Simple), None, None);
        atr.update_raw(1.00020, 1.0, 1.00010);
        atr.update_raw(1.00020, 1.0, 1.00010);
        atr.update_raw(1.00020, 1.0, 1.00010);
        assert!(approx_equal(atr.value, 0.0002));
    }

    #[rstest]
    fn test_value_with_close_on_high() {
        let mut atr = AverageTrueRange::new(10, Some(MovingAverageType::Simple), None, None);
        let mut high = 1.00010;
        let mut low = 1.0;
        for _ in 0..1000 {
            high += 0.00010;
            low += 0.00010;
            let close = high;
            atr.update_raw(high, low, close);
        }
        assert!(approx_equal(atr.value, 0.000_099_999_999_999_988_99));
    }

    #[rstest]
    fn test_value_with_close_on_low() {
        let mut atr = AverageTrueRange::new(10, Some(MovingAverageType::Simple), None, None);
        let mut high = 1.00010;
        let mut low = 1.0;
        for _ in 0..1000 {
            high -= 0.00010;
            low -= 0.00010;
            let close = low;
            atr.update_raw(high, low, close);
        }
        assert!(approx_equal(atr.value, 0.000_099_999_999_999_988_99));
    }

    #[rstest]
    fn test_floor_with_ten_ones_inputs() {
        let floor = 0.00005;
        let mut floored_atr =
            AverageTrueRange::new(10, Some(MovingAverageType::Simple), None, Some(floor));
        for _ in 0..20 {
            floored_atr.update_raw(1.0, 1.0, 1.0);
        }
        assert_eq!(floored_atr.value, 5e-05);
    }

    #[rstest]
    fn test_floor_with_exponentially_decreasing_high_inputs() {
        let floor = 0.00005;
        let mut floored_atr =
            AverageTrueRange::new(10, Some(MovingAverageType::Simple), None, Some(floor));
        let mut high = 1.00020;
        let low = 1.0;
        let close = 1.0;
        for _ in 0..20 {
            high -= (high - low) / 2.0;
            floored_atr.update_raw(high, low, close);
        }
        assert_eq!(floored_atr.value, floor);
    }

    #[rstest]
    fn test_reset_successfully_returns_indicator_to_fresh_state() {
        let mut atr = AverageTrueRange::new(10, Some(MovingAverageType::Simple), None, None);
        for _ in 0..1000 {
            atr.update_raw(1.00010, 1.0, 1.00005);
        }
        atr.reset();
        assert!(!atr.initialized);
        assert!(!atr.has_inputs());
        assert_eq!(atr.count, 0);
        assert_eq!(atr.value, 0.0);
    }
}