chacha20/backends/avx2.rs

//! AVX2 backend.
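//!
//! Processes four ChaCha blocks in parallel (`PAR_BLOCKS`): each `__m256i` holds the same
//! 16-byte state row for two consecutive blocks (low 128-bit lane = block `i`, high lane =
//! block `i + 1`), so two counter vectors cover all four blocks.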

#![allow(unsafe_op_in_unsafe_fn, reason = "needs triage")]
#![allow(clippy::cast_possible_truncation, reason = "needs triage")]
#![allow(clippy::cast_possible_wrap, reason = "needs triage")]
#![allow(clippy::cast_sign_loss, reason = "needs triage")]
#![allow(clippy::undocumented_unsafe_blocks, reason = "TODO")]

use crate::{Rounds, Variant};
use core::marker::PhantomData;

#[cfg(feature = "rng")]
use crate::ChaChaCore;

#[cfg(feature = "cipher")]
use crate::{STATE_WORDS, chacha::Block};

#[cfg(feature = "cipher")]
use cipher::{
    BlockSizeUser, ParBlocks, ParBlocksSizeUser, StreamCipherBackend, StreamCipherClosure,
    consts::{U4, U64},
};

#[cfg(target_arch = "x86")]
use core::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;

/// Number of blocks processed in parallel.
const PAR_BLOCKS: usize = 4;
/// Number of `__m256i` to store parallel blocks.
const N: usize = PAR_BLOCKS / 2;

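/// Run the stream-cipher closure `f` against an AVX2 [`Backend`] built from `state`, then
/// write the advanced 32- or 64-bit block counter back into `state`.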
#[inline]
#[target_feature(enable = "avx2")]
#[cfg(feature = "cipher")]
#[cfg_attr(chacha20_backend = "avx512", expect(unused))]
pub(crate) unsafe fn inner<R, F, V>(state: &mut [u32; STATE_WORDS], f: F)
where
    R: Rounds,
    F: StreamCipherClosure<BlockSize = U64>,
    V: Variant,
{
    let state_ptr = state.as_ptr().cast::<__m128i>();
    let v = [
        _mm256_broadcastsi128_si256(_mm_loadu_si128(state_ptr.add(0))),
        _mm256_broadcastsi128_si256(_mm_loadu_si128(state_ptr.add(1))),
        _mm256_broadcastsi128_si256(_mm_loadu_si128(state_ptr.add(2))),
    ];
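    // Row 3 (counter + nonce) is broadcast to both lanes and the high lane's counter is
    // then bumped by one, so each 256-bit vector holds counters for two consecutive
    // blocks: `ctr[0]` covers blocks {n, n + 1} and `ctr[1]` covers {n + 2, n + 3}.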
    let mut c = _mm256_broadcastsi128_si256(_mm_loadu_si128(state_ptr.add(3)));
    c = match size_of::<V::Counter>() {
        4 => _mm256_add_epi32(c, _mm256_set_epi32(0, 0, 0, 1, 0, 0, 0, 0)),
        8 => _mm256_add_epi64(c, _mm256_set_epi64x(0, 1, 0, 0)),
        _ => unreachable!(),
    };
    let mut ctr = [c; N];
    for i in 0..N {
        ctr[i] = c;
        c = match size_of::<V::Counter>() {
            4 => _mm256_add_epi32(c, _mm256_set_epi32(0, 0, 0, 2, 0, 0, 0, 2)),
            8 => _mm256_add_epi64(c, _mm256_set_epi64x(0, 2, 0, 2)),
            _ => unreachable!(),
        };
    }
    let mut backend = Backend::<R, V> {
        v,
        ctr,
        _pd: PhantomData,
    };

    f.call(&mut backend);

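    // After `f` has run, the low lane of `ctr[0]` holds the counter of the next
    // unprocessed block; write it back (including the upper word for 64-bit counters).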
    state[12] = _mm256_extract_epi32(backend.ctr[0], 0) as u32;
    match size_of::<V::Counter>() {
        4 => {}
        8 => state[13] = _mm256_extract_epi32(backend.ctr[0], 1) as u32,
        _ => unreachable!(),
    }
}

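/// Fill `buffer` (256 bytes) with four blocks of keystream and write the updated counter
/// words back to `core.state[12]` and `core.state[13]`.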
#[inline]
#[target_feature(enable = "avx2")]
#[cfg(feature = "rng")]
pub(crate) unsafe fn rng_inner<R, V>(core: &mut ChaChaCore<R, V>, buffer: &mut [u32; 64])
where
    R: Rounds,
    V: Variant,
{
    let state_ptr = core.state.as_ptr().cast::<__m128i>();
    let v = [
        _mm256_broadcastsi128_si256(_mm_loadu_si128(state_ptr.add(0))),
        _mm256_broadcastsi128_si256(_mm_loadu_si128(state_ptr.add(1))),
        _mm256_broadcastsi128_si256(_mm_loadu_si128(state_ptr.add(2))),
    ];
    let mut c = _mm256_broadcastsi128_si256(_mm_loadu_si128(state_ptr.add(3)));
    c = _mm256_add_epi64(c, _mm256_set_epi64x(0, 1, 0, 0));
    let mut ctr = [c; N];
    for i in 0..N {
        ctr[i] = c;
        c = _mm256_add_epi64(c, _mm256_set_epi64x(0, 2, 0, 2));
    }
    let mut backend = Backend::<R, V> {
        v,
        ctr,
        _pd: PhantomData,
    };

    backend.rng_gen_par_ks_blocks(buffer);

    core.state[12] = _mm256_extract_epi32(backend.ctr[0], 0) as u32;
    core.state[13] = _mm256_extract_epi32(backend.ctr[0], 1) as u32;
}

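/// SIMD state for computing four blocks at a time: `v` holds the first three rows of the
/// ChaCha state broadcast to both 128-bit lanes, and each entry of `ctr` holds the fourth
/// row for two consecutive blocks (low lane = block `i`, high lane = block `i + 1`).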
struct Backend<R: Rounds, V: Variant> {
    v: [__m256i; 3],
    ctr: [__m256i; N],
    _pd: PhantomData<(R, V)>,
}

#[cfg(feature = "cipher")]
impl<R: Rounds, V: Variant> BlockSizeUser for Backend<R, V> {
    type BlockSize = U64;
}

#[cfg(feature = "cipher")]
impl<R: Rounds, V: Variant> ParBlocksSizeUser for Backend<R, V> {
    type ParBlocksSize = U4;
}

#[cfg(feature = "cipher")]
impl<R: Rounds, V: Variant> StreamCipherBackend for Backend<R, V> {
    #[inline(always)]
    fn gen_ks_block(&mut self, block: &mut Block) {
        unsafe {
            let res = rounds::<R>(&self.v, &self.ctr);
            for c in self.ctr.iter_mut() {
                *c = match size_of::<V::Counter>() {
                    4 => _mm256_add_epi32(*c, _mm256_set_epi32(0, 0, 0, 1, 0, 0, 0, 1)),
                    8 => _mm256_add_epi64(*c, _mm256_set_epi64x(0, 1, 0, 1)),
                    _ => unreachable!(),
                };
            }

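            // `res[0]` interleaves two blocks; its even-indexed 128-bit halves (the low
            // lane of each row) form the single requested block, so store only those.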
            let res0: [__m128i; 8] = core::mem::transmute(res[0]);

            let block_ptr = block.as_mut_ptr().cast::<__m128i>();
            for i in 0..4 {
                _mm_storeu_si128(block_ptr.add(i), res0[2 * i]);
            }
        }
    }

    #[inline(always)]
    fn gen_par_ks_blocks(&mut self, blocks: &mut ParBlocks<Self>) {
        unsafe {
            let vs = rounds::<R>(&self.v, &self.ctr);

            let pb = PAR_BLOCKS as i32;
            for c in self.ctr.iter_mut() {
                *c = match size_of::<V::Counter>() {
                    4 => _mm256_add_epi32(*c, _mm256_set_epi32(0, 0, 0, pb, 0, 0, 0, pb)),
                    8 => {
                        _mm256_add_epi64(*c, _mm256_set_epi64x(0, i64::from(pb), 0, i64::from(pb)))
                    }
                    _ => unreachable!(),
                }
            }

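            // Each `__m256i` holds one row of two interleaved blocks (low lane = first
            // block, high lane = second), so the even/odd 128-bit halves are
            // de-interleaved into two consecutive 64-byte output blocks.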
            let mut block_ptr = blocks.as_mut_ptr().cast::<__m128i>();
            for v in vs {
                let t: [__m128i; 8] = core::mem::transmute(v);
                for i in 0..4 {
                    _mm_storeu_si128(block_ptr.add(i), t[2 * i]);
                    _mm_storeu_si128(block_ptr.add(4 + i), t[2 * i + 1]);
                }
                block_ptr = block_ptr.add(8);
            }
        }
    }
}

#[cfg(feature = "rng")]
impl<R: Rounds, V: Variant> Backend<R, V> {
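    /// Fill `blocks` (256 bytes) with four blocks of keystream, advancing the 64-bit
    /// counter in every `ctr` entry by four.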
    #[inline(always)]
    fn rng_gen_par_ks_blocks(&mut self, blocks: &mut [u32; 64]) {
        unsafe {
            let vs = rounds::<R>(&self.v, &self.ctr);

            let pb = PAR_BLOCKS as i32;
            for c in self.ctr.iter_mut() {
                *c = _mm256_add_epi64(*c, _mm256_set_epi64x(0, i64::from(pb), 0, i64::from(pb)));
            }

            let mut block_ptr = blocks.as_mut_ptr().cast::<__m128i>();
            for v in vs {
                let t: [__m128i; 8] = core::mem::transmute(v);
                for i in 0..4 {
                    _mm_storeu_si128(block_ptr.add(i), t[2 * i]);
                    _mm_storeu_si128(block_ptr.add(4 + i), t[2 * i + 1]);
                }
                block_ptr = block_ptr.add(8);
            }
        }
    }
}

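/// Run `R::COUNT` double rounds over `N` interleaved two-block states built from the
/// broadcast rows `v` and the per-pair counters `c`, then add the original input state
/// back in (the ChaCha feed-forward).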
#[inline]
#[target_feature(enable = "avx2")]
unsafe fn rounds<R: Rounds>(v: &[__m256i; 3], c: &[__m256i; N]) -> [[__m256i; 4]; N] {
    let mut vs: [[__m256i; 4]; N] = [[_mm256_setzero_si256(); 4]; N];
    for i in 0..N {
        vs[i] = [v[0], v[1], v[2], c[i]];
    }
    for _ in 0..R::COUNT {
        double_quarter_round(&mut vs);
    }

    for i in 0..N {
        for j in 0..3 {
            vs[i][j] = _mm256_add_epi32(vs[i][j], v[j]);
        }
        vs[i][3] = _mm256_add_epi32(vs[i][3], c[i]);
    }

    vs
}

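/// One ChaCha double round: a column round, a shuffle into the diagonal layout, a second
/// round over the diagonals, and the inverse shuffle back to the row layout.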
#[inline]
#[target_feature(enable = "avx2")]
unsafe fn double_quarter_round(v: &mut [[__m256i; 4]; N]) {
    add_xor_rot(v);
    rows_to_cols(v);
    add_xor_rot(v);
    cols_to_rows(v);
}

/// The goal of this function is to transform the state words from:
/// ```text
/// [a0, a1, a2, a3]    [ 0,  1,  2,  3]
/// [b0, b1, b2, b3] == [ 4,  5,  6,  7]
/// [c0, c1, c2, c3]    [ 8,  9, 10, 11]
/// [d0, d1, d2, d3]    [12, 13, 14, 15]
/// ```
///
/// to:
/// ```text
/// [a0, a1, a2, a3]    [ 0,  1,  2,  3]
/// [b1, b2, b3, b0] == [ 5,  6,  7,  4]
/// [c2, c3, c0, c1]    [10, 11,  8,  9]
/// [d3, d0, d1, d2]    [15, 12, 13, 14]
/// ```
///
/// so that we can apply [`add_xor_rot`] to the resulting columns, and have it compute the
/// "diagonal rounds" (as defined in RFC 7539) in parallel. In practice, this shuffle is
/// non-optimal: the last state word to be altered in `add_xor_rot` is `b`, so the shuffle
/// blocks on the result of `b` being calculated.
///
/// We can optimize this by observing that the four quarter rounds in `add_xor_rot` are
/// data-independent: they only access a single column of the state, and thus the order of
/// the columns does not matter. We therefore instead shuffle the other three state words,
/// to obtain the following equivalent layout:
/// ```text
/// [a3, a0, a1, a2]    [ 3,  0,  1,  2]
/// [b0, b1, b2, b3] == [ 4,  5,  6,  7]
/// [c1, c2, c3, c0]    [ 9, 10, 11,  8]
/// [d2, d3, d0, d1]    [14, 15, 12, 13]
/// ```
///
/// See https://github.com/sneves/blake2-avx2/pull/4 for additional details. The earliest
/// known occurrence of this optimization is in floodyberry's SSE4 ChaCha code from 2014:
/// - https://github.com/floodyberry/chacha-opt/blob/0ab65cb99f5016633b652edebaf3691ceb4ff753/chacha_blocks_ssse3-64.S#L639-L643
#[inline]
#[target_feature(enable = "avx2")]
unsafe fn rows_to_cols(vs: &mut [[__m256i; 4]; N]) {
    // c >>>= 32; d >>>= 64; a >>>= 96;
    for [a, _, c, d] in vs {
        *c = _mm256_shuffle_epi32(*c, 0b_00_11_10_01); // _MM_SHUFFLE(0, 3, 2, 1)
        *d = _mm256_shuffle_epi32(*d, 0b_01_00_11_10); // _MM_SHUFFLE(1, 0, 3, 2)
        *a = _mm256_shuffle_epi32(*a, 0b_10_01_00_11); // _MM_SHUFFLE(2, 1, 0, 3)
    }
}

/// The goal of this function is to transform the state words from:
/// ```text
/// [a3, a0, a1, a2]    [ 3,  0,  1,  2]
/// [b0, b1, b2, b3] == [ 4,  5,  6,  7]
/// [c1, c2, c3, c0]    [ 9, 10, 11,  8]
/// [d2, d3, d0, d1]    [14, 15, 12, 13]
/// ```
///
/// to:
/// ```text
/// [a0, a1, a2, a3]    [ 0,  1,  2,  3]
/// [b0, b1, b2, b3] == [ 4,  5,  6,  7]
/// [c0, c1, c2, c3]    [ 8,  9, 10, 11]
/// [d0, d1, d2, d3]    [12, 13, 14, 15]
/// ```
///
/// reversing the transformation of [`rows_to_cols`].
#[inline]
#[target_feature(enable = "avx2")]
unsafe fn cols_to_rows(vs: &mut [[__m256i; 4]; N]) {
    // c <<<= 32; d <<<= 64; a <<<= 96;
    for [a, _, c, d] in vs {
        *c = _mm256_shuffle_epi32(*c, 0b_10_01_00_11); // _MM_SHUFFLE(2, 1, 0, 3)
        *d = _mm256_shuffle_epi32(*d, 0b_01_00_11_10); // _MM_SHUFFLE(1, 0, 3, 2)
        *a = _mm256_shuffle_epi32(*a, 0b_00_11_10_01); // _MM_SHUFFLE(0, 3, 2, 1)
    }
}

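/// Apply four ChaCha quarter-rounds in parallel, one per column, using the rotation
/// amounts 16, 12, 8, and 7. The 16- and 8-bit rotations are byte shuffles; the 12- and
/// 7-bit rotations use a pair of shifts combined with an xor.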
#[inline]
#[target_feature(enable = "avx2")]
unsafe fn add_xor_rot(vs: &mut [[__m256i; 4]; N]) {
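    // `_mm256_shuffle_epi8` masks that rotate every 32-bit word of each 128-bit lane left
    // by 16 bits (`rol16_mask`) and 8 bits (`rol8_mask`): a single byte shuffle replaces
    // the shift/shift/xor sequence used for the other rotation amounts.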
    let rol16_mask = _mm256_set_epi64x(
        0x0d0c_0f0e_0908_0b0a,
        0x0504_0706_0100_0302,
        0x0d0c_0f0e_0908_0b0a,
        0x0504_0706_0100_0302,
    );
    let rol8_mask = _mm256_set_epi64x(
        0x0e0d_0c0f_0a09_080b,
        0x0605_0407_0201_0003,
        0x0e0d_0c0f_0a09_080b,
        0x0605_0407_0201_0003,
    );

    // a += b; d ^= a; d <<<= (16, 16, 16, 16);
    for [a, b, _, d] in vs.iter_mut() {
        *a = _mm256_add_epi32(*a, *b);
        *d = _mm256_xor_si256(*d, *a);
        *d = _mm256_shuffle_epi8(*d, rol16_mask);
    }

    // c += d; b ^= c; b <<<= (12, 12, 12, 12);
    for [_, b, c, d] in vs.iter_mut() {
        *c = _mm256_add_epi32(*c, *d);
        *b = _mm256_xor_si256(*b, *c);
        *b = _mm256_xor_si256(_mm256_slli_epi32(*b, 12), _mm256_srli_epi32(*b, 20));
    }

    // a += b; d ^= a; d <<<= (8, 8, 8, 8);
    for [a, b, _, d] in vs.iter_mut() {
        *a = _mm256_add_epi32(*a, *b);
        *d = _mm256_xor_si256(*d, *a);
        *d = _mm256_shuffle_epi8(*d, rol8_mask);
    }

    // c += d; b ^= c; b <<<= (7, 7, 7, 7);
    for [_, b, c, d] in vs.iter_mut() {
        *c = _mm256_add_epi32(*c, *d);
        *b = _mm256_xor_si256(*b, *c);
        *b = _mm256_xor_si256(_mm256_slli_epi32(*b, 7), _mm256_srli_epi32(*b, 25));
    }
}