blake2b_simd/avx2.rs

#[cfg(target_arch = "x86")]
use core::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;

use crate::guts::{
    assemble_count, count_high, count_low, final_block, flag_word, input_debug_asserts, Finalize,
    Job, LastNode, Stride,
};
use crate::{Count, Word, BLOCKBYTES, IV, SIGMA};
use arrayref::{array_refs, mut_array_refs};
use core::cmp;
use core::mem;

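// Each AVX2 vector holds DEGREE (4) 64-bit words, so the parallel code paths
// below (compress4_loop) process four independent BLAKE2b states at a time.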
pub const DEGREE: usize = 4;

#[inline(always)]
unsafe fn loadu(src: *const [Word; DEGREE]) -> __m256i {
    // This is an unaligned load, so the pointer cast is allowed.
    _mm256_loadu_si256(src as *const __m256i)
}

#[inline(always)]
unsafe fn storeu(src: __m256i, dest: *mut [Word; DEGREE]) {
    // This is an unaligned store, so the pointer cast is allowed.
    _mm256_storeu_si256(dest as *mut __m256i, src)
}

#[inline(always)]
unsafe fn loadu_128(mem_addr: &[u8; 16]) -> __m128i {
    _mm_loadu_si128(mem_addr.as_ptr() as *const __m128i)
}

#[inline(always)]
unsafe fn add(a: __m256i, b: __m256i) -> __m256i {
    _mm256_add_epi64(a, b)
}

#[inline(always)]
unsafe fn eq(a: __m256i, b: __m256i) -> __m256i {
    _mm256_cmpeq_epi64(a, b)
}

#[inline(always)]
unsafe fn and(a: __m256i, b: __m256i) -> __m256i {
    _mm256_and_si256(a, b)
}

#[inline(always)]
unsafe fn negate_and(a: __m256i, b: __m256i) -> __m256i {
    // Note that "and not" implies the reverse of the actual arg order.
    _mm256_andnot_si256(a, b)
}

#[inline(always)]
unsafe fn xor(a: __m256i, b: __m256i) -> __m256i {
    _mm256_xor_si256(a, b)
}

#[inline(always)]
unsafe fn set1(x: u64) -> __m256i {
    _mm256_set1_epi64x(x as i64)
}

#[inline(always)]
unsafe fn set4(a: u64, b: u64, c: u64, d: u64) -> __m256i {
    _mm256_setr_epi64x(a as i64, b as i64, c as i64, d as i64)
}

// Adapted from https://github.com/rust-lang-nursery/stdsimd/pull/479.
macro_rules! _MM_SHUFFLE {
    ($z:expr, $y:expr, $x:expr, $w:expr) => {
        ($z << 6) | ($y << 4) | ($x << 2) | $w
    };
}
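// For example, _MM_SHUFFLE!(2, 1, 0, 3) is 0b10_01_00_11 (0x93). When that constant is
// passed to _mm256_permute4x64_epi64, each 2-bit field selects the source lane for one
// destination lane, so [a0, a1, a2, a3] becomes [a3, a0, a1, a2].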

// These rotations are the "simple version". For the "complicated version", see
// https://github.com/sneves/blake2-avx2/blob/b3723921f668df09ece52dcd225a36d4a4eea1d9/blake2b-common.h#L43-L46.
// For a discussion of the tradeoffs, see
// https://github.com/sneves/blake2-avx2/pull/5. In short:
// - Due to an LLVM bug (https://bugs.llvm.org/show_bug.cgi?id=44379), this
//   version performs better on recent x86 chips.
// - LLVM is able to optimize this version to AVX-512 rotation instructions
//   when those are enabled.

#[inline(always)]
unsafe fn rot32(x: __m256i) -> __m256i {
    _mm256_or_si256(_mm256_srli_epi64(x, 32), _mm256_slli_epi64(x, 64 - 32))
}

#[inline(always)]
unsafe fn rot24(x: __m256i) -> __m256i {
    _mm256_or_si256(_mm256_srli_epi64(x, 24), _mm256_slli_epi64(x, 64 - 24))
}

#[inline(always)]
unsafe fn rot16(x: __m256i) -> __m256i {
    _mm256_or_si256(_mm256_srli_epi64(x, 16), _mm256_slli_epi64(x, 64 - 16))
}

#[inline(always)]
unsafe fn rot63(x: __m256i) -> __m256i {
    _mm256_or_si256(_mm256_srli_epi64(x, 63), _mm256_slli_epi64(x, 64 - 63))
}

#[inline(always)]
unsafe fn g1(a: &mut __m256i, b: &mut __m256i, c: &mut __m256i, d: &mut __m256i, m: &mut __m256i) {
    *a = add(*a, *m);
    *a = add(*a, *b);
    *d = xor(*d, *a);
    *d = rot32(*d);
    *c = add(*c, *d);
    *b = xor(*b, *c);
    *b = rot24(*b);
}

#[inline(always)]
unsafe fn g2(a: &mut __m256i, b: &mut __m256i, c: &mut __m256i, d: &mut __m256i, m: &mut __m256i) {
    *a = add(*a, *m);
    *a = add(*a, *b);
    *d = xor(*d, *a);
    *d = rot16(*d);
    *c = add(*c, *d);
    *b = xor(*b, *c);
    *b = rot63(*b);
}

// Note the optimization here of leaving b as the unrotated row, rather than a.
// All the message loads below are adjusted to compensate for this. See
// discussion at https://github.com/sneves/blake2-avx2/pull/4
#[inline(always)]
unsafe fn diagonalize(a: &mut __m256i, _b: &mut __m256i, c: &mut __m256i, d: &mut __m256i) {
    *a = _mm256_permute4x64_epi64(*a, _MM_SHUFFLE!(2, 1, 0, 3));
    *d = _mm256_permute4x64_epi64(*d, _MM_SHUFFLE!(1, 0, 3, 2));
    *c = _mm256_permute4x64_epi64(*c, _MM_SHUFFLE!(0, 3, 2, 1));
}

#[inline(always)]
unsafe fn undiagonalize(a: &mut __m256i, _b: &mut __m256i, c: &mut __m256i, d: &mut __m256i) {
    *a = _mm256_permute4x64_epi64(*a, _MM_SHUFFLE!(0, 3, 2, 1));
    *d = _mm256_permute4x64_epi64(*d, _MM_SHUFFLE!(1, 0, 3, 2));
    *c = _mm256_permute4x64_epi64(*c, _MM_SHUFFLE!(2, 1, 0, 3));
}

#[inline(always)]
unsafe fn compress_block(
    block: &[u8; BLOCKBYTES],
    words: &mut [Word; 8],
    count: Count,
    last_block: Word,
    last_node: Word,
) {
    let (words_low, words_high) = mut_array_refs!(words, DEGREE, DEGREE);
    let (iv_low, iv_high) = array_refs!(&IV, DEGREE, DEGREE);
    let mut a = loadu(words_low);
    let mut b = loadu(words_high);
    let mut c = loadu(iv_low);
    let flags = set4(count_low(count), count_high(count), last_block, last_node);
    let mut d = xor(loadu(iv_high), flags);

    let msg_chunks = array_refs!(block, 16, 16, 16, 16, 16, 16, 16, 16);
    let m0 = _mm256_broadcastsi128_si256(loadu_128(msg_chunks.0));
    let m1 = _mm256_broadcastsi128_si256(loadu_128(msg_chunks.1));
    let m2 = _mm256_broadcastsi128_si256(loadu_128(msg_chunks.2));
    let m3 = _mm256_broadcastsi128_si256(loadu_128(msg_chunks.3));
    let m4 = _mm256_broadcastsi128_si256(loadu_128(msg_chunks.4));
    let m5 = _mm256_broadcastsi128_si256(loadu_128(msg_chunks.5));
    let m6 = _mm256_broadcastsi128_si256(loadu_128(msg_chunks.6));
    let m7 = _mm256_broadcastsi128_si256(loadu_128(msg_chunks.7));

    let iv0 = a;
    let iv1 = b;
    let mut t0;
    let mut t1;
    let mut b0;

    // round 1
    t0 = _mm256_unpacklo_epi64(m0, m1);
    t1 = _mm256_unpacklo_epi64(m2, m3);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g1(&mut a, &mut b, &mut c, &mut d, &mut b0);
    t0 = _mm256_unpackhi_epi64(m0, m1);
    t1 = _mm256_unpackhi_epi64(m2, m3);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g2(&mut a, &mut b, &mut c, &mut d, &mut b0);
    diagonalize(&mut a, &mut b, &mut c, &mut d);
    t0 = _mm256_unpacklo_epi64(m7, m4);
    t1 = _mm256_unpacklo_epi64(m5, m6);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g1(&mut a, &mut b, &mut c, &mut d, &mut b0);
    t0 = _mm256_unpackhi_epi64(m7, m4);
    t1 = _mm256_unpackhi_epi64(m5, m6);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g2(&mut a, &mut b, &mut c, &mut d, &mut b0);
    undiagonalize(&mut a, &mut b, &mut c, &mut d);

    // round 2
    t0 = _mm256_unpacklo_epi64(m7, m2);
    t1 = _mm256_unpackhi_epi64(m4, m6);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g1(&mut a, &mut b, &mut c, &mut d, &mut b0);
    t0 = _mm256_unpacklo_epi64(m5, m4);
    t1 = _mm256_alignr_epi8(m3, m7, 8);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g2(&mut a, &mut b, &mut c, &mut d, &mut b0);
    diagonalize(&mut a, &mut b, &mut c, &mut d);
    t0 = _mm256_unpackhi_epi64(m2, m0);
    t1 = _mm256_blend_epi32(m5, m0, 0x33);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g1(&mut a, &mut b, &mut c, &mut d, &mut b0);
    t0 = _mm256_alignr_epi8(m6, m1, 8);
    t1 = _mm256_blend_epi32(m3, m1, 0x33);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g2(&mut a, &mut b, &mut c, &mut d, &mut b0);
    undiagonalize(&mut a, &mut b, &mut c, &mut d);

    // round 3
    t0 = _mm256_alignr_epi8(m6, m5, 8);
    t1 = _mm256_unpackhi_epi64(m2, m7);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g1(&mut a, &mut b, &mut c, &mut d, &mut b0);
    t0 = _mm256_unpacklo_epi64(m4, m0);
    t1 = _mm256_blend_epi32(m6, m1, 0x33);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g2(&mut a, &mut b, &mut c, &mut d, &mut b0);
    diagonalize(&mut a, &mut b, &mut c, &mut d);
    t0 = _mm256_alignr_epi8(m5, m4, 8);
    t1 = _mm256_unpackhi_epi64(m1, m3);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g1(&mut a, &mut b, &mut c, &mut d, &mut b0);
    t0 = _mm256_unpacklo_epi64(m2, m7);
    t1 = _mm256_blend_epi32(m0, m3, 0x33);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g2(&mut a, &mut b, &mut c, &mut d, &mut b0);
    undiagonalize(&mut a, &mut b, &mut c, &mut d);

    // round 4
    t0 = _mm256_unpackhi_epi64(m3, m1);
    t1 = _mm256_unpackhi_epi64(m6, m5);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g1(&mut a, &mut b, &mut c, &mut d, &mut b0);
    t0 = _mm256_unpackhi_epi64(m4, m0);
    t1 = _mm256_unpacklo_epi64(m6, m7);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g2(&mut a, &mut b, &mut c, &mut d, &mut b0);
    diagonalize(&mut a, &mut b, &mut c, &mut d);
    t0 = _mm256_alignr_epi8(m1, m7, 8);
    t1 = _mm256_shuffle_epi32(m2, _MM_SHUFFLE!(1, 0, 3, 2));
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g1(&mut a, &mut b, &mut c, &mut d, &mut b0);
    t0 = _mm256_unpacklo_epi64(m4, m3);
    t1 = _mm256_unpacklo_epi64(m5, m0);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g2(&mut a, &mut b, &mut c, &mut d, &mut b0);
    undiagonalize(&mut a, &mut b, &mut c, &mut d);

    // round 5
    t0 = _mm256_unpackhi_epi64(m4, m2);
    t1 = _mm256_unpacklo_epi64(m1, m5);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g1(&mut a, &mut b, &mut c, &mut d, &mut b0);
    t0 = _mm256_blend_epi32(m3, m0, 0x33);
    t1 = _mm256_blend_epi32(m7, m2, 0x33);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g2(&mut a, &mut b, &mut c, &mut d, &mut b0);
    diagonalize(&mut a, &mut b, &mut c, &mut d);
    t0 = _mm256_alignr_epi8(m7, m1, 8);
    t1 = _mm256_alignr_epi8(m3, m5, 8);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g1(&mut a, &mut b, &mut c, &mut d, &mut b0);
    t0 = _mm256_unpackhi_epi64(m6, m0);
    t1 = _mm256_unpacklo_epi64(m6, m4);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g2(&mut a, &mut b, &mut c, &mut d, &mut b0);
    undiagonalize(&mut a, &mut b, &mut c, &mut d);

    // round 6
    t0 = _mm256_unpacklo_epi64(m1, m3);
    t1 = _mm256_unpacklo_epi64(m0, m4);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g1(&mut a, &mut b, &mut c, &mut d, &mut b0);
    t0 = _mm256_unpacklo_epi64(m6, m5);
    t1 = _mm256_unpackhi_epi64(m5, m1);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g2(&mut a, &mut b, &mut c, &mut d, &mut b0);
    diagonalize(&mut a, &mut b, &mut c, &mut d);
    t0 = _mm256_alignr_epi8(m2, m0, 8);
    t1 = _mm256_unpackhi_epi64(m3, m7);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g1(&mut a, &mut b, &mut c, &mut d, &mut b0);
    t0 = _mm256_unpackhi_epi64(m4, m6);
    t1 = _mm256_alignr_epi8(m7, m2, 8);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g2(&mut a, &mut b, &mut c, &mut d, &mut b0);
    undiagonalize(&mut a, &mut b, &mut c, &mut d);

    // round 7
    t0 = _mm256_blend_epi32(m0, m6, 0x33);
    t1 = _mm256_unpacklo_epi64(m7, m2);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g1(&mut a, &mut b, &mut c, &mut d, &mut b0);
    t0 = _mm256_unpackhi_epi64(m2, m7);
    t1 = _mm256_alignr_epi8(m5, m6, 8);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g2(&mut a, &mut b, &mut c, &mut d, &mut b0);
    diagonalize(&mut a, &mut b, &mut c, &mut d);
    t0 = _mm256_unpacklo_epi64(m4, m0);
    t1 = _mm256_blend_epi32(m4, m3, 0x33);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g1(&mut a, &mut b, &mut c, &mut d, &mut b0);
    t0 = _mm256_unpackhi_epi64(m5, m3);
    t1 = _mm256_shuffle_epi32(m1, _MM_SHUFFLE!(1, 0, 3, 2));
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g2(&mut a, &mut b, &mut c, &mut d, &mut b0);
    undiagonalize(&mut a, &mut b, &mut c, &mut d);

    // round 8
    t0 = _mm256_unpackhi_epi64(m6, m3);
    t1 = _mm256_blend_epi32(m1, m6, 0x33);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g1(&mut a, &mut b, &mut c, &mut d, &mut b0);
    t0 = _mm256_alignr_epi8(m7, m5, 8);
    t1 = _mm256_unpackhi_epi64(m0, m4);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g2(&mut a, &mut b, &mut c, &mut d, &mut b0);
    diagonalize(&mut a, &mut b, &mut c, &mut d);
    t0 = _mm256_blend_epi32(m2, m1, 0x33);
    t1 = _mm256_alignr_epi8(m4, m7, 8);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g1(&mut a, &mut b, &mut c, &mut d, &mut b0);
    t0 = _mm256_unpacklo_epi64(m5, m0);
    t1 = _mm256_unpacklo_epi64(m2, m3);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g2(&mut a, &mut b, &mut c, &mut d, &mut b0);
    undiagonalize(&mut a, &mut b, &mut c, &mut d);

    // round 9
    t0 = _mm256_unpacklo_epi64(m3, m7);
    t1 = _mm256_alignr_epi8(m0, m5, 8);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g1(&mut a, &mut b, &mut c, &mut d, &mut b0);
    t0 = _mm256_unpackhi_epi64(m7, m4);
    t1 = _mm256_alignr_epi8(m4, m1, 8);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g2(&mut a, &mut b, &mut c, &mut d, &mut b0);
    diagonalize(&mut a, &mut b, &mut c, &mut d);
    t0 = _mm256_unpacklo_epi64(m5, m6);
    t1 = _mm256_unpackhi_epi64(m6, m0);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g1(&mut a, &mut b, &mut c, &mut d, &mut b0);
    t0 = _mm256_alignr_epi8(m1, m2, 8);
    t1 = _mm256_alignr_epi8(m2, m3, 8);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g2(&mut a, &mut b, &mut c, &mut d, &mut b0);
    undiagonalize(&mut a, &mut b, &mut c, &mut d);

    // round 10
    t0 = _mm256_unpacklo_epi64(m5, m4);
    t1 = _mm256_unpackhi_epi64(m3, m0);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g1(&mut a, &mut b, &mut c, &mut d, &mut b0);
    t0 = _mm256_unpacklo_epi64(m1, m2);
    t1 = _mm256_blend_epi32(m2, m3, 0x33);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g2(&mut a, &mut b, &mut c, &mut d, &mut b0);
    diagonalize(&mut a, &mut b, &mut c, &mut d);
    t0 = _mm256_unpackhi_epi64(m6, m7);
    t1 = _mm256_unpackhi_epi64(m4, m1);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g1(&mut a, &mut b, &mut c, &mut d, &mut b0);
    t0 = _mm256_blend_epi32(m5, m0, 0x33);
    t1 = _mm256_unpacklo_epi64(m7, m6);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g2(&mut a, &mut b, &mut c, &mut d, &mut b0);
    undiagonalize(&mut a, &mut b, &mut c, &mut d);

    // round 11
    t0 = _mm256_unpacklo_epi64(m0, m1);
    t1 = _mm256_unpacklo_epi64(m2, m3);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g1(&mut a, &mut b, &mut c, &mut d, &mut b0);
    t0 = _mm256_unpackhi_epi64(m0, m1);
    t1 = _mm256_unpackhi_epi64(m2, m3);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g2(&mut a, &mut b, &mut c, &mut d, &mut b0);
    diagonalize(&mut a, &mut b, &mut c, &mut d);
    t0 = _mm256_unpacklo_epi64(m7, m4);
    t1 = _mm256_unpacklo_epi64(m5, m6);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g1(&mut a, &mut b, &mut c, &mut d, &mut b0);
    t0 = _mm256_unpackhi_epi64(m7, m4);
    t1 = _mm256_unpackhi_epi64(m5, m6);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g2(&mut a, &mut b, &mut c, &mut d, &mut b0);
    undiagonalize(&mut a, &mut b, &mut c, &mut d);

    // round 12
    t0 = _mm256_unpacklo_epi64(m7, m2);
    t1 = _mm256_unpackhi_epi64(m4, m6);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g1(&mut a, &mut b, &mut c, &mut d, &mut b0);
    t0 = _mm256_unpacklo_epi64(m5, m4);
    t1 = _mm256_alignr_epi8(m3, m7, 8);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g2(&mut a, &mut b, &mut c, &mut d, &mut b0);
    diagonalize(&mut a, &mut b, &mut c, &mut d);
    t0 = _mm256_unpackhi_epi64(m2, m0);
    t1 = _mm256_blend_epi32(m5, m0, 0x33);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g1(&mut a, &mut b, &mut c, &mut d, &mut b0);
    t0 = _mm256_alignr_epi8(m6, m1, 8);
    t1 = _mm256_blend_epi32(m3, m1, 0x33);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g2(&mut a, &mut b, &mut c, &mut d, &mut b0);
    undiagonalize(&mut a, &mut b, &mut c, &mut d);

    a = xor(a, c);
    b = xor(b, d);
    a = xor(a, iv0);
    b = xor(b, iv1);

    storeu(a, words_low);
    storeu(b, words_high);
}

#[target_feature(enable = "avx2")]
pub unsafe fn compress1_loop(
    input: &[u8],
    words: &mut [Word; 8],
    mut count: Count,
    last_node: LastNode,
    finalize: Finalize,
    stride: Stride,
) {
    input_debug_asserts(input, finalize);

    let mut local_words = *words;

    let mut fin_offset = input.len().saturating_sub(1);
    fin_offset -= fin_offset % stride.padded_blockbytes();
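    // For example, with input.len() == 3 * BLOCKBYTES and
    // stride.padded_blockbytes() == BLOCKBYTES, fin_offset lands on the start
    // of the last block, 2 * BLOCKBYTES.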
    let mut buf = [0; BLOCKBYTES];
    let (fin_block, fin_len, _) = final_block(input, fin_offset, &mut buf, stride);
    let fin_last_block = flag_word(finalize.yes());
    let fin_last_node = flag_word(finalize.yes() && last_node.yes());

    let mut offset = 0;
    loop {
        let block;
        let count_delta;
        let last_block;
        let last_node;
        if offset == fin_offset {
            block = fin_block;
            count_delta = fin_len;
            last_block = fin_last_block;
            last_node = fin_last_node;
        } else {
            // This unsafe cast avoids bounds checks. There's guaranteed to be
            // enough input because `offset < fin_offset`.
            block = &*(input.as_ptr().add(offset) as *const [u8; BLOCKBYTES]);
            count_delta = BLOCKBYTES;
            last_block = flag_word(false);
            last_node = flag_word(false);
        };

        count = count.wrapping_add(count_delta as Count);
        compress_block(block, &mut local_words, count, last_block, last_node);

        // Check for termination before bumping the offset, to avoid overflow.
        if offset == fin_offset {
            break;
        }

        offset += stride.padded_blockbytes();
    }

    *words = local_words;
}

// Performance note: Factoring out a G function here doesn't hurt performance,
// unlike in the case of BLAKE2s where it hurts substantially. In fact, on my
// machine, it helps a tiny bit. But the difference is tiny, so I'm going to
// stick to the approach used by https://github.com/sneves/blake2-avx2
// until/unless I can be sure the (tiny) improvement is consistent across
// different Intel microarchitectures. Smaller code size is nice, but a
// divergence between the BLAKE2b and BLAKE2s implementations is less nice.
#[inline(always)]
unsafe fn round(v: &mut [__m256i; 16], m: &[__m256i; 16], r: usize) {
    v[0] = add(v[0], m[SIGMA[r][0] as usize]);
    v[1] = add(v[1], m[SIGMA[r][2] as usize]);
    v[2] = add(v[2], m[SIGMA[r][4] as usize]);
    v[3] = add(v[3], m[SIGMA[r][6] as usize]);
    v[0] = add(v[0], v[4]);
    v[1] = add(v[1], v[5]);
    v[2] = add(v[2], v[6]);
    v[3] = add(v[3], v[7]);
    v[12] = xor(v[12], v[0]);
    v[13] = xor(v[13], v[1]);
    v[14] = xor(v[14], v[2]);
    v[15] = xor(v[15], v[3]);
    v[12] = rot32(v[12]);
    v[13] = rot32(v[13]);
    v[14] = rot32(v[14]);
    v[15] = rot32(v[15]);
    v[8] = add(v[8], v[12]);
    v[9] = add(v[9], v[13]);
    v[10] = add(v[10], v[14]);
    v[11] = add(v[11], v[15]);
    v[4] = xor(v[4], v[8]);
    v[5] = xor(v[5], v[9]);
    v[6] = xor(v[6], v[10]);
    v[7] = xor(v[7], v[11]);
    v[4] = rot24(v[4]);
    v[5] = rot24(v[5]);
    v[6] = rot24(v[6]);
    v[7] = rot24(v[7]);
    v[0] = add(v[0], m[SIGMA[r][1] as usize]);
    v[1] = add(v[1], m[SIGMA[r][3] as usize]);
    v[2] = add(v[2], m[SIGMA[r][5] as usize]);
    v[3] = add(v[3], m[SIGMA[r][7] as usize]);
    v[0] = add(v[0], v[4]);
    v[1] = add(v[1], v[5]);
    v[2] = add(v[2], v[6]);
    v[3] = add(v[3], v[7]);
    v[12] = xor(v[12], v[0]);
    v[13] = xor(v[13], v[1]);
    v[14] = xor(v[14], v[2]);
    v[15] = xor(v[15], v[3]);
    v[12] = rot16(v[12]);
    v[13] = rot16(v[13]);
    v[14] = rot16(v[14]);
    v[15] = rot16(v[15]);
    v[8] = add(v[8], v[12]);
    v[9] = add(v[9], v[13]);
    v[10] = add(v[10], v[14]);
    v[11] = add(v[11], v[15]);
    v[4] = xor(v[4], v[8]);
    v[5] = xor(v[5], v[9]);
    v[6] = xor(v[6], v[10]);
    v[7] = xor(v[7], v[11]);
    v[4] = rot63(v[4]);
    v[5] = rot63(v[5]);
    v[6] = rot63(v[6]);
    v[7] = rot63(v[7]);

    v[0] = add(v[0], m[SIGMA[r][8] as usize]);
    v[1] = add(v[1], m[SIGMA[r][10] as usize]);
    v[2] = add(v[2], m[SIGMA[r][12] as usize]);
    v[3] = add(v[3], m[SIGMA[r][14] as usize]);
    v[0] = add(v[0], v[5]);
    v[1] = add(v[1], v[6]);
    v[2] = add(v[2], v[7]);
    v[3] = add(v[3], v[4]);
    v[15] = xor(v[15], v[0]);
    v[12] = xor(v[12], v[1]);
    v[13] = xor(v[13], v[2]);
    v[14] = xor(v[14], v[3]);
    v[15] = rot32(v[15]);
    v[12] = rot32(v[12]);
    v[13] = rot32(v[13]);
    v[14] = rot32(v[14]);
    v[10] = add(v[10], v[15]);
    v[11] = add(v[11], v[12]);
    v[8] = add(v[8], v[13]);
    v[9] = add(v[9], v[14]);
    v[5] = xor(v[5], v[10]);
    v[6] = xor(v[6], v[11]);
    v[7] = xor(v[7], v[8]);
    v[4] = xor(v[4], v[9]);
    v[5] = rot24(v[5]);
    v[6] = rot24(v[6]);
    v[7] = rot24(v[7]);
    v[4] = rot24(v[4]);
    v[0] = add(v[0], m[SIGMA[r][9] as usize]);
    v[1] = add(v[1], m[SIGMA[r][11] as usize]);
    v[2] = add(v[2], m[SIGMA[r][13] as usize]);
    v[3] = add(v[3], m[SIGMA[r][15] as usize]);
    v[0] = add(v[0], v[5]);
    v[1] = add(v[1], v[6]);
    v[2] = add(v[2], v[7]);
    v[3] = add(v[3], v[4]);
    v[15] = xor(v[15], v[0]);
    v[12] = xor(v[12], v[1]);
    v[13] = xor(v[13], v[2]);
    v[14] = xor(v[14], v[3]);
    v[15] = rot16(v[15]);
    v[12] = rot16(v[12]);
    v[13] = rot16(v[13]);
    v[14] = rot16(v[14]);
    v[10] = add(v[10], v[15]);
    v[11] = add(v[11], v[12]);
    v[8] = add(v[8], v[13]);
    v[9] = add(v[9], v[14]);
    v[5] = xor(v[5], v[10]);
    v[6] = xor(v[6], v[11]);
    v[7] = xor(v[7], v[8]);
    v[4] = xor(v[4], v[9]);
    v[5] = rot63(v[5]);
    v[6] = rot63(v[6]);
    v[7] = rot63(v[7]);
    v[4] = rot63(v[4]);
}

// We'd rather make this a regular function with #[inline(always)], but for
// some reason that blows up compile times by about 10 seconds, at least in
// some cases (BLAKE2b avx2.rs). This macro seems to get the same performance
// result, without the compile time issue.
macro_rules! compress4_transposed {
    (
        $h_vecs:expr,
        $msg_vecs:expr,
        $count_low:expr,
        $count_high:expr,
        $lastblock:expr,
        $lastnode:expr,
    ) => {
        let h_vecs: &mut [__m256i; 8] = $h_vecs;
        let msg_vecs: &[__m256i; 16] = $msg_vecs;
        let count_low: __m256i = $count_low;
        let count_high: __m256i = $count_high;
        let lastblock: __m256i = $lastblock;
        let lastnode: __m256i = $lastnode;

        let mut v = [
            h_vecs[0],
            h_vecs[1],
            h_vecs[2],
            h_vecs[3],
            h_vecs[4],
            h_vecs[5],
            h_vecs[6],
            h_vecs[7],
            set1(IV[0]),
            set1(IV[1]),
            set1(IV[2]),
            set1(IV[3]),
            xor(set1(IV[4]), count_low),
            xor(set1(IV[5]), count_high),
            xor(set1(IV[6]), lastblock),
            xor(set1(IV[7]), lastnode),
        ];

        round(&mut v, &msg_vecs, 0);
        round(&mut v, &msg_vecs, 1);
        round(&mut v, &msg_vecs, 2);
        round(&mut v, &msg_vecs, 3);
        round(&mut v, &msg_vecs, 4);
        round(&mut v, &msg_vecs, 5);
        round(&mut v, &msg_vecs, 6);
        round(&mut v, &msg_vecs, 7);
        round(&mut v, &msg_vecs, 8);
        round(&mut v, &msg_vecs, 9);
        round(&mut v, &msg_vecs, 10);
        round(&mut v, &msg_vecs, 11);

        h_vecs[0] = xor(xor(h_vecs[0], v[0]), v[8]);
        h_vecs[1] = xor(xor(h_vecs[1], v[1]), v[9]);
        h_vecs[2] = xor(xor(h_vecs[2], v[2]), v[10]);
        h_vecs[3] = xor(xor(h_vecs[3], v[3]), v[11]);
        h_vecs[4] = xor(xor(h_vecs[4], v[4]), v[12]);
        h_vecs[5] = xor(xor(h_vecs[5], v[5]), v[13]);
        h_vecs[6] = xor(xor(h_vecs[6], v[6]), v[14]);
        h_vecs[7] = xor(xor(h_vecs[7], v[7]), v[15]);
    };
}

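// _mm256_permute2x128_si256 selects 128-bit halves from its two inputs: control
// 0x20 yields [low(a), low(b)] and control 0x31 yields [high(a), high(b)],
// listing the result's lanes from low to high.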
#[inline(always)]
unsafe fn interleave128(a: __m256i, b: __m256i) -> (__m256i, __m256i) {
    (
        _mm256_permute2x128_si256(a, b, 0x20),
        _mm256_permute2x128_si256(a, b, 0x31),
    )
}

// There are several ways to do a transposition. We could do it naively, with 8 separate
// _mm256_set_epi64x instructions, referencing each of the 64 words explicitly. Or we could copy
// the vecs into contiguous storage and then use gather instructions. The third approach, used
// here, is a series of unpack instructions to interleave the vectors. In my benchmarks,
// interleaving is the fastest approach. To test this, run `cargo +nightly bench --bench libtest load_4`
// in the https://github.com/oconnor663/bao_experiments repo.
#[inline(always)]
unsafe fn transpose_vecs(
    vec_a: __m256i,
    vec_b: __m256i,
    vec_c: __m256i,
    vec_d: __m256i,
) -> [__m256i; DEGREE] {
    // Interleave 64-bit lanes. The low unpack is lanes 00/22 and the high is 11/33.
    let ab_02 = _mm256_unpacklo_epi64(vec_a, vec_b);
    let ab_13 = _mm256_unpackhi_epi64(vec_a, vec_b);
    let cd_02 = _mm256_unpacklo_epi64(vec_c, vec_d);
    let cd_13 = _mm256_unpackhi_epi64(vec_c, vec_d);

    // Interleave 128-bit lanes.
    let (abcd_0, abcd_2) = interleave128(ab_02, cd_02);
    let (abcd_1, abcd_3) = interleave128(ab_13, cd_13);

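    // With inputs [a0, a1, a2, a3] through [d0, d1, d2, d3], the results are
    // abcd_0 = [a0, b0, c0, d0], abcd_1 = [a1, b1, c1, d1], and so on.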
    [abcd_0, abcd_1, abcd_2, abcd_3]
}

#[inline(always)]
unsafe fn transpose_state_vecs(jobs: &[Job; DEGREE]) -> [__m256i; 8] {
    // Load all the state words into transposed vectors, where the first vector
    // has the first word of each state, etc. Transposing once at the beginning
    // and once at the end is more efficient than repeating it for each block.
    let words0 = array_refs!(&jobs[0].words, DEGREE, DEGREE);
    let words1 = array_refs!(&jobs[1].words, DEGREE, DEGREE);
    let words2 = array_refs!(&jobs[2].words, DEGREE, DEGREE);
    let words3 = array_refs!(&jobs[3].words, DEGREE, DEGREE);
    let [h0, h1, h2, h3] = transpose_vecs(
        loadu(words0.0),
        loadu(words1.0),
        loadu(words2.0),
        loadu(words3.0),
    );
    let [h4, h5, h6, h7] = transpose_vecs(
        loadu(words0.1),
        loadu(words1.1),
        loadu(words2.1),
        loadu(words3.1),
    );
    [h0, h1, h2, h3, h4, h5, h6, h7]
}

#[inline(always)]
unsafe fn untranspose_state_vecs(h_vecs: &[__m256i; 8], jobs: &mut [Job; DEGREE]) {
    // Un-transpose the updated state vectors back into the caller's arrays.
    let [job0, job1, job2, job3] = jobs;
    let words0 = mut_array_refs!(&mut job0.words, DEGREE, DEGREE);
    let words1 = mut_array_refs!(&mut job1.words, DEGREE, DEGREE);
    let words2 = mut_array_refs!(&mut job2.words, DEGREE, DEGREE);
    let words3 = mut_array_refs!(&mut job3.words, DEGREE, DEGREE);
    let out = transpose_vecs(h_vecs[0], h_vecs[1], h_vecs[2], h_vecs[3]);
    storeu(out[0], words0.0);
    storeu(out[1], words1.0);
    storeu(out[2], words2.0);
    storeu(out[3], words3.0);
    let out = transpose_vecs(h_vecs[4], h_vecs[5], h_vecs[6], h_vecs[7]);
    storeu(out[0], words0.1);
    storeu(out[1], words1.1);
    storeu(out[2], words2.1);
    storeu(out[3], words3.1);
}

#[inline(always)]
unsafe fn transpose_msg_vecs(blocks: [*const [u8; BLOCKBYTES]; DEGREE]) -> [__m256i; 16] {
    // These input arrays have no particular alignment, so we use unaligned
    // loads to read from them.
    let block0 = blocks[0] as *const [Word; DEGREE];
    let block1 = blocks[1] as *const [Word; DEGREE];
    let block2 = blocks[2] as *const [Word; DEGREE];
    let block3 = blocks[3] as *const [Word; DEGREE];
    let [m0, m1, m2, m3] = transpose_vecs(
        loadu(block0.add(0)),
        loadu(block1.add(0)),
        loadu(block2.add(0)),
        loadu(block3.add(0)),
    );
    let [m4, m5, m6, m7] = transpose_vecs(
        loadu(block0.add(1)),
        loadu(block1.add(1)),
        loadu(block2.add(1)),
        loadu(block3.add(1)),
    );
    let [m8, m9, m10, m11] = transpose_vecs(
        loadu(block0.add(2)),
        loadu(block1.add(2)),
        loadu(block2.add(2)),
        loadu(block3.add(2)),
    );
    let [m12, m13, m14, m15] = transpose_vecs(
        loadu(block0.add(3)),
        loadu(block1.add(3)),
        loadu(block2.add(3)),
        loadu(block3.add(3)),
    );
    [
        m0, m1, m2, m3, m4, m5, m6, m7, m8, m9, m10, m11, m12, m13, m14, m15,
    ]
}

#[inline(always)]
unsafe fn load_counts(jobs: &[Job; DEGREE]) -> (__m256i, __m256i) {
    (
        set4(
            count_low(jobs[0].count),
            count_low(jobs[1].count),
            count_low(jobs[2].count),
            count_low(jobs[3].count),
        ),
        set4(
            count_high(jobs[0].count),
            count_high(jobs[1].count),
            count_high(jobs[2].count),
            count_high(jobs[3].count),
        ),
    )
}

#[inline(always)]
unsafe fn store_counts(jobs: &mut [Job; DEGREE], low: __m256i, high: __m256i) {
    let low_ints: [Word; DEGREE] = mem::transmute(low);
    let high_ints: [Word; DEGREE] = mem::transmute(high);
    for i in 0..DEGREE {
        jobs[i].count = assemble_count(low_ints[i], high_ints[i]);
    }
}

#[inline(always)]
unsafe fn add_to_counts(lo: &mut __m256i, hi: &mut __m256i, delta: __m256i) {
    // If the low counts reach zero, that means they wrapped, unless the delta
    // was also zero.
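    // For example, if a lane of lo is u64::MAX and the corresponding lane of
    // delta is 1, the sum wraps to 0 and that lane of hi is incremented by 1;
    // a zero delta produces no carry even when lo is already 0.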
    *lo = add(*lo, delta);
    let lo_reached_zero = eq(*lo, set1(0));
    let delta_was_zero = eq(delta, set1(0));
    let hi_inc = and(set1(1), negate_and(delta_was_zero, lo_reached_zero));
    *hi = add(*hi, hi_inc);
}

#[inline(always)]
unsafe fn flags_vec(flags: [bool; DEGREE]) -> __m256i {
    set4(
        flag_word(flags[0]),
        flag_word(flags[1]),
        flag_word(flags[2]),
        flag_word(flags[3]),
    )
}

#[target_feature(enable = "avx2")]
pub unsafe fn compress4_loop(jobs: &mut [Job; DEGREE], finalize: Finalize, stride: Stride) {
    // If we're not finalizing, there can't be a partial block at the end.
    for job in jobs.iter() {
        input_debug_asserts(job.input, finalize);
    }

    let msg_ptrs = [
        jobs[0].input.as_ptr(),
        jobs[1].input.as_ptr(),
        jobs[2].input.as_ptr(),
        jobs[3].input.as_ptr(),
    ];
    let mut h_vecs = transpose_state_vecs(&jobs);
    let (mut counts_lo, mut counts_hi) = load_counts(&jobs);

    // Prepare the final blocks. Note that these could be empty, if the input
    // is empty. Do all this before entering the main loop.
    let min_len = jobs.iter().map(|job| job.input.len()).min().unwrap();
    let mut fin_offset = min_len.saturating_sub(1);
    fin_offset -= fin_offset % stride.padded_blockbytes();
    // Performance note: making these buffers mem::uninitialized() seems to
    // cause problems in the optimizer.
    let mut buf0: [u8; BLOCKBYTES] = [0; BLOCKBYTES];
    let mut buf1: [u8; BLOCKBYTES] = [0; BLOCKBYTES];
    let mut buf2: [u8; BLOCKBYTES] = [0; BLOCKBYTES];
    let mut buf3: [u8; BLOCKBYTES] = [0; BLOCKBYTES];
    let (block0, len0, finalize0) = final_block(jobs[0].input, fin_offset, &mut buf0, stride);
    let (block1, len1, finalize1) = final_block(jobs[1].input, fin_offset, &mut buf1, stride);
    let (block2, len2, finalize2) = final_block(jobs[2].input, fin_offset, &mut buf2, stride);
    let (block3, len3, finalize3) = final_block(jobs[3].input, fin_offset, &mut buf3, stride);
    let fin_blocks: [*const [u8; BLOCKBYTES]; DEGREE] = [block0, block1, block2, block3];
    let fin_counts_delta = set4(len0 as Word, len1 as Word, len2 as Word, len3 as Word);
    let fin_last_block;
    let fin_last_node;
    if finalize.yes() {
        fin_last_block = flags_vec([finalize0, finalize1, finalize2, finalize3]);
        fin_last_node = flags_vec([
            finalize0 && jobs[0].last_node.yes(),
            finalize1 && jobs[1].last_node.yes(),
            finalize2 && jobs[2].last_node.yes(),
            finalize3 && jobs[3].last_node.yes(),
        ]);
    } else {
        fin_last_block = set1(0);
        fin_last_node = set1(0);
    }

    // The main loop.
    let mut offset = 0;
    loop {
        let blocks;
        let counts_delta;
        let last_block;
        let last_node;
        if offset == fin_offset {
            blocks = fin_blocks;
            counts_delta = fin_counts_delta;
            last_block = fin_last_block;
            last_node = fin_last_node;
        } else {
            blocks = [
                msg_ptrs[0].add(offset) as *const [u8; BLOCKBYTES],
                msg_ptrs[1].add(offset) as *const [u8; BLOCKBYTES],
                msg_ptrs[2].add(offset) as *const [u8; BLOCKBYTES],
                msg_ptrs[3].add(offset) as *const [u8; BLOCKBYTES],
            ];
            counts_delta = set1(BLOCKBYTES as Word);
            last_block = set1(0);
            last_node = set1(0);
        };

        let m_vecs = transpose_msg_vecs(blocks);
        add_to_counts(&mut counts_lo, &mut counts_hi, counts_delta);
        compress4_transposed!(
            &mut h_vecs,
            &m_vecs,
            counts_lo,
            counts_hi,
            last_block,
            last_node,
        );

        // Check for termination before bumping the offset, to avoid overflow.
        if offset == fin_offset {
            break;
        }

        offset += stride.padded_blockbytes();
    }

    // Write out the results.
    untranspose_state_vecs(&h_vecs, &mut *jobs);
    store_counts(&mut *jobs, counts_lo, counts_hi);
    let max_consumed = offset.saturating_add(stride.padded_blockbytes());
    for job in jobs.iter_mut() {
        let consumed = cmp::min(max_consumed, job.input.len());
        job.input = &job.input[consumed..];
    }
}
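
// A minimal sanity-check sketch, not part of the upstream file: it spot-checks that the
// shift-based rotation helpers above agree with u64::rotate_right on every lane. The
// module name and the compile-time `target_feature = "avx2"` gate are assumptions; the
// test only runs in builds where AVX2 is enabled statically (e.g. -C target-cpu=native).
#[cfg(all(test, target_feature = "avx2"))]
mod rotation_sanity_sketch {
    use super::*;

    #[test]
    fn rotations_match_rotate_right() {
        unsafe {
            let x: u64 = 0x0123_4567_89ab_cdef;
            let v = set1(x);
            // Transmute each rotated vector back to plain words and compare lanes.
            let r32: [u64; DEGREE] = mem::transmute(rot32(v));
            let r24: [u64; DEGREE] = mem::transmute(rot24(v));
            let r16: [u64; DEGREE] = mem::transmute(rot16(v));
            let r63: [u64; DEGREE] = mem::transmute(rot63(v));
            assert!(r32.iter().all(|&w| w == x.rotate_right(32)));
            assert!(r24.iter().all(|&w| w == x.rotate_right(24)));
            assert!(r16.iter().all(|&w| w == x.rotate_right(16)));
            assert!(r63.iter().all(|&w| w == x.rotate_right(63)));
        }
    }
}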