nautilus_core/
math.rs

// -------------------------------------------------------------------------------------------------
//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
//  https://nautechsystems.io
//
//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
//  You may not use this file except in compliance with the License.
//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
//
//  Unless required by applicable law or agreed to in writing, software
//  distributed under the License is distributed on an "AS IS" BASIS,
//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//  See the License for the specific language governing permissions and
//  limitations under the License.
// -------------------------------------------------------------------------------------------------

//! Mathematical functions and interpolation utilities.
//!
//! This module provides essential mathematical operations for quantitative trading,
//! including linear and quadratic interpolation functions commonly used in financial
//! data processing and analysis. It also exports the `approx_eq!` macro for
//! tolerance-based floating-point comparison.

/// Macro for approximate floating-point equality comparison.
///
/// Evaluates to `true` when the two values are exactly equal (`==`), or when their absolute
/// difference is at most `epsilon`. The exact-equality check means identical values compare
/// equal even with `epsilon = 0.0`, and equal infinities compare equal (their difference would
/// be NaN). NaN never compares equal to anything, including itself.
///
/// Both operands are evaluated exactly once and bound to the given float type before the
/// comparison.
///
/// The second form also accepts an `ulps` margin for call-site compatibility (NOTE(review):
/// presumably with the `float-cmp` crate's macro of the same name). The `ulps` expression is
/// **not** evaluated, and only the `epsilon` margin is used.
///
/// # Usage
///
/// ```rust
/// use nautilus_core::approx_eq;
///
/// let a = 0.1 + 0.2;
/// let b = 0.3;
/// assert!(approx_eq!(f64, a, b, epsilon = 1e-10));
/// ```
#[macro_export]
macro_rules! approx_eq {
    ($type:ty, $left:expr, $right:expr, epsilon = $epsilon:expr) => {{
        let left_val: $type = $left;
        let right_val: $type = $right;
        // Exact equality first: handles `epsilon = 0.0` and equal infinities, where the
        // subtraction below would yield `0.0 < 0.0` (false) or NaN respectively.
        left_val == right_val || (left_val - right_val).abs() <= $epsilon
    }};
    ($type:ty, $left:expr, $right:expr, epsilon = $epsilon:expr, ulps = $ulps:expr) => {{
        let left_val: $type = $left;
        let right_val: $type = $right;
        // For compatibility the `ulps` margin is accepted but ignored; only exact equality
        // and the epsilon margin decide the result.
        left_val == right_val || (left_val - right_val).abs() <= $epsilon
    }};
}

/// Returns the relative position of `x` within the interval from `x1` to `x2`.
///
/// The weight is `(x - x1) / (x2 - x1)`. For finite inputs it is `0` at `x1` and `1` at `x2`,
/// and it falls outside `[0, 1]` when `x` lies outside the interval (no clamping is applied).
/// With ordinates `y1` at `x1` and `y2` at `x2`, the interpolated value is
/// `y = (1 - w) * y1 + w * y2`.
///
/// # Panics
///
/// Panics if `x1 == x2`, because the interval has zero width and the division is undefined.
#[inline]
#[must_use]
pub fn linear_weight(x1: f64, x2: f64, x: f64) -> f64 {
    assert!(
        x1 != x2,
        "`x1` and `x2` must differ to compute a linear weight"
    );

    let offset_from_start = x - x1;
    let interval_width = x2 - x1;
    offset_from_start / interval_width
}

/// Linearly blends from `y1` toward `y2` by the weight `x1_diff`.
///
/// Computes `y1 + x1_diff * (y2 - y1)` with a fused multiply-add, so the product and the sum
/// are rounded once rather than twice. For finite inputs, a weight of `0` yields `y1`. Weights
/// outside `[0, 1]` extrapolate. The weight is typically produced by [`linear_weight`].
#[inline]
#[must_use]
pub fn linear_weighting(y1: f64, y2: f64, x1_diff: f64) -> f64 {
    let spread = y2 - y1;
    f64::mul_add(x1_diff, spread, y1)
}

/// Finds the interval index for `x` in an ascending slice `xs`.
///
/// Returns the index of the last element of `xs` that is strictly less than `x`. When no
/// element is less than `x` (including when `x` is NaN), returns `0`. The result is therefore
/// always in `[0, xs.len() - 1]` for a non-empty slice. For an empty slice it is `0`, which is
/// not a valid index, so callers must not index with it.
///
/// `xs` is expected to be sorted in ascending order. For an unsorted slice the result is
/// unspecified (as with [`slice::partition_point`]), but it still lies within the range
/// above.
#[inline]
#[must_use]
pub fn pos_search(x: f64, xs: &[f64]) -> usize {
    // `partition_point` returns the number of leading elements `< x`, which is at most
    // `xs.len()`. Subtracting one (saturating at zero) therefore already yields an index in
    // `[0, xs.len() - 1]` without any extra clamping, and it avoids the `len() - 1` underflow
    // on an empty slice.
    xs.partition_point(|&val| val < x).saturating_sub(1)
}

