#![allow(unsafe_op_in_unsafe_fn, reason = "needs triage")]
#![allow(clippy::cast_possible_truncation, reason = "needs triage")]
#![allow(clippy::cast_possible_wrap, reason = "needs triage")]
#![allow(clippy::cast_sign_loss, reason = "needs triage")]
#![allow(clippy::undocumented_unsafe_blocks, reason = "TODO")]

use crate::{Rounds, Variant};

#[cfg(feature = "rng")]
use crate::ChaChaCore;

#[cfg(feature = "cipher")]
use crate::{STATE_WORDS, chacha::Block};
#[cfg(feature = "cipher")]
use cipher::{
    BlockSizeUser, ParBlocksSizeUser, StreamCipherBackend, StreamCipherClosure,
    consts::{U4, U64},
};
use core::marker::PhantomData;

#[cfg(target_arch = "x86")]
use core::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;

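/// Number of 64-byte ChaCha blocks this backend generates per batch.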
const PAR_BLOCKS: usize = 4;

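/// Runs the given stream-cipher closure against the SSE2 backend.
///
/// Loads the 16-word ChaCha state into four `__m128i` rows, hands the backend
/// to `f`, then writes the advanced block counter back into `state[12]`
/// (and `state[13]` for variants with a 64-bit counter).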
#[inline]
#[target_feature(enable = "sse2")]
#[cfg(feature = "cipher")]
pub(crate) unsafe fn inner<R, F, V>(state: &mut [u32; STATE_WORDS], f: F)
where
    R: Rounds,
    F: StreamCipherClosure<BlockSize = U64>,
    V: Variant,
{
    let state_ptr = state.as_ptr().cast::<__m128i>();
    let mut backend = Backend::<R, V> {
        v: [
            _mm_loadu_si128(state_ptr.add(0)),
            _mm_loadu_si128(state_ptr.add(1)),
            _mm_loadu_si128(state_ptr.add(2)),
            _mm_loadu_si128(state_ptr.add(3)),
        ],
        _pd: PhantomData,
    };

    f.call(&mut backend);

    // Write the updated counter back to the state. Lane 1 is read via an
    // SSE2-only byte shift, avoiding SSE4.1's `_mm_extract_epi32`.
    state[12] = _mm_cvtsi128_si32(backend.v[3]) as u32;
    if size_of::<V::Counter>() == 8 {
        state[13] = _mm_cvtsi128_si32(_mm_srli_si128(backend.v[3], 4)) as u32;
    }
}

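/// SSE2 backend state: the four 128-bit rows of the current ChaCha state.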
struct Backend<R: Rounds, V: Variant> {
    v: [__m128i; 4],
    _pd: PhantomData<(R, V)>,
}

#[cfg(feature = "cipher")]
impl<R: Rounds, V: Variant> BlockSizeUser for Backend<R, V> {
    type BlockSize = U64;
}

#[cfg(feature = "cipher")]
impl<R: Rounds, V: Variant> ParBlocksSizeUser for Backend<R, V> {
    type ParBlocksSize = U4;
}

#[cfg(feature = "cipher")]
impl<R: Rounds, V: Variant> StreamCipherBackend for Backend<R, V> {
    #[inline(always)]
    fn gen_ks_block(&mut self, block: &mut Block) {
        unsafe {
            // `rounds` computes PAR_BLOCKS blocks; only the first is used here.
            let res = rounds::<R, V>(&self.v);
            // Advance the counter by one block, with 32- or 64-bit addition
            // depending on the variant's counter width.
            self.v[3] = match size_of::<V::Counter>() {
                4 => _mm_add_epi32(self.v[3], _mm_set_epi32(0, 0, 0, 1)),
                8 => _mm_add_epi64(self.v[3], _mm_set_epi64x(0, 1)),
                _ => unreachable!(),
            };

            let block_ptr = block.as_mut_ptr().cast::<__m128i>();
            for i in 0..4 {
                _mm_storeu_si128(block_ptr.add(i), res[0][i]);
            }
        }
    }

    #[inline(always)]
    fn gen_par_ks_blocks(&mut self, blocks: &mut cipher::ParBlocks<Self>) {
        unsafe {
            let res = rounds::<R, V>(&self.v);
            // `rounds` consumed counters `n..n + PAR_BLOCKS`, so skip past them.
            self.v[3] = match size_of::<V::Counter>() {
                4 => _mm_add_epi32(self.v[3], _mm_set_epi32(0, 0, 0, PAR_BLOCKS as i32)),
                8 => _mm_add_epi64(self.v[3], _mm_set_epi64x(0, PAR_BLOCKS as i64)),
                _ => unreachable!(),
            };

            let blocks_ptr = blocks.as_mut_ptr().cast::<__m128i>();
            for block in 0..PAR_BLOCKS {
                for i in 0..4 {
                    // Each 64-byte block spans four `__m128i`, hence the stride of 4.
                    _mm_storeu_si128(blocks_ptr.add(i + block * 4), res[block][i]);
                }
            }
        }
    }
}

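/// Fills the RNG's 64-word buffer with four blocks of keystream and writes
/// the advanced 64-bit counter back into `core.state[12..14]`.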
#[inline]
#[target_feature(enable = "sse2")]
#[cfg(feature = "rng")]
pub(crate) unsafe fn rng_inner<R, V>(core: &mut ChaChaCore<R, V>, buffer: &mut [u32; 64])
where
    R: Rounds,
    V: Variant,
{
    let state_ptr = core.state.as_ptr().cast::<__m128i>();
    let mut backend = Backend::<R, V> {
        v: [
            _mm_loadu_si128(state_ptr.add(0)),
            _mm_loadu_si128(state_ptr.add(1)),
            _mm_loadu_si128(state_ptr.add(2)),
            _mm_loadu_si128(state_ptr.add(3)),
        ],
        _pd: PhantomData,
    };

    backend.gen_ks_blocks(buffer);

    // Write back both counter words. Lane 1 is read via an SSE2-only byte
    // shift, avoiding SSE4.1's `_mm_extract_epi32`.
    core.state[12] = _mm_cvtsi128_si32(backend.v[3]) as u32;
    core.state[13] = _mm_cvtsi128_si32(_mm_srli_si128(backend.v[3], 4)) as u32;
}

#[cfg(feature = "rng")]
impl<R: Rounds, V: Variant> Backend<R, V> {
    #[inline(always)]
    fn gen_ks_blocks(&mut self, buffer: &mut [u32; 64]) {
        // The four generated blocks must exactly fill the output buffer.
        const _: () = assert!(4 * PAR_BLOCKS * size_of::<__m128i>() == size_of::<[u32; 64]>());
        unsafe {
            let res = rounds::<R, V>(&self.v);
            // The RNG always uses a 64-bit counter.
            self.v[3] = _mm_add_epi64(self.v[3], _mm_set_epi64x(0, PAR_BLOCKS as i64));

            let blocks_ptr = buffer.as_mut_ptr().cast::<__m128i>();
            for block in 0..PAR_BLOCKS {
                for i in 0..4 {
                    // Each 64-byte block spans four `__m128i`, hence the stride of 4.
                    _mm_storeu_si128(blocks_ptr.add(i + block * 4), res[block][i]);
                }
            }
        }
    }
}

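/// Computes `PAR_BLOCKS` ChaCha blocks in parallel from the state rows `v`.
///
/// Block `b` uses the input counter plus `b`; after `R::COUNT` double rounds,
/// the original state (with the matching per-block counter) is added back in,
/// per the ChaCha construction.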
#[inline]
#[target_feature(enable = "sse2")]
unsafe fn rounds<R: Rounds, V: Variant>(v: &[__m128i; 4]) -> [[__m128i; 4]; PAR_BLOCKS] {
    let mut res = [*v; PAR_BLOCKS];
    // Give each block its own counter value: block `b` encrypts counter `n + b`.
    for block in 1..PAR_BLOCKS {
        res[block][3] = match size_of::<V::Counter>() {
            4 => _mm_add_epi32(res[block][3], _mm_set_epi32(0, 0, 0, block as i32)),
            8 => _mm_add_epi64(res[block][3], _mm_set_epi64x(0, block as i64)),
            _ => unreachable!(),
        };
    }

    for _ in 0..R::COUNT {
        double_quarter_round(&mut res);
    }

    // Feed-forward: add the original state (with the per-block counter
    // offset) to the permuted state.
    for block in 0..PAR_BLOCKS {
        for i in 0..3 {
            res[block][i] = _mm_add_epi32(res[block][i], v[i]);
        }
        let ctr = match size_of::<V::Counter>() {
            4 => _mm_add_epi32(v[3], _mm_set_epi32(0, 0, 0, block as i32)),
            8 => _mm_add_epi64(v[3], _mm_set_epi64x(0, block as i64)),
            _ => unreachable!(),
        };
        res[block][3] = _mm_add_epi32(res[block][3], ctr);
    }

    res
}

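/// One ChaCha double round: a column round, then a diagonal round. The
/// diagonal round reuses [`add_xor_rot`] by rotating the rows so that the
/// diagonals line up as columns.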
#[inline]
#[target_feature(enable = "sse2")]
unsafe fn double_quarter_round(v: &mut [[__m128i; 4]; PAR_BLOCKS]) {
    add_xor_rot(v);
    rows_to_cols(v);
    add_xor_rot(v);
    cols_to_rows(v);
}

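/// Rotates three of the four rows so that each diagonal of the ChaCha state
/// matrix lines up as a column, letting [`add_xor_rot`] compute the
/// "diagonal rounds" of RFC 8439 with the same column-wise kernel:
///
/// ```text
/// [a0, a1, a2, a3]         [a3, a0, a1, a2]
/// [b0, b1, b2, b3]   -->   [b0, b1, b2, b3]
/// [c0, c1, c2, c3]         [c1, c2, c3, c0]
/// [d0, d1, d2, d3]         [d2, d3, d0, d1]
/// ```
///
/// `b` is the row left in place: it is the last word written by
/// [`add_xor_rot`], so shuffling the other three rows avoids stalling on it.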
#[inline]
#[target_feature(enable = "sse2")]
unsafe fn rows_to_cols(blocks: &mut [[__m128i; 4]; PAR_BLOCKS]) {
    // c >>>= 32; d >>>= 64; a >>>= 96;
    for [a, _, c, d] in blocks.iter_mut() {
        *c = _mm_shuffle_epi32(*c, 0b_00_11_10_01); // _MM_SHUFFLE(0, 3, 2, 1)
        *d = _mm_shuffle_epi32(*d, 0b_01_00_11_10); // _MM_SHUFFLE(1, 0, 3, 2)
        *a = _mm_shuffle_epi32(*a, 0b_10_01_00_11); // _MM_SHUFFLE(2, 1, 0, 3)
    }
}

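/// Inverse of [`rows_to_cols`]: rotates each shuffled row back to its
/// original lane order after a diagonal round.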
#[inline]
#[target_feature(enable = "sse2")]
unsafe fn cols_to_rows(blocks: &mut [[__m128i; 4]; PAR_BLOCKS]) {
    // c <<<= 32; d <<<= 64; a <<<= 96;
    for [a, _, c, d] in blocks.iter_mut() {
        *c = _mm_shuffle_epi32(*c, 0b_10_01_00_11); // _MM_SHUFFLE(2, 1, 0, 3)
        *d = _mm_shuffle_epi32(*d, 0b_01_00_11_10); // _MM_SHUFFLE(1, 0, 3, 2)
        *a = _mm_shuffle_epi32(*a, 0b_00_11_10_01); // _MM_SHUFFLE(0, 3, 2, 1)
    }
}

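/// Four vectorized ChaCha quarter rounds, one per column. Each 32-bit
/// rotation is emulated with a shift-left/shift-right pair, since SSE2 has
/// no rotate instruction.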
#[inline]
#[target_feature(enable = "sse2")]
unsafe fn add_xor_rot(blocks: &mut [[__m128i; 4]; PAR_BLOCKS]) {
    for [a, b, c, d] in blocks.iter_mut() {
        // a += b; d ^= a; d <<<= 16;
        *a = _mm_add_epi32(*a, *b);
        *d = _mm_xor_si128(*d, *a);
        *d = _mm_xor_si128(_mm_slli_epi32(*d, 16), _mm_srli_epi32(*d, 16));

        // c += d; b ^= c; b <<<= 12;
        *c = _mm_add_epi32(*c, *d);
        *b = _mm_xor_si128(*b, *c);
        *b = _mm_xor_si128(_mm_slli_epi32(*b, 12), _mm_srli_epi32(*b, 20));

        // a += b; d ^= a; d <<<= 8;
        *a = _mm_add_epi32(*a, *b);
        *d = _mm_xor_si128(*d, *a);
        *d = _mm_xor_si128(_mm_slli_epi32(*d, 8), _mm_srli_epi32(*d, 24));

        // c += d; b ^= c; b <<<= 7;
        *c = _mm_add_epi32(*c, *d);
        *b = _mm_xor_si128(*b, *c);
        *b = _mm_xor_si128(_mm_slli_epi32(*b, 7), _mm_srli_epi32(*b, 25));
    }
}