nautilus_model/defi/tick_map/
full_math.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use alloy_primitives::{I256, U160, U256};
17
18pub const Q128: U256 = U256::from_limbs([0, 0, 1, 0]);
19pub const Q96_U160: U160 = U160::from_limbs([0, 1 << 32, 0]);
20
21/// Contains 512-bit math functions for Uniswap V3 style calculations
22/// Handles "phantom overflow" - allows multiplication and division where
23/// intermediate values overflow 256 bits
24#[derive(Debug)]
25pub struct FullMath;
26
27impl FullMath {
28    /// Calculates floor(a×b÷denominator) with full precision
29    ///
30    /// # Errors
31    ///
32    /// Returns error if `denominator` is zero or the result would overflow 256 bits.
33    pub fn mul_div(a: U256, b: U256, denominator: U256) -> anyhow::Result<U256> {
34        if denominator.is_zero() {
35            anyhow::bail!("Cannot divide by zero");
36        }
37
38        // Compute the 512-bit product [prod1,prod2] = a * b
39        // prod0 - least significant  256 bit
40        // prod1 - most significant 256 bit
41        let (prod0, overflow) = a.overflowing_mul(b);
42
43        // Calculate prod1 using mulmod equivalent
44        // prod1 = (a * b - prod0) / 2^256
45        // We need to handle the high part of multiplication
46        let prod1 = if overflow {
47            // When overflow occurs, we need the high 256 bits
48            // This is equivalent to: mulmod(a, b, 2^256) but for the high part
49            Self::mul_high(a, b)
50        } else {
51            U256::ZERO
52        };
53
54        // Handle non-overflow cases, 256 by 256 division
55        if prod1.is_zero() {
56            return Ok(prod0 / denominator);
57        }
58
59        if denominator <= prod1 {
60            anyhow::bail!("Result would overflow 256 bits");
61        }
62
63        // 512 by 256 division
64
65        // Make division exact by subtracting the remainder from [prod1 prod0]
66        // Compute remainder using modular multiplication
67        let remainder = Self::mulmod(a, b, denominator);
68
69        // Subtract 256 bit number from 512 bit number
70        let (prod0, borrow) = prod0.overflowing_sub(remainder);
71        let prod1 = if borrow { prod1 - U256::from(1) } else { prod1 };
72
73        // Factor powers of two out of denominator
74        // Compute largest power of two divisor of denominator (always >= 1)
75        let twos = (!denominator).wrapping_add(U256::from(1)) & denominator;
76        let denominator = denominator / twos;
77
78        // Divide [prod1 prod0] by the factors of two
79        let prod0 = prod0 / twos;
80
81        // Shift in bits from prod1 into prod0
82        // We need to flip `twos` such that it is 2^256 / twos
83        let twos = Self::div_2_256_by(twos);
84        let prod0 = prod0 | (prod1 * twos);
85
86        // Invert denominator mod 2^256
87        // Now that denominator is an odd number, it has an inverse
88        // modulo 2^256 such that denominator * inv = 1 mod 2^256
89        let inv = Self::mod_inverse(denominator);
90
91        // Because the division is now exact we can divide by multiplying
92        // with the modular inverse of denominator
93        Ok(prod0.wrapping_mul(inv))
94    }
95
96    /// Calculates ceil(a×b÷denominator) with full precision
97    /// Returns `Ok` with the rounded result or an error when rounding cannot be performed safely.
98    ///
99    /// # Errors
100    ///
101    /// Returns error if `denominator` is zero or the rounded result would overflow `U256`.
102    pub fn mul_div_rounding_up(a: U256, b: U256, denominator: U256) -> anyhow::Result<U256> {
103        let result = Self::mul_div(a, b, denominator)?;
104
105        // Check if there's a remainder
106        if !Self::mulmod(a, b, denominator).is_zero() {
107            // Check for overflow before incrementing
108            if result == U256::MAX {
109                anyhow::bail!("Result would overflow 256 bits");
110            }
111            Ok(result + U256::from(1))
112        } else {
113            Ok(result)
114        }
115    }
116
117    /// Calculates ceil(a÷b) with proper rounding up
118    /// Equivalent to Solidity's divRoundingUp function
119    ///
120    /// # Errors
121    ///
122    /// Returns error if `b` is zero or if the rounded quotient would overflow `U256`.
123    pub fn div_rounding_up(a: U256, b: U256) -> anyhow::Result<U256> {
124        if b.is_zero() {
125            anyhow::bail!("Cannot divide by zero");
126        }
127
128        let quotient = a / b;
129        let remainder = a % b;
130
131        // Add 1 if there's a remainder (equivalent to gt(mod(x, y), 0) in assembly)
132        if remainder > U256::ZERO {
133            // Check for overflow before incrementing
134            if quotient == U256::MAX {
135                anyhow::bail!("Result would overflow 256 bits");
136            }
137            Ok(quotient + U256::from(1))
138        } else {
139            Ok(quotient)
140        }
141    }
142
143    /// Computes modular multiplicative inverse using Newton-Raphson iteration
144    fn mod_inverse(denominator: U256) -> U256 {
145        // Start with a seed that is correct for four bits
146        // That is, denominator * inv = 1 mod 2^4
147        let mut inv = (U256::from(3) * denominator) ^ U256::from(2);
148
149        // Use Newton-Raphson iteration to improve precision
150        // Thanks to Hensel's lifting lemma, this works in modular arithmetic,
151        // doubling the correct bits in each step
152
153        // inverse mod 2^8
154        inv = inv.wrapping_mul(U256::from(2).wrapping_sub(denominator.wrapping_mul(inv)));
155        // inverse mod 2^16
156        inv = inv.wrapping_mul(U256::from(2).wrapping_sub(denominator.wrapping_mul(inv)));
157        // inverse mod 2^32
158        inv = inv.wrapping_mul(U256::from(2).wrapping_sub(denominator.wrapping_mul(inv)));
159        // inverse mod 2^64
160        inv = inv.wrapping_mul(U256::from(2).wrapping_sub(denominator.wrapping_mul(inv)));
161        // inverse mod 2^128
162        inv = inv.wrapping_mul(U256::from(2).wrapping_sub(denominator.wrapping_mul(inv)));
163        // inverse mod 2^256
164        inv = inv.wrapping_mul(U256::from(2).wrapping_sub(denominator.wrapping_mul(inv)));
165
166        inv
167    }
168
169    /// Computes 2^256 / x (assuming x is a power of 2)
170    fn div_2_256_by(x: U256) -> U256 {
171        if x.is_zero() {
172            return U256::from(1);
173        }
174
175        // For a power of 2, we can use bit manipulation
176        // Count trailing zeros to find the power
177        let trailing_zeros = x.trailing_zeros();
178
179        if trailing_zeros >= 256 {
180            U256::from(1)
181        } else {
182            U256::from(1) << (256 - trailing_zeros)
183        }
184    }
185
186    /// Computes (a * b) mod m
187    fn mulmod(a: U256, b: U256, m: U256) -> U256 {
188        if m.is_zero() {
189            return U256::ZERO;
190        }
191
192        // For small values, we can use simple approach
193        if a < U256::from(u128::MAX) && b < U256::from(u128::MAX) {
194            return (a * b) % m;
195        }
196
197        let (low, overflow) = a.overflowing_mul(b);
198        if !overflow {
199            return low % m;
200        }
201
202        // When overflow occurs, use bit-by-bit modular multiplication
203        // This is slower but handles all cases correctly
204        Self::mulmod_slow(a, b, m)
205    }
206
207    /// Slow but correct modular multiplication for large numbers
208    fn mulmod_slow(mut a: U256, mut b: U256, m: U256) -> U256 {
209        let mut result = U256::ZERO;
210        a %= m;
211
212        while b > U256::ZERO {
213            if b & U256::from(1) == U256::from(1) {
214                result = (result + a) % m;
215            }
216            a = (a * U256::from(2)) % m;
217            b >>= 1
218        }
219
220        result
221    }
222
223    // Computes the high 256 bits of a * b
224    fn mul_high(a: U256, b: U256) -> U256 {
225        // Split each number into high and low 128-bit parts
226        let a_low = a & U256::from(u128::MAX);
227        let a_high = a >> 128;
228        let b_low = b & U256::from(u128::MAX);
229        let b_high = b >> 128;
230
231        // Compute partial products
232        let ll = a_low * b_low;
233        let lh = a_low * b_high;
234        let hl = a_high * b_low;
235        let hh = a_high * b_high;
236
237        let mid_sum = lh + hl;
238        let mid_high = mid_sum >> 128;
239        let mid_low = mid_sum << 128;
240
241        // Check for carry from the low addition
242        let (_, carry) = ll.overflowing_add(mid_low);
243
244        hh + mid_high + if carry { U256::from(1) } else { U256::ZERO }
245    }
246
247    /// Computes the integer square root of a 256-bit unsigned integer using the Babylonian method
248    pub fn sqrt(x: U256) -> U256 {
249        if x.is_zero() {
250            return U256::ZERO;
251        }
252        if x == U256::from(1u128) {
253            return U256::from(1u128);
254        }
255
256        let mut z = x;
257        let mut y = (x + U256::from(1u128)) >> 1;
258
259        while y < z {
260            z = y;
261            y = (x / z + z) >> 1;
262        }
263
264        z
265    }
266
267    /// Truncates a U256 value to u128 by extracting the lower 128 bits.
268    ///
269    /// This matches Solidity's `uint128(value)` cast behavior, which discards
270    /// the upper 128 bits. If the value is larger than `u128::MAX`, the upper
271    /// bits are lost.
272    #[must_use]
273    pub fn truncate_to_u128(value: U256) -> u128 {
274        (value & U256::from(u128::MAX)).to::<u128>()
275    }
276
277    /// Converts an I256 signed integer to U256, mimicking Solidity's `uint256(int256)` cast.
278    ///
279    /// This performs a reinterpret cast, preserving the bit pattern:
280    /// - Positive values: returns the value as-is
281    /// - Negative values: returns the two's complement representation as unsigned
282    #[must_use]
283    pub fn truncate_to_u256(value: I256) -> U256 {
284        value.into_raw()
285    }
286
287    /// Converts a U256 unsigned integer to I256, mimicking Solidity's `int256(uint256)` cast.
288    ///
289    /// This performs a reinterpret cast, preserving the bit pattern.
290    /// Solidity's SafeCast.toInt256() checks the value fits in I256::MAX, then reinterprets.
291    ///
292    /// # Panics
293    /// Panics if the value exceeds I256::MAX (matching Solidity's require check)
294    #[must_use]
295    pub fn truncate_to_i256(value: U256) -> I256 {
296        I256::from_raw(value)
297    }
298}
299
300////////////////////////////////////////////////////////////////////////////////
301// Tests
302////////////////////////////////////////////////////////////////////////////////
303
#[cfg(test)]
mod tests {
    use rstest::*;