/// Evaluates, at `x`, the quadratic Lagrange interpolant through three points.
///
/// The interpolant is the unique polynomial of degree at most two passing through
/// `(x0, y0)`, `(x1, y1)` and `(x2, y2)`. It is written in Lagrange form as
/// `y0·L0(x) + y1·L1(x) + y2·L2(x)`, where each basis polynomial `Li` is `1` at `xi`
/// and `0` at the other two abscissas.
///
/// # Panics
///
/// Panics if any two of `x0`, `x1`, `x2` are equal, because each basis polynomial divides by
/// differences between the abscissas.
#[inline]
#[must_use]
pub fn quad_polynomial(x: f64, x0: f64, x1: f64, x2: f64, y0: f64, y1: f64, y2: f64) -> f64 {
    // Coincident abscissas would make one of the denominators below zero.
    let all_distinct = x0 != x1 && x0 != x2 && x1 != x2;
    assert!(all_distinct, "Abscissas must be distinct");

    let (dx0, dx1, dx2) = (x - x0, x - x1, x - x2);

    let term0 = y0 * dx1 * dx2 / ((x0 - x1) * (x0 - x2));
    let term1 = y1 * dx0 * dx2 / ((x1 - x0) * (x1 - x2));
    let term2 = y2 * dx0 * dx1 / ((x2 - x0) * (x2 - x1));

    term0 + term1 + term2
}

116/// Performs quadratic interpolation for the point `x` given vectors of abscissas `xs` and ordinates `ys`.
117///
118/// # Panics
119///
120/// Panics if `xs.len() < 3` or `xs.len() != ys.len()`.
121#[must_use]
122pub fn quadratic_interpolation(x: f64, xs: &[f64], ys: &[f64]) -> f64 {
123    let n_elem = xs.len();
124    let epsilon = 1e-8;
125
126    assert!(
127        n_elem >= 3,
128        "Need at least 3 points for quadratic interpolation"
129    );
130    assert_eq!(xs.len(), ys.len(), "xs and ys must have the same length");
131
132    if x <= xs[0] {
133        return ys[0];
134    }
135
136    if x >= xs[n_elem - 1] {
137        return ys[n_elem - 1];
138    }
139
140    let pos = pos_search(x, xs);
141
142    if (xs[pos] - x).abs() < epsilon {
143        return ys[pos];
144    }
145
146    if pos == 0 {
147        return quad_polynomial(x, xs[0], xs[1], xs[2], ys[0], ys[1], ys[2]);
148    }
149
150    if pos == n_elem - 2 {
151        return quad_polynomial(
152            x,
153            xs[n_elem - 3],
154            xs[n_elem - 2],
155            xs[n_elem - 1],
156            ys[n_elem - 3],
157            ys[n_elem - 2],
158            ys[n_elem - 1],
159        );
160    }
161
162    let w = linear_weight(xs[pos], xs[pos + 1], x);
163
164    linear_weighting(
165        quad_polynomial(
166            x,
167            xs[pos - 1],
168            xs[pos],
169            xs[pos + 1],
170            ys[pos - 1],
171            ys[pos],
172            ys[pos + 1],
173        ),
174        quad_polynomial(
175            x,
176            xs[pos],
177            xs[pos + 1],
178            xs[pos + 2],
179            ys[pos],
180            ys[pos + 1],
181            ys[pos + 2],
182        ),
183        w,
184    )
185}
186
////////////////////////////////////////////////////////////////////////////////
// Tests
////////////////////////////////////////////////////////////////////////////////
#[cfg(test)]
mod tests {
    use rstest::*;

    use super::*;

    /// Tolerance for comparisons whose expected values are exact or nearly exact.
    const TIGHT_EPS: f64 = 1e-10;

    /// Samples of `y = x^2` at `x = 1..=5`, shared by the quadratic interpolation tests.
    const SQUARE_XS: [f64; 5] = [1.0, 2.0, 3.0, 4.0, 5.0];
    const SQUARE_YS: [f64; 5] = [1.0, 4.0, 9.0, 16.0, 25.0];

    #[rstest]
    #[case(0.0, 10.0, 5.0, 0.5)]
    #[case(1.0, 3.0, 2.0, 0.5)]
    #[case(0.0, 1.0, 0.25, 0.25)]
    #[case(0.0, 1.0, 0.75, 0.75)]
    fn test_linear_weight_valid_cases(
        #[case] x1: f64,
        #[case] x2: f64,
        #[case] x: f64,
        #[case] expected: f64,
    ) {
        let result = linear_weight(x1, x2, x);
        assert!(
            approx_eq!(f64, result, expected, epsilon = TIGHT_EPS),
            "Expected {expected}, got {result}"
        );
    }

