diff --git a/benches/int.rs b/benches/int.rs index 466eb6b7..d9a6c728 100644 --- a/benches/int.rs +++ b/benches/int.rs @@ -1,7 +1,6 @@ use std::ops::Div; use criterion::{black_box, criterion_group, criterion_main, BatchSize, Criterion}; -use num_traits::WrappingSub; use rand_core::OsRng; use crypto_bigint::{NonZero, Random, I1024, I128, I2048, I256, I4096, I512}; diff --git a/benches/uint.rs b/benches/uint.rs index 3bb2a961..4a6de5ea 100644 --- a/benches/uint.rs +++ b/benches/uint.rs @@ -1,7 +1,11 @@ -use criterion::{black_box, criterion_group, criterion_main, BatchSize, Criterion}; +use criterion::measurement::WallTime; +use criterion::{ + black_box, criterion_group, criterion_main, BatchSize, BenchmarkGroup, BenchmarkId, Criterion, +}; +use crypto_bigint::modular::SafeGcdInverter; use crypto_bigint::{ - Limb, NonZero, Odd, Random, RandomBits, RandomMod, Reciprocal, Uint, U1024, U128, U2048, U256, - U4096, U512, + Int, Limb, NonZero, Odd, PrecomputeInverter, Random, RandomBits, RandomMod, Reciprocal, Uint, + U1024, U128, U16384, U192, U2048, U256, U320, U384, U4096, U448, U512, U64, U8192, }; use rand_chacha::ChaCha8Rng; use rand_core::{OsRng, RngCore, SeedableRng}; @@ -302,32 +306,121 @@ fn bench_division(c: &mut Criterion) { group.finish(); } -fn bench_gcd(c: &mut Criterion) { - let mut group = c.benchmark_group("greatest common divisor"); +fn gcd_bench( + g: &mut BenchmarkGroup, + _x: Uint, +) where + Odd>: PrecomputeInverter>, +{ + g.bench_function(BenchmarkId::new("gcd", LIMBS), |b| { + b.iter_batched( + || { + let f = Uint::::random(&mut OsRng); + let g = Uint::::random(&mut OsRng); + (f, g) + }, + |(f, g)| black_box(Uint::gcd(&f, &g)), + BatchSize::SmallInput, + ) + }); + g.bench_function(BenchmarkId::new("bingcd", LIMBS), |b| { + b.iter_batched( + || { + let f = Uint::::random(&mut OsRng); + let g = Uint::::random(&mut OsRng); + (f, g) + }, + |(f, g)| black_box(Uint::bingcd(&f, &g)), + BatchSize::SmallInput, + ) + }); - group.bench_function("gcd, U256", |b| { + g.bench_function(BenchmarkId::new("bingcd_small", LIMBS), |b| { + b.iter_batched( + || { + let f = Uint::::random(&mut OsRng) + .bitor(&Uint::ONE) + .to_odd() + .unwrap(); + let g = Uint::::random(&mut OsRng); + (f, g) + }, + |(f, g)| black_box(f.classic_bingcd(&g)), + BatchSize::SmallInput, + ) + }); + g.bench_function(BenchmarkId::new("bingcd_large", LIMBS), |b| { b.iter_batched( || { - let f = U256::random(&mut OsRng); - let g = U256::random(&mut OsRng); + let f = Uint::::random(&mut OsRng) + .bitor(&Uint::ONE) + .to_odd() + .unwrap(); + let g = Uint::::random(&mut OsRng); (f, g) }, - |(f, g)| black_box(f.gcd(&g)), + |(f, g)| black_box(f.optimized_bingcd(&g)), BatchSize::SmallInput, ) }); +} + +fn bench_gcd(c: &mut Criterion) { + let mut group = c.benchmark_group("greatest common divisor"); + + gcd_bench(&mut group, U64::ZERO); + gcd_bench(&mut group, U128::ZERO); + gcd_bench(&mut group, U192::ZERO); + gcd_bench(&mut group, U256::ZERO); + gcd_bench(&mut group, U320::ZERO); + gcd_bench(&mut group, U384::ZERO); + gcd_bench(&mut group, U448::ZERO); + gcd_bench(&mut group, U512::ZERO); + gcd_bench(&mut group, U1024::ZERO); + gcd_bench(&mut group, U2048::ZERO); + gcd_bench(&mut group, U4096::ZERO); + gcd_bench(&mut group, U8192::ZERO); + gcd_bench(&mut group, U16384::ZERO); + + group.finish(); +} - group.bench_function("gcd_vartime, U256", |b| { +fn xgcd_bench( + g: &mut BenchmarkGroup, + _x: Uint, +) where + Odd>: PrecomputeInverter>, +{ + g.bench_function(BenchmarkId::new("binxgcd", LIMBS), |b| { b.iter_batched( || { - let f = Odd::::random(&mut OsRng); - let g = U256::random(&mut OsRng); + let modulus = Int::MIN.as_uint().wrapping_add(&Uint::ONE).to_nz().unwrap(); + let f = Uint::::random_mod(&mut OsRng, &modulus).as_int(); + let g = Uint::::random_mod(&mut OsRng, &modulus).as_int(); (f, g) }, - |(f, g)| black_box(f.gcd_vartime(&g)), + |(f, g)| black_box(f.binxgcd(&g)), BatchSize::SmallInput, ) }); +} + +fn bench_xgcd(c: &mut Criterion) { + let mut group = c.benchmark_group("greatest common divisor"); + + xgcd_bench(&mut group, U64::ZERO); + xgcd_bench(&mut group, U128::ZERO); + xgcd_bench(&mut group, U192::ZERO); + xgcd_bench(&mut group, U256::ZERO); + xgcd_bench(&mut group, U320::ZERO); + xgcd_bench(&mut group, U384::ZERO); + xgcd_bench(&mut group, U448::ZERO); + xgcd_bench(&mut group, U512::ZERO); + xgcd_bench(&mut group, U1024::ZERO); + xgcd_bench(&mut group, U2048::ZERO); + xgcd_bench(&mut group, U4096::ZERO); + xgcd_bench(&mut group, U8192::ZERO); + xgcd_bench(&mut group, U16384::ZERO); group.finish(); } @@ -491,6 +584,7 @@ criterion_group!( bench_mul, bench_division, bench_gcd, + bench_xgcd, bench_shl, bench_shr, bench_inv_mod, diff --git a/src/const_choice.rs b/src/const_choice.rs index 5e43e38c..b612c307 100644 --- a/src/const_choice.rs +++ b/src/const_choice.rs @@ -413,6 +413,20 @@ impl ConstCtOption<(Uint, Uint)> { } } +impl ConstCtOption<(Uint, ConstChoice)> { + /// Returns the contained value, consuming the `self` value. + /// + /// # Panics + /// + /// Panics if the value is none with a custom panic message provided by + /// `msg`. + #[inline] + pub const fn expect(self, msg: &str) -> (Uint, ConstChoice) { + assert!(self.is_some.is_true_vartime(), "{}", msg); + self.value + } +} + impl ConstCtOption>> { /// Returns the contained value, consuming the `self` value. /// @@ -461,6 +475,34 @@ impl ConstCtOption> { } } +impl ConstCtOption>> { + /// Returns the contained value, consuming the `self` value. + /// + /// # Panics + /// + /// Panics if the value is none with a custom panic message provided by + /// `msg`. + #[inline] + pub const fn expect(self, msg: &str) -> NonZero> { + assert!(self.is_some.is_true_vartime(), "{}", msg); + self.value + } +} + +impl ConstCtOption>> { + /// Returns the contained value, consuming the `self` value. + /// + /// # Panics + /// + /// Panics if the value is none with a custom panic message provided by + /// `msg`. + #[inline] + pub const fn expect(self, msg: &str) -> Odd> { + assert!(self.is_some.is_true_vartime(), "{}", msg); + self.value + } +} + impl ConstCtOption> { /// Returns the contained value, consuming the `self` value. /// diff --git a/src/int.rs b/src/int.rs index 3b3d60ee..3e39692c 100644 --- a/src/int.rs +++ b/src/int.rs @@ -12,6 +12,7 @@ use crate::Encoding; use crate::{Bounded, ConstChoice, ConstCtOption, Constants, Limb, NonZero, Odd, Uint, Word}; mod add; +mod bingcd; mod bit_and; mod bit_not; mod bit_or; @@ -119,7 +120,7 @@ impl Int { } /// Borrow the limbs of this [`Int`] mutably. - pub fn as_limbs_mut(&mut self) -> &mut [Limb; LIMBS] { + pub const fn as_limbs_mut(&mut self) -> &mut [Limb; LIMBS] { self.0.as_limbs_mut() } diff --git a/src/int/bingcd.rs b/src/int/bingcd.rs new file mode 100644 index 00000000..6fb909e8 --- /dev/null +++ b/src/int/bingcd.rs @@ -0,0 +1,419 @@ +//! This module implements (a constant variant of) the Optimized Extended Binary GCD algorithm, +//! which is described by Pornin in "Optimized Binary GCD for Modular Inversion". +//! Ref: + +use crate::modular::bingcd::tools::const_min; +use crate::{ConstChoice, Int, NonZero, Odd, Uint}; + +#[derive(Debug)] +pub struct BinXgcdOutput { + gcd: Uint, + x: Int, + y: Int, + lhs_on_gcd: Int, + rhs_on_gcd: Int, +} + +impl BinXgcdOutput { + /// Return the quotients `lhs.gcd` and `rhs/gcd`. + pub const fn quotients(&self) -> (Int, Int) { + (self.lhs_on_gcd, self.rhs_on_gcd) + } + + /// Provide mutable access to the quotients `lhs.gcd` and `rhs/gcd`. + pub const fn quotients_as_mut(&mut self) -> (&mut Int, &mut Int) { + (&mut self.lhs_on_gcd, &mut self.rhs_on_gcd) + } + + /// Return the Bézout coefficients `x` and `y` s.t. `lhs * x + rhs * y = gcd`. + pub const fn bezout_coefficients(&self) -> (Int, Int) { + (self.x, self.y) + } + + /// Provide mutable access to the Bézout coefficients. + pub const fn bezout_coefficients_as_mut(&mut self) -> (&mut Int, &mut Int) { + (&mut self.x, &mut self.y) + } + + /// Obtain a pair of minimal Bézout coefficients. + pub const fn minimal_bezout_coefficients(&self) -> (Int, Int) { + // Attempt to reduce x and y mod rhs_on_gcd and lhs_on_gcd, respectively. + let rhs_on_gcd_is_zero = self.rhs_on_gcd.is_nonzero().not(); + let lhs_on_gcd_is_zero = self.lhs_on_gcd.is_nonzero().not(); + let nz_rhs_on_gcd = Int::select(&self.rhs_on_gcd, &Int::ONE, rhs_on_gcd_is_zero); + let nz_lhs_on_gcd = Int::select(&self.lhs_on_gcd, &Int::ONE, lhs_on_gcd_is_zero); + let mut minimal_x = self.x.rem(&nz_rhs_on_gcd.to_nz().expect("is nz")); + let mut minimal_y = self.y.rem(&nz_lhs_on_gcd.to_nz().expect("is nz")); + + // This trick only needs to be applied whenever lhs/rhs > 1. + minimal_x = Int::select( + &self.x, + &minimal_x, + Uint::gt(&self.rhs_on_gcd.abs(), &Uint::ONE), + ); + minimal_y = Int::select( + &self.y, + &minimal_y, + Uint::gt(&self.lhs_on_gcd.abs(), &Uint::ONE), + ); + + (minimal_x, minimal_y) + } +} + +impl Int { + /// Compute the gcd of `self` and `rhs` leveraging the Binary GCD algorithm. + pub const fn bingcd(&self, rhs: &Self) -> Uint { + self.abs().bingcd(&rhs.abs()) + } + + /// Executes the Binary Extended GCD algorithm. + /// + /// Given `(self, rhs)`, computes `(g, x, y)`, s.t. `self * x + rhs * y = g = gcd(self, rhs)`. + pub const fn binxgcd(&self, rhs: &Self) -> BinXgcdOutput { + // Make sure `self` and `rhs` are nonzero. + let self_is_zero = self.is_nonzero().not(); + let self_nz = Int::select(self, &Int::ONE, self_is_zero) + .to_nz() + .expect("self is non zero by construction"); + let rhs_is_zero = rhs.is_nonzero().not(); + let rhs_nz = Int::select(rhs, &Int::ONE, rhs_is_zero) + .to_nz() + .expect("rhs is non zero by construction"); + + let mut output = self_nz.binxgcd(&rhs_nz); + + // Correct the gcd in case self and/or rhs was zero + let gcd = &mut output.gcd; + *gcd = Uint::select(gcd, &rhs.abs(), self_is_zero); + *gcd = Uint::select(gcd, &self.abs(), rhs_is_zero); + + // Correct the Bézout coefficients in case self and/or rhs was zero. + let (x, y) = output.bezout_coefficients_as_mut(); + let signum_self = Int::new_from_abs_sign(Uint::ONE, self.is_negative()).expect("+/- 1"); + let signum_rhs = Int::new_from_abs_sign(Uint::ONE, rhs.is_negative()).expect("+/- 1"); + *x = Int::select(x, &Int::ZERO, self_is_zero); + *y = Int::select(y, &signum_rhs, self_is_zero); + *x = Int::select(x, &signum_self, rhs_is_zero); + *y = Int::select(y, &Int::ZERO, rhs_is_zero); + + // Correct the quotients in case self and/or rhs was zero. + let (lhs_on_gcd, rhs_on_gcd) = output.quotients_as_mut(); + *lhs_on_gcd = Int::select(lhs_on_gcd, &signum_self, rhs_is_zero); + *lhs_on_gcd = Int::select(lhs_on_gcd, &Int::ZERO, self_is_zero); + *rhs_on_gcd = Int::select(rhs_on_gcd, &signum_rhs, self_is_zero); + *rhs_on_gcd = Int::select(rhs_on_gcd, &Int::ZERO, rhs_is_zero); + + output + } +} + +impl NonZero> { + /// Compute the gcd of `self` and `rhs` leveraging the Binary GCD algorithm. + pub const fn bingcd(&self, rhs: &Self) -> NonZero> { + self.abs().bingcd(&rhs.as_ref().abs()) + } + + /// Execute the Binary Extended GCD algorithm. + /// + /// Given `(self, rhs)`, computes `(g, x, y)` s.t. `self * x + rhs * y = g = gcd(self, rhs)`. + pub const fn binxgcd(&self, rhs: &Self) -> BinXgcdOutput { + let (mut lhs, mut rhs) = (*self.as_ref(), *rhs.as_ref()); + + // Leverage the property that gcd(2^k * a, 2^k *b) = 2^k * gcd(a, b) + let i = lhs.0.trailing_zeros(); + let j = rhs.0.trailing_zeros(); + let k = const_min(i, j); + lhs = lhs.shr(k); + rhs = rhs.shr(k); + + // Note: at this point, either lhs or rhs is odd (or both). + // Swap to make sure lhs is odd. + let swap = ConstChoice::from_u32_lt(j, i); + Int::conditional_swap(&mut lhs, &mut rhs, swap); + let lhs = lhs.to_odd().expect("odd by construction"); + + let rhs = rhs.to_nz().expect("non-zero by construction"); + let mut output = lhs.binxgcd(&rhs); + + // Account for the parameter swap + let (x, y) = output.bezout_coefficients_as_mut(); + Int::conditional_swap(x, y, swap); + let (lhs_on_gcd, rhs_on_gcd) = output.quotients_as_mut(); + Int::conditional_swap(lhs_on_gcd, rhs_on_gcd, swap); + + // Reintroduce the factor 2^k to the gcd. + output.gcd = output.gcd.shl(k); + + output + } +} + +impl Odd> { + /// Compute the gcd of `self` and `rhs` leveraging the Binary GCD algorithm. + pub const fn bingcd(&self, rhs: &Self) -> Odd> { + self.abs().bingcd(&rhs.as_ref().abs()) + } + + /// Execute the Binary Extended GCD algorithm. + /// + /// Given `(self, rhs)`, computes `(g, x, y)` s.t. `self * x + rhs * y = g = gcd(self, rhs)`. + pub const fn binxgcd(&self, rhs: &NonZero>) -> BinXgcdOutput { + let (abs_lhs, sgn_lhs) = self.abs_sign(); + let (abs_rhs, sgn_rhs) = rhs.abs_sign(); + + let output = abs_lhs.binxgcd_nz(&abs_rhs); + + let (mut x, mut y) = output.bezout_coefficients(); + x = x.wrapping_neg_if(sgn_lhs); + y = y.wrapping_neg_if(sgn_rhs); + + let (abs_lhs_on_gcd, abs_rhs_on_gcd) = output.quotients(); + let lhs_on_gcd = Int::new_from_abs_sign(abs_lhs_on_gcd, sgn_lhs).expect("no overflow"); + let rhs_on_gcd = Int::new_from_abs_sign(abs_rhs_on_gcd, sgn_rhs).expect("no overflow"); + + BinXgcdOutput { + gcd: *output.gcd.as_ref(), + x, + y, + lhs_on_gcd, + rhs_on_gcd, + } + } +} + +#[cfg(test)] +mod test { + use crate::int::bingcd::BinXgcdOutput; + use crate::{ConcatMixed, Int, Uint}; + use num_traits::Zero; + + fn binxgcd_test( + lhs: Int, + rhs: Int, + output: BinXgcdOutput, + ) where + Uint: ConcatMixed, MixedOutput = Uint>, + { + let gcd = lhs.bingcd(&rhs); + assert_eq!(gcd, output.gcd); + + // Test quotients + let (lhs_on_gcd, rhs_on_gcd) = output.quotients(); + if gcd.is_zero() { + assert_eq!(lhs_on_gcd, Int::ZERO); + assert_eq!(rhs_on_gcd, Int::ZERO); + } else { + assert_eq!(lhs_on_gcd, lhs.div_uint(&gcd.to_nz().unwrap())); + assert_eq!(rhs_on_gcd, rhs.div_uint(&gcd.to_nz().unwrap())); + } + + // Test the Bezout coefficients + let (x, y) = output.bezout_coefficients(); + assert_eq!( + x.widening_mul(&lhs).wrapping_add(&y.widening_mul(&rhs)), + gcd.resize().as_int() + ); + + // Test the minimal Bezout coefficients on minimality + let (x, y) = output.minimal_bezout_coefficients(); + assert!(x.abs() <= rhs_on_gcd.abs() || rhs_on_gcd.is_zero()); + assert!(y.abs() <= lhs_on_gcd.abs() || lhs_on_gcd.is_zero()); + + // Test the minimal Bezout coefficients for correctness + assert_eq!( + x.widening_mul(&lhs).wrapping_add(&y.widening_mul(&rhs)), + gcd.resize().as_int() + ); + } + + mod test_int_binxgcd { + use crate::int::bingcd::test::binxgcd_test; + use crate::{ + ConcatMixed, Gcd, Int, Random, Uint, U1024, U128, U192, U2048, U256, U384, U4096, U512, + U64, U768, U8192, + }; + use rand_core::OsRng; + + fn int_binxgcd_test( + lhs: Int, + rhs: Int, + ) where + Uint: ConcatMixed, MixedOutput = Uint>, + Int: Gcd>, + { + binxgcd_test(lhs, rhs, lhs.binxgcd(&rhs)) + } + + fn int_binxgcd_tests() + where + Uint: ConcatMixed, MixedOutput = Uint>, + Int: Gcd>, + { + int_binxgcd_test(Int::MIN, Int::MIN); + int_binxgcd_test(Int::MIN, Int::MINUS_ONE); + int_binxgcd_test(Int::MIN, Int::ZERO); + int_binxgcd_test(Int::MIN, Int::ONE); + int_binxgcd_test(Int::MIN, Int::MAX); + int_binxgcd_test(Int::ONE, Int::MIN); + int_binxgcd_test(Int::ONE, Int::MINUS_ONE); + int_binxgcd_test(Int::ONE, Int::ZERO); + int_binxgcd_test(Int::ONE, Int::ONE); + int_binxgcd_test(Int::ONE, Int::MAX); + int_binxgcd_test(Int::ZERO, Int::MIN); + int_binxgcd_test(Int::ZERO, Int::MINUS_ONE); + int_binxgcd_test(Int::ZERO, Int::ZERO); + int_binxgcd_test(Int::ZERO, Int::ONE); + int_binxgcd_test(Int::ZERO, Int::MAX); + int_binxgcd_test(Int::ONE, Int::MIN); + int_binxgcd_test(Int::ONE, Int::MINUS_ONE); + int_binxgcd_test(Int::ONE, Int::ZERO); + int_binxgcd_test(Int::ONE, Int::ONE); + int_binxgcd_test(Int::ONE, Int::MAX); + int_binxgcd_test(Int::MAX, Int::MIN); + int_binxgcd_test(Int::MAX, Int::MINUS_ONE); + int_binxgcd_test(Int::MAX, Int::ZERO); + int_binxgcd_test(Int::MAX, Int::ONE); + int_binxgcd_test(Int::MAX, Int::MAX); + + for _ in 0..100 { + let x = Int::random(&mut OsRng); + let y = Int::random(&mut OsRng); + int_binxgcd_test(x, y); + } + } + + #[test] + fn test_int_binxgcd() { + int_binxgcd_tests::<{ U64::LIMBS }, { U128::LIMBS }>(); + int_binxgcd_tests::<{ U128::LIMBS }, { U256::LIMBS }>(); + int_binxgcd_tests::<{ U192::LIMBS }, { U384::LIMBS }>(); + int_binxgcd_tests::<{ U256::LIMBS }, { U512::LIMBS }>(); + int_binxgcd_tests::<{ U384::LIMBS }, { U768::LIMBS }>(); + int_binxgcd_tests::<{ U512::LIMBS }, { U1024::LIMBS }>(); + int_binxgcd_tests::<{ U1024::LIMBS }, { U2048::LIMBS }>(); + int_binxgcd_tests::<{ U2048::LIMBS }, { U4096::LIMBS }>(); + int_binxgcd_tests::<{ U4096::LIMBS }, { U8192::LIMBS }>(); + } + } + + mod test_nonzero_int_binxgcd { + use crate::int::bingcd::test::binxgcd_test; + use crate::{ + ConcatMixed, Gcd, Int, RandomMod, Uint, U1024, U128, U192, U2048, U256, U384, U4096, + U512, U64, U768, U8192, + }; + use rand_core::OsRng; + + fn nz_int_binxgcd_test( + lhs: Int, + rhs: Int, + ) where + Uint: ConcatMixed, MixedOutput = Uint>, + Int: Gcd>, + { + let output = lhs.to_nz().unwrap().binxgcd(&rhs.to_nz().unwrap()); + binxgcd_test(lhs, rhs, output); + } + + fn nz_int_binxgcd_tests() + where + Uint: ConcatMixed, MixedOutput = Uint>, + Int: Gcd>, + { + nz_int_binxgcd_test(Int::MIN, Int::MIN); + nz_int_binxgcd_test(Int::MIN, Int::MINUS_ONE); + nz_int_binxgcd_test(Int::MIN, Int::ONE); + nz_int_binxgcd_test(Int::MIN, Int::MAX); + nz_int_binxgcd_test(Int::MINUS_ONE, Int::MIN); + nz_int_binxgcd_test(Int::MINUS_ONE, Int::MINUS_ONE); + nz_int_binxgcd_test(Int::MINUS_ONE, Int::ONE); + nz_int_binxgcd_test(Int::MINUS_ONE, Int::MAX); + nz_int_binxgcd_test(Int::ONE, Int::MIN); + nz_int_binxgcd_test(Int::ONE, Int::MINUS_ONE); + nz_int_binxgcd_test(Int::ONE, Int::ONE); + nz_int_binxgcd_test(Int::ONE, Int::MAX); + nz_int_binxgcd_test(Int::MAX, Int::MIN); + nz_int_binxgcd_test(Int::MAX, Int::MINUS_ONE); + nz_int_binxgcd_test(Int::MAX, Int::ONE); + nz_int_binxgcd_test(Int::MAX, Int::MAX); + + let bound = Int::MIN.abs().to_nz().unwrap(); + for _ in 0..100 { + let x = Uint::random_mod(&mut OsRng, &bound).as_int(); + let y = Uint::random_mod(&mut OsRng, &bound).as_int(); + nz_int_binxgcd_test(x, y); + } + } + + #[test] + fn test_nz_int_binxgcd() { + nz_int_binxgcd_tests::<{ U64::LIMBS }, { U128::LIMBS }>(); + nz_int_binxgcd_tests::<{ U128::LIMBS }, { U256::LIMBS }>(); + nz_int_binxgcd_tests::<{ U192::LIMBS }, { U384::LIMBS }>(); + nz_int_binxgcd_tests::<{ U256::LIMBS }, { U512::LIMBS }>(); + nz_int_binxgcd_tests::<{ U384::LIMBS }, { U768::LIMBS }>(); + nz_int_binxgcd_tests::<{ U512::LIMBS }, { U1024::LIMBS }>(); + nz_int_binxgcd_tests::<{ U1024::LIMBS }, { U2048::LIMBS }>(); + nz_int_binxgcd_tests::<{ U2048::LIMBS }, { U4096::LIMBS }>(); + nz_int_binxgcd_tests::<{ U4096::LIMBS }, { U8192::LIMBS }>(); + } + } + + mod test_odd_int_binxgcd { + use crate::int::bingcd::test::binxgcd_test; + use crate::{ + ConcatMixed, Int, Random, Uint, U1024, U128, U192, U2048, U256, U384, U4096, U512, U64, + U768, U8192, + }; + use rand_core::OsRng; + + fn odd_int_binxgcd_test( + lhs: Int, + rhs: Int, + ) where + Uint: ConcatMixed, MixedOutput = Uint>, + { + let output = lhs.to_odd().unwrap().binxgcd(&rhs.to_nz().unwrap()); + binxgcd_test(lhs, rhs, output); + } + + fn odd_int_binxgcd_tests() + where + Uint: ConcatMixed, MixedOutput = Uint>, + { + let neg_max = Int::MAX.wrapping_neg(); + odd_int_binxgcd_test(neg_max, neg_max); + odd_int_binxgcd_test(neg_max, Int::MINUS_ONE); + odd_int_binxgcd_test(neg_max, Int::ONE); + odd_int_binxgcd_test(neg_max, Int::MAX); + odd_int_binxgcd_test(Int::ONE, neg_max); + odd_int_binxgcd_test(Int::ONE, Int::MINUS_ONE); + odd_int_binxgcd_test(Int::ONE, Int::ONE); + odd_int_binxgcd_test(Int::ONE, Int::MAX); + odd_int_binxgcd_test(Int::MAX, neg_max); + odd_int_binxgcd_test(Int::MAX, Int::MINUS_ONE); + odd_int_binxgcd_test(Int::MAX, Int::ONE); + odd_int_binxgcd_test(Int::MAX, Int::MAX); + + for _ in 0..100 { + let x = Int::::random(&mut OsRng).bitor(&Int::ONE); + let y = Int::::random(&mut OsRng); + odd_int_binxgcd_test(x, y); + } + } + + #[test] + fn test_odd_int_binxgcd() { + odd_int_binxgcd_tests::<{ U64::LIMBS }, { U128::LIMBS }>(); + odd_int_binxgcd_tests::<{ U128::LIMBS }, { U256::LIMBS }>(); + odd_int_binxgcd_tests::<{ U192::LIMBS }, { U384::LIMBS }>(); + odd_int_binxgcd_tests::<{ U256::LIMBS }, { U512::LIMBS }>(); + odd_int_binxgcd_tests::<{ U384::LIMBS }, { U768::LIMBS }>(); + odd_int_binxgcd_tests::<{ U512::LIMBS }, { U1024::LIMBS }>(); + odd_int_binxgcd_tests::<{ U1024::LIMBS }, { U2048::LIMBS }>(); + odd_int_binxgcd_tests::<{ U2048::LIMBS }, { U4096::LIMBS }>(); + odd_int_binxgcd_tests::<{ U4096::LIMBS }, { U8192::LIMBS }>(); + } + } +} diff --git a/src/int/cmp.rs b/src/int/cmp.rs index 63a8b220..b9705790 100644 --- a/src/int/cmp.rs +++ b/src/int/cmp.rs @@ -15,6 +15,12 @@ impl Int { Self(Uint::select(&a.0, &b.0, c)) } + /// Swap `a` and `b` if `c` is truthy, otherwise, do nothing. + #[inline] + pub(crate) const fn conditional_swap(a: &mut Self, b: &mut Self, c: ConstChoice) { + Uint::conditional_swap(&mut a.0, &mut b.0, c); + } + /// Returns the truthy value if `self`!=0 or the falsy value otherwise. #[inline] pub(crate) const fn is_nonzero(&self) -> ConstChoice { diff --git a/src/int/gcd.rs b/src/int/gcd.rs index 0df12457..a7c18ea1 100644 --- a/src/int/gcd.rs +++ b/src/int/gcd.rs @@ -37,7 +37,7 @@ where #[cfg(test)] mod tests { - use crate::{Gcd, I256, U256}; + use crate::{Gcd, I256, I64, U256}; #[test] fn gcd_always_positive() { @@ -60,4 +60,10 @@ mod tests { assert_eq!(U256::from(61u32), f.gcd(&g)); assert_eq!(U256::from(61u32), f.wrapping_neg().gcd(&g)); } + + #[test] + fn gcd() { + assert_eq!(I64::MIN.gcd(&I64::ZERO), I64::MIN.abs()); + assert_eq!(I64::ZERO.gcd(&I64::MIN), I64::MIN.abs()); + } } diff --git a/src/int/mul.rs b/src/int/mul.rs index 315564c4..03877a46 100644 --- a/src/int/mul.rs +++ b/src/int/mul.rs @@ -4,7 +4,7 @@ use core::ops::{Mul, MulAssign}; use subtle::CtOption; -use crate::{Checked, CheckedMul, ConcatMixed, ConstChoice, ConstCtOption, Int, Uint, Zero}; +use crate::{Checked, CheckedMul, ConcatMixed, ConstChoice, ConstCtOption, Int, Uint}; impl Int { /// Compute "wide" multiplication as a 3-tuple `(lo, hi, negate)`. @@ -49,6 +49,23 @@ impl Int { // always fits Int::from_bits(product_abs.wrapping_neg_if(product_sign)) } + + /// Multiply `self` with `rhs`, returning a [ConstCtOption] that `is_some` only if the result + /// fits in an `Int`. + pub(crate) const fn const_checked_mul( + &self, + rhs: &Int, + ) -> ConstCtOption> { + let (lo, hi, is_negative) = self.split_mul(rhs); + Self::new_from_abs_sign(lo, is_negative).and_choice(hi.is_nonzero().not()) + } + + /// Multiply `self` with `rhs`, returning a [ConstCtOption] that `is_some` only if the result + /// fits in an `Int`. + pub const fn wrapping_mul(&self, rhs: &Int) -> Int { + let (lo, _, is_negative) = self.split_mul(rhs); + Self(lo.wrapping_neg_if(is_negative)) + } } /// Squaring operations. @@ -80,9 +97,7 @@ impl Int { impl CheckedMul> for Int { #[inline] fn checked_mul(&self, rhs: &Int) -> CtOption { - let (lo, hi, is_negative) = self.split_mul(rhs); - let val = Self::new_from_abs_sign(lo, is_negative); - CtOption::from(val).and_then(|int| CtOption::new(int, hi.is_zero())) + Self::const_checked_mul(self, rhs).into() } } @@ -114,7 +129,7 @@ impl Mul<&Int> for &Int; fn mul(self, rhs: &Int) -> Self::Output { - self.checked_mul(rhs) + self.const_checked_mul(rhs) .expect("attempted to multiply with overflow") } } diff --git a/src/int/resize.rs b/src/int/resize.rs index 4aed2899..1dc9cb4d 100644 --- a/src/int/resize.rs +++ b/src/int/resize.rs @@ -18,7 +18,6 @@ impl Int { #[cfg(test)] mod tests { - use num_traits::WrappingSub; use crate::{I128, I256}; diff --git a/src/int/sign.rs b/src/int/sign.rs index e8bd3ed0..11b9bde5 100644 --- a/src/int/sign.rs +++ b/src/int/sign.rs @@ -1,4 +1,4 @@ -use crate::{ConstChoice, ConstCtOption, Int, Uint, Word}; +use crate::{ConstChoice, ConstCtOption, Int, Odd, Uint, Word}; use num_traits::ConstZero; impl Int { @@ -49,6 +49,20 @@ impl Int { } } +impl Odd> { + /// The sign and magnitude of this [`Odd`]. + pub const fn abs_sign(&self) -> (Odd>, ConstChoice) { + let (abs, sgn) = Int::abs_sign(self.as_ref()); + let odd_abs = abs.to_odd().expect("abs value of an odd number is odd"); + (odd_abs, sgn) + } + + /// The magnitude of this [`Odd`]. + pub const fn abs(&self) -> Odd> { + self.abs_sign().0 + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/int/sub.rs b/src/int/sub.rs index bb94f642..b392dc0f 100644 --- a/src/int/sub.rs +++ b/src/int/sub.rs @@ -3,12 +3,14 @@ use core::ops::{Sub, SubAssign}; use num_traits::WrappingSub; -use subtle::{Choice, ConstantTimeEq, CtOption}; +use subtle::CtOption; -use crate::{Checked, CheckedSub, Int, Wrapping}; +use crate::{Checked, CheckedSub, ConstChoice, ConstCtOption, Int, Wrapping}; -impl CheckedSub for Int { - fn checked_sub(&self, rhs: &Self) -> CtOption { +impl Int { + /// Perform subtraction, returning the result along with a [ConstChoice] which `is_true` + /// only if the operation underflowed. + pub const fn underflowing_sub(&self, rhs: &Self) -> (Self, ConstChoice) { // Step 1. subtract operands let res = Self(self.0.wrapping_sub(&rhs.0)); @@ -18,12 +20,26 @@ impl CheckedSub for Int { // - underflow occurs if and only if the result and the lhs have opposing signs. // // We can thus express the overflow flag as: (self.msb != rhs.msb) & (self.msb != res.msb) - let self_msb: Choice = self.is_negative().into(); - let underflow = - self_msb.ct_ne(&rhs.is_negative().into()) & self_msb.ct_ne(&res.is_negative().into()); + let self_msb = self.is_negative(); + let underflow = self_msb + .ne(rhs.is_negative()) + .and(self_msb.ne(res.is_negative())); // Step 3. Construct result - CtOption::new(res, !underflow) + (res, underflow) + } + + /// Perform wrapping subtraction, discarding underflow and wrapping around the boundary of the + /// type. + pub const fn wrapping_sub(&self, rhs: &Self) -> Self { + self.underflowing_sub(rhs).0 + } +} + +impl CheckedSub for Int { + fn checked_sub(&self, rhs: &Self) -> CtOption { + let (res, underflow) = Self::underflowing_sub(self, rhs); + ConstCtOption::new(res, underflow.not()).into() } } @@ -79,8 +95,6 @@ mod tests { #[cfg(test)] mod tests { - use num_traits::WrappingSub; - use crate::{CheckedSub, Int, I128, U128}; #[test] diff --git a/src/modular.rs b/src/modular.rs index 1159d6a4..ca1d46a3 100644 --- a/src/modular.rs +++ b/src/modular.rs @@ -22,6 +22,7 @@ mod monty_form; mod reduction; mod add; +pub(crate) mod bingcd; mod div_by_2; mod mul; mod pow; diff --git a/src/modular/bingcd.rs b/src/modular/bingcd.rs new file mode 100644 index 00000000..e6a939e8 --- /dev/null +++ b/src/modular/bingcd.rs @@ -0,0 +1,9 @@ +//! This module implements (a constant variant of) the Optimized Extended Binary GCD algorithm, +//! which is described by Pornin as Algorithm 2 in "Optimized Binary GCD for Modular Inversion". +//! Ref: + +mod extension; +mod gcd; +mod matrix; +pub(crate) mod tools; +mod xgcd; diff --git a/src/modular/bingcd/extension.rs b/src/modular/bingcd/extension.rs new file mode 100644 index 00000000..5b651f7d --- /dev/null +++ b/src/modular/bingcd/extension.rs @@ -0,0 +1,151 @@ +use crate::{ConstChoice, Int, Limb, Uint}; + +pub(crate) struct ExtendedUint( + Uint, + Uint, +); + +impl ExtendedUint { + /// Interpret `self` as an [ExtendedInt] + #[inline] + pub const fn as_extended_int(&self) -> ExtendedInt { + ExtendedInt(self.0, self.1) + } + + /// Construction the binary negation of `self`, i.e., map `self` to `!self + 1`. + /// + /// Note: maps `0` to itself. + #[inline] + pub const fn wrapping_neg(&self) -> Self { + let (lhs, carry) = self.0.carrying_neg(); + let mut rhs = self.1.not(); + rhs = Uint::select(&rhs, &rhs.wrapping_add(&Uint::ONE), carry); + Self(lhs, rhs) + } + + /// Negate `self` if `negate` is truthy. Otherwise returns `self`. + #[inline] + pub const fn wrapping_neg_if(&self, negate: ConstChoice) -> Self { + let neg = self.wrapping_neg(); + Self( + Uint::select(&self.0, &neg.0, negate), + Uint::select(&self.1, &neg.1, negate), + ) + } + + /// Shift `self` right by `shift` bits. + /// + /// Assumes `shift <= Uint::::BITS`. + #[inline] + pub const fn shr(&self, shift: u32) -> Self { + debug_assert!(shift <= Uint::::BITS); + + let shift_is_zero = ConstChoice::from_u32_eq(shift, 0); + let left_shift = shift_is_zero.select_u32(Uint::::BITS - shift, 0); + + let hi = self.1.shr(shift); + // TODO: replace with carrying_shl + let carry = Uint::select(&self.1, &Uint::ZERO, shift_is_zero).shl(left_shift); + let mut lo = self.0.shr(shift); + + // Apply carry + let limb_diff = LIMBS.wrapping_sub(EXTRA) as u32; + // safe to vartime; shr_vartime is variable in the value of shift only. Since this shift + // is a public constant, the constant time property of this algorithm is not impacted. + let carry = carry.resize::().shl_vartime(limb_diff * Limb::BITS); + lo = lo.bitxor(&carry); + + Self(lo, hi) + } + + /// Vartime equivalent of [Self::shr]. + #[inline] + pub const fn shr_vartime(&self, shift: u32) -> Self { + debug_assert!(shift <= Uint::::BITS); + + let shift_is_zero = ConstChoice::from_u32_eq(shift, 0); + let left_shift = shift_is_zero.select_u32(Uint::::BITS - shift, 0); + + let hi = self.1.shr_vartime(shift); + let carry = Uint::select(&self.1, &Uint::ZERO, shift_is_zero).wrapping_shl(left_shift); + let mut lo = self.0.shr_vartime(shift); + + // Apply carry + let limb_diff = LIMBS.wrapping_sub(EXTRA) as u32; + // safe to vartime; shr_vartime is variable in the value of shift only. Since this shift + // is a public constant, the constant time property of this algorithm is not impacted. + let carry = carry.resize::().shl_vartime(limb_diff * Limb::BITS); + lo = lo.bitxor(&carry); + + Self(lo, hi) + } +} + +pub(crate) struct ExtendedInt( + Uint, + Uint, +); + +impl ExtendedInt { + /// Construct an [ExtendedInt] from the product of a [Uint] and an [Int]. + /// + /// Assumes the top bit of the product is not set. + #[inline] + pub const fn from_product(lhs: Uint, rhs: Int) -> Self { + let (lo, hi, sgn) = rhs.split_mul_uint_right(&lhs); + ExtendedUint(lo, hi).wrapping_neg_if(sgn).as_extended_int() + } + + /// Interpret this as an [ExtendedUint]. + #[inline] + pub const fn as_extended_uint(&self) -> ExtendedUint { + ExtendedUint(self.0, self.1) + } + + /// Return the negation of `self` if `negate` is truthy. Otherwise, return `self`. + #[inline] + pub const fn wrapping_neg_if(&self, negate: ConstChoice) -> Self { + self.as_extended_uint() + .wrapping_neg_if(negate) + .as_extended_int() + } + + /// Compute `self + rhs`, wrapping any overflow. + #[inline] + pub const fn wrapping_add(&self, rhs: &Self) -> Self { + let (lo, carry) = self.0.adc(&rhs.0, Limb::ZERO); + let (hi, _) = self.1.adc(&rhs.1, carry); + Self(lo, hi) + } + + /// Returns self without the extension. + #[inline] + pub const fn wrapping_drop_extension(&self) -> (Uint, ConstChoice) { + let (abs, sgn) = self.abs_sgn(); + (abs.0, sgn) + } + + /// Decompose `self` into is absolute value and signum. + #[inline] + pub const fn abs_sgn(&self) -> (ExtendedUint, ConstChoice) { + let is_negative = self.1.as_int().is_negative(); + ( + self.wrapping_neg_if(is_negative).as_extended_uint(), + is_negative, + ) + } + + /// Divide self by `2^k`, rounding towards zero. + #[inline] + pub const fn div_2k(&self, k: u32) -> Self { + let (abs, sgn) = self.abs_sgn(); + abs.shr(k).wrapping_neg_if(sgn).as_extended_int() + } + + /// Divide self by `2^k`, rounding towards zero. + #[inline] + pub const fn div_2k_vartime(&self, k: u32) -> Self { + let (abs, sgn) = self.abs_sgn(); + abs.shr_vartime(k).wrapping_neg_if(sgn).as_extended_int() + } +} diff --git a/src/modular/bingcd/gcd.rs b/src/modular/bingcd/gcd.rs new file mode 100644 index 00000000..54daeec8 --- /dev/null +++ b/src/modular/bingcd/gcd.rs @@ -0,0 +1,237 @@ +use crate::modular::bingcd::tools::const_max; +use crate::{ConstChoice, Odd, Uint, U128, U64}; + +impl Odd> { + /// The minimal number of iterations required to ensure the Binary GCD algorithm terminates and + /// returns the proper value. + const MINIMAL_BINGCD_ITERATIONS: u32 = 2 * Self::BITS - 1; + + /// Computes `gcd(self, rhs)`, leveraging (a constant time implementation of) the classic + /// Binary GCD algorithm. + /// + /// Note: this algorithm is efficient for [Uint]s with relatively few `LIMBS`. + /// + /// Ref: Pornin, Optimized Binary GCD for Modular Inversion, Algorithm 1. + /// + #[inline] + pub const fn classic_bingcd(&self, rhs: &Uint) -> Self { + // (self, rhs) corresponds to (m, y) in the Algorithm 1 notation. + let (mut a, mut b) = (*rhs, *self.as_ref()); + let mut j = 0; + while j < Self::MINIMAL_BINGCD_ITERATIONS { + Self::bingcd_step(&mut a, &mut b); + j += 1; + } + + b.to_odd() + .expect("gcd of an odd value with something else is always odd") + } + + /// Binary GCD update step. + /// + /// This is a condensed, constant time execution of the following algorithm: + /// ```text + /// if a mod 2 == 1 + /// if a < b + /// (a, b) ← (b, a) + /// a ← a - b + /// a ← a/2 + /// ``` + /// + /// Note: assumes `b` to be odd. Might yield an incorrect result if this is not the case. + /// + /// Ref: Pornin, Algorithm 1, L3-9, . + #[inline] + const fn bingcd_step(a: &mut Uint, b: &mut Uint) { + let a_odd = a.is_odd(); + let a_lt_b = Uint::lt(a, b); + Uint::conditional_swap(a, b, a_odd.and(a_lt_b)); + *a = a + .wrapping_sub(&Uint::select(&Uint::ZERO, b, a_odd)) + .shr_vartime(1); + } + + /// Computes `gcd(self, rhs)`, leveraging the optimized Binary GCD algorithm. + /// + /// Note: this algorithm becomes more efficient than the classical algorithm for [Uint]s with + /// relatively many `LIMBS`. A best-effort threshold is presented in [Self::bingcd]. + /// + /// Note: the full algorithm has an additional parameter; this function selects the best-effort + /// value for this parameter. You might be able to further tune your performance by calling the + /// [Self::optimized_bingcd_] function directly. + /// + /// Ref: Pornin, Optimized Binary GCD for Modular Inversion, Algorithm 2. + /// + #[inline] + pub const fn optimized_bingcd(&self, rhs: &Uint) -> Self { + self.optimized_bingcd_::<{ U64::BITS }, { U64::LIMBS }, { U128::LIMBS }>(rhs) + } + + /// Computes `gcd(self, rhs)`, leveraging the optimized Binary GCD algorithm. + /// + /// Ref: Pornin, Optimized Binary GCD for Modular Inversion, Algorithm 2. + /// + /// + /// In summary, the optimized algorithm does not operate on `self` and `rhs` directly, but + /// instead of condensed summaries that fit in few registers. Based on these summaries, an + /// update matrix is constructed by which `self` and `rhs` are updated in larger steps. + /// + /// This function is generic over the following three values: + /// - `K`: the number of bits used when summarizing `self` and `rhs` for the inner loop. The + /// `K+1` top bits and `K-1` least significant bits are selected. It is recommended to keep + /// `K` close to a (multiple of) the number of bits that fit in a single register. + /// - `LIMBS_K`: should be chosen as the minimum number s.t. `Uint::::BITS ≥ K`, + /// - `LIMBS_2K`: should be chosen as the minimum number s.t. `Uint::::BITS ≥ 2K`. + #[inline] + pub const fn optimized_bingcd_( + &self, + rhs: &Uint, + ) -> Self { + let (mut a, mut b) = (*self.as_ref(), *rhs); + + let mut i = 0; + while i < Self::MINIMAL_BINGCD_ITERATIONS.div_ceil(K - 1) { + i += 1; + + // Construct a_ and b_ as the summary of a and b, respectively. + let n = const_max(2 * K, const_max(a.bits(), b.bits())); + let a_ = a.compact::(n); + let b_ = b.compact::(n); + + // Compute the K-1 iteration update matrix from a_ and b_ + // Safe to vartime; function executes in time variable in `iterations` only, which is + // a public constant K-1 here. + let (.., matrix, _) = a_ + .to_odd() + .expect("a_ is always odd") + .partial_binxgcd_vartime::(&b_, K - 1, ConstChoice::FALSE); + + // Update `a` and `b` using the update matrix + let (updated_a, updated_b) = matrix.extended_apply_to((a, b)); + (a, _) = updated_a.div_2k_vartime(K - 1).wrapping_drop_extension(); + (b, _) = updated_b.div_2k_vartime(K - 1).wrapping_drop_extension(); + } + + a.to_odd() + .expect("gcd of an odd value with something else is always odd") + } +} + +#[cfg(feature = "rand_core")] +#[cfg(test)] +mod tests { + + mod test_classic_bingcd { + use crate::{ + Gcd, Int, Random, Uint, U1024, U128, U192, U2048, U256, U384, U4096, U512, U64, + }; + use rand_core::OsRng; + + fn classic_bingcd_test(lhs: Uint, rhs: Uint) + where + Uint: Gcd>, + { + let gcd = lhs.gcd(&rhs); + let bingcd = lhs.to_odd().unwrap().classic_bingcd(&rhs); + assert_eq!(gcd, bingcd); + } + + fn classic_bingcd_tests() + where + Uint: Gcd>, + { + // Edge cases + classic_bingcd_test(Uint::ONE, Uint::ZERO); + classic_bingcd_test(Uint::ONE, Uint::ONE); + classic_bingcd_test(Uint::ONE, Int::MAX.abs()); + classic_bingcd_test(Uint::ONE, Int::MIN.abs()); + classic_bingcd_test(Uint::ONE, Uint::MAX); + classic_bingcd_test(Int::MAX.abs(), Uint::ZERO); + classic_bingcd_test(Int::MAX.abs(), Uint::ONE); + classic_bingcd_test(Int::MAX.abs(), Int::MAX.abs()); + classic_bingcd_test(Int::MAX.abs(), Int::MIN.abs()); + classic_bingcd_test(Int::MAX.abs(), Uint::MAX); + classic_bingcd_test(Uint::MAX, Uint::ZERO); + classic_bingcd_test(Uint::MAX, Uint::ONE); + classic_bingcd_test(Uint::MAX, Int::MAX.abs()); + classic_bingcd_test(Uint::MAX, Int::MIN.abs()); + classic_bingcd_test(Uint::MAX, Uint::MAX); + + // Randomized test cases + for _ in 0..1000 { + let x = Uint::::random(&mut OsRng).bitor(&Uint::ONE); + let y = Uint::::random(&mut OsRng); + classic_bingcd_test(x, y); + } + } + + #[test] + fn test_classic_bingcd() { + classic_bingcd_tests::<{ U64::LIMBS }>(); + classic_bingcd_tests::<{ U128::LIMBS }>(); + classic_bingcd_tests::<{ U192::LIMBS }>(); + classic_bingcd_tests::<{ U256::LIMBS }>(); + classic_bingcd_tests::<{ U384::LIMBS }>(); + classic_bingcd_tests::<{ U512::LIMBS }>(); + classic_bingcd_tests::<{ U1024::LIMBS }>(); + classic_bingcd_tests::<{ U2048::LIMBS }>(); + classic_bingcd_tests::<{ U4096::LIMBS }>(); + } + } + + mod test_optimized_bingcd { + use crate::{Gcd, Int, Random, Uint, U1024, U128, U192, U2048, U256, U384, U4096, U512}; + use rand_core::OsRng; + + fn optimized_bingcd_test(lhs: Uint, rhs: Uint) + where + Uint: Gcd>, + { + let gcd = lhs.gcd(&rhs); + let bingcd = lhs.to_odd().unwrap().optimized_bingcd(&rhs); + assert_eq!(gcd, bingcd); + } + + fn optimized_bingcd_tests() + where + Uint: Gcd>, + { + // Edge cases + optimized_bingcd_test(Uint::ONE, Uint::ZERO); + optimized_bingcd_test(Uint::ONE, Uint::ONE); + optimized_bingcd_test(Uint::ONE, Int::MAX.abs()); + optimized_bingcd_test(Uint::ONE, Int::MIN.abs()); + optimized_bingcd_test(Uint::ONE, Uint::MAX); + optimized_bingcd_test(Int::MAX.abs(), Uint::ZERO); + optimized_bingcd_test(Int::MAX.abs(), Uint::ONE); + optimized_bingcd_test(Int::MAX.abs(), Int::MAX.abs()); + optimized_bingcd_test(Int::MAX.abs(), Int::MIN.abs()); + optimized_bingcd_test(Int::MAX.abs(), Uint::MAX); + optimized_bingcd_test(Uint::MAX, Uint::ZERO); + optimized_bingcd_test(Uint::MAX, Uint::ONE); + optimized_bingcd_test(Uint::MAX, Int::MAX.abs()); + optimized_bingcd_test(Uint::MAX, Int::MIN.abs()); + optimized_bingcd_test(Uint::MAX, Uint::MAX); + + // Randomized testing + for _ in 0..1000 { + let x = Uint::::random(&mut OsRng).bitor(&Uint::ONE); + let y = Uint::::random(&mut OsRng); + optimized_bingcd_test(x, y); + } + } + + #[test] + fn test_optimized_bingcd() { + // Not applicable for U64 + optimized_bingcd_tests::<{ U128::LIMBS }>(); + optimized_bingcd_tests::<{ U192::LIMBS }>(); + optimized_bingcd_tests::<{ U256::LIMBS }>(); + optimized_bingcd_tests::<{ U384::LIMBS }>(); + optimized_bingcd_tests::<{ U512::LIMBS }>(); + optimized_bingcd_tests::<{ U1024::LIMBS }>(); + optimized_bingcd_tests::<{ U2048::LIMBS }>(); + optimized_bingcd_tests::<{ U4096::LIMBS }>(); + } + } +} diff --git a/src/modular/bingcd/matrix.rs b/src/modular/bingcd/matrix.rs new file mode 100644 index 00000000..ff2c6de6 --- /dev/null +++ b/src/modular/bingcd/matrix.rs @@ -0,0 +1,183 @@ +use crate::modular::bingcd::extension::ExtendedInt; +use crate::{ConstChoice, Int, Uint}; + +type Vector = (T, T); + +#[derive(Debug, Clone, Copy, PartialEq)] +pub(crate) struct IntMatrix { + pub(crate) m00: Int, + pub(crate) m01: Int, + pub(crate) m10: Int, + pub(crate) m11: Int, +} + +impl IntMatrix { + /// The unit matrix. + pub(crate) const UNIT: Self = Self::new(Int::ONE, Int::ZERO, Int::ZERO, Int::ONE); + + pub(crate) const fn new( + m00: Int, + m01: Int, + m10: Int, + m11: Int, + ) -> Self { + Self { m00, m01, m10, m11 } + } + + /// Apply this matrix to a vector of [Uint]s, returning the result as a vector of + /// [ExtendedInt]s. + #[inline] + pub(crate) const fn extended_apply_to( + &self, + vec: Vector>, + ) -> Vector> { + let (a, b) = vec; + let a0 = ExtendedInt::from_product(a, self.m00); + let a1 = ExtendedInt::from_product(a, self.m10); + let b0 = ExtendedInt::from_product(b, self.m01); + let b1 = ExtendedInt::from_product(b, self.m11); + (a0.wrapping_add(&b0), a1.wrapping_add(&b1)) + } + + /// Wrapping apply this matrix to `rhs`. Return the result in `RHS_LIMBS`. + #[inline] + pub(crate) const fn wrapping_mul_right( + &self, + rhs: &IntMatrix, + ) -> IntMatrix { + let a0 = rhs.m00.wrapping_mul(&self.m00); + let a1 = rhs.m10.wrapping_mul(&self.m01); + let a = a0.wrapping_add(&a1); + let b0 = rhs.m01.wrapping_mul(&self.m00); + let b1 = rhs.m11.wrapping_mul(&self.m01); + let b = b0.wrapping_add(&b1); + let c0 = rhs.m00.wrapping_mul(&self.m10); + let c1 = rhs.m10.wrapping_mul(&self.m11); + let c = c0.wrapping_add(&c1); + let d0 = rhs.m01.wrapping_mul(&self.m10); + let d1 = rhs.m11.wrapping_mul(&self.m11); + let d = d0.wrapping_add(&d1); + IntMatrix::new(a, b, c, d) + } + + /// Swap the rows of this matrix if `swap` is truthy. Otherwise, do nothing. + #[inline] + pub(crate) const fn conditional_swap_rows(&mut self, swap: ConstChoice) { + Int::conditional_swap(&mut self.m00, &mut self.m10, swap); + Int::conditional_swap(&mut self.m01, &mut self.m11, swap); + } + + /// Swap the rows of this matrix. + #[inline] + pub(crate) const fn swap_rows(&mut self) { + self.conditional_swap_rows(ConstChoice::TRUE) + } + + /// Subtract the bottom row from the top if `subtract` is truthy. Otherwise, do nothing. + #[inline] + pub(crate) const fn conditional_subtract_bottom_row_from_top(&mut self, subtract: ConstChoice) { + self.m00 = Int::select(&self.m00, &self.m00.wrapping_sub(&self.m10), subtract); + self.m01 = Int::select(&self.m01, &self.m01.wrapping_sub(&self.m11), subtract); + } + + /// Double the bottom row of this matrix if `double` is truthy. Otherwise, do nothing. + #[inline] + pub(crate) const fn conditional_double_bottom_row(&mut self, double: ConstChoice) { + // safe to vartime; shr_vartime is variable in the value of shift only. Since this shift + // is a public constant, the constant time property of this algorithm is not impacted. + self.m10 = Int::select(&self.m10, &self.m10.shl_vartime(1), double); + self.m11 = Int::select(&self.m11, &self.m11.shl_vartime(1), double); + } + + /// Negate the elements in the top row if `negate` is truthy. Otherwise, do nothing. + #[inline] + pub(crate) const fn conditional_negate_top_row(&mut self, negate: ConstChoice) { + self.m00 = self.m00.wrapping_neg_if(negate); + self.m01 = self.m01.wrapping_neg_if(negate); + } + + /// Negate the elements in the bottom row if `negate` is truthy. Otherwise, do nothing. + #[inline] + pub(crate) const fn conditional_negate_bottom_row(&mut self, negate: ConstChoice) { + self.m10 = self.m10.wrapping_neg_if(negate); + self.m11 = self.m11.wrapping_neg_if(negate); + } +} + +#[cfg(test)] +mod tests { + use crate::modular::bingcd::matrix::IntMatrix; + use crate::{ConstChoice, Int, I256, U256}; + + const X: IntMatrix<{ U256::LIMBS }> = IntMatrix::new( + Int::from_i64(1i64), + Int::from_i64(7i64), + Int::from_i64(23i64), + Int::from_i64(53i64), + ); + + #[test] + fn test_conditional_swap() { + let mut y = X.clone(); + y.conditional_swap_rows(ConstChoice::FALSE); + assert_eq!(y, X); + y.conditional_swap_rows(ConstChoice::TRUE); + assert_eq!( + y, + IntMatrix::new( + Int::from(23i32), + Int::from(53i32), + Int::from(1i32), + Int::from(7i32) + ) + ); + } + + #[test] + fn test_conditional_subtract() { + let mut y = X.clone(); + y.conditional_subtract_bottom_row_from_top(ConstChoice::FALSE); + assert_eq!(y, X); + y.conditional_subtract_bottom_row_from_top(ConstChoice::TRUE); + assert_eq!( + y, + IntMatrix::new( + Int::from(-22i32), + Int::from(-46i32), + Int::from(23i32), + Int::from(53i32) + ) + ); + } + + #[test] + fn test_conditional_double() { + let mut y = X.clone(); + y.conditional_double_bottom_row(ConstChoice::FALSE); + assert_eq!(y, X); + y.conditional_double_bottom_row(ConstChoice::TRUE); + assert_eq!( + y, + IntMatrix::new( + Int::from(1i32), + Int::from(7i32), + Int::from(46i32), + Int::from(106i32), + ) + ); + } + + #[test] + fn test_wrapping_mul() { + let res = X.wrapping_mul_right(&X); + assert_eq!( + res, + IntMatrix::new( + I256::from_i64(162i64), + I256::from_i64(378i64), + I256::from_i64(1242i64), + I256::from_i64(2970i64), + ) + ) + } +} diff --git a/src/modular/bingcd/tools.rs b/src/modular/bingcd/tools.rs new file mode 100644 index 00000000..5255516a --- /dev/null +++ b/src/modular/bingcd/tools.rs @@ -0,0 +1,116 @@ +use crate::{ConstChoice, Int, Odd, Uint}; + +/// `const` equivalent of `u32::max(a, b)`. +pub(crate) const fn const_max(a: u32, b: u32) -> u32 { + ConstChoice::from_u32_lt(a, b).select_u32(a, b) +} + +/// `const` equivalent of `u32::min(a, b)`. +pub(crate) const fn const_min(a: u32, b: u32) -> u32 { + ConstChoice::from_u32_lt(a, b).select_u32(b, a) +} + +impl Int { + /// Compute `self / 2^k mod q`. Executes in time variable in `k_bound`. This value should be + /// chosen as an inclusive upperbound to the value of `k`. + #[inline] + pub(crate) const fn div_2k_mod_q(&self, k: u32, k_bound: u32, q: &Odd>) -> Self { + let (abs, sgn) = self.abs_sign(); + let abs_div_2k_mod_q = abs.div_2k_mod_q(k, k_bound, q); + Int::new_from_abs_sign(abs_div_2k_mod_q, sgn).expect("no overflow") + } +} + +impl Uint { + /// Compute `self / 2^k mod q`. + /// + /// Executes in time variable in `k_bound`. This value should be + /// chosen as an inclusive upperbound to the value of `k`. + #[inline] + const fn div_2k_mod_q(mut self, k: u32, k_bound: u32, q: &Odd) -> Self { + // 1 / 2 mod q + // = (q + 1) / 2 mod q + // = (q - 1) / 2 + 1 mod q + // = floor(q / 2) + 1 mod q, since q is odd. + let one_half_mod_q = q.as_ref().shr_vartime(1).wrapping_add(&Uint::ONE); + let mut i = 0; + while i < k_bound { + // Apply only while i < k + let apply = ConstChoice::from_u32_lt(i, k); + self = Self::select(&self, &self.div_2_mod_q(&one_half_mod_q), apply); + i += 1; + } + + self + } + + /// Compute `self / 2 mod q`. + #[inline] + const fn div_2_mod_q(self, half_mod_q: &Self) -> Self { + // Floor-divide self by 2. When self was odd, add back 1/2 mod q. + let add_one_half = self.is_odd(); + let floored_half = self.shr_vartime(1); + floored_half.wrapping_add(&Self::select(&Self::ZERO, half_mod_q, add_one_half)) + } + + /// Construct a [Uint] containing the bits in `self` in the range `[idx, idx + length)`. + /// + /// Assumes `length ≤ Uint::::BITS` and `idx + length ≤ Self::BITS`. + /// + /// Executes in time variable in `length` only. + #[inline(always)] + pub(crate) const fn section_vartime_length( + &self, + idx: u32, + length: u32, + ) -> Uint { + debug_assert!(length <= Uint::::BITS); + debug_assert!(idx + length <= Self::BITS); + + let mask = Uint::ONE.shl_vartime(length).wrapping_sub(&Uint::ONE); + self.shr(idx).resize::().bitand(&mask) + } + + /// Construct a [Uint] containing the bits in `self` in the range `[idx, idx + length)`. + /// + /// Assumes `length ≤ Uint::::BITS` and `idx + length ≤ Self::BITS`. + /// + /// Executes in time variable in `idx` and `length`. + #[inline(always)] + pub(crate) const fn section_vartime( + &self, + idx: u32, + length: u32, + ) -> Uint { + debug_assert!(length <= Uint::::BITS); + debug_assert!(idx + length <= Self::BITS); + + let mask = Uint::ONE.shl_vartime(length).wrapping_sub(&Uint::ONE); + self.shr_vartime(idx) + .resize::() + .bitand(&mask) + } + + /// Compact `self` to a form containing the concatenation of its bit ranges `[0, K-1)` + /// and `[n-K-1, n)`. + /// + /// Assumes `K ≤ Uint::::BITS`, `n ≤ Self::BITS` and `n ≥ 2K`. + #[inline(always)] + pub(crate) const fn compact( + &self, + n: u32, + ) -> Uint { + debug_assert!(K <= Uint::::BITS); + debug_assert!(n <= Self::BITS); + debug_assert!(n >= 2 * K); + + // safe to vartime; this function is vartime in length only, which is a public constant + let hi = self.section_vartime_length(n - K - 1, K + 1); + // safe to vartime; this function is vartime in idx and length only, which are both public + // constants + let lo = self.section_vartime(0, K - 1); + // safe to vartime; shl_vartime is variable in the value of shift only. Since this shift + // is a public constant, the constant time property of this algorithm is not impacted. + hi.shl_vartime(K - 1).bitxor(&lo) + } +} diff --git a/src/modular/bingcd/xgcd.rs b/src/modular/bingcd/xgcd.rs new file mode 100644 index 00000000..70ed405d --- /dev/null +++ b/src/modular/bingcd/xgcd.rs @@ -0,0 +1,813 @@ +use crate::modular::bingcd::matrix::IntMatrix; +use crate::modular::bingcd::tools::const_max; +use crate::{ConstChoice, Int, NonZero, Odd, Uint, U128, U64}; + +/// Container for the raw output of the Binary XGCD algorithm. +pub(crate) struct RawBinxgcdOutput { + lhs: Odd>, + rhs: Odd>, + gcd: Odd>, + matrix: IntMatrix, + k: u32, + k_upper_bound: u32, +} + +impl RawBinxgcdOutput { + /// Process raw output, constructing an OddBinXgcdOutput object. + const fn process(&self) -> OddBinxgcdUintOutput { + let (x, y) = self.derive_bezout_coefficients(); + let (lhs_on_gcd, rhs_on_gcd) = self.extract_quotients(); + OddBinxgcdUintOutput { + gcd: self.gcd, + x, + y, + lhs_on_gcd, + rhs_on_gcd, + } + } + + /// Extract the Bézout coefficients from `matrix`, where it is assumed that + /// `matrix * (lhs, rhs) = (gcd * 2^k, 0)`. + const fn derive_bezout_coefficients(&self) -> (Int, Int) { + // The Bézout coefficients `x` and `y` can be extracted from `matrix.m00` and `matrix.m01`, + // respectively. In fact, `matrix.m00 = x * 2^k` and `matrix.m01 = y * 2^k`. + // Hence, we can compute + // `x = matrix.m00 / 2^k mod rhs`, and + // `y = matrix.m01 / 2^k mod lhs`. + let (x, y) = (self.matrix.m00, self.matrix.m01); + ( + x.div_2k_mod_q(self.k, self.k_upper_bound, &self.rhs), + y.div_2k_mod_q(self.k, self.k_upper_bound, &self.lhs), + ) + } + + /// Mutably borrow the quotients `lhs/gcd` and `rhs/gcd`. + const fn quotients_as_mut(&mut self) -> (&mut Int, &mut Int) { + (&mut self.matrix.m11, &mut self.matrix.m10) + } + + /// Extract the quotients `lhs/gcd` and `rhs/gcd` from `matrix`. + const fn extract_quotients(&self) -> (Uint, Uint) { + let lhs_on_gcd = self.matrix.m11.abs(); + let rhs_on_gcd = self.matrix.m10.abs(); + (lhs_on_gcd, rhs_on_gcd) + } +} + +/// Container for the processed output of the Binary XGCD algorithm. +pub(crate) struct OddBinxgcdUintOutput { + pub(crate) gcd: Odd>, + x: Int, + y: Int, + lhs_on_gcd: Uint, + rhs_on_gcd: Uint, +} + +impl OddBinxgcdUintOutput { + /// Obtain a copy of the Bézout coefficients. + pub(crate) const fn bezout_coefficients(&self) -> (Int, Int) { + (self.x, self.y) + } + + /// Mutably borrow the Bézout coefficients. + const fn bezout_coefficients_as_mut(&mut self) -> (&mut Int, &mut Int) { + (&mut self.x, &mut self.y) + } + + /// Obtain a copy of the quotients `lhs/gcd` and `rhs/gcd`. + pub(crate) const fn quotients(&self) -> (Uint, Uint) { + (self.lhs_on_gcd, self.rhs_on_gcd) + } +} + +impl Odd> { + /// The minimal number of binary GCD iterations required to guarantee successful completion. + const MIN_BINGCD_ITERATIONS: u32 = 2 * Self::BITS - 1; + + /// Given `(self, rhs)`, computes `(g, x, y)` s.t. `self * x + rhs * y = g = gcd(self, rhs)`, + /// leveraging the Binary Extended GCD algorithm. + /// + /// **Warning**: this algorithm is only guaranteed to work for `self` and `rhs` for which the + /// msb is **not** set. May panic otherwise. + pub(crate) const fn binxgcd_nz( + &self, + rhs: &NonZero>, + ) -> OddBinxgcdUintOutput { + // Note that for the `binxgcd` subroutine, `rhs` needs to be odd. + // + // We use the fact that gcd(a, b) = gcd(a, |a-b|) to + // 1) convert the input (self, rhs) to (self, rhs') where rhs' is guaranteed odd, + // 2) execute the xgcd algorithm on (self, rhs'), and + // 3) recover the Bezout coefficients for (self, rhs) from the (self, rhs') output. + + let (abs_lhs_sub_rhs, rhs_gt_lhs) = + self.as_ref().wrapping_sub(rhs.as_ref()).as_int().abs_sign(); + let rhs_is_even = rhs.as_ref().is_odd().not(); + let rhs_ = Uint::select(rhs.as_ref(), &abs_lhs_sub_rhs, rhs_is_even) + .to_odd() + .expect("rhs is odd by construction"); + + let mut output = self.binxgcd(&rhs_); + let (u, v) = output.quotients_as_mut(); + + // At this point, we have one of the following three situations: + // i. 0 = lhs * v + (rhs - lhs) * u, if rhs is even and rhs > lhs + // ii. 0 = lhs * v + (lhs - rhs) * u, if rhs is even and rhs < lhs + // iii. 0 = lhs * v + rhs * u, if rhs is odd + + // We can rearrange these terms to get the quotients to (self, rhs) as follows: + // i. gcd = lhs * (v - u) + rhs * u, if rhs is even and rhs > lhs + // ii. gcd = lhs * (v + u) - u * rhs, if rhs is even and rhs < lhs + // iii. gcd = lhs * v + rhs * u, if rhs is odd + + *v = Int::select(v, &v.wrapping_sub(u), rhs_is_even.and(rhs_gt_lhs)); + *v = Int::select(v, &v.wrapping_add(u), rhs_is_even.and(rhs_gt_lhs.not())); + *u = u.wrapping_neg_if(rhs_is_even.and(rhs_gt_lhs.not())); + + let mut processed_output = output.process(); + let (x, y) = processed_output.bezout_coefficients_as_mut(); + + // At this point, we have one of the following three situations: + // i. gcd = lhs * x + (rhs - lhs) * y, if rhs is even and rhs > lhs + // ii. gcd = lhs * x + (lhs - rhs) * y, if rhs is even and rhs < lhs + // iii. gcd = lhs * x + rhs * y, if rhs is odd + + // We can rearrange these terms to get the Bezout coefficients to (self, rhs) as follows: + // i. gcd = lhs * (x - y) + rhs * y, if rhs is even and rhs > lhs + // ii. gcd = lhs * (x + y) - y * rhs, if rhs is even and rhs < lhs + // iii. gcd = lhs * x + rhs * y, if rhs is odd + + *x = Int::select(x, &x.wrapping_sub(y), rhs_is_even.and(rhs_gt_lhs)); + *x = Int::select(x, &x.wrapping_add(y), rhs_is_even.and(rhs_gt_lhs.not())); + *y = y.wrapping_neg_if(rhs_is_even.and(rhs_gt_lhs.not())); + + processed_output + } + + /// Given `(self, rhs)`, computes `(g, x, y)` s.t. `self * x + rhs * y = g = gcd(self, rhs)`, + /// leveraging the Binary Extended GCD algorithm. + /// + /// **Warning**: this algorithm is only guaranteed to work for `self` and `rhs` for which the + /// msb is **not** set. May panic otherwise. + /// + /// This function switches between the "classic" and "optimized" algorithm at a best-effort + /// threshold. When using [Uint]s with `LIMBS` close to the threshold, it may be useful to + /// manually test whether the classic or optimized algorithm is faster for your machine. + pub(crate) const fn binxgcd(&self, rhs: &Self) -> RawBinxgcdOutput { + if LIMBS < 4 { + self.classic_binxgcd(rhs) + } else { + self.optimized_binxgcd(rhs) + } + } + + /// Execute the classic Binary Extended GCD algorithm. + /// + /// Given `(self, rhs)`, computes `(g, x, y)` s.t. `self * x + rhs * y = g = gcd(self, rhs)`. + /// + /// **Warning**: this algorithm is only guaranteed to work for `self` and `rhs` for which the + /// msb is **not** set. May panic otherwise. + /// + /// Ref: Pornin, Optimized Binary GCD for Modular Inversion, Algorithm 1. + /// . + pub(crate) const fn classic_binxgcd(&self, rhs: &Self) -> RawBinxgcdOutput { + let (gcd, _, matrix, total_doublings) = self.partial_binxgcd_vartime::( + rhs.as_ref(), + Self::MIN_BINGCD_ITERATIONS, + ConstChoice::TRUE, + ); + + RawBinxgcdOutput { + lhs: *self, + rhs: *rhs, + gcd, + matrix, + k: total_doublings, + k_upper_bound: Self::MIN_BINGCD_ITERATIONS, + } + } + + /// Given `(self, rhs)`, computes `(g, x, y)` s.t. `self * x + rhs * y = g = gcd(self, rhs)`, + /// leveraging the Binary Extended GCD algorithm. + /// + /// **Warning**: this algorithm is only guaranteed to work for `self` and `rhs` for which the + /// msb is **not** set. May panic otherwise. Furthermore, at `self` and `rhs` must contain at + /// least 128 bits. + /// + /// Note: this algorithm becomes more efficient than the classical algorithm for [Uint]s with + /// relatively many `LIMBS`. A best-effort threshold is presented in [Self::binxgcd]. + /// + /// Note: the full algorithm has an additional parameter; this function selects the best-effort + /// value for this parameter. You might be able to further tune your performance by calling the + /// [Self::optimized_bingcd_] function directly. + /// + /// Ref: Pornin, Optimized Binary GCD for Modular Inversion, Algorithm 2. + /// . + pub(crate) const fn optimized_binxgcd(&self, rhs: &Self) -> RawBinxgcdOutput { + assert!(Self::BITS >= U128::BITS); + self.optimized_binxgcd_::<{ U64::BITS }, { U64::LIMBS }, { U128::LIMBS }>(rhs) + } + + /// Given `(self, rhs)`, computes `(g, x, y)`, s.t. `self * x + rhs * y = g = gcd(self, rhs)`, + /// leveraging the optimized Binary Extended GCD algorithm. + /// + /// **Warning**: this algorithm is only guaranteed to work for `self` and `rhs` for which the + /// msb is **not** set. May panic otherwise. + /// + /// Ref: Pornin, Optimized Binary GCD for Modular Inversion, Algorithm 2. + /// + /// + /// In summary, the optimized algorithm does not operate on `self` and `rhs` directly, but + /// instead of condensed summaries that fit in few registers. Based on these summaries, an + /// update matrix is constructed by which `self` and `rhs` are updated in larger steps. + /// + /// This function is generic over the following three values: + /// - `K`: the number of bits used when summarizing `self` and `rhs` for the inner loop. The + /// `K+1` top bits and `K-1` least significant bits are selected. It is recommended to keep + /// `K` close to a (multiple of) the number of bits that fit in a single register. + /// - `LIMBS_K`: should be chosen as the minimum number s.t. `Uint::::BITS ≥ K`, + /// - `LIMBS_2K`: should be chosen as the minimum number s.t. `Uint::::BITS ≥ 2K`. + pub(crate) const fn optimized_binxgcd_< + const K: u32, + const LIMBS_K: usize, + const LIMBS_2K: usize, + >( + &self, + rhs: &Self, + ) -> RawBinxgcdOutput { + let (mut a, mut b) = (*self.as_ref(), *rhs.as_ref()); + let mut matrix = IntMatrix::UNIT; + let mut total_doublings = 0; + + let (mut a_sgn, mut b_sgn); + let mut i = 0; + while i < Self::MIN_BINGCD_ITERATIONS.div_ceil(K - 1) { + i += 1; + + // Construct a_ and b_ as the summary of a and b, respectively. + let b_bits = b.bits(); + let n = const_max(2 * K, const_max(a.bits(), b_bits)); + let a_ = a.compact::(n); + let b_ = b.compact::(n); + let b_fits_in_compact = + ConstChoice::from_u32_le(b_bits, K - 1).or(ConstChoice::from_u32_eq(n, 2 * K)); + + // Compute the K-1 iteration update matrix from a_ and b_ + let (.., update_matrix, doublings) = a_ + .to_odd() + .expect("a is always odd") + .partial_binxgcd_vartime::(&b_, K - 1, b_fits_in_compact); + + // Update `a` and `b` using the update matrix + let (updated_a, updated_b) = update_matrix.extended_apply_to((a, b)); + (a, a_sgn) = updated_a.div_2k(doublings).wrapping_drop_extension(); + (b, b_sgn) = updated_b.div_2k(doublings).wrapping_drop_extension(); + + matrix = update_matrix.wrapping_mul_right(&matrix); + matrix.conditional_negate_top_row(a_sgn); + matrix.conditional_negate_bottom_row(b_sgn); + total_doublings += doublings; + } + + let gcd = a + .to_odd() + .expect("gcd of an odd value with something else is always odd"); + + RawBinxgcdOutput { + lhs: *self, + rhs: *rhs, + gcd, + matrix, + k: total_doublings, + k_upper_bound: Self::MIN_BINGCD_ITERATIONS, + } + } + + /// Executes the optimized Binary GCD inner loop. + /// + /// Ref: Pornin, Optimized Binary GCD for Modular Inversion, Algorithm 2. + /// . + /// + /// The function outputs the reduced values `(a, b)` for the input values `(self, rhs)` as well + /// as the matrix that yields the former two when multiplied with the latter two. + /// + /// Additionally, the number doublings that were executed is returned. By construction, each + /// element in `M` lies in the interval `(-2^doublings, 2^doublings]`. + /// + /// Note: this implementation deviates slightly from the paper, in that it can be instructed to + /// "run in place" (i.e., execute iterations that do nothing) once `a` becomes zero. + /// This is done by passing a truthy `halt_at_zero`. + /// + /// The function executes in time variable in `iterations`. + #[inline] + pub(crate) const fn partial_binxgcd_vartime( + &self, + rhs: &Uint, + iterations: u32, + halt_at_zero: ConstChoice, + ) -> (Self, Uint, IntMatrix, u32) { + let (mut a, mut b) = (*self.as_ref(), *rhs); + // This matrix corresponds with (f0, g0, f1, g1) in the paper. + let mut matrix = IntMatrix::UNIT; + + // Compute the update matrix. + // Note: to be consistent with the paper, the `binxgcd_step` algorithm requires the second + // argument to be odd. Here, we have `a` odd, so we have to swap a and b before and after + // calling the subroutine. The columns of the matrix have to be swapped accordingly. + Uint::swap(&mut a, &mut b); + matrix.swap_rows(); + + let mut doublings = 0; + let mut j = 0; + while j < iterations { + Self::binxgcd_step::( + &mut a, + &mut b, + &mut matrix, + &mut doublings, + halt_at_zero, + ); + j += 1; + } + + // Undo swap + Uint::swap(&mut a, &mut b); + matrix.swap_rows(); + + let a = a.to_odd().expect("a is always odd"); + (a, b, matrix, doublings) + } + + /// Binary XGCD update step. + /// + /// This is a condensed, constant time execution of the following algorithm: + /// ```text + /// if a mod 2 == 1 + /// if a < b + /// (a, b) ← (b, a) + /// (f0, g0, f1, g1) ← (f1, g1, f0, g0) + /// a ← a - b + /// (f0, g0) ← (f0 - f1, g0 - g1) + /// if a > 0 + /// a ← a/2 + /// (f1, g1) ← (2f1, 2g1) + /// ``` + /// where `matrix` represents + /// ```text + /// (f0 g0) + /// (f1 g1). + /// ``` + /// + /// Note: this algorithm assumes `b` to be an odd integer. The algorithm will likely not yield + /// the correct result when this is not the case. + /// + /// Ref: Pornin, Algorithm 2, L8-17, . + #[inline] + const fn binxgcd_step( + a: &mut Uint, + b: &mut Uint, + matrix: &mut IntMatrix, + executed_iterations: &mut u32, + halt_at_zero: ConstChoice, + ) { + let a_odd = a.is_odd(); + let a_lt_b = Uint::lt(a, b); + + // swap if a odd and a < b + let swap = a_odd.and(a_lt_b); + Uint::conditional_swap(a, b, swap); + matrix.conditional_swap_rows(swap); + + // subtract b from a when a is odd + *a = a.wrapping_sub(&Uint::select(&Uint::ZERO, b, a_odd)); + matrix.conditional_subtract_bottom_row_from_top(a_odd); + + // Div a by 2. + let double = a.is_nonzero().or(halt_at_zero.not()); + // safe to vartime; shr_vartime is variable in the value of shift only. Since this shift + // is a public constant, the constant time property of this algorithm is not impacted. + *a = a.shr_vartime(1); + + // Double the bottom row of the matrix when a was ≠ 0 and when not halting. + matrix.conditional_double_bottom_row(double); + // Something happened in this iteration only when a was non-zero before being halved. + *executed_iterations = double.select_u32(*executed_iterations, *executed_iterations + 1); + } +} + +#[cfg(test)] +mod tests { + use crate::modular::bingcd::xgcd::OddBinxgcdUintOutput; + use crate::{ConcatMixed, Gcd, Uint}; + use core::ops::Div; + + mod test_extract_quotients { + use crate::modular::bingcd::matrix::IntMatrix; + use crate::modular::bingcd::xgcd::RawBinxgcdOutput; + use crate::{Int, Uint, U64}; + + fn raw_binxgcdoutput_setup( + matrix: IntMatrix, + ) -> RawBinxgcdOutput { + RawBinxgcdOutput { + lhs: Uint::::ONE.to_odd().unwrap(), + rhs: Uint::::ONE.to_odd().unwrap(), + gcd: Uint::::ONE.to_odd().unwrap(), + matrix, + k: 0, + k_upper_bound: 0, + } + } + + #[test] + fn test_extract_quotients_unit() { + let output = raw_binxgcdoutput_setup(IntMatrix::<{ U64::LIMBS }>::UNIT); + let (lhs_on_gcd, rhs_on_gcd) = output.extract_quotients(); + assert_eq!(lhs_on_gcd, Uint::ONE); + assert_eq!(rhs_on_gcd, Uint::ZERO); + } + + #[test] + fn test_extract_quotients_basic() { + let output = raw_binxgcdoutput_setup(IntMatrix::<{ U64::LIMBS }>::new( + Int::ZERO, + Int::ZERO, + Int::from(5i32), + Int::from(-7i32), + )); + let (lhs_on_gcd, rhs_on_gcd) = output.extract_quotients(); + assert_eq!(lhs_on_gcd, Uint::from(7u32)); + assert_eq!(rhs_on_gcd, Uint::from(5u32)); + + let output = raw_binxgcdoutput_setup(IntMatrix::<{ U64::LIMBS }>::new( + Int::ZERO, + Int::ZERO, + Int::from(-7i32), + Int::from(5i32), + )); + let (lhs_on_gcd, rhs_on_gcd) = output.extract_quotients(); + assert_eq!(lhs_on_gcd, Uint::from(5u32)); + assert_eq!(rhs_on_gcd, Uint::from(7u32)); + } + } + + mod test_derive_bezout_coefficients { + use crate::modular::bingcd::matrix::IntMatrix; + use crate::modular::bingcd::xgcd::RawBinxgcdOutput; + use crate::{Int, Uint, I64, U64}; + + #[test] + fn test_derive_bezout_coefficients_unit() { + let output = RawBinxgcdOutput { + lhs: Uint::ONE.to_odd().unwrap(), + rhs: Uint::ONE.to_odd().unwrap(), + gcd: Uint::ONE.to_odd().unwrap(), + matrix: IntMatrix::<{ U64::LIMBS }>::UNIT, + k: 0, + k_upper_bound: 0, + }; + let (x, y) = output.derive_bezout_coefficients(); + assert_eq!(x, Int::ONE); + assert_eq!(y, Int::ZERO); + } + + #[test] + fn test_derive_bezout_coefficients_basic() { + let output = RawBinxgcdOutput { + lhs: Uint::ONE.to_odd().unwrap(), + rhs: Uint::ONE.to_odd().unwrap(), + gcd: Uint::ONE.to_odd().unwrap(), + matrix: IntMatrix::new( + I64::from(2i32), + I64::from(3i32), + I64::from(4i32), + I64::from(5i32), + ), + k: 0, + k_upper_bound: 0, + }; + let (x, y) = output.derive_bezout_coefficients(); + assert_eq!(x, Int::from(2i32)); + assert_eq!(y, Int::from(3i32)); + + let output = RawBinxgcdOutput { + lhs: Uint::ONE.to_odd().unwrap(), + rhs: Uint::ONE.to_odd().unwrap(), + gcd: Uint::ONE.to_odd().unwrap(), + matrix: IntMatrix::new( + I64::from(2i32), + I64::from(3i32), + I64::from(4i32), + I64::from(5i32), + ), + k: 0, + k_upper_bound: 1, + }; + let (x, y) = output.derive_bezout_coefficients(); + assert_eq!(x, Int::from(2i32)); + assert_eq!(y, Int::from(3i32)); + } + + #[test] + fn test_derive_bezout_coefficients_removes_doublings_easy() { + let output = RawBinxgcdOutput { + lhs: Uint::ONE.to_odd().unwrap(), + rhs: Uint::ONE.to_odd().unwrap(), + gcd: Uint::ONE.to_odd().unwrap(), + matrix: IntMatrix::new( + I64::from(2i32), + I64::from(6i32), + I64::from(4i32), + I64::from(5i32), + ), + k: 1, + k_upper_bound: 1, + }; + let (x, y) = output.derive_bezout_coefficients(); + assert_eq!(x, Int::ONE); + assert_eq!(y, Int::from(3i32)); + + let output = RawBinxgcdOutput { + lhs: Uint::ONE.to_odd().unwrap(), + rhs: Uint::ONE.to_odd().unwrap(), + gcd: Uint::ONE.to_odd().unwrap(), + matrix: IntMatrix::new( + I64::from(120i32), + I64::from(64i32), + I64::from(4i32), + I64::from(5i32), + ), + k: 5, + k_upper_bound: 6, + }; + let (x, y) = output.derive_bezout_coefficients(); + assert_eq!(x, Int::from(4i32)); + assert_eq!(y, Int::from(2i32)); + } + + #[test] + fn test_derive_bezout_coefficients_removes_doublings_for_odd_numbers() { + let output = RawBinxgcdOutput { + lhs: Uint::from(7u32).to_odd().unwrap(), + rhs: Uint::from(5u32).to_odd().unwrap(), + gcd: Uint::ONE.to_odd().unwrap(), + matrix: IntMatrix::new( + I64::from(2i32), + I64::from(6i32), + I64::from(4i32), + I64::from(5i32), + ), + k: 3, + k_upper_bound: 7, + }; + let (x, y) = output.derive_bezout_coefficients(); + assert_eq!(x, Int::from(4i32)); + assert_eq!(y, Int::from(6i32)); + } + } + + mod test_partial_binxgcd { + use crate::modular::bingcd::matrix::IntMatrix; + use crate::{ConstChoice, Odd, I64, U64}; + + const A: Odd = U64::from_be_hex("CA048AFA63CD6A1F").to_odd().expect("odd"); + const B: U64 = U64::from_be_hex("AE693BF7BE8E5566"); + + #[test] + fn test_partial_binxgcd() { + let (.., matrix, iters) = + A.partial_binxgcd_vartime::<{ U64::LIMBS }>(&B, 5, ConstChoice::TRUE); + assert_eq!(iters, 5); + assert_eq!( + matrix, + IntMatrix::new(I64::from(8), I64::from(-4), I64::from(-2), I64::from(5)) + ); + } + + #[test] + fn test_partial_binxgcd_constructs_correct_matrix() { + let (new_a, new_b, matrix, _) = + A.partial_binxgcd_vartime::<{ U64::LIMBS }>(&B, 5, ConstChoice::TRUE); + + let (computed_a, computed_b) = matrix.extended_apply_to((A.get(), B)); + let computed_a = computed_a.div_2k(5).wrapping_drop_extension().0; + let computed_b = computed_b.div_2k(5).wrapping_drop_extension().0; + + assert_eq!(new_a.get(), computed_a); + assert_eq!(new_b, computed_b); + } + + const SMALL_A: Odd = U64::from_be_hex("0000000003CD6A1F").to_odd().expect("odd"); + const SMALL_B: U64 = U64::from_be_hex("000000000E8E5566"); + + #[test] + fn test_partial_binxgcd_halts() { + let (gcd, .., iters) = + SMALL_A.partial_binxgcd_vartime::<{ U64::LIMBS }>(&SMALL_B, 60, ConstChoice::TRUE); + assert_eq!(iters, 35); + assert_eq!(gcd.get(), SMALL_A.gcd(&SMALL_B)); + } + + #[test] + fn test_partial_binxgcd_does_not_halt() { + let (gcd, .., iters) = + SMALL_A.partial_binxgcd_vartime::<{ U64::LIMBS }>(&SMALL_B, 60, ConstChoice::FALSE); + assert_eq!(iters, 60); + assert_eq!(gcd.get(), SMALL_A.gcd(&SMALL_B)); + } + } + + /// Helper function to effectively test xgcd. + fn test_xgcd( + lhs: Uint, + rhs: Uint, + output: OddBinxgcdUintOutput, + ) where + Uint: + Gcd> + ConcatMixed, MixedOutput = Uint>, + { + // Test the gcd + assert_eq!(lhs.gcd(&rhs), output.gcd); + + // Test the quotients + assert_eq!(output.lhs_on_gcd, lhs.div(output.gcd.as_nz_ref())); + assert_eq!(output.rhs_on_gcd, rhs.div(output.gcd.as_nz_ref())); + + // Test the Bezout coefficients + let (x, y) = output.bezout_coefficients(); + assert_eq!( + x.widening_mul_uint(&lhs) + y.widening_mul_uint(&rhs), + output.gcd.resize().as_int(), + ); + } + + mod test_binxgcd_nz { + use crate::modular::bingcd::xgcd::tests::test_xgcd; + use crate::{ + ConcatMixed, Gcd, Int, RandomMod, Uint, U1024, U128, U192, U2048, U256, U384, U4096, + U512, U64, U768, U8192, + }; + use rand_core::OsRng; + + fn binxgcd_nz_test( + lhs: Uint, + rhs: Uint, + ) where + Uint: + Gcd> + ConcatMixed, MixedOutput = Uint>, + { + let output = lhs.to_odd().unwrap().binxgcd_nz(&rhs.to_nz().unwrap()); + test_xgcd(lhs, rhs, output); + } + + fn binxgcd_nz_tests() + where + Uint: + Gcd> + ConcatMixed, MixedOutput = Uint>, + { + // Edge cases + let odd_upper_bound = *Int::MAX.as_uint(); + let even_upper_bound = Int::MIN.abs(); + binxgcd_nz_test(Uint::ONE, Uint::ONE); + binxgcd_nz_test(Uint::ONE, odd_upper_bound); + binxgcd_nz_test(Uint::ONE, even_upper_bound); + binxgcd_nz_test(odd_upper_bound, Uint::ONE); + binxgcd_nz_test(odd_upper_bound, odd_upper_bound); + binxgcd_nz_test(odd_upper_bound, even_upper_bound); + + // Randomized test cases + let bound = Int::MIN.as_uint().to_nz().unwrap(); + for _ in 0..100 { + let x = Uint::::random_mod(&mut OsRng, &bound).bitor(&Uint::ONE); + let y = Uint::::random_mod(&mut OsRng, &bound).saturating_add(&Uint::ONE); + binxgcd_nz_test(x, y); + } + } + + #[test] + fn test_binxgcd_nz() { + binxgcd_nz_tests::<{ U64::LIMBS }, { U128::LIMBS }>(); + binxgcd_nz_tests::<{ U128::LIMBS }, { U256::LIMBS }>(); + binxgcd_nz_tests::<{ U192::LIMBS }, { U384::LIMBS }>(); + binxgcd_nz_tests::<{ U256::LIMBS }, { U512::LIMBS }>(); + binxgcd_nz_tests::<{ U384::LIMBS }, { U768::LIMBS }>(); + binxgcd_nz_tests::<{ U512::LIMBS }, { U1024::LIMBS }>(); + binxgcd_nz_tests::<{ U1024::LIMBS }, { U2048::LIMBS }>(); + binxgcd_nz_tests::<{ U2048::LIMBS }, { U4096::LIMBS }>(); + binxgcd_nz_tests::<{ U4096::LIMBS }, { U8192::LIMBS }>(); + } + } + + mod test_classic_binxgcd { + use crate::modular::bingcd::xgcd::tests::test_xgcd; + use crate::{ + ConcatMixed, Gcd, Int, RandomMod, Uint, U1024, U128, U192, U2048, U256, U384, U4096, + U512, U64, U768, U8192, + }; + use rand_core::OsRng; + + fn classic_binxgcd_test( + lhs: Uint, + rhs: Uint, + ) where + Uint: + Gcd> + ConcatMixed, MixedOutput = Uint>, + { + let output = lhs + .to_odd() + .unwrap() + .classic_binxgcd(&rhs.to_odd().unwrap()); + test_xgcd(lhs, rhs, output.process()); + } + + fn classic_binxgcd_tests() + where + Uint: + Gcd> + ConcatMixed, MixedOutput = Uint>, + { + // Edge cases + let upper_bound = *Int::MAX.as_uint(); + classic_binxgcd_test(Uint::ONE, Uint::ONE); + classic_binxgcd_test(Uint::ONE, upper_bound); + classic_binxgcd_test(upper_bound, Uint::ONE); + classic_binxgcd_test(upper_bound, upper_bound); + + // Randomized test cases + let bound = Int::MIN.as_uint().to_nz().unwrap(); + for _ in 0..100 { + let x = Uint::::random_mod(&mut OsRng, &bound).bitor(&Uint::ONE); + let y = Uint::::random_mod(&mut OsRng, &bound).bitor(&Uint::ONE); + classic_binxgcd_test(x, y); + } + } + + #[test] + fn test_classic_binxgcd() { + classic_binxgcd_tests::<{ U64::LIMBS }, { U128::LIMBS }>(); + classic_binxgcd_tests::<{ U128::LIMBS }, { U256::LIMBS }>(); + classic_binxgcd_tests::<{ U192::LIMBS }, { U384::LIMBS }>(); + classic_binxgcd_tests::<{ U256::LIMBS }, { U512::LIMBS }>(); + classic_binxgcd_tests::<{ U384::LIMBS }, { U768::LIMBS }>(); + classic_binxgcd_tests::<{ U512::LIMBS }, { U1024::LIMBS }>(); + classic_binxgcd_tests::<{ U1024::LIMBS }, { U2048::LIMBS }>(); + classic_binxgcd_tests::<{ U2048::LIMBS }, { U4096::LIMBS }>(); + classic_binxgcd_tests::<{ U4096::LIMBS }, { U8192::LIMBS }>(); + } + } + + mod test_optimized_binxgcd { + use crate::modular::bingcd::xgcd::tests::test_xgcd; + use crate::{ + ConcatMixed, Gcd, Int, RandomMod, Uint, U1024, U128, U192, U2048, U256, U384, U4096, + U512, U768, U8192, + }; + use rand_core::OsRng; + + fn optimized_binxgcd_test( + lhs: Uint, + rhs: Uint, + ) where + Uint: + Gcd> + ConcatMixed, MixedOutput = Uint>, + { + let output = lhs + .to_odd() + .unwrap() + .optimized_binxgcd(&rhs.to_odd().unwrap()); + test_xgcd(lhs, rhs, output.process()); + } + + fn optimized_binxgcd_tests() + where + Uint: + Gcd> + ConcatMixed, MixedOutput = Uint>, + { + // Edge cases + let upper_bound = *Int::MAX.as_uint(); + optimized_binxgcd_test(Uint::ONE, Uint::ONE); + optimized_binxgcd_test(Uint::ONE, upper_bound); + optimized_binxgcd_test(upper_bound, Uint::ONE); + optimized_binxgcd_test(upper_bound, upper_bound); + + // Randomized test cases + let bound = Int::MIN.as_uint().to_nz().unwrap(); + for _ in 0..100 { + let x = Uint::::random_mod(&mut OsRng, &bound).bitor(&Uint::ONE); + let y = Uint::::random_mod(&mut OsRng, &bound).bitor(&Uint::ONE); + optimized_binxgcd_test(x, y); + } + } + + #[test] + fn test_optimized_binxgcd() { + optimized_binxgcd_tests::<{ U128::LIMBS }, { U256::LIMBS }>(); + optimized_binxgcd_tests::<{ U192::LIMBS }, { U384::LIMBS }>(); + optimized_binxgcd_tests::<{ U256::LIMBS }, { U512::LIMBS }>(); + optimized_binxgcd_tests::<{ U384::LIMBS }, { U768::LIMBS }>(); + optimized_binxgcd_tests::<{ U512::LIMBS }, { U1024::LIMBS }>(); + optimized_binxgcd_tests::<{ U1024::LIMBS }, { U2048::LIMBS }>(); + optimized_binxgcd_tests::<{ U2048::LIMBS }, { U4096::LIMBS }>(); + optimized_binxgcd_tests::<{ U4096::LIMBS }, { U8192::LIMBS }>(); + } + } +} diff --git a/src/non_zero.rs b/src/non_zero.rs index fcce2dba..b93dc997 100644 --- a/src/non_zero.rs +++ b/src/non_zero.rs @@ -177,6 +177,11 @@ impl NonZero> { // Note: a NonZero always has a non-zero magnitude, so it is safe to unwrap. (NonZero::>::new_unwrap(abs), sign) } + + /// Convert a [`NonZero`] to its [`NonZero`] magnitude. + pub const fn abs(&self) -> NonZero> { + self.abs_sign().0 + } } #[cfg(feature = "hybrid-array")] diff --git a/src/odd.rs b/src/odd.rs index a7995d02..87aa3dc4 100644 --- a/src/odd.rs +++ b/src/odd.rs @@ -1,6 +1,6 @@ //! Wrapper type for non-zero integers. -use crate::{Integer, Limb, NonZero, Uint}; +use crate::{Int, Integer, Limb, NonZero, Uint}; use core::{cmp::Ordering, fmt, ops::Deref}; use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption}; @@ -58,6 +58,9 @@ impl Odd { } impl Odd> { + /// Total size of the represented integer in bits. + pub const BITS: u32 = Uint::::BITS; + /// Create a new [`Odd>`] from the provided big endian hex string. /// /// Panics if the hex is malformed or not zero-padded accordingly for the size, or if the value is even. @@ -160,6 +163,14 @@ impl Random for Odd> { } } +#[cfg(feature = "rand_core")] +impl Random for Odd> { + /// Generate a random `Odd>`. + fn random(rng: &mut impl RngCore) -> Self { + Odd(Odd::>::random(rng).as_int()) + } +} + #[cfg(all(feature = "alloc", feature = "rand_core"))] impl Odd { /// Generate a random `Odd>`. diff --git a/src/uint.rs b/src/uint.rs index fe2208d2..a0aa1447 100644 --- a/src/uint.rs +++ b/src/uint.rs @@ -165,7 +165,7 @@ impl Uint { } /// Borrow the limbs of this [`Uint`] mutably. - pub fn as_limbs_mut(&mut self) -> &mut [Limb; LIMBS] { + pub const fn as_limbs_mut(&mut self) -> &mut [Limb; LIMBS] { &mut self.limbs } @@ -461,6 +461,7 @@ impl_uint_concat_split_mixed! { (U1024, [1, 2, 3, 4, 5, 6, 7, 9, 10, 11, 12, 13, 14, 15]), } +mod bingcd; #[cfg(feature = "extra-sizes")] mod extra_sizes; diff --git a/src/uint/bingcd.rs b/src/uint/bingcd.rs new file mode 100644 index 00000000..d8c425fe --- /dev/null +++ b/src/uint/bingcd.rs @@ -0,0 +1,118 @@ +//! This module implements (a constant variant of) the Optimized Extended Binary GCD algorithm, +//! which is described by Pornin as Algorithm 2 in "Optimized Binary GCD for Modular Inversion". +//! Ref: + +use crate::modular::bingcd::tools::const_min; +use crate::{NonZero, Odd, Uint}; + +impl Uint { + /// Compute the greatest common divisor of `self` and `rhs`. + pub const fn bingcd(&self, rhs: &Self) -> Self { + let self_is_zero = self.is_nonzero().not(); + let self_nz = Uint::select(self, &Uint::ONE, self_is_zero) + .to_nz() + .expect("self is non zero by construction"); + Uint::select(self_nz.bingcd(rhs).as_ref(), rhs, self_is_zero) + } +} + +impl NonZero> { + /// Compute the greatest common divisor of `self` and `rhs`. + pub const fn bingcd(&self, rhs: &Uint) -> Self { + let val = self.as_ref(); + // Leverage two GCD identity rules to make self odd. + // 1) gcd(2a, 2b) = 2 * gcd(a, b) + // 2) gcd(a, 2b) = gcd(a, b) if a is odd. + let i = val.trailing_zeros(); + let j = rhs.trailing_zeros(); + let k = const_min(i, j); + + val.shr(i) + .to_odd() + .expect("val.shr(i) is odd by construction") + .bingcd(rhs) + .as_ref() + .shl(k) + .to_nz() + .expect("gcd of non-zero element with another element is non-zero") + } +} + +impl Odd> { + /// Compute the greatest common divisor of `self` and `rhs` using the Binary GCD algorithm. + /// + /// This function switches between the "classic" and "optimized" algorithm at a best-effort + /// threshold. When using [Uint]s with `LIMBS` close to the threshold, it may be useful to + /// manually test whether the classic or optimized algorithm is faster for your machine. + #[inline(always)] + pub const fn bingcd(&self, rhs: &Uint) -> Self { + if LIMBS < 6 { + self.classic_bingcd(rhs) + } else { + self.optimized_bingcd(rhs) + } + } +} + +#[cfg(feature = "rand_core")] +#[cfg(test)] +mod tests { + use rand_core::OsRng; + + use crate::{ + Gcd, Int, Random, Uint, U1024, U128, U16384, U2048, U256, U4096, U512, U64, U8192, + }; + + fn bingcd_test(lhs: Uint, rhs: Uint) + where + Uint: Gcd>, + { + let gcd = lhs.gcd(&rhs); + let bingcd = lhs.bingcd(&rhs); + assert_eq!(gcd, bingcd); + } + + fn bingcd_tests() + where + Uint: Gcd>, + { + // Edge cases + let min = Int::MIN.abs(); + bingcd_test(Uint::ZERO, Uint::ZERO); + bingcd_test(Uint::ZERO, Uint::ONE); + bingcd_test(Uint::ZERO, min); + bingcd_test(Uint::ZERO, Uint::MAX); + bingcd_test(Uint::ONE, Uint::ZERO); + bingcd_test(Uint::ONE, Uint::ONE); + bingcd_test(Uint::ONE, min); + bingcd_test(Uint::ONE, Uint::MAX); + bingcd_test(min, Uint::ZERO); + bingcd_test(min, Uint::ONE); + bingcd_test(min, Int::MIN.abs()); + bingcd_test(min, Uint::MAX); + bingcd_test(Uint::MAX, Uint::ZERO); + bingcd_test(Uint::MAX, Uint::ONE); + bingcd_test(Uint::ONE, min); + bingcd_test(Uint::MAX, Uint::MAX); + + // Randomized test cases + for _ in 0..100 { + let x = Uint::::random(&mut OsRng); + let y = Uint::::random(&mut OsRng); + bingcd_test(x, y); + } + } + + #[test] + fn test_bingcd() { + bingcd_tests::<{ U64::LIMBS }>(); + bingcd_tests::<{ U128::LIMBS }>(); + bingcd_tests::<{ U256::LIMBS }>(); + bingcd_tests::<{ U512::LIMBS }>(); + bingcd_tests::<{ U1024::LIMBS }>(); + bingcd_tests::<{ U2048::LIMBS }>(); + bingcd_tests::<{ U4096::LIMBS }>(); + bingcd_tests::<{ U8192::LIMBS }>(); + bingcd_tests::<{ U16384::LIMBS }>(); + } +} diff --git a/src/uint/cmp.rs b/src/uint/cmp.rs index 453f8d11..747cb959 100644 --- a/src/uint/cmp.rs +++ b/src/uint/cmp.rs @@ -25,6 +25,18 @@ impl Uint { Uint { limbs } } + /// Swap `a` and `b` if `c` is truthy, otherwise, do nothing. + #[inline] + pub(crate) const fn conditional_swap(a: &mut Self, b: &mut Self, c: ConstChoice) { + (*a, *b) = (Self::select(a, b, c), Self::select(b, a, c)); + } + + /// Swap `a` and `b`. + #[inline] + pub(crate) const fn swap(a: &mut Self, b: &mut Self) { + Self::conditional_swap(a, b, ConstChoice::TRUE) + } + /// Returns the truthy value if `self`!=0 or the falsy value otherwise. #[inline] pub(crate) const fn is_nonzero(&self) -> ConstChoice {