    use super::*;

    #[rstest]
    fn test_mul_high_basic() {
        // Test 0 * 0 = 0
        assert_eq!(FullMath::mul_high(U256::ZERO, U256::ZERO), U256::ZERO);

        // Test 1 * 1 = 1, high bits should be 0
        assert_eq!(FullMath::mul_high(U256::from(1), U256::from(1)), U256::ZERO);

        // Test MAX * 1 = MAX, high bits should be 0
        assert_eq!(FullMath::mul_high(U256::MAX, U256::from(1)), U256::ZERO);
    }

    #[rstest]
    fn test_mul_high_simple_case() {
        // Test 2^128 * 2^128 = 2^256
        // This should give us 1 in the high bits (bit 256 set)
        let result = FullMath::mul_high(Q128, Q128);
        assert_eq!(result, U256::from(1));
    }

    #[rstest]
    fn test_mul_high_asymmetric() {
        // Test large * small
        let large = U256::MAX;
        let small = U256::from(2);
        let result = FullMath::mul_high(large, small);
        // MAX * 2 = 2 * (2^256 - 1) = 2^257 - 2
        // High bits should be 1
        assert_eq!(result, U256::from(1));

        // 2^200 * 2^100 = 2^300, high part should be 2^44
        let a = U256::from(1u128) << 200;
        let b = U256::from(1u128) << 100;
        let expected = U256::from(1u128) << 44;
        assert_eq!(FullMath::mul_high(a, b), expected);
    }

