1use std::time::Duration;
25
26use nautilus_core::correctness::{check_in_range_inclusive_f64, check_predicate_true};
27use rand::Rng;
28
/// Exponential backoff state used to compute successive (e.g. reconnect) delays.
///
/// Each non-immediate call to [`ExponentialBackoff::next_duration`] returns the current base
/// delay plus a random jitter, then multiplies the base delay by `factor`, capped at
/// `delay_max`.
#[derive(Clone, Debug)]
pub struct ExponentialBackoff {
    /// Starting base delay; also the value restored by `reset`.
    delay_initial: Duration,
    /// Upper bound applied to the base delay after each growth step.
    delay_max: Duration,
    /// Base delay (before jitter) used by the next call to `next_duration`.
    delay_current: Duration,
    /// Multiplier applied to the base delay after each non-immediate call.
    factor: f64,
    /// Inclusive upper bound, in milliseconds, of the random jitter added to each delay.
    jitter_ms: u64,
    /// When `true`, the next call to `next_duration` returns `Duration::ZERO` and clears this flag.
    immediate_reconnect: bool,
    /// The configured immediate-first setting, restored into `immediate_reconnect` by `reset`.
    immediate_reconnect_original: bool,
}
46
47impl ExponentialBackoff {
55 pub fn new(
65 delay_initial: Duration,
66 delay_max: Duration,
67 factor: f64,
68 jitter_ms: u64,
69 immediate_first: bool,
70 ) -> anyhow::Result<Self> {
71 check_predicate_true(!delay_initial.is_zero(), "delay_initial must be non-zero")?;
72 check_predicate_true(
73 delay_max >= delay_initial,
74 "delay_max must be >= delay_initial",
75 )?;
76 check_predicate_true(
77 delay_max.as_nanos() <= u128::from(u64::MAX),
78 "delay_max exceeds maximum representable duration (≈584 years)",
79 )?;
80 check_in_range_inclusive_f64(factor, 1.0, 100.0, "factor")?;
81
82 Ok(Self {
83 delay_initial,
84 delay_max,
85 delay_current: delay_initial,
86 factor,
87 jitter_ms,
88 immediate_reconnect: immediate_first,
89 immediate_reconnect_original: immediate_first,
90 })
91 }
92
93 pub fn next_duration(&mut self) -> Duration {
99 if self.immediate_reconnect && self.delay_current == self.delay_initial {
100 self.immediate_reconnect = false;
101 return Duration::ZERO;
102 }
103
104 let jitter = rand::rng().random_range(0..=self.jitter_ms);
106 let delay_with_jitter = self.delay_current + Duration::from_millis(jitter);
107
108 let current_nanos = self.delay_current.as_nanos();
111 let max_nanos = self.delay_max.as_nanos();
112
113 let next_nanos_u128 = if current_nanos > u128::from(u64::MAX) {
115 max_nanos
117 } else {
118 let current_u64 = current_nanos as u64;
119 let next_f64 = current_u64 as f64 * self.factor;
120
121 if next_f64 > u64::MAX as f64 {
123 u128::from(u64::MAX)
124 } else {
125 u128::from(next_f64 as u64)
126 }
127 };
128
129 let clamped = std::cmp::min(next_nanos_u128, max_nanos);
130 let final_nanos = if clamped > u128::from(u64::MAX) {
131 u64::MAX
132 } else {
133 clamped as u64
134 };
135
136 self.delay_current = Duration::from_nanos(final_nanos);
137
138 delay_with_jitter
139 }
140
141 pub const fn reset(&mut self) {
143 self.delay_current = self.delay_initial;
144 self.immediate_reconnect = self.immediate_reconnect_original;
145 }
146
147 #[must_use]
151 pub const fn current_delay(&self) -> Duration {
152 self.delay_current
153 }
154}
155
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use rstest::rstest;

    use super::*;

    #[rstest]
    fn test_no_jitter_exponential_growth() {
        let initial = Duration::from_millis(100);
        let max = Duration::from_millis(1600);
        let factor = 2.0;
        let jitter = 0;
        let mut backoff = ExponentialBackoff::new(initial, max, factor, jitter, false).unwrap();

        let d1 = backoff.next_duration();
        assert_eq!(d1, Duration::from_millis(100));

        let d2 = backoff.next_duration();
        assert_eq!(d2, Duration::from_millis(200));

        let d3 = backoff.next_duration();
        assert_eq!(d3, Duration::from_millis(400));

        let d4 = backoff.next_duration();
        assert_eq!(d4, Duration::from_millis(800));

        let d5 = backoff.next_duration();
        assert_eq!(d5, Duration::from_millis(1600));

        // Capped at `delay_max` from here on.
        let d6 = backoff.next_duration();
        assert_eq!(d6, Duration::from_millis(1600));
    }

    #[rstest]
    fn test_reset() {
        let initial = Duration::from_millis(100);
        let max = Duration::from_millis(1600);
        let factor = 2.0;
        let jitter = 0;
        let mut backoff = ExponentialBackoff::new(initial, max, factor, jitter, false).unwrap();

        let _ = backoff.next_duration();
        backoff.reset();

        let d = backoff.next_duration();
        assert_eq!(d, Duration::from_millis(100));
    }

    #[rstest]
    fn test_jitter_within_bounds() {
        let initial = Duration::from_millis(100);
        let max = Duration::from_millis(1000);
        let factor = 2.0;
        let jitter = 50;

        // Repeat to sample the RNG several times, and step through growth and the cap
        // (100, 200, 400, 800, 1000, 1000 ms) so the bounds hold for every base delay.
        for _ in 0..10 {
            let mut backoff = ExponentialBackoff::new(initial, max, factor, jitter, false).unwrap();
            for _ in 0..6 {
                let base = backoff.current_delay();
                let delay = backoff.next_duration();
                let min_expected = base;
                let max_expected = base + Duration::from_millis(jitter);
                assert!(
                    delay >= min_expected,
                    "Delay {delay:?} is less than expected minimum {min_expected:?}"
                );
                assert!(
                    delay <= max_expected,
                    "Delay {delay:?} exceeds expected maximum {max_expected:?}"
                );
            }
        }
    }

    #[rstest]
    fn test_factor_less_than_two() {
        let initial = Duration::from_millis(100);
        let max = Duration::from_millis(200);
        let factor = 1.5;
        let jitter = 0;
        let mut backoff = ExponentialBackoff::new(initial, max, factor, jitter, false).unwrap();

        let d1 = backoff.next_duration();
        assert_eq!(d1, Duration::from_millis(100));

        let d2 = backoff.next_duration();
        assert_eq!(d2, Duration::from_millis(150));

        let d3 = backoff.next_duration();
        assert_eq!(d3, Duration::from_millis(200));

        let d4 = backoff.next_duration();
        assert_eq!(d4, Duration::from_millis(200));
    }

    #[rstest]
    fn test_factor_one_keeps_delay_constant() {
        let initial = Duration::from_millis(250);
        let max = Duration::from_secs(10);
        let mut backoff = ExponentialBackoff::new(initial, max, 1.0, 0, false).unwrap();

        for _ in 0..5 {
            assert_eq!(backoff.next_duration(), initial);
            assert_eq!(backoff.current_delay(), initial);
        }
    }

    #[rstest]
    fn test_equal_initial_and_max_is_constant() {
        let delay = Duration::from_millis(300);
        let mut backoff = ExponentialBackoff::new(delay, delay, 2.0, 0, false).unwrap();

        for _ in 0..3 {
            assert_eq!(backoff.next_duration(), delay);
        }
    }

    #[rstest]
    fn test_max_delay_is_respected() {
        let initial = Duration::from_millis(500);
        let max = Duration::from_millis(1000);
        let factor = 3.0;
        let jitter = 0;
        let mut backoff = ExponentialBackoff::new(initial, max, factor, jitter, false).unwrap();

        let d1 = backoff.next_duration();
        assert_eq!(d1, Duration::from_millis(500));

        let d2 = backoff.next_duration();
        assert_eq!(d2, Duration::from_millis(1000));

        let d3 = backoff.next_duration();
        assert_eq!(d3, Duration::from_millis(1000));
    }

    #[rstest]
    fn test_current_delay_getter() {
        let initial = Duration::from_millis(100);
        let max = Duration::from_millis(1600);
        let factor = 2.0;
        let jitter = 0;
        let mut backoff = ExponentialBackoff::new(initial, max, factor, jitter, false).unwrap();

        assert_eq!(backoff.current_delay(), initial);

        let _ = backoff.next_duration();
        assert_eq!(backoff.current_delay(), Duration::from_millis(200));

        let _ = backoff.next_duration();
        assert_eq!(backoff.current_delay(), Duration::from_millis(400));

        backoff.reset();
        assert_eq!(backoff.current_delay(), initial);
    }

    #[rstest]
    fn test_validation_zero_initial_delay() {
        let result =
            ExponentialBackoff::new(Duration::ZERO, Duration::from_millis(1000), 2.0, 0, false);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("delay_initial must be non-zero")
        );
    }

    #[rstest]
    fn test_validation_max_less_than_initial() {
        let result = ExponentialBackoff::new(
            Duration::from_millis(1000),
            Duration::from_millis(500),
            2.0,
            0,
            false,
        );
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("delay_max must be >= delay_initial")
        );
    }

    #[rstest]
    fn test_validation_factor_too_small() {
        let result = ExponentialBackoff::new(
            Duration::from_millis(100),
            Duration::from_millis(1000),
            0.5,
            0,
            false,
        );
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("factor"));
    }

    #[rstest]
    fn test_validation_factor_too_large() {
        let result = ExponentialBackoff::new(
            Duration::from_millis(100),
            Duration::from_millis(1000),
            150.0,
            0,
            false,
        );
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("factor"));
    }

    #[rstest]
    fn test_validation_delay_max_exceeds_u64_max_nanos() {
        // The largest accepted `delay_max` is exactly u64::MAX nanoseconds; one more must fail.
        let max_valid = Duration::from_nanos(u64::MAX);
        let too_large = max_valid + Duration::from_nanos(1);

        let result = ExponentialBackoff::new(Duration::from_millis(100), too_large, 2.0, 0, false);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("delay_max exceeds maximum representable duration")
        );
    }

    #[rstest]
    fn test_immediate_first() {
        let initial = Duration::from_millis(100);
        let max = Duration::from_millis(1600);
        let factor = 2.0;
        let jitter = 0;
        let mut backoff = ExponentialBackoff::new(initial, max, factor, jitter, true).unwrap();

        let d1 = backoff.next_duration();
        assert_eq!(
            d1,
            Duration::ZERO,
            "Expected immediate reconnect (zero delay) on first call"
        );

        let d2 = backoff.next_duration();
        assert_eq!(
            d2, initial,
            "Expected the delay to be the initial delay after immediate reconnect"
        );

        let d3 = backoff.next_duration();
        let expected = initial * 2;
        assert_eq!(
            d3, expected,
            "Expected exponential growth from the initial delay"
        );
    }

    #[rstest]
    fn test_immediate_first_ignores_jitter() {
        let initial = Duration::from_millis(100);
        let max = Duration::from_millis(1600);
        let jitter = 50;
        let mut backoff = ExponentialBackoff::new(initial, max, 2.0, jitter, true).unwrap();

        assert_eq!(backoff.next_duration(), Duration::ZERO);
        // The immediate attempt must not advance the base delay.
        assert_eq!(backoff.current_delay(), initial);

        let d2 = backoff.next_duration();
        assert!(d2 >= initial);
        assert!(d2 <= initial + Duration::from_millis(jitter));
    }

    #[rstest]
    fn test_reset_restores_immediate_first() {
        let initial = Duration::from_millis(100);
        let max = Duration::from_millis(1600);
        let factor = 2.0;
        let jitter = 0;
        let mut backoff = ExponentialBackoff::new(initial, max, factor, jitter, true).unwrap();

        let d1 = backoff.next_duration();
        assert_eq!(d1, Duration::ZERO);

        let d2 = backoff.next_duration();
        assert_eq!(d2, initial);

        backoff.reset();

        let d3 = backoff.next_duration();
        assert_eq!(
            d3,
            Duration::ZERO,
            "Reset should restore immediate_first behavior"
        );
    }
}