nautilus_core/
math.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16//! Mathematical functions and interpolation utilities.
17//!
18//! This module provides essential mathematical operations for quantitative trading,
19//! including linear and quadratic interpolation functions commonly used in financial
20//! data processing and analysis.
21
/// Macro for approximate floating-point equality comparison.
///
/// This macro compares two floating-point values with a specified epsilon tolerance,
/// providing a safe alternative to exact equality checks which can fail due to
/// floating-point precision issues.
///
/// Values that are exactly equal (including equal infinities) always compare equal,
/// even with an epsilon of zero. Otherwise the values compare equal when their absolute
/// difference is strictly less than `epsilon`. NaN never compares equal to anything.
///
/// The `ulps = …` form is accepted for source compatibility with `float_cmp::approx_eq!`,
/// but the `ulps` argument is ignored: only the epsilon comparison is performed.
///
/// # Usage
///
/// ```rust
/// use nautilus_core::approx_eq;
///
/// let a = 0.1 + 0.2;
/// let b = 0.3;
/// assert!(approx_eq!(f64, a, b, epsilon = 1e-10));
/// assert!(approx_eq!(f64, f64::INFINITY, f64::INFINITY, epsilon = 1e-10));
/// ```
#[macro_export]
macro_rules! approx_eq {
    ($type:ty, $left:expr, $right:expr, epsilon = $epsilon:expr) => {{
        let left_val: $type = $left;
        let right_val: $type = $right;
        // Check exact equality first: `inf - inf` is NaN, so without this equal infinities
        // would compare unequal, and a zero epsilon would reject identical values.
        #[allow(clippy::float_cmp)]
        let exactly_equal = left_val == right_val;
        exactly_equal || (left_val - right_val).abs() < $epsilon
    }};
    ($type:ty, $left:expr, $right:expr, epsilon = $epsilon:expr, ulps = $ulps:expr) => {{
        let left_val: $type = $left;
        let right_val: $type = $right;
        // For compatibility, we use epsilon comparison and ignore ulps
        #[allow(clippy::float_cmp)]
        let exactly_equal = left_val == right_val;
        exactly_equal || (left_val - right_val).abs() < $epsilon
    }};
}
51
/// Computes the linear interpolation weight of `x` relative to the interval `[x1, x2]`.
///
/// The result `w = (x - x1) / (x2 - x1)` satisfies `y = (1 - w) * y1 + w * y2` when
/// interpolating ordinates `y1` and `y2` that belong to abscissas `x1` and `x2`.
/// `x` equal to `x1` gives `0.0` and `x` equal to `x2` gives `1.0`. Values of `x` outside
/// the interval give weights outside `[0, 1]`, which extrapolate.
///
/// # Panics
///
/// Panics if `|x2 - x1|` is smaller than `2 * f64::EPSILON` (~4.44e-16), which would
/// cause division by zero or numerical instability. The threshold is absolute, not
/// relative to the magnitude of the inputs. It is twice machine epsilon to leave
/// headroom for rounding in the subtraction. A NaN in `x1` or `x2` also fails the check
/// and panics.
#[inline]
#[must_use]
pub fn linear_weight(x1: f64, x2: f64, x: f64) -> f64 {
    const EPSILON: f64 = 2.0 * f64::EPSILON;

    let span = x2 - x1;
    let diff = span.abs();
    // `>=` is false for NaN, so a NaN span is rejected here as well.
    assert!(
        diff >= EPSILON,
        "`x1` ({x1}) and `x2` ({x2}) are too close for stable interpolation (diff: {diff}, min: {EPSILON})"
    );

    let offset = x - x1;
    offset / span
}
73
/// Linearly interpolates between the ordinates `y1` and `y2` using a weight.
///
/// `x1_diff` is the interpolation weight, for example the value returned by
/// [`linear_weight`]. A weight of `0.0` yields `y1` and `1.0` yields `y2` (up to
/// rounding). Weights outside `[0, 1]` extrapolate.
///
/// The result is `y1 + x1_diff * (y2 - y1)`, evaluated as a fused multiply-add so the
/// product and the sum are rounded only once.
#[inline]
#[must_use]
pub fn linear_weighting(y1: f64, y2: f64, x1_diff: f64) -> f64 {
    let span = y2 - y1;
    f64::mul_add(x1_diff, span, y1)
}
83
/// Finds the bracketing position for interpolation in a sorted slice.
///
/// Returns the index of the last element of `xs` that is strictly less than `x`, or `0`
/// when no element is less than `x`. For a non-empty slice the result is therefore
/// always a valid index in `[0, xs.len() - 1]`.
///
/// `xs` must be sorted in ascending order. Otherwise the result is unspecified, as with
/// `slice::partition_point`.
///
/// # Edge Cases
///
/// - For empty slices, returns 0 (which is not a valid index into `xs`).
/// - For single-element slices, always returns index 0, regardless of whether `x > xs[0]`.
/// - For values at or below the minimum, returns 0.
/// - For a value equal to an element `xs[i]` with `i >= 1` (distinct elements), returns
///   `i - 1`. In particular, `x == xs[xs.len() - 1]` yields `xs.len() - 2`, not
///   `xs.len() - 1`.
/// - For values strictly above the maximum, returns `xs.len() - 1`.
/// - For NaN, returns 0 (every comparison with NaN is false).
#[inline]
#[must_use]
pub fn pos_search(x: f64, xs: &[f64]) -> usize {
    // `partition_point` counts the elements strictly less than `x`, a value in
    // `[0, xs.len()]`. Subtracting one gives the index of the last such element, and
    // `saturating_sub` maps a count of zero (including an empty slice) to 0. Because
    // the count never exceeds `xs.len()`, no upper clamp is needed.
    xs.partition_point(|&val| val < x).saturating_sub(1)
}
106
/// Evaluates, at `x`, the quadratic Lagrange polynomial through three points.
///
/// Given the nodes `(x0, y0)`, `(x1, y1)` and `(x2, y2)`, this returns *P(x)*, where *P*
/// is the unique polynomial of degree ≤ 2 that passes through all three nodes. `x` may
/// lie outside the nodes, in which case the polynomial is extrapolated.
///
/// # Panics
///
/// Panics if any pair of abscissas differs by less than `2 * f64::EPSILON` (an absolute
/// threshold), or if any abscissa is NaN. In either case the Lagrange denominators would
/// be zero or meaningless.
#[inline]
#[must_use]
pub fn quad_polynomial(x: f64, x0: f64, x1: f64, x2: f64, y0: f64, y1: f64, y2: f64) -> f64 {
    // Twice machine epsilon (~4.44e-16) leaves headroom for rounding in the subtractions
    const EPSILON: f64 = 2.0 * f64::EPSILON;

    let diff_01 = (x0 - x1).abs();
    let diff_02 = (x0 - x2).abs();
    let diff_12 = (x1 - x2).abs();

    // `>=` is false for NaN, so a NaN gap fails the check just like coincident abscissas
    let well_separated = [diff_01, diff_02, diff_12]
        .iter()
        .all(|&gap| gap >= EPSILON);
    assert!(
        well_separated,
        "Abscissas are too close for stable interpolation: x0={x0}, x1={x1}, x2={x2} (diffs: {diff_01:.2e}, {diff_02:.2e}, {diff_12:.2e}, min: {EPSILON})"
    );

    // Each term is an ordinate times its Lagrange basis polynomial evaluated at `x`
    let term0 = y0 * (x - x1) * (x - x2) / ((x0 - x1) * (x0 - x2));
    let term1 = y1 * (x - x0) * (x - x2) / ((x1 - x0) * (x1 - x2));
    let term2 = y2 * (x - x0) * (x - x1) / ((x2 - x0) * (x2 - x1));

    term0 + term1 + term2
}
136
137/// Performs quadratic interpolation for the point `x` given vectors of abscissas `xs` and ordinates `ys`.
138///
139/// # Panics
140///
141/// Panics if `xs.len() < 3` or `xs.len() != ys.len()`.
142#[must_use]
143pub fn quadratic_interpolation(x: f64, xs: &[f64], ys: &[f64]) -> f64 {
144    let n_elem = xs.len();
145    let epsilon = 1e-8;
146
147    assert!(
148        n_elem >= 3,
149        "Need at least 3 points for quadratic interpolation"
150    );
151    assert_eq!(xs.len(), ys.len(), "xs and ys must have the same length");
152
153    if x <= xs[0] {
154        return ys[0];
155    }
156
157    if x >= xs[n_elem - 1] {
158        return ys[n_elem - 1];
159    }
160
161    let pos = pos_search(x, xs);
162
163    if (xs[pos] - x).abs() < epsilon {
164        return ys[pos];
165    }
166
167    if pos == 0 {
168        return quad_polynomial(x, xs[0], xs[1], xs[2], ys[0], ys[1], ys[2]);
169    }
170
171    if pos == n_elem - 2 {
172        return quad_polynomial(
173            x,
174            xs[n_elem - 3],
175            xs[n_elem - 2],
176            xs[n_elem - 1],
177            ys[n_elem - 3],
178            ys[n_elem - 2],
179            ys[n_elem - 1],
180        );
181    }
182
183    let w = linear_weight(xs[pos], xs[pos + 1], x);
184
185    linear_weighting(
186        quad_polynomial(
187            x,
188            xs[pos - 1],
189            xs[pos],
190            xs[pos + 1],
191            ys[pos - 1],
192            ys[pos],
193            ys[pos + 1],
194        ),
195        quad_polynomial(
196            x,
197            xs[pos],
198            xs[pos + 1],
199            xs[pos + 2],
200            ys[pos],
201            ys[pos + 1],
202            ys[pos + 2],
203        ),
204        w,
205    )
206}
207
#[cfg(test)]
mod tests {
    use rstest::*;

