1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
use crate::{Limb, Uint, Word};

use super::mul::{mul_montgomery_form, square_montgomery_form};

#[cfg(feature = "alloc")]
use alloc::vec::Vec;

/// Width, in bits, of each exponent window; the precomputed power tables in this
/// file hold `1 << WINDOW` entries.
///
/// NOTE(review): the exponentiation loop in this file takes `Limb::BITS / WINDOW`
/// windows per limb, so this is presumably expected to divide `Limb::BITS` evenly.
const WINDOW: usize = 4;
/// Mask selecting the low `WINDOW` bits of a word, i.e. one window's worth of
/// exponent bits.
const WINDOW_MASK: Word = (1 << WINDOW) - 1;

/// Raises `x` to the power `exponent` modulo `modulus`, operating on values in
/// Montgomery form.
///
/// This is the single-base case of [`multi_exponentiate_montgomery_form_array`]:
/// the base and exponent are copied into a one-element array and every other
/// argument is forwarded unchanged.
///
/// * `x` - the base (expected to be in Montgomery form).
/// * `exponent` - the exponent; only its lowest `exponent_bits` bits are used.
/// * `exponent_bits` - the number of exponent bits to take into account.
/// * `modulus` - the modulus.
/// * `r` - the value 1 in Montgomery form; returned as-is when `exponent_bits` is 0.
/// * `mod_neg_inv` - passed through to the Montgomery multiplication routines.
///
/// NOTE: `exponent_bits` is leaked in the time pattern.
pub const fn pow_montgomery_form<const LIMBS: usize, const RHS_LIMBS: usize>(
    x: &Uint<LIMBS>,
    exponent: &Uint<RHS_LIMBS>,
    exponent_bits: usize,
    modulus: &Uint<LIMBS>,
    r: &Uint<LIMBS>,
    mod_neg_inv: Limb,
) -> Uint<LIMBS> {
    let single_pair = [(*x, *exponent)];
    multi_exponentiate_montgomery_form_array(&single_pair, exponent_bits, modulus, r, mod_neg_inv)
}

/// Computes the product of `base^exponent` over all `N` `(base, exponent)` pairs,
/// modulo `modulus`, operating on values in Montgomery form.
///
/// A table of the first `1 << WINDOW` powers is precomputed for every base, and
/// the tables are then consumed by a shared windowed exponentiation loop.
///
/// * `bases_and_exponents` - the `(base, exponent)` pairs (bases expected to be in
///   Montgomery form).
/// * `exponent_bits` - the number of bits to take into account for each exponent.
/// * `modulus` - the modulus.
/// * `r` - the value 1 in Montgomery form; returned as-is when `exponent_bits` is 0.
/// * `mod_neg_inv` - passed through to the Montgomery multiplication routines.
///
/// NOTE: `exponent_bits` is leaked in the time pattern.
pub const fn multi_exponentiate_montgomery_form_array<
    const LIMBS: usize,
    const RHS_LIMBS: usize,
    const N: usize,
>(
    bases_and_exponents: &[(Uint<LIMBS>, Uint<RHS_LIMBS>); N],
    exponent_bits: usize,
    modulus: &Uint<LIMBS>,
    r: &Uint<LIMBS>,
    mod_neg_inv: Limb,
) -> Uint<LIMBS> {
    // No exponent bits to process: the result is 1 in Montgomery form.
    if exponent_bits == 0 {
        return *r;
    }

    // Placeholder entry; every slot is overwritten in the loop below.
    let empty_entry = ([Uint::<LIMBS>::ZERO; 1 << WINDOW], Uint::<RHS_LIMBS>::ZERO);
    let mut tables = [empty_entry; N];

    // `while` rather than `for`: this is a `const fn`.
    let mut pair_idx = 0;
    while pair_idx < N {
        let (base, exponent) = &bases_and_exponents[pair_idx];
        tables[pair_idx] = (compute_powers(base, modulus, r, mod_neg_inv), *exponent);
        pair_idx += 1;
    }

    multi_exponentiate_montgomery_form_internal(&tables, exponent_bits, modulus, r, mod_neg_inv)
}

/// Computes the product of `base^exponent` over all `(base, exponent)` pairs in
/// the slice, modulo `modulus`, operating on values in Montgomery form.
///
/// Heap-allocating counterpart of the array-based variant: a table of the first
/// `1 << WINDOW` powers is precomputed for every base and collected into a `Vec`,
/// which is then consumed by the shared windowed exponentiation loop.
///
/// * `bases_and_exponents` - the `(base, exponent)` pairs (bases expected to be in
///   Montgomery form).
/// * `exponent_bits` - the number of bits to take into account for each exponent.
/// * `modulus` - the modulus.
/// * `r` - the value 1 in Montgomery form; returned as-is when `exponent_bits` is 0.
/// * `mod_neg_inv` - passed through to the Montgomery multiplication routines.
///
/// See: Straus, E. G. Problems and solutions: Addition chains of vectors. American Mathematical Monthly 71 (1964), 806–808.
///
/// NOTE: `exponent_bits` is leaked in the time pattern.
#[cfg(feature = "alloc")]
pub fn multi_exponentiate_montgomery_form_slice<const LIMBS: usize, const RHS_LIMBS: usize>(
    bases_and_exponents: &[(Uint<LIMBS>, Uint<RHS_LIMBS>)],
    exponent_bits: usize,
    modulus: &Uint<LIMBS>,
    r: &Uint<LIMBS>,
    mod_neg_inv: Limb,
) -> Uint<LIMBS> {
    // No exponent bits to process: the result is 1 in Montgomery form.
    if exponent_bits == 0 {
        return *r;
    }

    // One (power table, exponent) entry per input pair, in input order.
    let mut tables = Vec::with_capacity(bases_and_exponents.len());
    for (base, exponent) in bases_and_exponents {
        tables.push((compute_powers(base, modulus, r, mod_neg_inv), *exponent));
    }

    multi_exponentiate_montgomery_form_internal(&tables, exponent_bits, modulus, r, mod_neg_inv)
}

