#![allow(unsafe_op_in_unsafe_fn, reason = "needs triage")]
#![allow(clippy::cast_possible_truncation, reason = "needs triage")]
#![allow(clippy::cast_possible_wrap, reason = "needs triage")]
#![allow(clippy::cast_sign_loss, reason = "needs triage")]
#![allow(clippy::undocumented_unsafe_blocks, reason = "TODO")]

use crate::{Rounds, Variant};
use core::marker::PhantomData;

#[cfg(feature = "rng")]
use crate::ChaChaCore;

#[cfg(feature = "cipher")]
use crate::{STATE_WORDS, chacha::Block};

#[cfg(feature = "cipher")]
use cipher::{
    BlockSizeUser, ParBlocks, ParBlocksSizeUser, StreamCipherBackend, StreamCipherClosure,
    consts::{U4, U64},
};

#[cfg(target_arch = "x86")]
use core::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;

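// Each `__m256i` packs the same 128-bit state row for two consecutive
// blocks, so `PAR_BLOCKS` blocks are produced from `N = PAR_BLOCKS / 2`
// vectors of counters.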
const PAR_BLOCKS: usize = 4;
const N: usize = PAR_BLOCKS / 2;

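/// Runs the provided `StreamCipherClosure` against an AVX2 backend that
/// computes `PAR_BLOCKS` ChaCha blocks at a time, then writes the updated
/// block counter back into `state`.
///
/// # Safety
///
/// The caller must ensure the CPU supports AVX2 before calling this
/// function (it is compiled with `#[target_feature(enable = "avx2")]`).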
#[inline]
#[target_feature(enable = "avx2")]
#[cfg(feature = "cipher")]
#[cfg_attr(chacha20_backend = "avx512", expect(unused))]
pub(crate) unsafe fn inner<R, F, V>(state: &mut [u32; STATE_WORDS], f: F)
where
    R: Rounds,
    F: StreamCipherClosure<BlockSize = U64>,
    V: Variant,
{
    // Broadcast each 128-bit state row into both lanes of a 256-bit vector.
    let state_ptr = state.as_ptr().cast::<__m128i>();
    let v = [
        _mm256_broadcastsi128_si256(_mm_loadu_si128(state_ptr.add(0))),
        _mm256_broadcastsi128_si256(_mm_loadu_si128(state_ptr.add(1))),
        _mm256_broadcastsi128_si256(_mm_loadu_si128(state_ptr.add(2))),
    ];
    let mut c = _mm256_broadcastsi128_si256(_mm_loadu_si128(state_ptr.add(3)));
    // Bump the counter in the upper lane only, so the two lanes hold
    // consecutive block counters.
    c = match size_of::<V::Counter>() {
        4 => _mm256_add_epi32(c, _mm256_set_epi32(0, 0, 0, 1, 0, 0, 0, 0)),
        8 => _mm256_add_epi64(c, _mm256_set_epi64x(0, 1, 0, 0)),
        _ => unreachable!(),
    };
    let mut ctr = [c; N];
    for i in 0..N {
        ctr[i] = c;
        // Each subsequent vector starts two blocks further along.
        c = match size_of::<V::Counter>() {
            4 => _mm256_add_epi32(c, _mm256_set_epi32(0, 0, 0, 2, 0, 0, 0, 2)),
            8 => _mm256_add_epi64(c, _mm256_set_epi64x(0, 2, 0, 2)),
            _ => unreachable!(),
        };
    }
    let mut backend = Backend::<R, V> {
        v,
        ctr,
        _pd: PhantomData,
    };

    f.call(&mut backend);

    // Write the (possibly advanced) counter back into the scalar state.
    state[12] = _mm256_extract_epi32(backend.ctr[0], 0) as u32;
    match size_of::<V::Counter>() {
        4 => {}
        8 => state[13] = _mm256_extract_epi32(backend.ctr[0], 1) as u32,
        _ => unreachable!(),
    }
}

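/// Fills `buffer` with `PAR_BLOCKS` blocks (64 words) of keystream and
/// advances the 64-bit block counter in `core.state`.
///
/// # Safety
///
/// The caller must ensure the CPU supports AVX2 before calling this
/// function.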
#[inline]
#[target_feature(enable = "avx2")]
#[cfg(feature = "rng")]
pub(crate) unsafe fn rng_inner<R, V>(core: &mut ChaChaCore<R, V>, buffer: &mut [u32; 64])
where
    R: Rounds,
    V: Variant,
{
    let state_ptr = core.state.as_ptr().cast::<__m128i>();
    let v = [
        _mm256_broadcastsi128_si256(_mm_loadu_si128(state_ptr.add(0))),
        _mm256_broadcastsi128_si256(_mm_loadu_si128(state_ptr.add(1))),
        _mm256_broadcastsi128_si256(_mm_loadu_si128(state_ptr.add(2))),
    ];
    let mut c = _mm256_broadcastsi128_si256(_mm_loadu_si128(state_ptr.add(3)));
    // The RNG always uses a 64-bit counter; stagger the lanes as in `inner`.
    c = _mm256_add_epi64(c, _mm256_set_epi64x(0, 1, 0, 0));
    let mut ctr = [c; N];
    for i in 0..N {
        ctr[i] = c;
        c = _mm256_add_epi64(c, _mm256_set_epi64x(0, 2, 0, 2));
    }
    let mut backend = Backend::<R, V> {
        v,
        ctr,
        _pd: PhantomData,
    };

    backend.rng_gen_par_ks_blocks(buffer);

    core.state[12] = _mm256_extract_epi32(backend.ctr[0], 0) as u32;
    core.state[13] = _mm256_extract_epi32(backend.ctr[0], 1) as u32;
}

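/// AVX2 working state: `v` holds the first three state rows broadcast to
/// both 128-bit lanes; `ctr[i]` holds the fourth row for blocks `2i` and
/// `2i + 1`, differing only in their counters.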
struct Backend<R: Rounds, V: Variant> {
    v: [__m256i; 3],
    ctr: [__m256i; N],
    _pd: PhantomData<(R, V)>,
}

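// The `cipher` crate drives this backend in units of `ParBlocksSize` blocks;
// `U4` must match `PAR_BLOCKS`.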
#[cfg(feature = "cipher")]
impl<R: Rounds, V: Variant> BlockSizeUser for Backend<R, V> {
    type BlockSize = U64;
}

#[cfg(feature = "cipher")]
impl<R: Rounds, V: Variant> ParBlocksSizeUser for Backend<R, V> {
    type ParBlocksSize = U4;
}

#[cfg(feature = "cipher")]
impl<R: Rounds, V: Variant> StreamCipherBackend for Backend<R, V> {
    #[inline(always)]
    fn gen_ks_block(&mut self, block: &mut Block) {
        unsafe {
            let res = rounds::<R>(&self.v, &self.ctr);
            // Only one block was consumed; advance every counter by one.
            for c in self.ctr.iter_mut() {
                *c = match size_of::<V::Counter>() {
                    4 => _mm256_add_epi32(*c, _mm256_set_epi32(0, 0, 0, 1, 0, 0, 0, 1)),
                    8 => _mm256_add_epi64(*c, _mm256_set_epi64x(0, 1, 0, 1)),
                    _ => unreachable!(),
                };
            }

            // Keep only the low 128-bit lane of each row: that is the first
            // of the two interleaved blocks.
            let res0: [__m128i; 8] = core::mem::transmute(res[0]);

            let block_ptr = block.as_mut_ptr().cast::<__m128i>();
            for i in 0..4 {
                _mm_storeu_si128(block_ptr.add(i), res0[2 * i]);
            }
        }
    }

    #[inline(always)]
    fn gen_par_ks_blocks(&mut self, blocks: &mut ParBlocks<Self>) {
        unsafe {
            let vs = rounds::<R>(&self.v, &self.ctr);

            let pb = PAR_BLOCKS as i32;
            for c in self.ctr.iter_mut() {
                *c = match size_of::<V::Counter>() {
                    4 => _mm256_add_epi32(*c, _mm256_set_epi32(0, 0, 0, pb, 0, 0, 0, pb)),
                    8 => {
                        _mm256_add_epi64(*c, _mm256_set_epi64x(0, i64::from(pb), 0, i64::from(pb)))
                    }
                    _ => unreachable!(),
                }
            }

            // De-interleave: even 128-bit halves belong to the first block
            // of each pair, odd halves to the second.
            let mut block_ptr = blocks.as_mut_ptr().cast::<__m128i>();
            for v in vs {
                let t: [__m128i; 8] = core::mem::transmute(v);
                for i in 0..4 {
                    _mm_storeu_si128(block_ptr.add(i), t[2 * i]);
                    _mm_storeu_si128(block_ptr.add(4 + i), t[2 * i + 1]);
                }
                block_ptr = block_ptr.add(8);
            }
        }
    }
}

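// RNG counterpart of `gen_par_ks_blocks`: identical keystream layout, but
// written into a raw `u32` buffer instead of cipher blocks.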
#[cfg(feature = "rng")]
impl<R: Rounds, V: Variant> Backend<R, V> {
    #[inline(always)]
    fn rng_gen_par_ks_blocks(&mut self, blocks: &mut [u32; 64]) {
        unsafe {
            let vs = rounds::<R>(&self.v, &self.ctr);

            let pb = PAR_BLOCKS as i32;
            for c in self.ctr.iter_mut() {
                *c = _mm256_add_epi64(*c, _mm256_set_epi64x(0, i64::from(pb), 0, i64::from(pb)));
            }

            let mut block_ptr = blocks.as_mut_ptr().cast::<__m128i>();
            for v in vs {
                let t: [__m128i; 8] = core::mem::transmute(v);
                for i in 0..4 {
                    _mm_storeu_si128(block_ptr.add(i), t[2 * i]);
                    _mm_storeu_si128(block_ptr.add(4 + i), t[2 * i + 1]);
                }
                block_ptr = block_ptr.add(8);
            }
        }
    }
}

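/// Runs `R::COUNT` double rounds over `N` interleaved two-block states and
/// applies the final feed-forward addition of the initial state.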
#[inline]
#[target_feature(enable = "avx2")]
unsafe fn rounds<R: Rounds>(v: &[__m256i; 3], c: &[__m256i; N]) -> [[__m256i; 4]; N] {
    let mut vs: [[__m256i; 4]; N] = [[_mm256_setzero_si256(); 4]; N];
    for i in 0..N {
        vs[i] = [v[0], v[1], v[2], c[i]];
    }
    for _ in 0..R::COUNT {
        double_quarter_round(&mut vs);
    }

    // Feed-forward: add the initial state words back into the output.
    for i in 0..N {
        for j in 0..3 {
            vs[i][j] = _mm256_add_epi32(vs[i][j], v[j]);
        }
        vs[i][3] = _mm256_add_epi32(vs[i][3], c[i]);
    }

    vs
}

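/// One ChaCha double round: a column round, then a diagonal round obtained
/// by rotating the rows into diagonal order and back.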
#[inline]
#[target_feature(enable = "avx2")]
unsafe fn double_quarter_round(v: &mut [[__m256i; 4]; N]) {
    add_xor_rot(v);
    rows_to_cols(v);
    add_xor_rot(v);
    cols_to_rows(v);
}

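/// Rotates the words of `c` by one position, `d` by two, and `a` by three
/// (`b` is left in place), so that the diagonals of the ChaCha matrix line
/// up in columns for the next `add_xor_rot`.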
#[inline]
#[target_feature(enable = "avx2")]
unsafe fn rows_to_cols(vs: &mut [[__m256i; 4]; N]) {
    for [a, _, c, d] in vs {
        *c = _mm256_shuffle_epi32(*c, 0b_00_11_10_01);
        *d = _mm256_shuffle_epi32(*d, 0b_01_00_11_10);
        *a = _mm256_shuffle_epi32(*a, 0b_10_01_00_11);
    }
}

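/// Inverse of `rows_to_cols`: undoes the diagonalization so the state rows
/// are back in column order.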
#[inline]
#[target_feature(enable = "avx2")]
unsafe fn cols_to_rows(vs: &mut [[__m256i; 4]; N]) {
    for [a, _, c, d] in vs {
        *c = _mm256_shuffle_epi32(*c, 0b_10_01_00_11);
        *d = _mm256_shuffle_epi32(*d, 0b_01_00_11_10);
        *a = _mm256_shuffle_epi32(*a, 0b_00_11_10_01);
    }
}

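/// Runs the four ChaCha quarter rounds of each block in parallel, one per
/// column (or per diagonal, after `rows_to_cols`). The 16- and 8-bit
/// rotations are byte-granular and use `vpshufb` masks; the 12- and 7-bit
/// rotations are built from paired shifts.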
#[inline]
#[target_feature(enable = "avx2")]
unsafe fn add_xor_rot(vs: &mut [[__m256i; 4]; N]) {
    let rol16_mask = _mm256_set_epi64x(
        0x0d0c_0f0e_0908_0b0a,
        0x0504_0706_0100_0302,
        0x0d0c_0f0e_0908_0b0a,
        0x0504_0706_0100_0302,
    );
    let rol8_mask = _mm256_set_epi64x(
        0x0e0d_0c0f_0a09_080b,
        0x0605_0407_0201_0003,
        0x0e0d_0c0f_0a09_080b,
        0x0605_0407_0201_0003,
    );

    // a += b; d ^= a; d <<<= 16;
    for [a, b, _, d] in vs.iter_mut() {
        *a = _mm256_add_epi32(*a, *b);
        *d = _mm256_xor_si256(*d, *a);
        *d = _mm256_shuffle_epi8(*d, rol16_mask);
    }

    // c += d; b ^= c; b <<<= 12;
    for [_, b, c, d] in vs.iter_mut() {
        *c = _mm256_add_epi32(*c, *d);
        *b = _mm256_xor_si256(*b, *c);
        *b = _mm256_xor_si256(_mm256_slli_epi32(*b, 12), _mm256_srli_epi32(*b, 20));
    }

    // a += b; d ^= a; d <<<= 8;
    for [a, b, _, d] in vs.iter_mut() {
        *a = _mm256_add_epi32(*a, *b);
        *d = _mm256_xor_si256(*d, *a);
        *d = _mm256_shuffle_epi8(*d, rol8_mask);
    }

    // c += d; b ^= c; b <<<= 7;
    for [_, b, c, d] in vs.iter_mut() {
        *c = _mm256_add_epi32(*c, *d);
        *b = _mm256_xor_si256(*b, *c);
        *b = _mm256_xor_si256(_mm256_slli_epi32(*b, 7), _mm256_srli_epi32(*b, 25));
    }
}
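
// Illustrative usage sketch (hypothetical, not part of this module): callers
// must confirm AVX2 support at runtime before invoking these
// `#[target_feature]` functions, e.g. with the `cpufeatures` crate:
//
//     cpufeatures::new!(avx2_cpuid, "avx2");
//     if avx2_cpuid::get() {
//         // SAFETY: AVX2 presence was just verified.
//         unsafe { backends::avx2::inner::<R, _, V>(&mut state, closure) };
//     }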