1use std::time::Duration;
25
26use nautilus_core::correctness::{check_in_range_inclusive_f64, check_predicate_true};
27use rand::Rng;
28
29#[derive(Clone, Debug)]
30pub struct ExponentialBackoff {
31 delay_initial: Duration,
33 delay_max: Duration,
35 delay_current: Duration,
37 factor: f64,
39 jitter_ms: u64,
41 immediate_reconnect: bool,
43 immediate_reconnect_original: bool,
45}
46
47impl ExponentialBackoff {
55 pub fn new(
65 delay_initial: Duration,
66 delay_max: Duration,
67 factor: f64,
68 jitter_ms: u64,
69 immediate_first: bool,
70 ) -> anyhow::Result<Self> {
71 check_predicate_true(!delay_initial.is_zero(), "delay_initial must be non-zero")?;
72 check_predicate_true(
73 delay_max >= delay_initial,
74 "delay_max must be >= delay_initial",
75 )?;
76 check_predicate_true(
77 delay_max.as_nanos() <= u128::from(u64::MAX),
78 "delay_max exceeds maximum representable duration (≈584 years)",
79 )?;
80 check_in_range_inclusive_f64(factor, 1.0, 100.0, "factor")?;
81
82 Ok(Self {
83 delay_initial,
84 delay_max,
85 delay_current: delay_initial,
86 factor,
87 jitter_ms,
88 immediate_reconnect: immediate_first,
89 immediate_reconnect_original: immediate_first,
90 })
91 }
92
93 pub fn next_duration(&mut self) -> Duration {
99 if self.immediate_reconnect && self.delay_current == self.delay_initial {
100 self.immediate_reconnect = false;
101 return Duration::ZERO;
102 }
103
104 let jitter = rand::rng().random_range(0..=self.jitter_ms);
106 let delay_with_jitter = self.delay_current + Duration::from_millis(jitter);
107
108 let clamped_delay = std::cmp::min(delay_with_jitter, self.delay_max);
110
111 let current_nanos = self.delay_current.as_nanos();
114 let max_nanos = self.delay_max.as_nanos();
115
116 let next_nanos_u128 = if current_nanos > u128::from(u64::MAX) {
118 max_nanos
120 } else {
121 let current_u64 = current_nanos as u64;
122 let next_f64 = current_u64 as f64 * self.factor;
123
124 if next_f64 > u64::MAX as f64 {
126 u128::from(u64::MAX)
127 } else {
128 u128::from(next_f64 as u64)
129 }
130 };
131
132 let clamped = std::cmp::min(next_nanos_u128, max_nanos);
133 let final_nanos = if clamped > u128::from(u64::MAX) {
134 u64::MAX
135 } else {
136 clamped as u64
137 };
138
139 self.delay_current = Duration::from_nanos(final_nanos);
140
141 clamped_delay
142 }
143
144 pub const fn reset(&mut self) {
146 self.delay_current = self.delay_initial;
147 self.immediate_reconnect = self.immediate_reconnect_original;
148 }
149
150 #[must_use]
154 pub const fn current_delay(&self) -> Duration {
155 self.delay_current
156 }
157}
158
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use rstest::rstest;

    use super::*;

    #[rstest]
    fn test_no_jitter_exponential_growth() {
        let initial = Duration::from_millis(100);
        let max = Duration::from_millis(1600);
        let factor = 2.0;
        let jitter = 0;
        let mut backoff = ExponentialBackoff::new(initial, max, factor, jitter, false).unwrap();

        // Delays double on every call until they are capped at `max`
        let d1 = backoff.next_duration();
        assert_eq!(d1, Duration::from_millis(100));

        let d2 = backoff.next_duration();
        assert_eq!(d2, Duration::from_millis(200));

        let d3 = backoff.next_duration();
        assert_eq!(d3, Duration::from_millis(400));

        let d4 = backoff.next_duration();
        assert_eq!(d4, Duration::from_millis(800));

        let d5 = backoff.next_duration();
        assert_eq!(d5, Duration::from_millis(1600));

        let d6 = backoff.next_duration();
        assert_eq!(d6, Duration::from_millis(1600));
    }

    #[rstest]
    fn test_reset() {
        let initial = Duration::from_millis(100);
        let max = Duration::from_millis(1600);
        let factor = 2.0;
        let jitter = 0;
        let mut backoff = ExponentialBackoff::new(initial, max, factor, jitter, false).unwrap();

        let _ = backoff.next_duration();
        backoff.reset();

        let d = backoff.next_duration();
        assert_eq!(d, Duration::from_millis(100));
    }

    #[rstest]
    fn test_jitter_within_bounds() {
        let initial = Duration::from_millis(100);
        let max = Duration::from_millis(1000);
        let factor = 2.0;
        let jitter = 50;

        // Jitter is random, so sample several fresh instances
        for _ in 0..10 {
            let mut backoff = ExponentialBackoff::new(initial, max, factor, jitter, false).unwrap();
            let base = backoff.current_delay();

            let delay = backoff.next_duration();
            let min_expected = base;
            let max_expected = base + Duration::from_millis(jitter);
            assert!(
                delay >= min_expected,
                "Delay {delay:?} is less than expected minimum {min_expected:?}"
            );
            assert!(
                delay <= max_expected,
                "Delay {delay:?} exceeds expected maximum {max_expected:?}"
            );
        }
    }

    #[rstest]
    fn test_factor_less_than_two() {
        let initial = Duration::from_millis(100);
        let max = Duration::from_millis(200);
        let factor = 1.5;
        let jitter = 0;
        let mut backoff = ExponentialBackoff::new(initial, max, factor, jitter, false).unwrap();

        let d1 = backoff.next_duration();
        assert_eq!(d1, Duration::from_millis(100));

        let d2 = backoff.next_duration();
        assert_eq!(d2, Duration::from_millis(150));

        let d3 = backoff.next_duration();
        assert_eq!(d3, Duration::from_millis(200));

        let d4 = backoff.next_duration();
        assert_eq!(d4, Duration::from_millis(200));
    }

    #[rstest]
    fn test_factor_one_keeps_delay_constant() {
        let initial = Duration::from_millis(100);
        let max = Duration::from_millis(1000);
        let mut backoff = ExponentialBackoff::new(initial, max, 1.0, 0, false).unwrap();

        for _ in 0..5 {
            assert_eq!(backoff.next_duration(), initial);
            assert_eq!(backoff.current_delay(), initial);
        }
    }

    #[rstest]
    fn test_max_delay_is_respected() {
        let initial = Duration::from_millis(500);
        let max = Duration::from_millis(1000);
        let factor = 3.0;
        let jitter = 0;
        let mut backoff = ExponentialBackoff::new(initial, max, factor, jitter, false).unwrap();

        let d1 = backoff.next_duration();
        assert_eq!(d1, Duration::from_millis(500));

        let d2 = backoff.next_duration();
        assert_eq!(d2, Duration::from_millis(1000));

        let d3 = backoff.next_duration();
        assert_eq!(d3, Duration::from_millis(1000));
    }

    #[rstest]
    fn test_initial_equal_to_max_clips_jitter() {
        let delay = Duration::from_millis(500);
        let mut backoff = ExponentialBackoff::new(delay, delay, 2.0, 100, false).unwrap();

        // With no headroom above the initial delay, every result is exactly the cap
        for _ in 0..5 {
            assert_eq!(backoff.next_duration(), delay);
        }
    }

    #[rstest]
    fn test_growth_saturates_at_largest_allowed_max() {
        let max = Duration::from_nanos(u64::MAX);
        let mut backoff =
            ExponentialBackoff::new(Duration::from_secs(1), max, 100.0, 0, false).unwrap();

        for _ in 0..20 {
            let _ = backoff.next_duration();
        }
        assert_eq!(backoff.current_delay(), max);
        assert_eq!(backoff.next_duration(), max);
    }

    #[rstest]
    fn test_current_delay_getter() {
        let initial = Duration::from_millis(100);
        let max = Duration::from_millis(1600);
        let factor = 2.0;
        let jitter = 0;
        let mut backoff = ExponentialBackoff::new(initial, max, factor, jitter, false).unwrap();

        assert_eq!(backoff.current_delay(), initial);

        let _ = backoff.next_duration();
        assert_eq!(backoff.current_delay(), Duration::from_millis(200));

        let _ = backoff.next_duration();
        assert_eq!(backoff.current_delay(), Duration::from_millis(400));

        backoff.reset();
        assert_eq!(backoff.current_delay(), initial);
    }

    #[rstest]
    fn test_validation_zero_initial_delay() {
        let result =
            ExponentialBackoff::new(Duration::ZERO, Duration::from_millis(1000), 2.0, 0, false);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("delay_initial must be non-zero")
        );
    }

    #[rstest]
    fn test_validation_max_less_than_initial() {
        let result = ExponentialBackoff::new(
            Duration::from_millis(1000),
            Duration::from_millis(500),
            2.0,
            0,
            false,
        );
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("delay_max must be >= delay_initial")
        );
    }

    #[rstest]
    fn test_validation_factor_too_small() {
        let result = ExponentialBackoff::new(
            Duration::from_millis(100),
            Duration::from_millis(1000),
            0.5,
            0,
            false,
        );
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("factor"));
    }

    #[rstest]
    fn test_validation_factor_too_large() {
        let result = ExponentialBackoff::new(
            Duration::from_millis(100),
            Duration::from_millis(1000),
            150.0,
            0,
            false,
        );
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("factor"));
    }

    #[rstest]
    fn test_validation_delay_max_exceeds_u64_max_nanos() {
        // `Duration` can exceed `u64::MAX` nanoseconds, which the backoff cannot represent
        let max_valid = Duration::from_nanos(u64::MAX);
        let too_large = max_valid + Duration::from_nanos(1);

        let result = ExponentialBackoff::new(Duration::from_millis(100), too_large, 2.0, 0, false);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("delay_max exceeds maximum representable duration")
        );
    }

    #[rstest]
    fn test_validation_delay_max_at_u64_max_nanos_is_accepted() {
        let max_valid = Duration::from_nanos(u64::MAX);
        let result = ExponentialBackoff::new(Duration::from_millis(100), max_valid, 2.0, 0, false);
        assert!(result.is_ok());
    }

    #[rstest]
    fn test_immediate_first() {
        let initial = Duration::from_millis(100);
        let max = Duration::from_millis(1600);
        let factor = 2.0;
        let jitter = 0;
        let mut backoff = ExponentialBackoff::new(initial, max, factor, jitter, true).unwrap();

        let d1 = backoff.next_duration();
        assert_eq!(
            d1,
            Duration::ZERO,
            "Expected immediate reconnect (zero delay) on first call"
        );

        let d2 = backoff.next_duration();
        assert_eq!(
            d2, initial,
            "Expected the delay to be the initial delay after immediate reconnect"
        );

        let d3 = backoff.next_duration();
        let expected = initial * 2;
        assert_eq!(
            d3, expected,
            "Expected exponential growth from the initial delay"
        );
    }

    #[rstest]
    fn test_reset_restores_immediate_first() {
        let initial = Duration::from_millis(100);
        let max = Duration::from_millis(1600);
        let factor = 2.0;
        let jitter = 0;
        let mut backoff = ExponentialBackoff::new(initial, max, factor, jitter, true).unwrap();

        let d1 = backoff.next_duration();
        assert_eq!(d1, Duration::ZERO);

        let d2 = backoff.next_duration();
        assert_eq!(d2, initial);

        backoff.reset();

        let d3 = backoff.next_duration();
        assert_eq!(
            d3,
            Duration::ZERO,
            "Reset should restore immediate_first behavior"
        );
    }

    #[rstest]
    fn test_jitter_never_exceeds_max_delay() {
        let initial = Duration::from_millis(100);
        let max = Duration::from_millis(1000);
        let factor = 2.0;
        let jitter = 500;

        let mut backoff = ExponentialBackoff::new(initial, max, factor, jitter, false).unwrap();

        // Advance until the base delay has reached the cap
        while backoff.current_delay() < max {
            let _ = backoff.next_duration();
        }

        for _ in 0..100 {
            let delay = backoff.next_duration();
            assert!(
                delay <= max,
                "Delay with jitter {delay:?} exceeded max {max:?}"
            );
        }
    }
}