    #[rstest]
    fn test_mul_high_known_values() {
        // Test with known 128-bit values
        let a = U256::from(u128::MAX); // 2^128 - 1
        let b = U256::from(u128::MAX); // 2^128 - 1
        let result = FullMath::mul_high(a, b);
        // (2^128 - 1)^2 = 2^256 - 2^129 + 1
        // High 256 bits should be 0 (since result < 2^256)
        assert_eq!(result, U256::ZERO);
    }

    #[rstest]
    fn test_mul_high_carry_propagation() {
        // Both operands have only their high 128-bit halves set, so the whole product comes
        // from the high-by-high partial product
        let a = U256::from_str_radix(
            "ffffffffffffffffffffffffffffffff00000000000000000000000000000000",
            16,
        )
        .unwrap();
        let b = U256::from_str_radix(
            "ffffffffffffffffffffffffffffffff00000000000000000000000000000000",
            16,
        )
        .unwrap();

        let result = FullMath::mul_high(a, b);

        // a = b = (2^128 - 1) * 2^128, so a * b = (2^128 - 1)^2 * 2^256
        // The high word is therefore (2^128 - 1)^2 = 2^256 - 2^129 + 1
        let expected = U256::from_str_radix(
            "fffffffffffffffffffffffffffffffe00000000000000000000000000000001",
            16,
        )
        .unwrap();
        assert_eq!(result, expected);
    }