    use super::*;

    #[rstest]
    #[case(0.0, 10.0, 5.0, 0.5)]
    #[case(1.0, 3.0, 2.0, 0.5)]
    #[case(0.0, 1.0, 0.25, 0.25)]
    #[case(0.0, 1.0, 0.75, 0.75)]
    #[case(0.0, 2.0, 3.0, 1.5)] // Extrapolation beyond x2
    #[case(0.0, 2.0, -1.0, -0.5)] // Extrapolation below x1
    fn test_linear_weight_valid_cases(
        #[case] x1: f64,
        #[case] x2: f64,
        #[case] x: f64,
        #[case] expected: f64,
    ) {
        let result = linear_weight(x1, x2, x);
        assert!(
            approx_eq!(f64, result, expected, epsilon = 1e-10),
            "Expected {expected}, was {result}"
        );
    }

    #[rstest]
    #[should_panic(expected = "too close for stable interpolation")]
    fn test_linear_weight_zero_divisor() {
        let _ = linear_weight(1.0, 1.0, 0.5);
    }

    #[rstest]
    #[should_panic(expected = "too close for stable interpolation")]
    fn test_linear_weight_near_equal_values() {
        // Values within machine epsilon should be rejected
        let _ = linear_weight(1.0, 1.0 + f64::EPSILON, 0.5);
    }

