1use alloy_primitives::{I256, U160, U256};
17
18pub const Q128: U256 = U256::from_limbs([0, 0, 1, 0]);
19pub const Q96_U160: U160 = U160::from_limbs([0, 1 << 32, 0]);
20
/// Stateless namespace for overflow-aware 256-bit integer math: multiply-then-divide with a
/// 512-bit intermediate, rounding-up division, integer square root, and truncating casts.
///
/// The type carries no data, so it cheaply implements the common standard traits.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct FullMath;
26
27impl FullMath {
28 pub fn mul_div(a: U256, b: U256, denominator: U256) -> anyhow::Result<U256> {
34 if denominator.is_zero() {
35 anyhow::bail!("Cannot divide by zero");
36 }
37
38 let (prod0, overflow) = a.overflowing_mul(b);
42
43 let prod1 = if overflow {
47 Self::mul_high(a, b)
50 } else {
51 U256::ZERO
52 };
53
54 if prod1.is_zero() {
56 return Ok(prod0 / denominator);
57 }
58
59 if denominator <= prod1 {
60 anyhow::bail!("Result would overflow 256 bits");
61 }
62
63 let remainder = Self::mulmod(a, b, denominator);
68
69 let (prod0, borrow) = prod0.overflowing_sub(remainder);
71 let prod1 = if borrow { prod1 - U256::from(1) } else { prod1 };
72
73 let twos = (!denominator).wrapping_add(U256::from(1)) & denominator;
76 let denominator = denominator / twos;
77
78 let prod0 = prod0 / twos;
80
81 let twos = Self::div_2_256_by(twos);
84 let prod0 = prod0 | (prod1 * twos);
85
86 let inv = Self::mod_inverse(denominator);
90
91 Ok(prod0.wrapping_mul(inv))
94 }
95
96 pub fn mul_div_rounding_up(a: U256, b: U256, denominator: U256) -> anyhow::Result<U256> {
103 let result = Self::mul_div(a, b, denominator)?;
104
105 if !Self::mulmod(a, b, denominator).is_zero() {
107 if result == U256::MAX {
109 anyhow::bail!("Result would overflow 256 bits");
110 }
111 Ok(result + U256::from(1))
112 } else {
113 Ok(result)
114 }
115 }
116
117 pub fn div_rounding_up(a: U256, b: U256) -> anyhow::Result<U256> {
124 if b.is_zero() {
125 anyhow::bail!("Cannot divide by zero");
126 }
127
128 let quotient = a / b;
129 let remainder = a % b;
130
131 if remainder > U256::ZERO {
133 if quotient == U256::MAX {
135 anyhow::bail!("Result would overflow 256 bits");
136 }
137 Ok(quotient + U256::from(1))
138 } else {
139 Ok(quotient)
140 }
141 }
142
143 fn mod_inverse(denominator: U256) -> U256 {
145 let mut inv = (U256::from(3) * denominator) ^ U256::from(2);
148
149 inv = inv.wrapping_mul(U256::from(2).wrapping_sub(denominator.wrapping_mul(inv)));
155 inv = inv.wrapping_mul(U256::from(2).wrapping_sub(denominator.wrapping_mul(inv)));
157 inv = inv.wrapping_mul(U256::from(2).wrapping_sub(denominator.wrapping_mul(inv)));
159 inv = inv.wrapping_mul(U256::from(2).wrapping_sub(denominator.wrapping_mul(inv)));
161 inv = inv.wrapping_mul(U256::from(2).wrapping_sub(denominator.wrapping_mul(inv)));
163 inv = inv.wrapping_mul(U256::from(2).wrapping_sub(denominator.wrapping_mul(inv)));
165
166 inv
167 }
168
169 fn div_2_256_by(x: U256) -> U256 {
171 if x.is_zero() {
172 return U256::from(1);
173 }
174
175 let trailing_zeros = x.trailing_zeros();
178
179 if trailing_zeros >= 256 {
180 U256::from(1)
181 } else {
182 U256::from(1) << (256 - trailing_zeros)
183 }
184 }
185
186 fn mulmod(a: U256, b: U256, m: U256) -> U256 {
188 if m.is_zero() {
189 return U256::ZERO;
190 }
191
192 if a < U256::from(u128::MAX) && b < U256::from(u128::MAX) {
194 return (a * b) % m;
195 }
196
197 let (low, overflow) = a.overflowing_mul(b);
198 if !overflow {
199 return low % m;
200 }
201
202 Self::mulmod_slow(a, b, m)
205 }
206
207 fn mulmod_slow(mut a: U256, mut b: U256, m: U256) -> U256 {
209 let mut result = U256::ZERO;
210 a %= m;
211
212 while b > U256::ZERO {
213 if b & U256::from(1) == U256::from(1) {
214 result = (result + a) % m;
215 }
216 a = (a * U256::from(2)) % m;
217 b >>= 1
218 }
219
220 result
221 }
222
223 fn mul_high(a: U256, b: U256) -> U256 {
225 let a_low = a & U256::from(u128::MAX);
227 let a_high = a >> 128;
228 let b_low = b & U256::from(u128::MAX);
229 let b_high = b >> 128;
230
231 let ll = a_low * b_low;
233 let lh = a_low * b_high;
234 let hl = a_high * b_low;
235 let hh = a_high * b_high;
236
237 let mid_sum = lh + hl;
238 let mid_high = mid_sum >> 128;
239 let mid_low = mid_sum << 128;
240
241 let (_, carry) = ll.overflowing_add(mid_low);
243
244 hh + mid_high + if carry { U256::from(1) } else { U256::ZERO }
245 }
246
247 pub fn sqrt(x: U256) -> U256 {
249 if x.is_zero() {
250 return U256::ZERO;
251 }
252 if x == U256::from(1u128) {
253 return U256::from(1u128);
254 }
255
256 let mut z = x;
257 let mut y = (x + U256::from(1u128)) >> 1;
258
259 while y < z {
260 z = y;
261 y = (x / z + z) >> 1;
262 }
263
264 z
265 }
266
267 #[must_use]
273 pub fn truncate_to_u128(value: U256) -> u128 {
274 (value & U256::from(u128::MAX)).to::<u128>()
275 }
276
277 #[must_use]
283 pub fn truncate_to_u256(value: I256) -> U256 {
284 value.into_raw()
285 }
286
287 #[must_use]
295 pub fn truncate_to_i256(value: U256) -> I256 {
296 I256::from_raw(value)
297 }
298}
299
#[cfg(test)]
mod tests {
    use rstest::*;

