nautilus_indicators/volatility/
atr.rs

// -------------------------------------------------------------------------------------------------
//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
//  https://nautechsystems.io
//
//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
//  You may not use this file except in compliance with the License.
//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
//
//  Unless required by applicable law or agreed to in writing, software
//  distributed under the License is distributed on an "AS IS" BASIS,
//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//  See the License for the specific language governing permissions and
//  limitations under the License.
// -------------------------------------------------------------------------------------------------
15
16use std::fmt::{Debug, Display};
17
18use nautilus_model::data::Bar;
19
20use crate::{
21    average::{MovingAverageFactory, MovingAverageType},
22    indicator::{Indicator, MovingAverage},
23};
24
/// An indicator which calculates an Average True Range (ATR) across a rolling window.
///
/// Each update computes a true range from the bar's high/low (and, when `use_previous` is
/// set, the previous close), feeds it into an inner moving average, and exposes that
/// average as `value`, optionally clamped up to `value_floor`.
#[repr(C)]
#[derive(Debug)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.indicators", unsendable)
)]
pub struct AverageTrueRange {
    // NOTE: `#[repr(C)]` fixes the field order/layout; do not reorder these fields.
    /// The rolling window length; `initialized` becomes `true` once `count >= period`.
    pub period: usize,
    /// The moving-average type requested at construction (`Simple` when not given).
    pub ma_type: MovingAverageType,
    /// If `true`, the true range also spans the gap from the previous close;
    /// if `false`, it is simply `high - low`.
    pub use_previous: bool,
    /// Lower bound applied to `value`; `0.0` disables the floor.
    pub value_floor: f64,
    /// The current ATR output (the inner moving average's value, after flooring).
    pub value: f64,
    /// The number of updates received since construction or the last reset.
    pub count: usize,
    /// Whether at least `period` updates have been received.
    pub initialized: bool,
    /// Moving average that smooths the per-update true range values.
    ma: Box<dyn MovingAverage + Send + 'static>,
    /// Whether at least one update has been received.
    has_inputs: bool,
    /// The close from the most recent update, used when `use_previous` is `true`.
    previous_close: f64,
}
44
45impl Display for AverageTrueRange {
46    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47        write!(
48            f,
49            "{}({},{},{},{})",
50            self.name(),
51            self.period,
52            self.ma_type,
53            self.use_previous,
54            self.value_floor,
55        )
56    }
57}
58
59impl Indicator for AverageTrueRange {
60    fn name(&self) -> String {
61        stringify!(AverageTrueRange).to_string()
62    }
63
64    fn has_inputs(&self) -> bool {
65        self.has_inputs
66    }
67
68    fn initialized(&self) -> bool {
69        self.initialized
70    }
71
72    fn handle_bar(&mut self, bar: &Bar) {
73        self.update_raw((&bar.high).into(), (&bar.low).into(), (&bar.close).into());
74    }
75
76    fn reset(&mut self) {
77        self.previous_close = 0.0;
78        self.value = 0.0;
79        self.count = 0;
80        self.has_inputs = false;
81        self.initialized = false;
82    }
83}
84
85impl AverageTrueRange {
86    /// Creates a new [`AverageTrueRange`] instance.
87    #[must_use]
88    pub fn new(
89        period: usize,
90        ma_type: Option<MovingAverageType>,
91        use_previous: Option<bool>,
92        value_floor: Option<f64>,
93    ) -> Self {
94        Self {
95            period,
96            ma_type: ma_type.unwrap_or(MovingAverageType::Simple),
97            use_previous: use_previous.unwrap_or(true),
98            value_floor: value_floor.unwrap_or(0.0),
99            value: 0.0,
100            count: 0,
101            previous_close: 0.0,
102            ma: MovingAverageFactory::create(MovingAverageType::Simple, period),
103            has_inputs: false,
104            initialized: false,
105        }
106    }
107
108    pub fn update_raw(&mut self, high: f64, low: f64, close: f64) {
109        if self.use_previous {
110            if !self.has_inputs {
111                self.previous_close = close;
112            }
113            self.ma.update_raw(
114                f64::max(self.previous_close, high) - f64::min(low, self.previous_close),
115            );
116            self.previous_close = close;
117        } else {
118            self.ma.update_raw(high - low);
119        }
120
121        self._floor_value();
122        self.increment_count();
123    }
124
125    fn _floor_value(&mut self) {
126        if self.value_floor == 0.0 || self.value_floor < self.ma.value() {
127            self.value = self.ma.value();
128        } else {
129            // Floor the value
130            self.value = self.value_floor;
131        }
132    }
133
134    const fn increment_count(&mut self) {
135        self.count += 1;
136
137        if !self.initialized {
138            self.has_inputs = true;
139            if self.count >= self.period {
140                self.initialized = true;
141            }
142        }
143    }
144}
145
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;
    use crate::testing::approx_equal;

    /// Builds a 10-period simple-MA ATR with default `use_previous` and the given floor.
    fn simple_atr(value_floor: Option<f64>) -> AverageTrueRange {
        AverageTrueRange::new(10, Some(MovingAverageType::Simple), None, value_floor)
    }

    #[rstest]
    fn test_name_returns_expected_string() {
        assert_eq!(simple_atr(None).name(), "AverageTrueRange");
    }

    #[rstest]
    fn test_str_repr_returns_expected_string() {
        let indicator =
            AverageTrueRange::new(10, Some(MovingAverageType::Simple), Some(true), Some(0.0));
        assert_eq!(indicator.to_string(), "AverageTrueRange(10,SIMPLE,true,0)");
    }

    #[rstest]
    fn test_period() {
        assert_eq!(simple_atr(None).period, 10);
    }

    #[rstest]
    fn test_initialized_without_inputs_returns_false() {
        assert!(!simple_atr(None).initialized());
    }

    #[rstest]
    fn test_initialized_with_required_inputs_returns_true() {
        let mut indicator = simple_atr(None);
        (0..10).for_each(|_| indicator.update_raw(1.0, 1.0, 1.0));
        assert!(indicator.initialized());
    }

    #[rstest]
    fn test_value_with_no_inputs_returns_zero() {
        assert_eq!(simple_atr(None).value, 0.0);
    }

    #[rstest]
    fn test_value_with_epsilon_input() {
        let mut indicator = simple_atr(None);
        let eps = f64::EPSILON;
        indicator.update_raw(eps, eps, eps);
        assert_eq!(indicator.value, 0.0);
    }

    #[rstest]
    fn test_value_with_one_ones_input() {
        let mut indicator = simple_atr(None);
        indicator.update_raw(1.0, 1.0, 1.0);
        assert_eq!(indicator.value, 0.0);
    }

    #[rstest]
    fn test_value_with_one_input() {
        let mut indicator = simple_atr(None);
        indicator.update_raw(1.00020, 1.0, 1.00010);
        assert!(approx_equal(indicator.value, 0.0002));
    }

    #[rstest]
    fn test_value_with_three_inputs() {
        let mut indicator = simple_atr(None);
        for _ in 0..3 {
            indicator.update_raw(1.00020, 1.0, 1.00010);
        }
        assert!(approx_equal(indicator.value, 0.0002));
    }

    #[rstest]
    fn test_value_with_close_on_high() {
        let mut indicator = simple_atr(None);
        let (mut hi, mut lo) = (1.00010, 1.0);
        for _ in 0..1000 {
            hi += 0.00010;
            lo += 0.00010;
            // Close sits exactly on the high.
            indicator.update_raw(hi, lo, hi);
        }
        assert!(approx_equal(indicator.value, 0.000_099_999_999_999_988_99));
    }

    #[rstest]
    fn test_value_with_close_on_low() {
        let mut indicator = simple_atr(None);
        let (mut hi, mut lo) = (1.00010, 1.0);
        for _ in 0..1000 {
            hi -= 0.00010;
            lo -= 0.00010;
            // Close sits exactly on the low.
            indicator.update_raw(hi, lo, lo);
        }
        assert!(approx_equal(indicator.value, 0.000_099_999_999_999_988_99));
    }

    #[rstest]
    fn test_floor_with_ten_ones_inputs() {
        let floor = 0.00005;
        let mut floored = simple_atr(Some(floor));
        for _ in 0..20 {
            floored.update_raw(1.0, 1.0, 1.0);
        }
        assert_eq!(floored.value, 5e-05);
    }

    #[rstest]
    fn test_floor_with_exponentially_decreasing_high_inputs() {
        let floor = 0.00005;
        let mut floored = simple_atr(Some(floor));
        let (lo, close) = (1.0, 1.0);
        let mut hi = 1.00020;
        for _ in 0..20 {
            // Halve the high-low range on every bar.
            hi -= (hi - lo) / 2.0;
            floored.update_raw(hi, lo, close);
        }
        assert_eq!(floored.value, floor);
    }

    #[rstest]
    fn test_reset_successfully_returns_indicator_to_fresh_state() {
        let mut indicator = simple_atr(None);
        for _ in 0..1000 {
            indicator.update_raw(1.00010, 1.0, 1.00005);
        }
        indicator.reset();
        assert!(!indicator.initialized);
        assert_eq!(indicator.value, 0.0);
    }
}
287}