    #[rstest]
    fn test_mul_high_symmetry() {
        // Test that mul_high(a, b) == mul_high(b, a)
        let a = U256::from_str_radix("123456789abcdef0123456789abcdef0", 16).unwrap();
        let b = U256::from_str_radix("fedcba9876543210fedcba9876543210", 16).unwrap();

        assert_eq!(FullMath::mul_high(a, b), FullMath::mul_high(b, a));
    }

    #[rstest]
    fn test_mulmod_basic() {
        // Test basic cases
        assert_eq!(
            FullMath::mulmod(U256::ZERO, U256::from(5), U256::from(3)),
            U256::ZERO
        );
        assert_eq!(
            FullMath::mulmod(U256::from(5), U256::ZERO, U256::from(3)),
            U256::ZERO
        );
        assert_eq!(
            FullMath::mulmod(U256::from(5), U256::from(3), U256::ZERO),
            U256::ZERO
        );

        // Simple multiplication: 5 * 3 mod 7 = 15 mod 7 = 1
        assert_eq!(
            FullMath::mulmod(U256::from(5), U256::from(3), U256::from(7)),
            U256::from(1)
        );

        // Product is an exact multiple of the modulus: 6 * 2 mod 12 = 0
        assert_eq!(
            FullMath::mulmod(U256::from(6), U256::from(2), U256::from(12)),
            U256::ZERO
        );
    }

    #[rstest]
    fn test_mulmod_small_values() {
        // Test the fast path for small values
        let a = U256::from(123456u64);
        let b = U256::from(789012u64);
        let m = U256::from(100000u64);

        // 123456 * 789012 = 97408265472
        // 97408265472 mod 100000 = 65472
        assert_eq!(FullMath::mulmod(a, b, m), U256::from(65472));
    }

