/// Approximate floating-point equality check that evaluates to a `bool`.
///
/// Both operands are bound to `$type` (which must support `-` and `.abs()`, e.g. `f64` or
/// `f32`) and compared with a strict absolute tolerance. The result is `true` exactly when
/// `|left - right| < epsilon`. Operands whose difference is NaN (for example two equal
/// infinities, or any NaN operand) therefore compare as unequal.
///
/// The `ulps = ...` form is accepted so that call sites written for that syntax keep
/// compiling. The `ulps` argument is ignored and never evaluated; only the epsilon check
/// is performed.
#[macro_export]
macro_rules! approx_eq {
    ($type:ty, $left:expr, $right:expr, epsilon = $epsilon:expr) => {{
        let lhs: $type = $left;
        let rhs: $type = $right;
        let abs_diff = (lhs - rhs).abs();
        abs_diff < $epsilon
    }};
    ($type:ty, $left:expr, $right:expr, epsilon = $epsilon:expr, ulps = $ulps:expr) => {{
        // `$ulps` is intentionally unused: only the absolute epsilon tolerance applies.
        let lhs: $type = $left;
        let rhs: $type = $right;
        let abs_diff = (lhs - rhs).abs();
        abs_diff < $epsilon
    }};
}
51
/// Returns the fractional position of `x` along the segment from `x1` to `x2`.
///
/// The weight is `(x - x1) / (x2 - x1)`. It is `0.0` at `x1` and `1.0` at `x2`, and it
/// extrapolates linearly for `x` outside that range.
///
/// # Panics
///
/// Panics if `x1 == x2`, since the weight would require dividing by zero.
#[inline]
#[must_use]
pub fn linear_weight(x1: f64, x2: f64, x: f64) -> f64 {
    if x1 == x2 {
        panic!("`x1` and `x2` must differ to compute a linear weight");
    }
    let offset = x - x1;
    let span = x2 - x1;
    offset / span
}
69
/// Linearly interpolates between `y1` and `y2` using the weight `x1_diff`.
///
/// Computes `y1 + x1_diff * (y2 - y1)` as a single fused multiply-add, so the
/// multiply-and-add step is rounded only once. A weight of `0.0` yields `y1` and a weight of
/// `1.0` yields `y2` (up to the rounding of `y2 - y1`). Weights outside `[0, 1]` extrapolate.
#[inline]
#[must_use]
pub fn linear_weighting(y1: f64, y2: f64, x1_diff: f64) -> f64 {
    let rise = y2 - y1;
    rise.mul_add(x1_diff, y1)
}
79
/// Returns the index of the last element of `xs` that is strictly less than `x`.
///
/// Returns `0` when no element is strictly less than `x` (i.e. `x <= xs[0]` for sorted
/// input). Because the comparison is strict, an `x` equal to `xs[k]` (with `k > 0`) yields
/// `k - 1`.
///
/// `xs` is expected to be sorted in ascending order; the lookup is a binary search via
/// [`slice::partition_point`]. The returned index is always `< xs.len()`.
///
/// # Panics
///
/// Panics if `xs` is empty, since no valid index exists.
#[inline]
#[must_use]
pub fn pos_search(x: f64, xs: &[f64]) -> usize {
    // Previously the empty case underflowed `len - 1`. That panicked in debug builds but
    // silently returned the invalid index 0 in release builds.
    assert!(!xs.is_empty(), "`xs` must not be empty");

    // `partition_point` yields the first index whose value is >= x, and it is at most
    // `xs.len()`. Stepping back one therefore lands in `0..xs.len()`, so no upper clamp is
    // needed.
    xs.partition_point(|&val| val < x).saturating_sub(1)
}
91
/// Evaluates at `x` the quadratic passing through `(x0, y0)`, `(x1, y1)` and `(x2, y2)`.
///
/// Uses the Lagrange form: each `y_i` is scaled by the basis polynomial that equals `1` at
/// `x_i` and `0` at the other two abscissas, and the three terms are summed.
///
/// # Panics
///
/// Panics if any two of `x0`, `x1`, `x2` are equal, since a basis denominator would be zero.
#[inline]
#[must_use]
pub fn quad_polynomial(x: f64, x0: f64, x1: f64, x2: f64, y0: f64, y1: f64, y2: f64) -> f64 {
    if x0 == x1 || x0 == x2 || x1 == x2 {
        panic!("Abscissas must be distinct");
    }

    // Offsets of the evaluation point from each abscissa.
    let dx0 = x - x0;
    let dx1 = x - x1;
    let dx2 = x - x2;

    let term0 = y0 * dx1 * dx2 / ((x0 - x1) * (x0 - x2));
    let term1 = y1 * dx0 * dx2 / ((x1 - x0) * (x1 - x2));
    let term2 = y2 * dx0 * dx1 / ((x2 - x0) * (x2 - x1));

    term0 + term1 + term2
}
115
116#[must_use]
122pub fn quadratic_interpolation(x: f64, xs: &[f64], ys: &[f64]) -> f64 {
123 let n_elem = xs.len();
124 let epsilon = 1e-8;
125
126 assert!(
127 n_elem >= 3,
128 "Need at least 3 points for quadratic interpolation"
129 );
130 assert_eq!(xs.len(), ys.len(), "xs and ys must have the same length");
131
132 if x <= xs[0] {
133 return ys[0];
134 }
135
136 if x >= xs[n_elem - 1] {
137 return ys[n_elem - 1];
138 }
139
140 let pos = pos_search(x, xs);
141
142 if (xs[pos] - x).abs() < epsilon {
143 return ys[pos];
144 }
145
146 if pos == 0 {
147 return quad_polynomial(x, xs[0], xs[1], xs[2], ys[0], ys[1], ys[2]);
148 }
149
150 if pos == n_elem - 2 {
151 return quad_polynomial(
152 x,
153 xs[n_elem - 3],
154 xs[n_elem - 2],
155 xs[n_elem - 1],
156 ys[n_elem - 3],
157 ys[n_elem - 2],
158 ys[n_elem - 1],
159 );
160 }
161
162 let w = linear_weight(xs[pos], xs[pos + 1], x);
163
164 linear_weighting(
165 quad_polynomial(
166 x,
167 xs[pos - 1],
168 xs[pos],
169 xs[pos + 1],
170 ys[pos - 1],
171 ys[pos],
172 ys[pos + 1],
173 ),
174 quad_polynomial(
175 x,
176 xs[pos],
177 xs[pos + 1],
178 xs[pos + 2],
179 ys[pos],
180 ys[pos + 1],
181 ys[pos + 2],
182 ),
183 w,
184 )
185}
186
#[cfg(test)]
mod tests {
    use rstest::*;

