blake2b_simd/
sse41.rs

1#[cfg(target_arch = "x86")]
2use core::arch::x86::*;
3#[cfg(target_arch = "x86_64")]
4use core::arch::x86_64::*;
5
6use crate::guts::{
7    assemble_count, count_high, count_low, final_block, flag_word, input_debug_asserts, Finalize,
8    Job, Stride,
9};
10use crate::{Word, BLOCKBYTES, IV, SIGMA};
11use arrayref::{array_refs, mut_array_refs};
12use core::cmp;
13use core::mem;
14
15pub const DEGREE: usize = 2;
16
17#[inline(always)]
18unsafe fn loadu(src: *const [Word; DEGREE]) -> __m128i {
19    // This is an unaligned load, so the pointer cast is allowed.
20    _mm_loadu_si128(src as *const __m128i)
21}
22
23#[inline(always)]
24unsafe fn storeu(src: __m128i, dest: *mut [Word; DEGREE]) {
25    // This is an unaligned store, so the pointer cast is allowed.
26    _mm_storeu_si128(dest as *mut __m128i, src)
27}
28
/// Lane-wise wrapping addition of two vectors of 64-bit words.
#[inline(always)]
unsafe fn add(a: __m128i, b: __m128i) -> __m128i {
    _mm_add_epi64(b, a)
}
33
/// Lane-wise 64-bit equality: a lane is all ones where `a` and `b` match,
/// and all zeros otherwise.
#[inline(always)]
unsafe fn eq(a: __m128i, b: __m128i) -> __m128i {
    _mm_cmpeq_epi64(b, a)
}
38
/// Bitwise AND of two 128-bit vectors.
#[inline(always)]
unsafe fn and(a: __m128i, b: __m128i) -> __m128i {
    _mm_and_si128(b, a)
}
43
/// Computes `!a & b`: the bits of `b` that are not set in `a`.
#[inline(always)]
unsafe fn negate_and(a: __m128i, b: __m128i) -> __m128i {
    // The intrinsic complements its FIRST argument, despite the "and not"
    // name suggesting the reverse order.
    let (complemented, kept) = (a, b);
    _mm_andnot_si128(complemented, kept)
}
49
/// Bitwise XOR of two 128-bit vectors.
#[inline(always)]
unsafe fn xor(a: __m128i, b: __m128i) -> __m128i {
    _mm_xor_si128(b, a)
}
54
/// Broadcasts the 64-bit word `x` into both lanes.
#[inline(always)]
unsafe fn set1(x: u64) -> __m128i {
    let bits = x as i64;
    _mm_set_epi64x(bits, bits)
}
59
/// Builds a vector whose lane 0 is `a` and lane 1 is `b`.
#[inline(always)]
unsafe fn set2(a: u64, b: u64) -> __m128i {
    // `_mm_set_epi64x` takes the high lane first and there is no "setr"
    // variant, so the arguments are passed high-then-low.
    let (high, low) = (b as i64, a as i64);
    _mm_set_epi64x(high, low)
}
65
// Adapted from https://github.com/rust-lang-nursery/stdsimd/pull/479.
/// Packs four 2-bit lane selectors into a shuffle immediate, most
/// significant selector first (mirrors the C `_MM_SHUFFLE` macro).
///
/// NOTE(review): no use of this macro is visible in this file; confirm it is
/// still needed.
macro_rules! _MM_SHUFFLE {
    ($z:expr, $y:expr, $x:expr, $w:expr) => {
        $w | ($x << 2) | ($y << 4) | ($z << 6)
    };
}
72
// These rotations are the "simple version". For the "complicated version", see
// https://github.com/sneves/blake2-avx2/blob/b3723921f668df09ece52dcd225a36d4a4eea1d9/blake2b-common.h#L43-L46.
// For a discussion of the tradeoffs, see
// https://github.com/sneves/blake2-avx2/pull/5. In short:
// - Due to an LLVM bug (https://bugs.llvm.org/show_bug.cgi?id=44379), this
//   version performs better on recent x86 chips.
// - LLVM is able to optimize this version to AVX-512 rotation instructions
//   when those are enabled.