    #[rstest]
    #[should_panic(expected = "too close for stable interpolation")]
    fn test_linear_weight_nan_abscissa() {
        // A NaN difference fails the `>=` check and must be rejected, not propagated
        let _ = linear_weight(f64::NAN, 1.0, 0.5);
    }

    #[rstest]
    fn test_linear_weight_with_small_differences() {
        // High-resolution data (e.g., nanosecond timestamps as seconds) should work
        let result = linear_weight(0.0, 1e-12, 5e-13);
        assert!(result.is_finite());
        assert!((result - 0.5).abs() < 1e-10); // Should be approximately 0.5
    }

    #[rstest]
    fn test_linear_weight_just_above_epsilon() {
        // Values differing by more than machine epsilon should work
        let result = linear_weight(1.0, 1.0 + 1e-9, 1.0 + 5e-10);
        assert!(result.is_finite());
        // Rounding of inputs near 1.0 limits accuracy to roughly 1e-7 relative
        assert!((result - 0.5).abs() < 1e-6, "Expected ~0.5, was {result}");
    }

    #[rstest]
    #[case(1.0, 3.0, 0.5, 2.0)]
    #[case(10.0, 20.0, 0.25, 12.5)]
    #[case(0.0, 10.0, 0.0, 0.0)]
    #[case(0.0, 10.0, 1.0, 10.0)]
    #[case(0.0, 10.0, 2.0, 20.0)] // Extrapolation
    fn test_linear_weighting(
        #[case] y1: f64,
        #[case] y2: f64,
        #[case] weight: f64,
        #[case] expected: f64,
    ) {
        let result = linear_weighting(y1, y2, weight);
        assert!(
            approx_eq!(f64, result, expected, epsilon = 1e-10),
            "Expected {expected}, was {result}"
        );
    }