    #[rstest]
    fn test_mulmod_no_overflow() {
        // Test medium-sized values that don't overflow
        let a = U256::from(u64::MAX);
        let b = U256::from(1000u64);
        let m = U256::from(u64::MAX);

        let result = FullMath::mulmod(a, b, m);
        let expected = (U256::from(u64::MAX) * U256::from(1000)) % U256::from(u64::MAX);
        assert_eq!(result, expected);
    }

    #[rstest]
    fn test_mulmod_large_overflow() {
        // Test with values whose product overflows 256 bits
        let a = U256::MAX;
        let b = U256::MAX;
        let m = U256::from(1000000007u64); // Large prime

        let result = FullMath::mulmod(a, b, m);

        // Verify via the modular arithmetic property:
        // (a mod m) * (b mod m) mod m, which cannot overflow for this small modulus
        let a_mod = a % m;
        let b_mod = b % m;
        let expected = (a_mod * b_mod) % m;
        assert_eq!(result, expected);
    }

    #[rstest]
    fn test_mulmod_symmetry() {
        // Test that mulmod(a, b, m) == mulmod(b, a, m)
        let a = U256::from_str_radix("123456789abcdef0", 16).unwrap();
        let b = U256::from_str_radix("fedcba9876543210", 16).unwrap();
        let m = U256::from(1000000007u64);

        assert_eq!(FullMath::mulmod(a, b, m), FullMath::mulmod(b, a, m));
    }

    #[rstest]
    fn test_mulmod_identity() {
        // Test multiplicative identity: a * 1 mod m = a mod m
        let a = U256::from_str_radix("123456789abcdef0123456789abcdef0", 16).unwrap();
        let m = U256::from(1000000007u64);

        assert_eq!(FullMath::mulmod(a, U256::from(1), m), a % m);
        assert_eq!(FullMath::mulmod(U256::from(1), a, m), a % m);
    }

    #[rstest]
    fn test_mulmod_powers_of_two() {
        // Test with powers of 2 for easier verification
        let a = U256::from(1) << 100; // 2^100
        let b = U256::from(1) << 50; // 2^50
        let m = U256::from(1) << 60; // 2^60

        // 2^100 * 2^50 = 2^150
        // 2^150 mod 2^60 = 0 (since 150 > 60)
        assert_eq!(FullMath::mulmod(a, b, m), U256::ZERO);

        // Test case where result is non-zero
        let a2 = U256::from(1) << 30; // 2^30
        let b2 = U256::from(1) << 20; // 2^20
        let m2 = U256::from(1) << 60; // 2^60

        // 2^30 * 2^20 = 2^50, which is < 2^60
        let expected = U256::from(1) << 50;
        assert_eq!(FullMath::mulmod(a2, b2, m2), expected);
    }

    #[rstest]
    fn test_mul_div_reverts_denominator_zero() {
        // Test that denominator 0 causes error
        assert!(FullMath::mul_div(Q128, U256::from(5), U256::ZERO).is_err());

        // Test with numerator overflow and denominator 0
        assert!(FullMath::mul_div(Q128, Q128, U256::ZERO).is_err());
    }

    #[rstest]
    fn test_mul_div_reverts_output_overflow() {
        // Test output overflow: Q128 * Q128 / 1 would overflow
        assert!(FullMath::mul_div(Q128, Q128, U256::from(1)).is_err());

        // Test overflow with inputs that would cause prod1 >= denominator
        // MAX * MAX / 1 would definitely overflow
        assert!(FullMath::mul_div(U256::MAX, U256::MAX, U256::from(1)).is_err());

        // Test with a smaller denominator that should still cause overflow
        assert!(FullMath::mul_div(U256::MAX, U256::MAX, U256::from(2)).is_err());
    }

    #[rstest]
    fn test_mul_div_all_max_inputs() {
        // MAX * MAX / MAX = MAX
        let result = FullMath::mul_div(U256::MAX, U256::MAX, U256::MAX).unwrap();
        assert_eq!(result, U256::MAX);
    }

