use super::{Scalar, MODULUS};
use crate::ORDER;
use elliptic_curve::{
bigint::{Limb, U256, U512},
subtle::{Choice, ConditionallySelectable},
};
/// Limb-wise two's-complement negation of `MODULUS`: every limb is bitwise
/// inverted and `1` is added to limb 0 only.
///
/// Assuming `MODULUS` stores its least-significant limb first (TODO confirm
/// against its definition), this equals `2^256 - MODULUS`. No carry is
/// propagated beyond limb 0, so the expression relies on `!MODULUS[0]` not being
/// `u64::MAX`; if it were, the `+ 1` would overflow during constant evaluation.
const NEG_MODULUS: [u64; 4] = [!MODULUS[0] + 1, !MODULUS[1], !MODULUS[2], !MODULUS[3]];
/// A 512-bit unsigned integer (wrapping a `U512`) used as the wide intermediate
/// value of scalar arithmetic, e.g. the unreduced product of two `Scalar`s,
/// before it is reduced back into a `Scalar`.
#[derive(Clone, Copy, Debug, Default)]
pub(crate) struct WideScalar(pub(super) U512);
impl WideScalar {
/// Builds a wide scalar from 64 bytes interpreted as a big-endian integer.
pub const fn from_bytes(bytes: &[u8; 64]) -> Self {
    let value = U512::from_be_slice(bytes);
    Self(value)
}
/// Multiplies two scalars without any modular reduction, returning the full
/// (up to 512-bit) product as a `WideScalar`.
///
/// The product is computed column by column: for each output limb `l_k`, all
/// partial products `a[i] * b[j]` with `i + j == k` are added into a running
/// accumulator `(c0, c1, c2)`, after which the low limb is emitted and the
/// accumulator is shifted down by one limb.
#[inline(always)] pub fn mul_wide(a: &Scalar, b: &Scalar) -> Self {
let a = a.0.to_words();
let b = b.0.to_words();
// Three-limb accumulator, starting at zero.
let c0 = 0;
let c1 = 0;
let c2 = 0;
// Column 0: a single product, so the two-limb accumulator variant suffices.
let (c0, c1) = muladd_fast(a[0], b[0], c0, c1);
// Emit l0 and shift the accumulator down one limb.
let (l0, c0, c1) = (c0, c1, 0);
// Column 1.
let (c0, c1, c2) = muladd(a[0], b[1], c0, c1, c2);
let (c0, c1, c2) = muladd(a[1], b[0], c0, c1, c2);
let (l1, c0, c1, c2) = (c0, c1, c2, 0);
// Column 2.
let (c0, c1, c2) = muladd(a[0], b[2], c0, c1, c2);
let (c0, c1, c2) = muladd(a[1], b[1], c0, c1, c2);
let (c0, c1, c2) = muladd(a[2], b[0], c0, c1, c2);
let (l2, c0, c1, c2) = (c0, c1, c2, 0);
// Column 3 (the widest column: four partial products).
let (c0, c1, c2) = muladd(a[0], b[3], c0, c1, c2);
let (c0, c1, c2) = muladd(a[1], b[2], c0, c1, c2);
let (c0, c1, c2) = muladd(a[2], b[1], c0, c1, c2);
let (c0, c1, c2) = muladd(a[3], b[0], c0, c1, c2);
let (l3, c0, c1, c2) = (c0, c1, c2, 0);
// Column 4.
let (c0, c1, c2) = muladd(a[1], b[3], c0, c1, c2);
let (c0, c1, c2) = muladd(a[2], b[2], c0, c1, c2);
let (c0, c1, c2) = muladd(a[3], b[1], c0, c1, c2);
let (l4, c0, c1, c2) = (c0, c1, c2, 0);
// Column 5.
let (c0, c1, c2) = muladd(a[2], b[3], c0, c1, c2);
let (c0, c1, c2) = muladd(a[3], b[2], c0, c1, c2);
// The third accumulator limb is no longer needed after this extraction.
let (l5, c0, c1, _c2) = (c0, c1, c2, 0);
// Column 6: the last product; whatever remains afterwards is the top limb.
let (c0, c1) = muladd_fast(a[3], b[3], c0, c1);
let (l6, c0, _c1) = (c0, c1, 0);
let l7 = c0;
// Limbs are passed least-significant first.
Self(U512::from_words([l0, l1, l2, l3, l4, l5, l6, l7]))
}
/// Multiplies `a` by `b` without modular reduction, divides the product by
/// `2^shift`, and rounds the quotient to the nearest integer.
///
/// The 512-bit product comes from [`Self::mul_wide`]; the quotient is the
/// 256-bit window of that product starting at bit `shift`, and rounding adds
/// one when the most significant discarded bit (bit `shift - 1`) is set.
///
/// `shift` is expected to be at least 256 (checked with `debug_assert!`).
/// For `shift > 512` every bit of the product, including the rounding bit, lies
/// below the cut, so the result is zero.
///
/// Variable time in `shift`: it is used for branching and indexing.
pub(crate) fn mul_shift_vartime(a: &Scalar, b: &Scalar, shift: usize) -> Scalar {
    debug_assert!(shift >= 256);

    let l = Self::mul_wide(a, b).0.to_words();
    // Whole limbs to skip, and the remaining bit offset inside a limb.
    let shiftlimbs = shift >> 6;
    let shiftlow = shift & 0x3F;
    // Only used when `shiftlow != 0`, so a shift by a full 64 bits never happens.
    let shifthigh = 64 - shiftlow;

    // Each result limb combines the high part of one product limb with the low
    // part of the next one; limbs beyond the end of the product contribute zero.
    let r0 = if shift < 512 {
        let lo = l[shiftlimbs] >> shiftlow;
        let hi = if shift < 448 && shiftlow != 0 {
            l[1 + shiftlimbs] << shifthigh
        } else {
            0
        };
        hi | lo
    } else {
        0
    };

    let r1 = if shift < 448 {
        let lo = l[1 + shiftlimbs] >> shiftlow;
        let hi = if shift < 384 && shiftlow != 0 {
            l[2 + shiftlimbs] << shifthigh
        } else {
            0
        };
        hi | lo
    } else {
        0
    };

    let r2 = if shift < 384 {
        let lo = l[2 + shiftlimbs] >> shiftlow;
        let hi = if shift < 320 && shiftlow != 0 {
            l[3 + shiftlimbs] << shifthigh
        } else {
            0
        };
        hi | lo
    } else {
        0
    };

    let r3 = if shift < 320 {
        l[3 + shiftlimbs] >> shiftlow
    } else {
        0
    };

    let res = Scalar(U256::from_words([r0, r1, r2, r3]));

    // Rounding bit: the most significant discarded bit of the product. When
    // `shift > 512` that position lies past the last limb, so a checked lookup
    // treats it as 0 instead of panicking on an out-of-bounds index.
    let round_pos = shift - 1;
    let c = l
        .get(round_pos >> 6)
        .map_or(0, |&limb| (limb >> (round_pos & 0x3f)) & 1);
    Scalar::conditional_select(&res, &res.add(&Scalar::ONE), Choice::from(c as u8))
}
/// Reduces this 512-bit value to a `Scalar`.
///
/// When `modulus_minus_one` is `false` the reduction is modulo `ORDER`; when it
/// is `true` the reduction is modulo `ORDER - 1`.
///
/// The upper half of the value is repeatedly folded into the lower half by
/// multiplying it with the negated modulus, in three rounds, followed by one
/// conditional subtraction of the modulus.
///
/// Only limbs 0 and 1 of `NEG_MODULUS` are used as multipliers. Where limb 2
/// would be used, the operand is added directly, and limb 3 never contributes.
/// NOTE(review): this presumes `NEG_MODULUS[2] == 1` and `NEG_MODULUS[3] == 0` —
/// confirm against the definition of `MODULUS`.
fn reduce_impl(&self, modulus_minus_one: bool) -> Scalar {
// Lowering the modulus by one raises its negation by one.
let neg_modulus0 = if modulus_minus_one {
NEG_MODULUS[0] + 1
} else {
NEG_MODULUS[0]
};
let modulus = if modulus_minus_one {
ORDER.wrapping_sub(&U256::ONE)
} else {
ORDER
};
let w = self.0.to_words();
// Upper four limbs of the input: the part that gets folded down.
let n0 = w[4];
let n1 = w[5];
let n2 = w[6];
let n3 = w[7];
// Round 1: m[0..=6] = w[0..=3] + n[0..=3] * neg_modulus.
// `(c0, c1, c2)` is the running column accumulator.
let c0 = w[0];
let c1 = 0;
let c2 = 0;
let (c0, c1) = muladd_fast(n0, neg_modulus0, c0, c1);
// Emit m0 and shift the accumulator down one limb.
let (m0, c0, c1) = (c0, c1, 0);
let (c0, c1) = sumadd_fast(w[1], c0, c1);
let (c0, c1, c2) = muladd(n1, neg_modulus0, c0, c1, c2);
let (c0, c1, c2) = muladd(n0, NEG_MODULUS[1], c0, c1, c2);
let (m1, c0, c1, c2) = (c0, c1, c2, 0);
let (c0, c1, c2) = sumadd(w[2], c0, c1, c2);
let (c0, c1, c2) = muladd(n2, neg_modulus0, c0, c1, c2);
let (c0, c1, c2) = muladd(n1, NEG_MODULUS[1], c0, c1, c2);
// Plain addition stands in for multiplication by negated-modulus limb 2.
let (c0, c1, c2) = sumadd(n0, c0, c1, c2);
let (m2, c0, c1, c2) = (c0, c1, c2, 0);
let (c0, c1, c2) = sumadd(w[3], c0, c1, c2);
let (c0, c1, c2) = muladd(n3, neg_modulus0, c0, c1, c2);
let (c0, c1, c2) = muladd(n2, NEG_MODULUS[1], c0, c1, c2);
let (c0, c1, c2) = sumadd(n1, c0, c1, c2);
let (m3, c0, c1, c2) = (c0, c1, c2, 0);
let (c0, c1, c2) = muladd(n3, NEG_MODULUS[1], c0, c1, c2);
let (c0, c1, c2) = sumadd(n2, c0, c1, c2);
let (m4, c0, c1, _c2) = (c0, c1, c2, 0);
let (c0, c1) = sumadd_fast(n3, c0, c1);
let (m5, c0, _c1) = (c0, c1, 0);
// The intermediate value is expected to have at most one bit in its seventh limb.
debug_assert!(c0 <= 1);
let m6 = c0;
// Round 2: p[0..=4] = m[0..=3] + m[4..=6] * neg_modulus.
let c0 = m0;
let c1 = 0;
let c2 = 0;
let (c0, c1) = muladd_fast(m4, neg_modulus0, c0, c1);
let (p0, c0, c1) = (c0, c1, 0);
let (c0, c1) = sumadd_fast(m1, c0, c1);
let (c0, c1, c2) = muladd(m5, neg_modulus0, c0, c1, c2);
let (c0, c1, c2) = muladd(m4, NEG_MODULUS[1], c0, c1, c2);
// NOTE(review): unlike the other extractions that follow a three-limb
// `muladd`, `c2` is not shifted down into `c1` here; it is carried unchanged
// into the next column. This matches a full extraction only if `c2` is
// always zero at this point — confirm the bound on the negated-modulus limbs.
let (p1, c0, c1) = (c0, c1, 0);
let (c0, c1, c2) = sumadd(m2, c0, c1, c2);
let (c0, c1, c2) = muladd(m6, neg_modulus0, c0, c1, c2);
let (c0, c1, c2) = muladd(m5, NEG_MODULUS[1], c0, c1, c2);
let (c0, c1, c2) = sumadd(m4, c0, c1, c2);
let (p2, c0, c1, _c2) = (c0, c1, c2, 0);
let (c0, c1) = sumadd_fast(m3, c0, c1);
let (c0, c1) = muladd_fast(m6, NEG_MODULUS[1], c0, c1);
let (c0, c1) = sumadd_fast(m5, c0, c1);
let (p3, c0, _c1) = (c0, c1, 0);
let p4 = c0 + m6;
// The intermediate value is expected to exceed 256 bits by at most two bits.
debug_assert!(p4 <= 2);
// Round 3: r[0..=3] = p[0..=3] + p4 * neg_modulus, using a `u128` as a
// limb-plus-carry accumulator. After the last step `c` holds the carry out
// of the 256-bit result.
let mut c = (p0 as u128) + (neg_modulus0 as u128) * (p4 as u128);
let r0 = (c & 0xFFFFFFFFFFFFFFFFu128) as u64;
c >>= 64;
c += (p1 as u128) + (NEG_MODULUS[1] as u128) * (p4 as u128);
let r1 = (c & 0xFFFFFFFFFFFFFFFFu128) as u64;
c >>= 64;
// `p4` is added as-is in place of multiplying by negated-modulus limb 2.
c += (p2 as u128) + (p4 as u128);
let r2 = (c & 0xFFFFFFFFFFFFFFFFu128) as u64;
c >>= 64;
c += p3 as u128;
let r3 = (c & 0xFFFFFFFFFFFFFFFFu128) as u64;
c >>= 64;
// Final step: trial-subtract the modulus, then pick between `r` and `r - modulus`.
let r = U256::from([r0, r1, r2, r3]);
let (r2, underflow) = r.sbb(&modulus, Limb::ZERO);
let high_bit = Choice::from(c as u8);
// Top bit of the returned borrow limb; presumably the limb is all-ones when
// the subtraction borrowed and zero otherwise — verify against `sbb`'s contract.
let underflow = Choice::from((underflow.0 >> 63) as u8);
// Keep the subtracted value when there was no borrow (`r >= modulus`) or
// when the round-3 sum carried out of 256 bits.
Scalar(U256::conditional_select(&r, &r2, !underflow | high_bit))
}
/// Reduces the wide value to a `Scalar` using the full modulus, i.e. calls
/// `reduce_impl` with `modulus_minus_one` disabled.
#[inline(always)]
pub(super) fn reduce(&self) -> Scalar {
    let modulus_minus_one = false;
    self.reduce_impl(modulus_minus_one)
}
/// Reduces the wide value with the modulus lowered by one and then adds
/// `Scalar::ONE`, which is intended to produce a non-zero scalar.
pub(super) fn reduce_nonzero(&self) -> Scalar {
    let reduced = self.reduce_impl(true);
    reduced + Scalar::ONE
}
}
/// Adds `a` to the three-limb number `(c0, c1, c2)` (least-significant limb
/// first) and returns the updated limbs.
///
/// The top limb `c2` must not overflow; that addition is unchecked arithmetic
/// and panics in builds with overflow checks enabled.
fn sumadd(a: u64, c0: u64, c1: u64, c2: u64) -> (u64, u64, u64) {
    // Widen each limb addition so the carry shows up in bit 64.
    let low = u128::from(c0) + u128::from(a);
    let mid = u128::from(c1) + (low >> 64);
    let top = c2 + (mid >> 64) as u64;
    (low as u64, mid as u64, top)
}
/// Adds `a` to the two-limb number `(c0, c1)` (least-significant limb first)
/// and returns the updated limbs.
///
/// The high limb `c1` must not overflow; that addition is unchecked arithmetic
/// and panics in builds with overflow checks enabled.
fn sumadd_fast(a: u64, c0: u64, c1: u64) -> (u64, u64) {
    // Widen the low-limb addition so the carry shows up in bit 64.
    let low = u128::from(c0) + u128::from(a);
    let high = c1 + (low >> 64) as u64;
    (low as u64, high)
}
/// Adds the 128-bit product `a * b` to the three-limb number `(c0, c1, c2)`
/// (least-significant limb first) and returns the updated limbs.
///
/// The top limb `c2` must not overflow; that addition is unchecked arithmetic
/// and panics in builds with overflow checks enabled.
fn muladd(a: u64, b: u64, c0: u64, c1: u64, c2: u64) -> (u64, u64, u64) {
    let product = u128::from(a) * u128::from(b);
    // Low limb plus the low half of the product; the carry lands in bit 64.
    let low = u128::from(c0) + u128::from(product as u64);
    // The high half of a 64x64 product is at most `u64::MAX - 1`, so adding
    // the incoming carry to it cannot itself overflow 64 bits.
    let mid = u128::from(c1) + (product >> 64) + (low >> 64);
    let top = c2 + (mid >> 64) as u64;
    (low as u64, mid as u64, top)
}
/// Adds the 128-bit product `a * b` to the two-limb number `(c0, c1)`
/// (least-significant limb first) and returns the updated limbs.
///
/// The high limb `c1` must not overflow; that addition is unchecked arithmetic
/// and panics in builds with overflow checks enabled.
fn muladd_fast(a: u64, b: u64, c0: u64, c1: u64) -> (u64, u64) {
    let product = u128::from(a) * u128::from(b);
    // Low limb plus the low half of the product; the carry lands in bit 64.
    let low = u128::from(c0) + u128::from(product as u64);
    // The high half of a 64x64 product is at most `u64::MAX - 1`, so folding
    // the carry into it still fits in 64 bits.
    let carried_high = ((product >> 64) + (low >> 64)) as u64;
    (low as u64, c1 + carried_high)
}