    #[rstest]
    #[case(5.0, &[1.0, 2.0, 3.0, 4.0, 6.0, 7.0], 3)]
    #[case(1.5, &[1.0, 2.0, 3.0, 4.0], 0)]
    #[case(0.5, &[1.0, 2.0, 3.0, 4.0], 0)]
    #[case(4.5, &[1.0, 2.0, 3.0, 4.0], 3)]
    #[case(2.0, &[1.0, 2.0, 3.0, 4.0], 0)]
    #[case(4.0, &[1.0, 2.0, 3.0, 4.0], 2)] // Equal to the maximum: last element < x is index 2
    fn test_pos_search(#[case] x: f64, #[case] xs: &[f64], #[case] expected: usize) {
        let result = pos_search(x, xs);
        assert_eq!(result, expected);
    }

    #[rstest]
    fn test_pos_search_edge_cases() {
        // Single element array
        let result = pos_search(5.0, &[10.0]);
        assert_eq!(result, 0);

        // Single element array, value above the only element
        let result = pos_search(15.0, &[10.0]);
        assert_eq!(result, 0);

        // Value at exact boundary
        let result = pos_search(3.0, &[1.0, 2.0, 3.0, 4.0]);
        assert_eq!(result, 1); // Index of largest element < 3.0 is index 1 (value 2.0)

        // Two element array
        let result = pos_search(1.5, &[1.0, 2.0]);
        assert_eq!(result, 0);

        // NaN compares false with everything, so no element is "less than" it
        let result = pos_search(f64::NAN, &[1.0, 2.0, 3.0]);
        assert_eq!(result, 0);
    }

    #[rstest]
    fn test_pos_search_empty_slice() {
        let empty: &[f64] = &[];
        assert_eq!(pos_search(42.0, empty), 0);
    }

    #[rstest]
    fn test_quad_polynomial_linear_case() {
        // Test with three collinear points - should behave like linear interpolation
        let result = quad_polynomial(1.5, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0);
        assert!(approx_eq!(f64, result, 1.5, epsilon = 1e-10));
    }

    #[rstest]
    fn test_quad_polynomial_parabola() {
        // Test with a simple parabola y = x^2
        // Points: (0,0), (1,1), (2,4)
        let result = quad_polynomial(1.5, 0.0, 1.0, 2.0, 0.0, 1.0, 4.0);
        let expected = 1.5 * 1.5; // Should be 2.25
        assert!(approx_eq!(f64, result, expected, epsilon = 1e-10));
    }

    #[rstest]
    fn test_quad_polynomial_extrapolation() {
        // The parabola through (0,0), (1,1), (2,4) is y = x^2 everywhere
        let result = quad_polynomial(3.0, 0.0, 1.0, 2.0, 0.0, 1.0, 4.0);
        assert!(approx_eq!(f64, result, 9.0, epsilon = 1e-10));
    }

    #[rstest]
    #[should_panic(expected = "too close for stable interpolation")]
    fn test_quad_polynomial_duplicate_x() {
        let _ = quad_polynomial(0.5, 1.0, 1.0, 2.0, 0.0, 1.0, 4.0);
    }

    #[rstest]
    #[should_panic(expected = "too close for stable interpolation")]
    fn test_quad_polynomial_near_equal_x_values() {
        // x0 and x1 differ by only machine epsilon
        let _ = quad_polynomial(0.5, 1.0, 1.0 + f64::EPSILON, 2.0, 0.0, 1.0, 4.0);
    }