/// Builds the lookup table of the first `1 << WINDOW` powers of `x`.
///
/// Entry `k` of the returned array holds `x^k`: entry 0 is `r` (1 in Montgomery
/// form), entry 1 is `x` itself, and every later entry is the previous one
/// multiplied by `x` with `mul_montgomery_form`.
const fn compute_powers<const LIMBS: usize>(
    x: &Uint<LIMBS>,
    modulus: &Uint<LIMBS>,
    r: &Uint<LIMBS>,
    mod_neg_inv: Limb,
) -> [Uint<LIMBS>; 1 << WINDOW] {
    // Filling with `r` takes care of entry 0 (x^0); the rest are overwritten.
    let mut table = [*r; 1 << WINDOW];
    table[1] = *x;

    // `latest` always holds the most recently stored power, x^(k - 1).
    let mut latest = *x;
    let mut k = 2;
    while k < table.len() {
        latest = mul_montgomery_form(&latest, x, modulus, mod_neg_inv);
        table[k] = latest;
        k += 1;
    }

    table
}

/// Shared windowed multi-exponentiation loop over precomputed power tables.
///
/// Each entry of `powers_and_exponents` pairs a table of `1 << WINDOW` values
/// (entry `k` is expected to be `base^k` in Montgomery form, as produced by
/// `compute_powers`) with the exponent for that base. The exponents are consumed
/// `WINDOW` bits at a time, from the most significant window that contains bit
/// `exponent_bits - 1` down to bit 0. For every window after the first, the
/// accumulator is squared `WINDOW` times; then, for every base, the table entry
/// selected by that window's bits is multiplied in.
///
/// * `powers_and_exponents` - one `(power table, exponent)` entry per base.
/// * `exponent_bits` - the number of bits to take into account for each exponent;
///   bits at or above this position are masked off.
/// * `modulus` - the modulus.
/// * `r` - the value 1 in Montgomery form, used as the initial accumulator.
/// * `mod_neg_inv` - passed through to the Montgomery multiplication routines.
///
/// Preconditions (not checked here; the public entry points in this file return
/// early for `exponent_bits == 0`):
/// * `exponent_bits` must be at least 1, otherwise `exponent_bits - 1` underflows.
/// * `exponent_bits` must not exceed the bit width of the exponent type,
///   otherwise the limb index is out of bounds.
///
/// NOTE: `exponent_bits` is leaked in the time pattern.
const fn multi_exponentiate_montgomery_form_internal<const LIMBS: usize, const RHS_LIMBS: usize>(
    powers_and_exponents: &[([Uint<LIMBS>; 1 << WINDOW], Uint<RHS_LIMBS>)],
    exponent_bits: usize,
    modulus: &Uint<LIMBS>,
    r: &Uint<LIMBS>,
    mod_neg_inv: Limb,
) -> Uint<LIMBS> {
    // Locate the most significant exponent bit that is taken into account.
    let starting_limb = (exponent_bits - 1) / Limb::BITS;
    let starting_bit_in_limb = (exponent_bits - 1) % Limb::BITS;
    let starting_window = starting_bit_in_limb / WINDOW;
    // Keeps only the bits of the top window that lie below `exponent_bits`.
    let starting_window_mask = (1 << (starting_bit_in_limb % WINDOW + 1)) - 1;

    let mut z = *r; // 1 in Montgomery form

    // Count down with a pre-decrement so that limb 0 is processed without
    // underflowing the unsigned counter.
    let mut limb_num = starting_limb + 1;
    while limb_num > 0 {
        limb_num -= 1;

        // The top limb starts at the window holding the highest used bit;
        // every other limb is processed in full.
        let mut window_num = if limb_num == starting_limb {
            starting_window + 1
        } else {
            Limb::BITS / WINDOW
        };
        while window_num > 0 {
            window_num -= 1;

            // Shift the accumulated exponent up by one window (WINDOW squarings).
            // Skipped for the very first window, where `z` is still 1.
            if limb_num != starting_limb || window_num != starting_window {
                let mut i = 0;
                while i < WINDOW {
                    i += 1;
                    z = square_montgomery_form(&z, modulus, mod_neg_inv);
                }
            }

            let mut i = 0;
            while i < powers_and_exponents.len() {
                // Borrow the entry: destructuring by value would copy the whole
                // power table and the exponent for every base in every window.
                let (powers, exponent) = &powers_and_exponents[i];
                let w = exponent.as_limbs()[limb_num].0;
                let mut idx = (w >> (window_num * WINDOW)) & WINDOW_MASK;

                // Drop the bits of the top window that are above `exponent_bits`.
                if limb_num == starting_limb && window_num == starting_window {
                    idx &= starting_window_mask;
                }

                // Constant-time lookup in the array of powers: every entry is
                // visited and the matching one is picked with `ct_select`, so
                // there is no branch or memory index that depends on `idx`.
                let mut power = powers[0];
                let mut j = 1;
                while j < (1 << WINDOW) {
                    let choice = Limb::ct_eq(Limb(j as Word), Limb(idx));
                    power = Uint::<LIMBS>::ct_select(&power, &powers[j], choice);
                    j += 1;
                }

                z = mul_montgomery_form(&z, &power, modulus, mod_neg_inv);
                i += 1;
            }
        }
    }

    z
}