    use super::*;

    #[rstest]
    #[case(0.0, 10.0, 5.0, 0.5)]
    #[case(1.0, 3.0, 2.0, 0.5)]
    #[case(0.0, 1.0, 0.25, 0.25)]
    #[case(0.0, 1.0, 0.75, 0.75)]
    fn test_linear_weight_valid_cases(
        #[case] x1: f64,
        #[case] x2: f64,
        #[case] x: f64,
        #[case] expected: f64,
    ) {
        let result = linear_weight(x1, x2, x);
        assert!(
            approx_eq!(f64, result, expected, epsilon = 1e-10),
            "Expected {expected}, got {result}"
        );
    }

    #[rstest]
    #[should_panic(expected = "must differ to compute a linear weight")]
    fn test_linear_weight_zero_divisor() {
        let _ = linear_weight(1.0, 1.0, 0.5);
    }

    #[rstest]
    #[case(1.0, 3.0, 0.5, 2.0)]
    #[case(10.0, 20.0, 0.25, 12.5)]
    #[case(0.0, 10.0, 0.0, 0.0)]
    #[case(0.0, 10.0, 1.0, 10.0)]
    fn test_linear_weighting(
        #[case] y1: f64,
        #[case] y2: f64,
        #[case] weight: f64,
        #[case] expected: f64,
    ) {
        let result = linear_weighting(y1, y2, weight);
        assert!(
            approx_eq!(f64, result, expected, epsilon = 1e-10),
            "Expected {expected}, got {result}"
        );
    }

    #[rstest]
    #[case(5.0, &[1.0, 2.0, 3.0, 4.0, 6.0, 7.0], 3)]
    #[case(1.5, &[1.0, 2.0, 3.0, 4.0], 0)]
    #[case(0.5, &[1.0, 2.0, 3.0, 4.0], 0)]
    #[case(4.5, &[1.0, 2.0, 3.0, 4.0], 3)]
    #[case(2.0, &[1.0, 2.0, 3.0, 4.0], 0)]
    fn test_pos_search(#[case] x: f64, #[case] xs: &[f64], #[case] expected: usize) {
        let result = pos_search(x, xs);
        assert_eq!(result, expected);
    }