    use super::*;

    /// Parses a base-16 literal into a `U256`, panicking on malformed input.
    fn hex(digits: &str) -> U256 {
        U256::from_str_radix(digits, 16).unwrap()
    }

    /// Parses a base-10 literal into a `U256`, panicking on malformed input.
    fn dec(digits: &str) -> U256 {
        U256::from_str_radix(digits, 10).unwrap()
    }

    /// Asserts that `mul_div_rounding_up` fails exactly when `mul_div` fails, or when a
    /// non-exact `U256::MAX` quotient would have to be rounded up.
    fn assert_rounding_up_consistent(a: U256, b: U256, denominator: U256) {
        let expect_err = match FullMath::mul_div(a, b, denominator) {
            Ok(floor) => floor == U256::MAX && !FullMath::mulmod(a, b, denominator).is_zero(),
            Err(_) => true,
        };
        assert_eq!(
            FullMath::mul_div_rounding_up(a, b, denominator).is_err(),
            expect_err
        );
    }

    #[rstest]
    fn test_mul_high_basic() {
        let one = U256::from(1);
        for (a, b) in [(U256::ZERO, U256::ZERO), (one, one), (U256::MAX, one)] {
            assert_eq!(FullMath::mul_high(a, b), U256::ZERO);
        }
    }

    #[rstest]
    fn test_mul_high_simple_case() {
        // 2^128 * 2^128 = 2^256, so the high word is exactly one.
        assert_eq!(FullMath::mul_high(Q128, Q128), U256::from(1));
    }

    #[rstest]
    fn test_mul_high_asymmetric() {
        // (2^256 - 1) * 2 = 2^257 - 2, whose high word is one.
        assert_eq!(FullMath::mul_high(U256::MAX, U256::from(2)), U256::from(1));

        // 2^200 * 2^100 = 2^300, whose high word is 2^44.
        let pow2 = |bits: usize| U256::from(1u128) << bits;
        assert_eq!(FullMath::mul_high(pow2(200), pow2(100)), pow2(44));
    }

    #[rstest]
    fn test_mul_high_known_values() {
        // (2^128 - 1)^2 < 2^256, so nothing spills into the high word.
        let max_u128 = U256::from(u128::MAX);
        assert_eq!(FullMath::mul_high(max_u128, max_u128), U256::ZERO);
    }

    #[rstest]
    fn test_mul_high_carry_propagation() {
        let upper_half = hex("ffffffffffffffffffffffffffffffff00000000000000000000000000000000");
        let expected = hex("fffffffffffffffffffffffffffffffe00000000000000000000000000000001");
        assert_eq!(FullMath::mul_high(upper_half, upper_half), expected);
    }

    #[rstest]
    fn test_mul_high_symmetry() {
        let x = hex("123456789abcdef0123456789abcdef0");
        let y = hex("fedcba9876543210fedcba9876543210");
        assert_eq!(FullMath::mul_high(x, y), FullMath::mul_high(y, x));
    }

