1use core::{marker::PhantomData, ptr::NonNull};
7
8use super::{BLOCK_SIZE, Error};
9
10#[derive(Clone, Copy)]
11pub(crate) struct CryptoBuffers<'a> {
12 buffers: UnsafeCryptoBuffers,
13 _marker: PhantomData<&'a mut ()>,
14}
15
16impl<'a> CryptoBuffers<'a> {
17 pub fn new(input: &'a [u8], output: &'a mut [u8]) -> Result<Self, Error> {
18 if input.len() != output.len() {
19 return Err(Error::BuffersNotEqual);
20 }
21 Ok(Self {
22 buffers: UnsafeCryptoBuffers {
23 input: NonNull::from(input),
24 output: NonNull::from(output),
25 },
26 _marker: PhantomData,
27 })
28 }
29
30 pub fn new_in_place(data: &'a mut [u8]) -> Self {
31 let ptr = NonNull::from(data);
32 Self {
33 buffers: UnsafeCryptoBuffers {
34 input: ptr,
35 output: ptr,
36 },
37 _marker: PhantomData,
38 }
39 }
40
41 pub(super) unsafe fn into_inner(self) -> UnsafeCryptoBuffers {
42 self.buffers
43 }
44}
45
46#[derive(Clone, Copy)]
47pub(super) struct UnsafeCryptoBuffers {
48 pub input: NonNull<[u8]>,
49 pub output: NonNull<[u8]>,
50}
51impl UnsafeCryptoBuffers {
52 pub fn in_place(&self) -> bool {
53 self.input.addr() == self.output.addr()
54 }
55
56 #[cfg(aes_dma)]
57 pub(crate) unsafe fn byte_add(self, bytes: usize) -> Self {
58 UnsafeCryptoBuffers {
59 input: unsafe { self.input.byte_add(bytes) },
60 output: unsafe { self.output.byte_add(bytes) },
61 }
62 }
63}
64
65#[derive(Clone)]
67pub struct Ecb;
68impl Ecb {
69 pub(super) fn encrypt_decrypt(
70 &mut self,
71 buffer: UnsafeCryptoBuffers,
72 mut process_block: impl FnMut(NonNull<[u8]>, NonNull<[u8]>),
73 ) {
74 buffer.for_data_chunks(BLOCK_SIZE, |input, output, len| {
75 process_block(
76 NonNull::slice_from_raw_parts(input, len),
77 NonNull::slice_from_raw_parts(output, len),
78 )
79 });
80 }
81}
82
83#[derive(Clone)]
85pub struct Cbc {
86 pub(super) iv: [u8; BLOCK_SIZE],
87}
88impl Cbc {
89 pub fn new(iv: [u8; BLOCK_SIZE]) -> Self {
91 Self { iv }
92 }
93
94 pub(super) fn encrypt(
95 &mut self,
96 buffer: UnsafeCryptoBuffers,
97 mut process_block: impl FnMut(NonNull<[u8]>, NonNull<[u8]>),
98 ) {
99 let iv = NonNull::from(self.iv.as_mut());
100
101 buffer.for_data_chunks(BLOCK_SIZE, |plaintext, ciphertext, len| {
102 xor_into(iv.cast(), plaintext, len);
105
106 process_block(iv, NonNull::slice_from_raw_parts(ciphertext, len));
108
109 copy(iv.cast(), ciphertext, len);
111 });
112 }
113
114 pub(super) fn decrypt(
115 &mut self,
116 buffer: UnsafeCryptoBuffers,
117 mut process_block: impl FnMut(NonNull<[u8]>, NonNull<[u8]>),
118 ) {
119 let iv = NonNull::from(self.iv.as_mut()).cast::<u8>();
120
121 if buffer.in_place() {
122 let mut temp_buffer = [0; 16];
123 let temp = NonNull::from(&mut temp_buffer).cast::<u8>();
124 buffer.for_data_chunks(BLOCK_SIZE, |ciphertext, plaintext, len| {
125 copy(temp, ciphertext, len);
126
127 process_block(
129 NonNull::slice_from_raw_parts(ciphertext, len),
130 NonNull::slice_from_raw_parts(plaintext, len),
131 );
132 xor_into(plaintext, iv, len);
133
134 copy(iv, temp, len);
136 });
137 } else {
138 buffer.for_data_chunks(BLOCK_SIZE, |ciphertext, plaintext, len| {
139 process_block(
141 NonNull::slice_from_raw_parts(ciphertext, len),
142 NonNull::slice_from_raw_parts(plaintext, len),
143 );
144 xor_into(plaintext, iv, len);
145
146 copy(iv, ciphertext, len);
148 });
149 }
150 }
151}
152
153#[derive(Clone)]
155pub struct Ofb {
156 pub(super) iv: [u8; BLOCK_SIZE],
157 pub(super) offset: usize,
158}
159impl Ofb {
160 pub fn new(iv: [u8; BLOCK_SIZE]) -> Self {
162 Self { iv, offset: 0 }
163 }
164
165 pub(super) fn encrypt_decrypt(
166 &mut self,
167 buffer: UnsafeCryptoBuffers,
168 mut process_block: impl FnMut(NonNull<[u8]>, NonNull<[u8]>),
169 ) {
170 let mut offset = self.offset;
171 buffer.for_data_chunks(1, |input, output, _| {
172 if offset == 0 {
173 let iv = NonNull::from(self.iv.as_mut());
175 process_block(iv, iv);
176 }
177
178 unsafe { output.write(input.read() ^ self.iv[offset]) };
180 offset = (offset + 1) % BLOCK_SIZE;
181 });
182 self.offset = offset;
183 }
184
185 #[cfg(aes_dma)]
186 pub(super) fn flush(&mut self, buffer: UnsafeCryptoBuffers) -> usize {
187 let mut offset = self.offset;
188 buffer
189 .first_n((BLOCK_SIZE - offset) % BLOCK_SIZE)
190 .for_data_chunks(1, |input, output, _| {
191 unsafe { output.write(input.read() ^ self.iv[offset]) };
192 offset += 1;
193 });
194 let flushed = offset - self.offset;
195 self.offset = offset % BLOCK_SIZE;
196 flushed
197 }
198}
199
200#[derive(Clone)]
202pub struct Ctr {
203 pub(super) nonce: [u8; BLOCK_SIZE],
205 pub(super) buffer: [u8; BLOCK_SIZE],
207 pub(super) offset: usize,
208}
209impl Ctr {
210 pub fn new(nonce: [u8; BLOCK_SIZE]) -> Self {
212 Self {
213 nonce,
214 buffer: [0; BLOCK_SIZE],
215 offset: 0,
216 }
217 }
218
219 pub(super) fn encrypt_decrypt(
220 &mut self,
221 buffer: UnsafeCryptoBuffers,
222 mut process_block: impl FnMut(NonNull<[u8]>, NonNull<[u8]>),
223 ) {
224 fn increment(nonce: &mut [u8]) {
225 for byte in nonce.iter_mut().rev() {
226 *byte = byte.wrapping_add(1);
227 if *byte != 0 {
228 break;
229 }
230 }
231 }
232
233 let mut offset = self.offset;
234 buffer.for_data_chunks(1, |plaintext, ciphertext, _| {
235 if offset == 0 {
236 let nonce = NonNull::from(self.nonce.as_mut());
237 let buffer = NonNull::from(self.buffer.as_mut());
238 process_block(nonce, buffer);
240 increment(&mut self.nonce);
241 }
242
243 unsafe { ciphertext.write(plaintext.read() ^ self.buffer[offset]) };
245 offset = (offset + 1) % BLOCK_SIZE;
246 });
247 self.offset = offset;
248 }
249
250 #[cfg(aes_dma)]
251 pub(super) fn flush(&mut self, buffer: UnsafeCryptoBuffers) -> usize {
252 let mut offset = self.offset;
253 buffer
254 .first_n((BLOCK_SIZE - offset) % BLOCK_SIZE)
255 .for_data_chunks(1, |plaintext, ciphertext, _| {
256 unsafe { ciphertext.write(plaintext.read() ^ self.buffer[offset]) };
257 offset += 1;
258 });
259 let flushed = offset - self.offset;
260 self.offset = offset % BLOCK_SIZE;
261 flushed
262 }
263}
264
265#[derive(Clone)]
267pub struct Cfb8 {
268 pub(super) iv: [u8; BLOCK_SIZE],
269}
270impl Cfb8 {
271 pub fn new(iv: [u8; BLOCK_SIZE]) -> Self {
273 Self { iv }
274 }
275
276 pub(super) fn encrypt(
277 &mut self,
278 buffer: UnsafeCryptoBuffers,
279 mut process_block: impl FnMut(NonNull<[u8]>, NonNull<[u8]>),
280 ) {
281 let mut ov = [0; BLOCK_SIZE];
282 buffer.for_data_chunks(1, |plaintext, ciphertext, _| {
283 process_block(NonNull::from(self.iv.as_mut()), NonNull::from(ov.as_mut()));
284
285 unsafe {
286 let out = ov[0] ^ plaintext.read();
287 ciphertext.write(out);
288
289 self.iv.copy_within(1.., 0);
291 self.iv[BLOCK_SIZE - 1] = out;
292 }
293 });
294 }
295
296 pub(super) fn decrypt(
297 &mut self,
298 buffer: UnsafeCryptoBuffers,
299 mut process_block: impl FnMut(NonNull<[u8]>, NonNull<[u8]>),
300 ) {
301 let mut ov = [0; BLOCK_SIZE];
302 buffer.for_data_chunks(1, |ciphertext, plaintext, _| {
303 process_block(NonNull::from(self.iv.as_mut()), NonNull::from(ov.as_mut()));
304
305 unsafe {
306 let c = ciphertext.read();
307 plaintext.write(ov[0] ^ c);
308
309 self.iv.copy_within(1.., 0);
311 self.iv[BLOCK_SIZE - 1] = c;
312 }
313 });
314 }
315}
316
317#[derive(Clone)]
319pub struct Cfb128 {
320 pub(super) iv: [u8; BLOCK_SIZE],
321 pub(super) offset: usize,
322}
323impl Cfb128 {
324 pub fn new(iv: [u8; BLOCK_SIZE]) -> Self {
326 Self { iv, offset: 0 }
327 }
328
329 pub(super) fn encrypt(
330 &mut self,
331 buffer: UnsafeCryptoBuffers,
332 mut process_block: impl FnMut(NonNull<[u8]>, NonNull<[u8]>),
333 ) {
334 let mut offset = self.offset;
335 buffer.for_data_chunks(1, |plaintext, ciphertext, _| {
336 if offset == 0 {
337 let iv = NonNull::from(self.iv.as_mut());
338 process_block(iv, iv);
339 }
340
341 unsafe {
342 self.iv[offset] ^= plaintext.read();
343 ciphertext.write(self.iv[offset]);
344 }
345 offset = (offset + 1) % BLOCK_SIZE;
346 });
347 self.offset = offset;
348 }
349
350 pub(super) fn decrypt(
351 &mut self,
352 buffer: UnsafeCryptoBuffers,
353 mut process_block: impl FnMut(NonNull<[u8]>, NonNull<[u8]>),
354 ) {
355 let mut offset = self.offset;
356 buffer.for_data_chunks(1, |ciphertext, plaintext, _| {
357 if offset == 0 {
358 let iv = NonNull::from(self.iv.as_mut());
359 process_block(iv, iv);
360 }
361
362 unsafe {
363 let c = ciphertext.read();
364 plaintext.write(self.iv[offset] ^ c);
365 self.iv[offset] = c;
366 }
367 offset = (offset + 1) % BLOCK_SIZE;
368 });
369 self.offset = offset;
370 }
371
372 #[cfg(aes_dma)]
373 pub(super) fn flush_encrypt(&mut self, buffer: UnsafeCryptoBuffers) -> usize {
374 let mut offset = self.offset;
375 buffer
376 .first_n((BLOCK_SIZE - offset) % BLOCK_SIZE)
377 .for_data_chunks(1, |plaintext, ciphertext, _| {
378 unsafe {
379 self.iv[offset] ^= plaintext.read();
380 ciphertext.write(self.iv[offset]);
381 }
382 offset += 1;
383 });
384 let flushed = offset - self.offset;
385 self.offset = offset % BLOCK_SIZE;
386 flushed
387 }
388
389 #[cfg(aes_dma)]
390 pub(super) fn flush_decrypt(&mut self, buffer: UnsafeCryptoBuffers) -> usize {
391 let mut offset = self.offset;
392 buffer
393 .first_n((BLOCK_SIZE - offset) % BLOCK_SIZE)
394 .for_data_chunks(1, |ciphertext, plaintext, _| {
395 unsafe {
396 let c = ciphertext.read();
397 plaintext.write(self.iv[offset] ^ c);
398 self.iv[offset] = c;
399 }
400 offset += 1;
401 });
402 let flushed = offset - self.offset;
403 self.offset = offset % BLOCK_SIZE;
404 flushed
405 }
406}
407
408impl UnsafeCryptoBuffers {
411 fn for_data_chunks(
412 self,
413 chunk_size: usize,
414 mut cb: impl FnMut(NonNull<u8>, NonNull<u8>, usize),
415 ) {
416 let input = pointer_chunks(self.input, chunk_size);
417 let output = pointer_chunks(self.output, chunk_size);
418
419 for (input, output, len) in input
420 .zip(output)
421 .map(|((input, len), (output, _))| (input, output, len))
422 {
423 cb(input, output, len)
424 }
425 }
426
427 #[cfg(aes_dma)]
428 fn first_n(self, n: usize) -> UnsafeCryptoBuffers {
429 let len = n.min(self.input.len());
430 Self {
431 input: NonNull::slice_from_raw_parts(self.input.cast(), len),
432 output: NonNull::slice_from_raw_parts(self.output.cast(), len),
433 }
434 }
435}
436
/// Splits the slice described by `ptr` into consecutive chunks of at most `chunk` elements.
///
/// Yields `(start, len)` for each chunk. Every chunk has `len == chunk` except possibly the
/// last, which holds the remainder. Yields nothing for an empty slice or when `chunk == 0`;
/// a zero chunk size would otherwise never advance and loop forever.
///
/// Only pointer arithmetic is performed; no memory is read or written.
fn pointer_chunks<T>(
    ptr: NonNull<[T]>,
    chunk: usize,
) -> impl Iterator<Item = (NonNull<T>, usize)> + Clone {
    let mut remaining = ptr.len();
    let mut cursor = ptr.cast::<T>();
    core::iter::from_fn(move || {
        let advance = remaining.min(chunk);
        if advance == 0 {
            return None;
        }

        let item = (cursor, advance);

        // SAFETY: `advance <= remaining`, so the new cursor is at most one past the end of the
        // slice `ptr` describes (assumes `ptr` points to a live allocation of `ptr.len()`
        // elements — callers derive it from a real slice).
        cursor = unsafe { cursor.add(advance) };
        remaining -= advance;
        Some(item)
    })
}
459
/// XORs `len` bytes starting at `a` into the `len` bytes starting at `out` (`out[i] ^= a[i]`).
///
/// Bytes are processed in increasing index order, so `out == a` zeroes the range.
///
/// Not marked `unsafe`, but both pointers must be valid for `len` bytes: `out` for reads and
/// writes, `a` for reads. Callers in this module are trusted to uphold this.
fn xor_into(out: NonNull<u8>, a: NonNull<u8>, len: usize) {
    for i in 0..len {
        // SAFETY: `i < len`, and the caller guarantees both ranges are valid for `len` bytes.
        unsafe {
            let dst = out.add(i);
            dst.write(dst.read() ^ a.add(i).read());
        }
    }
}
470
/// Copies `len` bytes from `from` to `out`. The regions may overlap (`ptr::copy` semantics).
///
/// Not marked `unsafe`, but `from` must be valid for reading and `out` for writing `len` bytes.
/// Callers in this module are trusted to uphold this.
fn copy(out: NonNull<u8>, from: NonNull<u8>, len: usize) {
    // SAFETY: validity of both ranges is the caller's responsibility (see above).
    unsafe { from.copy_to(out, len) }
}