/// Rotates each 64-bit lane right by 32 bits.
#[inline(always)]
unsafe fn rot32(x: __m128i) -> __m128i {
    let low_part = _mm_srli_epi64(x, 32);
    let high_part = _mm_slli_epi64(x, 64 - 32);
    _mm_or_si128(high_part, low_part)
}
86
/// Rotates each 64-bit lane right by 24 bits.
#[inline(always)]
unsafe fn rot24(x: __m128i) -> __m128i {
    let low_part = _mm_srli_epi64(x, 24);
    let high_part = _mm_slli_epi64(x, 64 - 24);
    _mm_or_si128(high_part, low_part)
}
91
/// Rotates each 64-bit lane right by 16 bits.
#[inline(always)]
unsafe fn rot16(x: __m128i) -> __m128i {
    let low_part = _mm_srli_epi64(x, 16);
    let high_part = _mm_slli_epi64(x, 64 - 16);
    _mm_or_si128(high_part, low_part)
}
96
/// Rotates each 64-bit lane right by 63 bits (equivalently, left by one).
#[inline(always)]
unsafe fn rot63(x: __m128i) -> __m128i {
    let low_part = _mm_srli_epi64(x, 63);
    let high_part = _mm_slli_epi64(x, 64 - 63);
    _mm_or_si128(high_part, low_part)
}
101
/// Applies one BLAKE2b round to the 16-vector working state `v`.
///
/// Every vector holds one 64-bit word per job, so both jobs are mixed at once.
/// `m` holds the 16 transposed message vectors and `r` selects the row of
/// `SIGMA` that permutes them for this round.
///
/// The first half applies the G mixing function to the four columns
/// (v0,v4,v8,v12) .. (v3,v7,v11,v15). The second half applies it to the four
/// diagonals (v0,v5,v10,v15), (v1,v6,v11,v12), (v2,v7,v8,v13), (v3,v4,v9,v14).
/// Within each half the four G computations are interleaved statement by
/// statement rather than written one after another.
#[inline(always)]
unsafe fn round(v: &mut [__m128i; 16], m: &[__m128i; 16], r: usize) {
    // Columns, first half of G: message words SIGMA[r][0,2,4,6], rotations by 32 then 24.
    v[0] = add(v[0], m[SIGMA[r][0] as usize]);
    v[1] = add(v[1], m[SIGMA[r][2] as usize]);
    v[2] = add(v[2], m[SIGMA[r][4] as usize]);
    v[3] = add(v[3], m[SIGMA[r][6] as usize]);
    v[0] = add(v[0], v[4]);
    v[1] = add(v[1], v[5]);
    v[2] = add(v[2], v[6]);
    v[3] = add(v[3], v[7]);
    v[12] = xor(v[12], v[0]);
    v[13] = xor(v[13], v[1]);
    v[14] = xor(v[14], v[2]);
    v[15] = xor(v[15], v[3]);
    v[12] = rot32(v[12]);
    v[13] = rot32(v[13]);
    v[14] = rot32(v[14]);
    v[15] = rot32(v[15]);
    v[8] = add(v[8], v[12]);
    v[9] = add(v[9], v[13]);
    v[10] = add(v[10], v[14]);
    v[11] = add(v[11], v[15]);
    v[4] = xor(v[4], v[8]);
    v[5] = xor(v[5], v[9]);
    v[6] = xor(v[6], v[10]);
    v[7] = xor(v[7], v[11]);
    v[4] = rot24(v[4]);
    v[5] = rot24(v[5]);
    v[6] = rot24(v[6]);
    v[7] = rot24(v[7]);
    // Columns, second half of G: message words SIGMA[r][1,3,5,7], rotations by 16 then 63.
    v[0] = add(v[0], m[SIGMA[r][1] as usize]);
    v[1] = add(v[1], m[SIGMA[r][3] as usize]);
    v[2] = add(v[2], m[SIGMA[r][5] as usize]);
    v[3] = add(v[3], m[SIGMA[r][7] as usize]);
    v[0] = add(v[0], v[4]);
    v[1] = add(v[1], v[5]);
    v[2] = add(v[2], v[6]);
    v[3] = add(v[3], v[7]);
    v[12] = xor(v[12], v[0]);
    v[13] = xor(v[13], v[1]);
    v[14] = xor(v[14], v[2]);
    v[15] = xor(v[15], v[3]);
    v[12] = rot16(v[12]);
    v[13] = rot16(v[13]);
    v[14] = rot16(v[14]);
    v[15] = rot16(v[15]);
    v[8] = add(v[8], v[12]);
    v[9] = add(v[9], v[13]);
    v[10] = add(v[10], v[14]);
    v[11] = add(v[11], v[15]);
    v[4] = xor(v[4], v[8]);
    v[5] = xor(v[5], v[9]);
    v[6] = xor(v[6], v[10]);
    v[7] = xor(v[7], v[11]);
    v[4] = rot63(v[4]);
    v[5] = rot63(v[5]);
    v[6] = rot63(v[6]);
    v[7] = rot63(v[7]);

    // Diagonals, first half of G: message words SIGMA[r][8,10,12,14]. The b/c/d
    // operands are the column vectors shifted by one, two and three positions.
    v[0] = add(v[0], m[SIGMA[r][8] as usize]);
    v[1] = add(v[1], m[SIGMA[r][10] as usize]);
    v[2] = add(v[2], m[SIGMA[r][12] as usize]);
    v[3] = add(v[3], m[SIGMA[r][14] as usize]);
    v[0] = add(v[0], v[5]);
    v[1] = add(v[1], v[6]);
    v[2] = add(v[2], v[7]);
    v[3] = add(v[3], v[4]);
    v[15] = xor(v[15], v[0]);
    v[12] = xor(v[12], v[1]);
    v[13] = xor(v[13], v[2]);
    v[14] = xor(v[14], v[3]);
    v[15] = rot32(v[15]);
    v[12] = rot32(v[12]);
    v[13] = rot32(v[13]);
    v[14] = rot32(v[14]);
    v[10] = add(v[10], v[15]);
    v[11] = add(v[11], v[12]);
    v[8] = add(v[8], v[13]);
    v[9] = add(v[9], v[14]);
    v[5] = xor(v[5], v[10]);
    v[6] = xor(v[6], v[11]);
    v[7] = xor(v[7], v[8]);
    v[4] = xor(v[4], v[9]);
    v[5] = rot24(v[5]);
    v[6] = rot24(v[6]);
    v[7] = rot24(v[7]);
    v[4] = rot24(v[4]);
    // Diagonals, second half of G: message words SIGMA[r][9,11,13,15].
    v[0] = add(v[0], m[SIGMA[r][9] as usize]);
    v[1] = add(v[1], m[SIGMA[r][11] as usize]);
    v[2] = add(v[2], m[SIGMA[r][13] as usize]);
    v[3] = add(v[3], m[SIGMA[r][15] as usize]);
    v[0] = add(v[0], v[5]);
    v[1] = add(v[1], v[6]);
    v[2] = add(v[2], v[7]);
    v[3] = add(v[3], v[4]);
    v[15] = xor(v[15], v[0]);
    v[12] = xor(v[12], v[1]);
    v[13] = xor(v[13], v[2]);
    v[14] = xor(v[14], v[3]);
    v[15] = rot16(v[15]);
    v[12] = rot16(v[12]);
    v[13] = rot16(v[13]);
    v[14] = rot16(v[14]);
    v[10] = add(v[10], v[15]);
    v[11] = add(v[11], v[12]);
    v[8] = add(v[8], v[13]);
    v[9] = add(v[9], v[14]);
    v[5] = xor(v[5], v[10]);
    v[6] = xor(v[6], v[11]);
    v[7] = xor(v[7], v[8]);
    v[4] = xor(v[4], v[9]);
    v[5] = rot63(v[5]);
    v[6] = rot63(v[6]);
    v[7] = rot63(v[7]);
    v[4] = rot63(v[4]);
}
218
// We'd rather make this a regular function with #[inline(always)], but for
// some reason that blows up compile times by about 10 seconds, at least in
// some cases (BLAKE2b avx2.rs). This macro seems to get the same performance
// result, without the compile time issue.
//
// Runs the BLAKE2b compression function on DEGREE jobs at once, with every
// vector holding one word per job:
// - `$h_vecs`: the 8 transposed state vectors, updated in place.
// - `$msg_vecs`: the 16 transposed message vectors for the current block.
// - `$count_low` / `$count_high`: per-lane low and high words of the byte
//   counter.
// - `$lastblock` / `$lastnode`: per-lane flag words (zero when not set).
macro_rules! compress2_transposed {
    (
        $h_vecs:expr,
        $msg_vecs:expr,
        $count_low:expr,
        $count_high:expr,
        $lastblock:expr,
        $lastnode:expr,
    ) => {
        // Bind each argument once, with an explicit type, so every macro
        // argument is evaluated exactly once and type-checked.
        let h_vecs: &mut [__m128i; 8] = $h_vecs;
        let msg_vecs: &[__m128i; 16] = $msg_vecs;
        let count_low: __m128i = $count_low;
        let count_high: __m128i = $count_high;
        let lastblock: __m128i = $lastblock;
        let lastnode: __m128i = $lastnode;
        // Working state: the current chaining values, then the IV. The
        // counter halves and the two flag words are XORed into IV[4..8].
        let mut v = [
            h_vecs[0],
            h_vecs[1],
            h_vecs[2],
            h_vecs[3],
            h_vecs[4],
            h_vecs[5],
            h_vecs[6],
            h_vecs[7],
            set1(IV[0]),
            set1(IV[1]),
            set1(IV[2]),
            set1(IV[3]),
            xor(set1(IV[4]), count_low),
            xor(set1(IV[5]), count_high),
            xor(set1(IV[6]), lastblock),
            xor(set1(IV[7]), lastnode),
        ];

        // Twelve rounds, each with its own SIGMA message permutation.
        round(&mut v, &msg_vecs, 0);
        round(&mut v, &msg_vecs, 1);
        round(&mut v, &msg_vecs, 2);
        round(&mut v, &msg_vecs, 3);
        round(&mut v, &msg_vecs, 4);
        round(&mut v, &msg_vecs, 5);
        round(&mut v, &msg_vecs, 6);
        round(&mut v, &msg_vecs, 7);
        round(&mut v, &msg_vecs, 8);
        round(&mut v, &msg_vecs, 9);
        round(&mut v, &msg_vecs, 10);
        round(&mut v, &msg_vecs, 11);

        // Feed-forward: fold both halves of the working state back into the
        // chaining values.
        h_vecs[0] = xor(xor(h_vecs[0], v[0]), v[8]);
        h_vecs[1] = xor(xor(h_vecs[1], v[1]), v[9]);
        h_vecs[2] = xor(xor(h_vecs[2], v[2]), v[10]);
        h_vecs[3] = xor(xor(h_vecs[3], v[3]), v[11]);
        h_vecs[4] = xor(xor(h_vecs[4], v[4]), v[12]);
        h_vecs[5] = xor(xor(h_vecs[5], v[5]), v[13]);
        h_vecs[6] = xor(xor(h_vecs[6], v[6]), v[14]);
        h_vecs[7] = xor(xor(h_vecs[7], v[7]), v[15]);
    };
}
280
281#[inline(always)]
282unsafe fn transpose_vecs(a: __m128i, b: __m128i) -> [__m128i; DEGREE] {
283    let a_words: [Word; DEGREE] = mem::transmute(a);
284    let b_words: [Word; DEGREE] = mem::transmute(b);
285    [set2(a_words[0], b_words[0]), set2(a_words[1], b_words[1])]
286}
287
288#[inline(always)]
289unsafe fn transpose_state_vecs(jobs: &[Job; DEGREE]) -> [__m128i; 8] {
290    // Load all the state words into transposed vectors, where the first vector
291    // has the first word of each state, etc. Transposing once at the beginning
292    // and once at the end is more efficient that repeating it for each block.
293    let words0 = array_refs!(&jobs[0].words, DEGREE, DEGREE, DEGREE, DEGREE);
294    let words1 = array_refs!(&jobs[1].words, DEGREE, DEGREE, DEGREE, DEGREE);
295    let [h0, h1] = transpose_vecs(loadu(words0.0), loadu(words1.0));
296    let [h2, h3] = transpose_vecs(loadu(words0.1), loadu(words1.1));
297    let [h4, h5] = transpose_vecs(loadu(words0.2), loadu(words1.2));
298    let [h6, h7] = transpose_vecs(loadu(words0.3), loadu(words1.3));
299    [h0, h1, h2, h3, h4, h5, h6, h7]
300}
301
302#[inline(always)]
303unsafe fn untranspose_state_vecs(h_vecs: &[__m128i; 8], jobs: &mut [Job; DEGREE]) {
304    // Un-transpose the updated state vectors back into the caller's arrays.
305    let [job0, job1] = jobs;
306    let words0 = mut_array_refs!(&mut job0.words, DEGREE, DEGREE, DEGREE, DEGREE);
307    let words1 = mut_array_refs!(&mut job1.words, DEGREE, DEGREE, DEGREE, DEGREE);
308
309    let out = transpose_vecs(h_vecs[0], h_vecs[1]);
310    storeu(out[0], words0.0);
311    storeu(out[1], words1.0);
312    let out = transpose_vecs(h_vecs[2], h_vecs[3]);
313    storeu(out[0], words0.1);
314    storeu(out[1], words1.1);
315    let out = transpose_vecs(h_vecs[4], h_vecs[5]);
316    storeu(out[0], words0.2);
317    storeu(out[1], words1.2);
318    let out = transpose_vecs(h_vecs[6], h_vecs[7]);
319    storeu(out[0], words0.3);
320    storeu(out[1], words1.3);
321}
322
323#[inline(always)]
324unsafe fn transpose_msg_vecs(blocks: [*const [u8; BLOCKBYTES]; DEGREE]) -> [__m128i; 16] {
325    // These input arrays have no particular alignment, so we use unaligned
326    // loads to read from them.
327    let block0 = blocks[0] as *const [Word; DEGREE];
328    let block1 = blocks[1] as *const [Word; DEGREE];
329    let [m0, m1] = transpose_vecs(loadu(block0.add(0)), loadu(block1.add(0)));
330    let [m2, m3] = transpose_vecs(loadu(block0.add(1)), loadu(block1.add(1)));
331    let [m4, m5] = transpose_vecs(loadu(block0.add(2)), loadu(block1.add(2)));
332    let [m6, m7] = transpose_vecs(loadu(block0.add(3)), loadu(block1.add(3)));
333    let [m8, m9] = transpose_vecs(loadu(block0.add(4)), loadu(block1.add(4)));
334    let [m10, m11] = transpose_vecs(loadu(block0.add(5)), loadu(block1.add(5)));
335    let [m12, m13] = transpose_vecs(loadu(block0.add(6)), loadu(block1.add(6)));
336    let [m14, m15] = transpose_vecs(loadu(block0.add(7)), loadu(block1.add(7)));
337    [
338        m0, m1, m2, m3, m4, m5, m6, m7, m8, m9, m10, m11, m12, m13, m14, m15,
339    ]
340}
341
342#[inline(always)]
343unsafe fn load_counts(jobs: &[Job; DEGREE]) -> (__m128i, __m128i) {
344    (
345        set2(count_low(jobs[0].count), count_low(jobs[1].count)),
346        set2(count_high(jobs[0].count), count_high(jobs[1].count)),
347    )
348}
349
350#[inline(always)]
351unsafe fn store_counts(jobs: &mut [Job; DEGREE], low: __m128i, high: __m128i) {
352    let low_ints: [Word; DEGREE] = mem::transmute(low);
353    let high_ints: [Word; DEGREE] = mem::transmute(high);
354    for i in 0..DEGREE {
355        jobs[i].count = assemble_count(low_ints[i], high_ints[i]);
356    }
357}
358
359#[inline(always)]
360unsafe fn add_to_counts(lo: &mut __m128i, hi: &mut __m128i, delta: __m128i) {
361    // If the low counts reach zero, that means they wrapped, unless the delta
362    // was also zero.
363    *lo = add(*lo, delta);
364    let lo_reached_zero = eq(*lo, set1(0));
365    let delta_was_zero = eq(delta, set1(0));
366    let hi_inc = and(set1(1), negate_and(delta_was_zero, lo_reached_zero));
367    *hi = add(*hi, hi_inc);
368}
369
370#[inline(always)]
371unsafe fn flags_vec(flags: [bool; DEGREE]) -> __m128i {
372    set2(flag_word(flags[0]), flag_word(flags[1]))
373}
374
375#[target_feature(enable = "sse4.1")]
376pub unsafe fn compress2_loop(jobs: &mut [Job; DEGREE], finalize: Finalize, stride: Stride) {
377    // If we're not finalizing, there can't be a partial block at the end.
378    for job in jobs.iter() {
379        input_debug_asserts(job.input, finalize);
380    }
381
382    let msg_ptrs = [jobs[0].input.as_ptr(), jobs[1].input.as_ptr()];
383    let mut h_vecs = transpose_state_vecs(&jobs);
384    let (mut counts_lo, mut counts_hi) = load_counts(&jobs);
385
386    // Prepare the final blocks (note, which could be empty if the input is
387    // empty). Do all this before entering the main loop.
388    let min_len = jobs.iter().map(|job| job.input.len()).min().unwrap();
389    let mut fin_offset = min_len.saturating_sub(1);
390    fin_offset -= fin_offset % stride.padded_blockbytes();
391    // Performance note, making these buffers mem::uninitialized() seems to
392    // cause problems in the optimizer.
393    let mut buf0: [u8; BLOCKBYTES] = [0; BLOCKBYTES];
394    let mut buf1: [u8; BLOCKBYTES] = [0; BLOCKBYTES];
395    let (block0, len0, finalize0) = final_block(jobs[0].input, fin_offset, &mut buf0, stride);
396    let (block1, len1, finalize1) = final_block(jobs[1].input, fin_offset, &mut buf1, stride);
397    let fin_blocks: [*const [u8; BLOCKBYTES]; DEGREE] = [block0, block1];
398    let fin_counts_delta = set2(len0 as Word, len1 as Word);
399    let fin_last_block;
400    let fin_last_node;
401    if finalize.yes() {
402        fin_last_block = flags_vec([finalize0, finalize1]);
403        fin_last_node = flags_vec([
404            finalize0 && jobs[0].last_node.yes(),
405            finalize1 && jobs[1].last_node.yes(),
406        ]);
407    } else {
408        fin_last_block = set1(0);
409        fin_last_node = set1(0);
410    }
411
412    // The main loop.
413    let mut offset = 0;
414    loop {
415        let blocks;
416        let counts_delta;
417        let last_block;
418        let last_node;
419        if offset == fin_offset {
420            blocks = fin_blocks;
421            counts_delta = fin_counts_delta;
422            last_block = fin_last_block;
423            last_node = fin_last_node;
424        } else {
425            blocks = [
426                msg_ptrs[0].add(offset) as *const [u8; BLOCKBYTES],
427                msg_ptrs[1].add(offset) as *const [u8; BLOCKBYTES],
428            ];
429            counts_delta = set1(BLOCKBYTES as Word);
430            last_block = set1(0);
431            last_node = set1(0);
432        };
433
434        let m_vecs = transpose_msg_vecs(blocks);
435        add_to_counts(&mut counts_lo, &mut counts_hi, counts_delta);
436        compress2_transposed!(
437            &mut h_vecs,
438            &m_vecs,
439            counts_lo,
440            counts_hi,
441            last_block,
442            last_node,
443        );
444
445        // Check for termination before bumping the offset, to avoid overflow.
446        if offset == fin_offset {
447            break;
448        }
449
450        offset += stride.padded_blockbytes();
451    }
452
453    // Write out the results.
454    untranspose_state_vecs(&h_vecs, &mut *jobs);
455    store_counts(&mut *jobs, counts_lo, counts_hi);
456    let max_consumed = offset.saturating_add(stride.padded_blockbytes());
457    for job in jobs.iter_mut() {
458        let consumed = cmp::min(max_consumed, job.input.len());
459        job.input = &job.input[consumed..];
460    }
461}