    #[rstest]
    fn test_mulmod_basic() {
        // (a, b, m, expected)
        let cases = [
            (0u64, 5u64, 3u64, 0u64),
            (5, 0, 3, 0),
            (5, 3, 0, 0),
            (5, 3, 7, 1),
            (6, 2, 12, 0),
        ];
        for (a, b, m, expected) in cases {
            assert_eq!(
                FullMath::mulmod(U256::from(a), U256::from(b), U256::from(m)),
                U256::from(expected)
            );
        }
    }

    #[rstest]
    fn test_mulmod_small_values() {
        // 123456 * 789012 = 97_408_265_472, which is 65_472 modulo 100_000.
        let residue = FullMath::mulmod(
            U256::from(123_456u64),
            U256::from(789_012u64),
            U256::from(100_000u64),
        );
        assert_eq!(residue, U256::from(65_472));
    }

    #[rstest]
    fn test_mulmod_no_overflow() {
        let modulus = U256::from(u64::MAX);
        let residue = FullMath::mulmod(U256::from(u64::MAX), U256::from(1000u64), modulus);
        assert_eq!(residue, (U256::from(u64::MAX) * U256::from(1000)) % modulus);
    }

    #[rstest]
    fn test_mulmod_large_overflow() {
        let modulus = U256::from(1_000_000_007u64);
        let reduced = U256::MAX % modulus;
        assert_eq!(
            FullMath::mulmod(U256::MAX, U256::MAX, modulus),
            (reduced * reduced) % modulus
        );
    }

    #[rstest]
    fn test_mulmod_symmetry() {
        let x = hex("123456789abcdef0");
        let y = hex("fedcba9876543210");
        let modulus = U256::from(1_000_000_007u64);
        assert_eq!(
            FullMath::mulmod(x, y, modulus),
            FullMath::mulmod(y, x, modulus)
        );
    }

    #[rstest]
    fn test_mulmod_identity() {
        let value = hex("123456789abcdef0123456789abcdef0");
        let modulus = U256::from(1_000_000_007u64);
        let one = U256::from(1);
        assert_eq!(FullMath::mulmod(value, one, modulus), value % modulus);
        assert_eq!(FullMath::mulmod(one, value, modulus), value % modulus);
    }

    #[rstest]
    fn test_mulmod_powers_of_two() {
        let pow2 = |bits: usize| U256::from(1) << bits;

        // 2^150 is a multiple of 2^60.
        assert_eq!(FullMath::mulmod(pow2(100), pow2(50), pow2(60)), U256::ZERO);

        // 2^50 < 2^60, so the product is its own residue.
        assert_eq!(FullMath::mulmod(pow2(30), pow2(20), pow2(60)), pow2(50));
    }

    #[rstest]
    fn test_mul_div_reverts_denominator_zero() {
        for b in [U256::from(5), Q128] {
            assert!(FullMath::mul_div(Q128, b, U256::ZERO).is_err());
        }
    }

    #[rstest]
    fn test_mul_div_reverts_output_overflow() {
        let one = U256::from(1);
        assert!(FullMath::mul_div(Q128, Q128, one).is_err());
        for denominator in [one, U256::from(2)] {
            assert!(FullMath::mul_div(U256::MAX, U256::MAX, denominator).is_err());
        }
    }

    #[rstest]
    fn test_mul_div_all_max_inputs() {
        assert_eq!(
            FullMath::mul_div(U256::MAX, U256::MAX, U256::MAX).unwrap(),
            U256::MAX
        );
    }

    #[rstest]
    fn test_mul_div_accurate_without_phantom_overflow() {
        let half = Q128 * U256::from(50) / U256::from(100);
        let one_and_half = Q128 * U256::from(150) / U256::from(100);
        let quotient = FullMath::mul_div(Q128, half, one_and_half).unwrap();
        assert_eq!(quotient, Q128 / U256::from(3));
    }

    #[rstest]
    fn test_mul_div_accurate_with_phantom_overflow() {
        let quotient =
            FullMath::mul_div(Q128, U256::from(35) * Q128, U256::from(8) * Q128).unwrap();
        assert_eq!(quotient, U256::from(4375) * Q128 / U256::from(1000));
    }

    #[rstest]
    fn test_mul_div_accurate_with_phantom_overflow_repeating_decimal() {
        let quotient =
            FullMath::mul_div(Q128, U256::from(1000) * Q128, U256::from(3000) * Q128).unwrap();
        assert_eq!(quotient, Q128 / U256::from(3));
    }