    #[rstest]
    fn test_mul_div_accurate_without_phantom_overflow() {
        // Calculate Q128 * 0.5 / 1.5 = Q128 / 3
        let numerator_b = Q128 * U256::from(50) / U256::from(100); // 0.5
        let denominator = Q128 * U256::from(150) / U256::from(100); // 1.5
        let expected = Q128 / U256::from(3);

        let result = FullMath::mul_div(Q128, numerator_b, denominator).unwrap();
        assert_eq!(result, expected);
    }

    #[rstest]
    fn test_mul_div_accurate_with_phantom_overflow() {
        // Calculate Q128 * 35 * Q128 / (8 * Q128) = 35/8 * Q128 = 4.375 * Q128
        let numerator_b = U256::from(35) * Q128;
        let denominator = U256::from(8) * Q128;
        let expected = U256::from(4375) * Q128 / U256::from(1000);

        let result = FullMath::mul_div(Q128, numerator_b, denominator).unwrap();
        assert_eq!(result, expected);
    }

    #[rstest]
    fn test_mul_div_accurate_with_phantom_overflow_repeating_decimal() {
        // Calculate Q128 * 1000 * Q128 / (3000 * Q128) = 1/3 * Q128
        let numerator_b = U256::from(1000) * Q128;
        let denominator = U256::from(3000) * Q128;
        let expected = Q128 / U256::from(3);

        let result = FullMath::mul_div(Q128, numerator_b, denominator).unwrap();
        assert_eq!(result, expected);
    }

    #[rstest]
    fn test_mul_div_basic_cases() {
        // Simple case: 100 * 200 / 50 = 400
        assert_eq!(
            FullMath::mul_div(U256::from(100), U256::from(200), U256::from(50)).unwrap(),
            U256::from(400)
        );

        // Test with 1: a * 1 / b = a / b
        assert_eq!(
            FullMath::mul_div(U256::from(1000), U256::from(1), U256::from(4)).unwrap(),
            U256::from(250)
        );

        // Test division that results in 0 due to floor
        assert_eq!(
            FullMath::mul_div(U256::from(1), U256::from(1), U256::from(3)).unwrap(),
            U256::ZERO
        );
    }

    // mul_div_rounding_up tests
    #[rstest]
    fn test_mul_div_rounding_up_reverts_denominator_zero() {
        // Test that denominator 0 causes error
        assert!(FullMath::mul_div_rounding_up(Q128, U256::from(5), U256::ZERO).is_err());

        // Test with numerator overflow and denominator 0
        assert!(FullMath::mul_div_rounding_up(Q128, Q128, U256::ZERO).is_err());
    }

    #[rstest]
    fn test_mul_div_rounding_up_reverts_output_overflow() {
        // Q128 * Q128 / 1 = 2^256 does not fit in 256 bits
        assert!(FullMath::mul_div_rounding_up(Q128, Q128, U256::from(1)).is_err());

        // MAX * MAX / 2 already overflows in the floor division, before any rounding
        assert!(FullMath::mul_div_rounding_up(U256::MAX, U256::MAX, U256::from(2)).is_err());
    }