    #[rstest]
    fn test_pos_search_edge_cases() {
        // Single-element slice: the only valid index is 0.
        let result = pos_search(5.0, &[10.0]);
        assert_eq!(result, 0);

        // `x` equal to a knot resolves to the preceding index (strictly-less search).
        let result = pos_search(3.0, &[1.0, 2.0, 3.0, 4.0]);
        assert_eq!(result, 1);

        // Two-element slice with `x` strictly inside the interval.
        let result = pos_search(1.5, &[1.0, 2.0]);
        assert_eq!(result, 0);
    }

    #[rstest]
    fn test_quad_polynomial_linear_case() {
        // Collinear points: the interpolating quadratic degenerates to y = x.
        let result = quad_polynomial(1.5, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0);
        assert!(approx_eq!(f64, result, 1.5, epsilon = 1e-10));
    }

    #[rstest]
    fn test_quad_polynomial_parabola() {
        // Points sampled from y = x^2 are reproduced by the quadratic.
        let result = quad_polynomial(1.5, 0.0, 1.0, 2.0, 0.0, 1.0, 4.0);
        let expected = 1.5 * 1.5;
        assert!(approx_eq!(f64, result, expected, epsilon = 1e-10));
    }

    #[rstest]
    #[should_panic(expected = "Abscissas must be distinct")]
    fn test_quad_polynomial_duplicate_x() {
        let _ = quad_polynomial(0.5, 1.0, 1.0, 2.0, 0.0, 1.0, 4.0);
    }

    #[rstest]
    fn test_quadratic_interpolation_boundary_conditions() {
        let xs = [1.0, 2.0, 3.0, 4.0, 5.0];
        let ys = [1.0, 4.0, 9.0, 16.0, 25.0];

        // Below the first knot: flat extrapolation to ys[0].
        let result = quadratic_interpolation(0.5, &xs, &ys);
        assert_eq!(result, ys[0]);

        // Above the last knot: flat extrapolation to ys[4].
        let result = quadratic_interpolation(6.0, &xs, &ys);
        assert_eq!(result, ys[4]);
    }

    #[rstest]
    fn test_quadratic_interpolation_exact_points() {
        let xs = [1.0, 2.0, 3.0, 4.0, 5.0];
        let ys = [1.0, 4.0, 9.0, 16.0, 25.0];

        // Evaluating at every knot must return (approximately) the sampled value.
        for (&x, &y) in xs.iter().zip(ys.iter()) {
            let result = quadratic_interpolation(x, &xs, &ys);
            assert!(
                approx_eq!(f64, result, y, epsilon = 1e-6),
                "At x = {x}: expected {y}, got {result}"
            );
        }
    }

    #[rstest]
    #[case(1.5)] // first interval: single quadratic through xs[0..=2]
    #[case(2.5)] // interior interval: blend of two quadratics
    #[case(3.5)] // interior interval: blend of two quadratics
    #[case(4.5)] // last interval: single quadratic through xs[2..=4]
    fn test_quadratic_interpolation_intermediate_values(#[case] x: f64) {
        let xs = [1.0, 2.0, 3.0, 4.0, 5.0];
        let ys = [1.0, 4.0, 9.0, 16.0, 25.0];

        let result = quadratic_interpolation(x, &xs, &ys);

        // Every quadratic piece reproduces samples of y = x^2, so the interpolant matches the
        // parabola up to rounding; a loose tolerance would hide branch-selection mistakes.
        let expected = x * x;
        assert!(
            approx_eq!(f64, result, expected, epsilon = 1e-10),
            "Expected {expected}, got {result}"
        );
    }

    #[rstest]
    #[should_panic(expected = "Need at least 3 points")]
    fn test_quadratic_interpolation_insufficient_points() {
        let xs = [1.0, 2.0];
        let ys = [1.0, 4.0];
        let _ = quadratic_interpolation(1.5, &xs, &ys);
    }

    #[rstest]
    #[should_panic(expected = "xs and ys must have the same length")]
    fn test_quadratic_interpolation_mismatched_lengths() {
        let xs = [1.0, 2.0, 3.0];
        let ys = [1.0, 4.0];
        let _ = quadratic_interpolation(1.5, &xs, &ys);
    }
}