    #[rstest]
    fn test_mul_div_basic_cases() {
        // (a, b, denominator, expected floor)
        let cases = [(100u64, 200u64, 50u64, 400u64), (1000, 1, 4, 250), (1, 1, 3, 0)];
        for (a, b, denominator, expected) in cases {
            assert_eq!(
                FullMath::mul_div(U256::from(a), U256::from(b), U256::from(denominator))
                    .unwrap(),
                U256::from(expected)
            );
        }
    }

    #[rstest]
    fn test_mul_div_rounding_up_reverts_denominator_zero() {
        for b in [U256::from(5), Q128] {
            assert!(FullMath::mul_div_rounding_up(Q128, b, U256::ZERO).is_err());
        }
    }

    #[rstest]
    fn test_mul_div_rounding_up_reverts_output_overflow() {
        assert!(FullMath::mul_div_rounding_up(Q128, Q128, U256::from(1)).is_err());
        assert!(FullMath::mul_div_rounding_up(U256::MAX, U256::MAX, U256::from(2)).is_err());
    }

    #[rstest]
    fn test_mul_div_rounding_up_specific_overflow_cases() {
        assert_rounding_up_consistent(
            dec("535006138814359"),
            dec("432862656469423142931042426214547535783388063929571229938474969"),
            U256::from(2),
        );
        assert_rounding_up_consistent(
            dec("115792089237316195423570985008687907853269984659341747863450311749907997002549"),
            dec("115792089237316195423570985008687907853269984659341747863450311749907997002550"),
            dec("115792089237316195423570985008687907853269984653042931687443039491902864365164"),
        );
    }

    #[rstest]
    fn test_mul_div_rounding_up_all_max_inputs() {
        assert_eq!(
            FullMath::mul_div_rounding_up(U256::MAX, U256::MAX, U256::MAX).unwrap(),
            U256::MAX
        );
    }

    #[rstest]
    fn test_mul_div_rounding_up_accurate_without_phantom_overflow() {
        let half = Q128 * U256::from(50) / U256::from(100);
        let one_and_half = Q128 * U256::from(150) / U256::from(100);
        let quotient = FullMath::mul_div_rounding_up(Q128, half, one_and_half).unwrap();
        assert_eq!(quotient, Q128 / U256::from(3) + U256::from(1));
    }

    #[rstest]
    fn test_mul_div_rounding_up_accurate_with_phantom_overflow() {
        let quotient =
            FullMath::mul_div_rounding_up(Q128, U256::from(35) * Q128, U256::from(8) * Q128)
                .unwrap();
        assert_eq!(quotient, U256::from(4375) * Q128 / U256::from(1000));
    }

    #[rstest]
    fn test_mul_div_rounding_up_accurate_with_phantom_overflow_repeating_decimal() {
        let quotient = FullMath::mul_div_rounding_up(
            Q128,
            U256::from(1000) * Q128,
            U256::from(3000) * Q128,
        )
        .unwrap();
        assert_eq!(quotient, Q128 / U256::from(3) + U256::from(1));
    }

    #[rstest]
    fn test_mul_div_rounding_up_basic_cases() {
        // (a, b, denominator, expected ceiling)
        let cases = [
            (100u64, 200u64, 50u64, 400u64),
            (1, 1, 3, 1),
            (7, 3, 4, 6),
            (0, 100, 3, 0),
        ];
        for (a, b, denominator, expected) in cases {
            assert_eq!(
                FullMath::mul_div_rounding_up(
                    U256::from(a),
                    U256::from(b),
                    U256::from(denominator)
                )
                .unwrap(),
                U256::from(expected)
            );
        }
    }

    #[rstest]
    fn test_mul_div_rounding_up_overflow_at_max() {
        assert!(FullMath::mul_div_rounding_up(U256::MAX, U256::from(2), U256::from(2)).is_ok());
        assert_eq!(
            FullMath::mul_div_rounding_up(U256::MAX, U256::from(1), U256::from(1)).unwrap(),
            U256::MAX
        );
    }

    #[rstest]
    fn test_truncate_to_u128_preserves_small_values() {
        assert_eq!(FullMath::truncate_to_u128(U256::from(12345u128)), 12345u128);
        assert_eq!(FullMath::truncate_to_u128(U256::from(u128::MAX)), u128::MAX);
    }

    #[rstest]
    fn test_truncate_to_u128_discards_upper_bits() {
        // 2^128 has no bits in the low 128, so it truncates to zero.
        let just_past_u128 = U256::from(u128::MAX) + U256::from(1);
        assert_eq!(FullMath::truncate_to_u128(just_past_u128), 0);

        let noisy_high = (U256::from(u128::MAX) << 128) | U256::from(0x1234u128);
        assert_eq!(FullMath::truncate_to_u128(noisy_high), 0x1234u128);
    }
}
777}