    #[rstest]
    fn test_mul_div_rounding_up_specific_overflow_cases() {
        // Overflow-after-rounding case from the TypeScript suite: a * b = 2^257 - 1, so
        // a * b / 2 floors to exactly U256::MAX with remainder 1 and rounding up must fail
        let a = U256::from_str_radix("535006138814359", 10).unwrap();
        let b = U256::from_str_radix(
            "432862656469423142931042426214547535783388063929571229938474969",
            10,
        )
        .unwrap();
        let denominator = U256::from(2);

        assert_eq!(FullMath::mul_div(a, b, denominator).unwrap(), U256::MAX);
        assert!(!FullMath::mulmod(a, b, denominator).is_zero());
        assert!(FullMath::mul_div_rounding_up(a, b, denominator).is_err());

        // Second case: checks that rounding up is consistent with mul_div and mulmod
        let a2 = U256::from_str_radix(
            "115792089237316195423570985008687907853269984659341747863450311749907997002549",
            10,
        )
        .unwrap();
        let b2 = U256::from_str_radix(
            "115792089237316195423570985008687907853269984659341747863450311749907997002550",
            10,
        )
        .unwrap();
        let denominator2 = U256::from_str_radix(
            "115792089237316195423570985008687907853269984653042931687443039491902864365164",
            10,
        )
        .unwrap();

        let base_result2 = FullMath::mul_div(a2, b2, denominator2);
        if base_result2.is_ok() {
            let remainder2 = FullMath::mulmod(a2, b2, denominator2);
            if !remainder2.is_zero() && base_result2.unwrap() == U256::MAX {
                assert!(FullMath::mul_div_rounding_up(a2, b2, denominator2).is_err());
            } else {
                // This case doesn't actually overflow after rounding
                assert!(FullMath::mul_div_rounding_up(a2, b2, denominator2).is_ok());
            }
        } else {
            // Base mul_div fails, so rounding up should also fail
            assert!(FullMath::mul_div_rounding_up(a2, b2, denominator2).is_err());
        }
    }

    #[rstest]
    fn test_mul_div_rounding_up_all_max_inputs() {
        // MAX * MAX / MAX = MAX (no rounding needed)
        let result = FullMath::mul_div_rounding_up(U256::MAX, U256::MAX, U256::MAX).unwrap();
        assert_eq!(result, U256::MAX);
    }

    #[rstest]
    fn test_mul_div_rounding_up_accurate_without_phantom_overflow() {
        // Calculate Q128 * 0.5 / 1.5 = Q128 / 3, but with rounding up
        let numerator_b = Q128 * U256::from(50) / U256::from(100); // 0.5
        let denominator = Q128 * U256::from(150) / U256::from(100); // 1.5
        let expected = Q128 / U256::from(3) + U256::from(1); // Rounded up

        let result = FullMath::mul_div_rounding_up(Q128, numerator_b, denominator).unwrap();
        assert_eq!(result, expected);
    }

    #[rstest]
    fn test_mul_div_rounding_up_accurate_with_phantom_overflow() {
        // Calculate Q128 * 35 * Q128 / (8 * Q128) = 35/8 * Q128 = 4.375 * Q128
        // This should be exact (no remainder), so no rounding up needed
        let numerator_b = U256::from(35) * Q128;
        let denominator = U256::from(8) * Q128;
        let expected = U256::from(4375) * Q128 / U256::from(1000);

        let result = FullMath::mul_div_rounding_up(Q128, numerator_b, denominator).unwrap();
        assert_eq!(result, expected);
    }

    #[rstest]
    fn test_mul_div_rounding_up_accurate_with_phantom_overflow_repeating_decimal() {
        // Calculate Q128 * 1000 * Q128 / (3000 * Q128) = 1/3 * Q128, with rounding up
        let numerator_b = U256::from(1000) * Q128;
        let denominator = U256::from(3000) * Q128;
        let expected = Q128 / U256::from(3) + U256::from(1); // Rounded up due to remainder

        let result = FullMath::mul_div_rounding_up(Q128, numerator_b, denominator).unwrap();
        assert_eq!(result, expected);
    }

    #[rstest]
    fn test_mul_div_rounding_up_basic_cases() {
        // Test exact division (no rounding needed)
        assert_eq!(
            FullMath::mul_div_rounding_up(U256::from(100), U256::from(200), U256::from(50))
                .unwrap(),
            U256::from(400)
        );

        // Test division with remainder (rounding up needed)
        assert_eq!(
            FullMath::mul_div_rounding_up(U256::from(1), U256::from(1), U256::from(3)).unwrap(),
            U256::from(1) // 0 rounded up to 1
        );

        // Test another rounding case: 7 * 3 / 4 = 21 / 4 = 5.25 -> 6
        assert_eq!(
            FullMath::mul_div_rounding_up(U256::from(7), U256::from(3), U256::from(4)).unwrap(),
            U256::from(6)
        );

        // Test case with zero result and zero remainder
        assert_eq!(
            FullMath::mul_div_rounding_up(U256::ZERO, U256::from(100), U256::from(3)).unwrap(),
            U256::ZERO
        );
    }