    #[rstest]
    fn test_quad_polynomial_with_small_differences() {
        // High-resolution data should work (e.g., 1e-12 spacing); the points lie on
        // y = (x / 1e-12)^2, so the value at 5e-13 is 0.25
        let result = quad_polynomial(5e-13, 0.0, 1e-12, 2e-12, 0.0, 1.0, 4.0);
        assert!(result.is_finite());
        assert!((result - 0.25).abs() < 1e-9, "Expected ~0.25, was {result}");
    }

    #[rstest]
    fn test_quad_polynomial_just_above_epsilon() {
        // Values differing by more than machine epsilon should work; the middle node is
        // only perturbed by 1e-9 from y = x^2, so the result stays close to 0.25
        let result = quad_polynomial(0.5, 0.0, 1.0 + 1e-9, 2.0, 0.0, 1.0, 4.0);
        assert!(result.is_finite());
        assert!((result - 0.25).abs() < 1e-6, "Expected ~0.25, was {result}");
    }

    #[rstest]
    fn test_quadratic_interpolation_boundary_conditions() {
        let xs = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let ys = vec![1.0, 4.0, 9.0, 16.0, 25.0]; // y = x^2

        // Test below minimum
        let result = quadratic_interpolation(0.5, &xs, &ys);
        assert_eq!(result, ys[0]);

        // Test above maximum
        let result = quadratic_interpolation(6.0, &xs, &ys);
        assert_eq!(result, ys[4]);
    }

    #[rstest]
    fn test_quadratic_interpolation_exact_points() {
        let xs = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let ys = vec![1.0, 4.0, 9.0, 16.0, 25.0];

        // Test exact points
        for (i, &x) in xs.iter().enumerate() {
            let result = quadratic_interpolation(x, &xs, &ys);
            assert!(
                approx_eq!(f64, result, ys[i], epsilon = 1e-10),
                "At x={x}: expected {}, was {result}",
                ys[i]
            );
        }
    }

    #[rstest]
    #[case(1.5, 2.25)] // First interval (single quadratic through the first three points)
    #[case(2.5, 6.25)] // Interior interval (blended quadratics)
    #[case(3.5, 12.25)] // Interior interval (blended quadratics)
    #[case(4.5, 20.25)] // Last interval (single quadratic through the last three points)
    fn test_quadratic_interpolation_intermediate_values(#[case] x: f64, #[case] expected: f64) {
        let xs = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let ys = vec![1.0, 4.0, 9.0, 16.0, 25.0]; // y = x^2

        // Quadratic data is reproduced exactly by quadratic interpolation
        let result = quadratic_interpolation(x, &xs, &ys);
        assert!(
            approx_eq!(f64, result, expected, epsilon = 1e-10),
            "Expected {expected}, was {result}"
        );
    }

    #[rstest]
    fn test_quadratic_interpolation_minimum_points() {
        let xs = vec![1.0, 2.0, 3.0];
        let ys = vec![1.0, 4.0, 9.0]; // y = x^2

        let result = quadratic_interpolation(1.5, &xs, &ys);
        assert!(approx_eq!(f64, result, 2.25, epsilon = 1e-10));

        let result = quadratic_interpolation(2.5, &xs, &ys);
        assert!(approx_eq!(f64, result, 6.25, epsilon = 1e-10));
    }

    #[rstest]
    #[should_panic(expected = "Need at least 3 points")]
    fn test_quadratic_interpolation_insufficient_points() {
        let xs = vec![1.0, 2.0];
        let ys = vec![1.0, 4.0];
        let _ = quadratic_interpolation(1.5, &xs, &ys);
    }

    #[rstest]
    #[should_panic(expected = "xs and ys must have the same length")]
    fn test_quadratic_interpolation_mismatched_lengths() {
        let xs = vec![1.0, 2.0, 3.0];
        let ys = vec![1.0, 4.0];
        let _ = quadratic_interpolation(1.5, &xs, &ys);
    }
}
407}