    #[rstest]
    #[should_panic(expected = "must differ to compute a linear weight")]
    fn test_linear_weight_zero_divisor() {
        let _ = linear_weight(1.0, 1.0, 0.5);
    }

    #[rstest]
    #[case(1.0, 3.0, 0.5, 2.0)]
    #[case(10.0, 20.0, 0.25, 12.5)]
    #[case(0.0, 10.0, 0.0, 0.0)]
    #[case(0.0, 10.0, 1.0, 10.0)]
    fn test_linear_weighting(
        #[case] y1: f64,
        #[case] y2: f64,
        #[case] weight: f64,
        #[case] expected: f64,
    ) {
        let result = linear_weighting(y1, y2, weight);
        assert!(
            approx_eq!(f64, result, expected, epsilon = TIGHT_EPS),
            "Expected {expected}, got {result}"
        );
    }

    #[rstest]
    #[case(5.0, &[1.0, 2.0, 3.0, 4.0, 6.0, 7.0], 3)]
    #[case(1.5, &[1.0, 2.0, 3.0, 4.0], 0)]
    #[case(0.5, &[1.0, 2.0, 3.0, 4.0], 0)]
    #[case(4.5, &[1.0, 2.0, 3.0, 4.0], 3)]
    #[case(2.0, &[1.0, 2.0, 3.0, 4.0], 0)]
    fn test_pos_search(#[case] x: f64, #[case] xs: &[f64], #[case] expected: usize) {
        assert_eq!(pos_search(x, xs), expected);
    }

    #[rstest]
    fn test_pos_search_edge_cases() {
        let scenarios: [(f64, &[f64], usize); 3] = [
            // Single element array.
            (5.0, &[10.0], 0),
            // Exact boundary: the largest element < 3.0 is 2.0, at index 1.
            (3.0, &[1.0, 2.0, 3.0, 4.0], 1),
            // Two element array.
            (1.5, &[1.0, 2.0], 0),
        ];

        for (x, xs, expected) in scenarios {
            assert_eq!(pos_search(x, xs), expected);
        }
    }

    #[rstest]
    fn test_quad_polynomial_linear_case() {
        // Collinear points on y = x: the quadratic term vanishes, leaving plain linear
        // interpolation.
        let value = quad_polynomial(1.5, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0);
        assert!(approx_eq!(f64, value, 1.5, epsilon = TIGHT_EPS));
    }

    #[rstest]
    fn test_quad_polynomial_parabola() {
        // (0, 0), (1, 1), (2, 4) lie on y = x^2, which the interpolant must reproduce.
        let value = quad_polynomial(1.5, 0.0, 1.0, 2.0, 0.0, 1.0, 4.0);
        assert!(approx_eq!(f64, value, 2.25, epsilon = TIGHT_EPS));
    }

    #[rstest]
    #[should_panic(expected = "Abscissas must be distinct")]
    fn test_quad_polynomial_duplicate_x() {
        let _ = quad_polynomial(0.5, 1.0, 1.0, 2.0, 0.0, 1.0, 4.0);
    }

    #[rstest]
    fn test_quadratic_interpolation_boundary_conditions() {
        // Outside the sampled range the end ordinates are returned unchanged.
        let below = quadratic_interpolation(0.5, &SQUARE_XS, &SQUARE_YS);
        let above = quadratic_interpolation(6.0, &SQUARE_XS, &SQUARE_YS);

        assert_eq!(below, SQUARE_YS[0]);
        assert_eq!(above, SQUARE_YS[4]);
    }

    #[rstest]
    fn test_quadratic_interpolation_exact_points() {
        for (&x, &y) in SQUARE_XS.iter().zip(SQUARE_YS.iter()) {
            let value = quadratic_interpolation(x, &SQUARE_XS, &SQUARE_YS);
            assert!(approx_eq!(f64, value, y, epsilon = 1e-6));
        }
    }

    #[rstest]
    fn test_quadratic_interpolation_intermediate_values() {
        let value = quadratic_interpolation(2.5, &SQUARE_XS, &SQUARE_YS);
        // 2.5^2 = 6.25. Leave room for interpolation error.
        assert!((value - 6.25).abs() < 0.1);
    }

    #[rstest]
    #[should_panic(expected = "Need at least 3 points")]
    fn test_quadratic_interpolation_insufficient_points() {
        let xs = [1.0, 2.0];
        let ys = [1.0, 4.0];
        let _ = quadratic_interpolation(1.5, &xs, &ys);
    }

    #[rstest]
    #[should_panic(expected = "xs and ys must have the same length")]
    fn test_quadratic_interpolation_mismatched_lengths() {
        let xs = [1.0, 2.0, 3.0];
        let ys = [1.0, 4.0];
        let _ = quadratic_interpolation(1.5, &xs, &ys);
    }
}