    #[rstest]
    fn test_mul_div_rounding_up_overflow_at_max() {
        // MAX * 2 / 2 = MAX exactly: no remainder, so no rounding and no overflow
        assert_eq!(
            FullMath::mul_div_rounding_up(U256::MAX, U256::from(2), U256::from(2)).unwrap(),
            U256::MAX
        );

        // This should succeed: MAX * 1 / 1 = MAX (no remainder)
        assert_eq!(
            FullMath::mul_div_rounding_up(U256::MAX, U256::from(1), U256::from(1)).unwrap(),
            U256::MAX
        );
    }

    #[rstest]
    fn test_div_rounding_up_basic_cases() {
        // Division by zero is rejected
        assert!(FullMath::div_rounding_up(U256::from(1), U256::ZERO).is_err());

        // Exact division is not rounded
        assert_eq!(
            FullMath::div_rounding_up(U256::from(8), U256::from(2)).unwrap(),
            U256::from(4)
        );

        // 7 / 2 = 3.5 -> 4
        assert_eq!(
            FullMath::div_rounding_up(U256::from(7), U256::from(2)).unwrap(),
            U256::from(4)
        );

        // Zero numerator
        assert_eq!(
            FullMath::div_rounding_up(U256::ZERO, U256::from(3)).unwrap(),
            U256::ZERO
        );

        // MAX / 1 has no remainder, so no increment is attempted
        assert_eq!(
            FullMath::div_rounding_up(U256::MAX, U256::from(1)).unwrap(),
            U256::MAX
        );
    }

    #[rstest]
    fn test_sqrt_small_values() {
        let cases = [
            (0u64, 0u64),
            (1, 1),
            (2, 1),
            (3, 1),
            (4, 2),
            (15, 3),
            (16, 4),
            (1_000_000, 1000),
        ];
        for (x, expected) in cases {
            assert_eq!(FullMath::sqrt(U256::from(x)), U256::from(expected));
        }
    }

    #[rstest]
    fn test_signed_reinterpret_casts() {
        // -1 and U256::MAX share the all-ones bit pattern
        assert_eq!(FullMath::truncate_to_u256(I256::MINUS_ONE), U256::MAX);
        assert_eq!(FullMath::truncate_to_i256(U256::MAX), I256::MINUS_ONE);

        // Small non-negative values round-trip unchanged
        let value = U256::from(42);
        assert_eq!(
            FullMath::truncate_to_u256(FullMath::truncate_to_i256(value)),
            value
        );
    }

    #[rstest]
    fn test_truncate_to_u128_preserves_small_values() {
        // Small value (fits in u128) should be preserved exactly
        let value = U256::from(12345u128);
        assert_eq!(FullMath::truncate_to_u128(value), 12345u128);

        // u128::MAX should be preserved
        let max_value = U256::from(u128::MAX);
        assert_eq!(FullMath::truncate_to_u128(max_value), u128::MAX);
    }

    #[rstest]
    fn test_truncate_to_u128_discards_upper_bits() {
        // Value = u128::MAX + 1 (sets bit 128)
        // Lower 128 bits = 0, so result should be 0
        let value = U256::from(u128::MAX) + U256::from(1);
        assert_eq!(FullMath::truncate_to_u128(value), 0);

        // Value with both high and low bits set:
        // High 128 bits = 0xFFFF...FFFF, Low 128 bits = 0x1234
        let value = (U256::from(u128::MAX) << 128) | U256::from(0x1234u128);
        assert_eq!(FullMath::truncate_to_u128(value), 0x1234u128);
    }
}