use core::{marker::PhantomData, ptr::NonNull};

use super::{BLOCK_SIZE, Error};

/// A validated pair of input and output buffers for a cipher operation.
///
/// The `PhantomData` marker ties the raw pointers to the lifetime of the
/// borrowed slices.
#[derive(Clone, Copy)]
pub(crate) struct CryptoBuffers<'a> {
    buffers: UnsafeCryptoBuffers,
    _marker: PhantomData<&'a mut ()>,
}

impl<'a> CryptoBuffers<'a> {
    /// Creates a buffer pair from separate input and output slices of equal
    /// length.
    pub fn new(input: &'a [u8], output: &'a mut [u8]) -> Result<Self, Error> {
        if input.len() != output.len() {
            return Err(Error::BuffersNotEqual);
        }
        Ok(Self {
            buffers: UnsafeCryptoBuffers {
                input: NonNull::from(input),
                output: NonNull::from(output),
            },
            _marker: PhantomData,
        })
    }

    /// Creates a buffer pair that reads from and writes back into the same
    /// slice.
    pub fn new_in_place(data: &'a mut [u8]) -> Self {
        let ptr = NonNull::from(data);
        Self {
            buffers: UnsafeCryptoBuffers {
                input: ptr,
                output: ptr,
            },
            _marker: PhantomData,
        }
    }

    /// Returns the raw buffer pointers.
    ///
    /// # Safety
    ///
    /// The returned pointers must not be used after the borrowed buffers go
    /// out of scope.
    pub(super) unsafe fn into_inner(self) -> UnsafeCryptoBuffers {
        self.buffers
    }
}
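
// A minimal usage sketch (hypothetical caller, not part of this module). The
// safe constructors validate lengths up front; a cipher mode then consumes the
// raw pointers via `into_inner`:
//
//     let input = [0u8; 32];
//     let mut output = [0u8; 32];
//     let distinct = CryptoBuffers::new(&input, &mut output)?; // separate buffers
//     let mut data = [0u8; 32];
//     let in_place = CryptoBuffers::new_in_place(&mut data);   // same buffer for both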

/// Unchecked pointer views of the input and output buffers; the two may alias
/// when operating in place.
#[derive(Clone, Copy)]
pub(super) struct UnsafeCryptoBuffers {
    pub input: NonNull<[u8]>,
    pub output: NonNull<[u8]>,
}
impl UnsafeCryptoBuffers {
    /// Returns whether input and output refer to the same buffer.
    pub fn in_place(&self) -> bool {
        self.input.addr() == self.output.addr()
    }

    /// Offsets both buffer pointers by `bytes` bytes.
    ///
    /// # Safety
    ///
    /// The caller must ensure the resulting pointers still lie within the
    /// original buffers (or one byte past their end).
    #[cfg(aes_dma)]
    pub(crate) unsafe fn byte_add(self, bytes: usize) -> Self {
        UnsafeCryptoBuffers {
            input: unsafe { self.input.byte_add(bytes) },
            output: unsafe { self.output.byte_add(bytes) },
        }
    }
}

/// Electronic codebook (ECB) mode: each block is processed independently.
#[derive(Clone)]
pub struct Ecb;
impl Ecb {
    pub(super) fn encrypt_decrypt(
        &mut self,
        buffer: UnsafeCryptoBuffers,
        mut process_block: impl FnMut(NonNull<[u8]>, NonNull<[u8]>),
    ) {
        buffer.for_data_chunks(BLOCK_SIZE, |input, output, len| {
            process_block(
                NonNull::slice_from_raw_parts(input, len),
                NonNull::slice_from_raw_parts(output, len),
            )
        });
    }
}

/// Cipher block chaining (CBC) mode.
#[derive(Clone)]
pub struct Cbc {
    pub(super) iv: [u8; BLOCK_SIZE],
}
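
// The implementation below follows the usual CBC recurrences, where E/D is the
// block operation supplied through `process_block`:
//
//   encrypt: iv ^= P[i];  C[i] = E(iv);  iv = C[i]
//   decrypt: P[i] = D(C[i]) ^ iv;        iv = C[i]
//
// For in-place decryption the ciphertext block is first copied to a temporary,
// because writing P[i] overwrites C[i] before it can become the next IV.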
impl Cbc {
    pub fn new(iv: [u8; BLOCK_SIZE]) -> Self {
        Self { iv }
    }

    pub(super) fn encrypt(
        &mut self,
        buffer: UnsafeCryptoBuffers,
        mut process_block: impl FnMut(NonNull<[u8]>, NonNull<[u8]>),
    ) {
        let iv = NonNull::from(self.iv.as_mut());

        buffer.for_data_chunks(BLOCK_SIZE, |plaintext, ciphertext, len| {
            // XOR the plaintext into the running IV, encrypt it, then keep the
            // ciphertext as the IV for the next block.
            xor_into(iv.cast(), plaintext, len);

            process_block(iv, NonNull::slice_from_raw_parts(ciphertext, len));

            copy(iv.cast(), ciphertext, len);
        });
    }

    pub(super) fn decrypt(
        &mut self,
        buffer: UnsafeCryptoBuffers,
        mut process_block: impl FnMut(NonNull<[u8]>, NonNull<[u8]>),
    ) {
        let iv = NonNull::from(self.iv.as_mut()).cast::<u8>();

        if buffer.in_place() {
            // Decrypting in place overwrites the ciphertext, so stash each
            // block in a temporary before it is needed as the next IV.
            let mut temp_buffer = [0; 16];
            let temp = NonNull::from(&mut temp_buffer).cast::<u8>();
            buffer.for_data_chunks(BLOCK_SIZE, |ciphertext, plaintext, len| {
                copy(temp, ciphertext, len);

                process_block(
                    NonNull::slice_from_raw_parts(ciphertext, len),
                    NonNull::slice_from_raw_parts(plaintext, len),
                );
                xor_into(plaintext, iv, len);

                copy(iv, temp, len);
            });
        } else {
            buffer.for_data_chunks(BLOCK_SIZE, |ciphertext, plaintext, len| {
                process_block(
                    NonNull::slice_from_raw_parts(ciphertext, len),
                    NonNull::slice_from_raw_parts(plaintext, len),
                );
                xor_into(plaintext, iv, len);

                copy(iv, ciphertext, len);
            });
        }
    }
}

/// Output feedback (OFB) mode.
#[derive(Clone)]
pub struct Ofb {
    pub(super) iv: [u8; BLOCK_SIZE],
    pub(super) offset: usize,
}
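
// OFB turns the block operation into a keystream generator: the IV is
// repeatedly encrypted in place and XORed with the data byte by byte.
// `offset` remembers how much of the current keystream block has been
// consumed, so a later call (or the `flush` helper below) can resume
// mid-block:
//
//   keystream = E(keystream);   out[i] = in[i] ^ keystream[offset]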
impl Ofb {
    pub fn new(iv: [u8; BLOCK_SIZE]) -> Self {
        Self { iv, offset: 0 }
    }

    pub(super) fn encrypt_decrypt(
        &mut self,
        buffer: UnsafeCryptoBuffers,
        mut process_block: impl FnMut(NonNull<[u8]>, NonNull<[u8]>),
    ) {
        let mut offset = self.offset;
        buffer.for_data_chunks(1, |input, output, _| {
            if offset == 0 {
                // Generate the next keystream block by encrypting the IV in
                // place.
                let iv = NonNull::from(self.iv.as_mut());
                process_block(iv, iv);
            }

            unsafe { output.write(input.read() ^ self.iv[offset]) };
            offset = (offset + 1) % BLOCK_SIZE;
        });
        self.offset = offset;
    }

    #[cfg(aes_dma)]
    pub(super) fn flush(&mut self, buffer: UnsafeCryptoBuffers) -> usize {
        // Consume the remainder of the current keystream block and report how
        // many bytes were handled.
        let mut offset = self.offset;
        buffer
            .first_n((BLOCK_SIZE - offset) % BLOCK_SIZE)
            .for_data_chunks(1, |input, output, _| {
                unsafe { output.write(input.read() ^ self.iv[offset]) };
                offset += 1;
            });
        let flushed = offset - self.offset;
        self.offset = offset % BLOCK_SIZE;
        flushed
    }
}

/// Counter (CTR) mode.
#[derive(Clone)]
pub struct Ctr {
    pub(super) nonce: [u8; BLOCK_SIZE],
    pub(super) buffer: [u8; BLOCK_SIZE],
    pub(super) offset: usize,
}
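
// CTR encrypts the counter block into `buffer` and XORs that keystream with
// the data; the counter is then incremented as a big-endian integer. For
// example, a counter ending in `[.., 0x01, 0xFF]` becomes `[.., 0x02, 0x00]`
// after one increment.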
impl Ctr {
    pub fn new(nonce: [u8; BLOCK_SIZE]) -> Self {
        Self {
            nonce,
            buffer: [0; BLOCK_SIZE],
            offset: 0,
        }
    }

    pub(super) fn encrypt_decrypt(
        &mut self,
        buffer: UnsafeCryptoBuffers,
        mut process_block: impl FnMut(NonNull<[u8]>, NonNull<[u8]>),
    ) {
        // Big-endian increment with carry propagation from the last byte.
        fn increment(nonce: &mut [u8]) {
            for byte in nonce.iter_mut().rev() {
                *byte = byte.wrapping_add(1);
                if *byte != 0 {
                    break;
                }
            }
        }

        let mut offset = self.offset;
        buffer.for_data_chunks(1, |plaintext, ciphertext, _| {
            if offset == 0 {
                // Encrypt the current counter block into the keystream buffer,
                // then advance the counter.
                let nonce = NonNull::from(self.nonce.as_mut());
                let buffer = NonNull::from(self.buffer.as_mut());
                process_block(nonce, buffer);
                increment(&mut self.nonce);
            }

            unsafe { ciphertext.write(plaintext.read() ^ self.buffer[offset]) };
            offset = (offset + 1) % BLOCK_SIZE;
        });
        self.offset = offset;
    }

    #[cfg(aes_dma)]
    pub(super) fn flush(&mut self, buffer: UnsafeCryptoBuffers) -> usize {
        // Consume the remainder of the current keystream block and report how
        // many bytes were handled.
        let mut offset = self.offset;
        buffer
            .first_n((BLOCK_SIZE - offset) % BLOCK_SIZE)
            .for_data_chunks(1, |plaintext, ciphertext, _| {
                unsafe { ciphertext.write(plaintext.read() ^ self.buffer[offset]) };
                offset += 1;
            });
        let flushed = offset - self.offset;
        self.offset = offset % BLOCK_SIZE;
        flushed
    }
}

/// Cipher feedback (CFB) mode with 8-bit feedback (CFB-8).
#[derive(Clone)]
pub struct Cfb8 {
    pub(super) iv: [u8; BLOCK_SIZE],
}
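
// CFB-8 encrypts the IV, XORs its first keystream byte with one byte of data,
// then shifts the IV left by one byte and appends the resulting ciphertext
// byte, so every ciphertext byte feeds back into the next block operation.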
impl Cfb8 {
    pub fn new(iv: [u8; BLOCK_SIZE]) -> Self {
        Self { iv }
    }

    pub(super) fn encrypt(
        &mut self,
        buffer: UnsafeCryptoBuffers,
        mut process_block: impl FnMut(NonNull<[u8]>, NonNull<[u8]>),
    ) {
        let mut ov = [0; BLOCK_SIZE];
        buffer.for_data_chunks(1, |plaintext, ciphertext, _| {
            process_block(NonNull::from(self.iv.as_mut()), NonNull::from(ov.as_mut()));

            unsafe {
                let out = ov[0] ^ plaintext.read();
                ciphertext.write(out);

                // Shift the IV and feed the new ciphertext byte back in.
                self.iv.copy_within(1.., 0);
                self.iv[BLOCK_SIZE - 1] = out;
            }
        });
    }

    pub(super) fn decrypt(
        &mut self,
        buffer: UnsafeCryptoBuffers,
        mut process_block: impl FnMut(NonNull<[u8]>, NonNull<[u8]>),
    ) {
        let mut ov = [0; BLOCK_SIZE];
        buffer.for_data_chunks(1, |ciphertext, plaintext, _| {
            process_block(NonNull::from(self.iv.as_mut()), NonNull::from(ov.as_mut()));

            unsafe {
                let c = ciphertext.read();
                plaintext.write(ov[0] ^ c);

                // Shift the IV and feed the ciphertext byte back in.
                self.iv.copy_within(1.., 0);
                self.iv[BLOCK_SIZE - 1] = c;
            }
        });
    }
}

/// Cipher feedback (CFB) mode with full-block feedback (CFB-128).
#[derive(Clone)]
pub struct Cfb128 {
    pub(super) iv: [u8; BLOCK_SIZE],
    pub(super) offset: usize,
}
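
// CFB-128 encrypts the IV to obtain a keystream block and XORs it with the
// data one byte at a time; the produced ciphertext bytes become the IV for the
// next block. `offset` tracks the position within the current block so a
// stream can stop and resume at arbitrary byte boundaries.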
impl Cfb128 {
    pub fn new(iv: [u8; BLOCK_SIZE]) -> Self {
        Self { iv, offset: 0 }
    }

    pub(super) fn encrypt(
        &mut self,
        buffer: UnsafeCryptoBuffers,
        mut process_block: impl FnMut(NonNull<[u8]>, NonNull<[u8]>),
    ) {
        let mut offset = self.offset;
        buffer.for_data_chunks(1, |plaintext, ciphertext, _| {
            if offset == 0 {
                let iv = NonNull::from(self.iv.as_mut());
                process_block(iv, iv);
            }

            unsafe {
                // The IV doubles as the feedback register: XORing the
                // plaintext into it produces the ciphertext byte in place.
                self.iv[offset] ^= plaintext.read();
                ciphertext.write(self.iv[offset]);
            }
            offset = (offset + 1) % BLOCK_SIZE;
        });
        self.offset = offset;
    }

    pub(super) fn decrypt(
        &mut self,
        buffer: UnsafeCryptoBuffers,
        mut process_block: impl FnMut(NonNull<[u8]>, NonNull<[u8]>),
    ) {
        let mut offset = self.offset;
        buffer.for_data_chunks(1, |ciphertext, plaintext, _| {
            if offset == 0 {
                let iv = NonNull::from(self.iv.as_mut());
                process_block(iv, iv);
            }

            unsafe {
                let c = ciphertext.read();
                plaintext.write(self.iv[offset] ^ c);
                self.iv[offset] = c;
            }
            offset = (offset + 1) % BLOCK_SIZE;
        });
        self.offset = offset;
    }

    #[cfg(aes_dma)]
    pub(super) fn flush_encrypt(&mut self, buffer: UnsafeCryptoBuffers) -> usize {
        // Consume the remainder of the current block and report how many bytes
        // were handled.
        let mut offset = self.offset;
        buffer
            .first_n((BLOCK_SIZE - offset) % BLOCK_SIZE)
            .for_data_chunks(1, |plaintext, ciphertext, _| {
                unsafe {
                    self.iv[offset] ^= plaintext.read();
                    ciphertext.write(self.iv[offset]);
                }
                offset += 1;
            });
        let flushed = offset - self.offset;
        self.offset = offset % BLOCK_SIZE;
        flushed
    }

    #[cfg(aes_dma)]
    pub(super) fn flush_decrypt(&mut self, buffer: UnsafeCryptoBuffers) -> usize {
        // Same as `flush_encrypt`, but for the decrypt direction.
        let mut offset = self.offset;
        buffer
            .first_n((BLOCK_SIZE - offset) % BLOCK_SIZE)
            .for_data_chunks(1, |ciphertext, plaintext, _| {
                unsafe {
                    let c = ciphertext.read();
                    plaintext.write(self.iv[offset] ^ c);
                    self.iv[offset] = c;
                }
                offset += 1;
            });
        let flushed = offset - self.offset;
        self.offset = offset % BLOCK_SIZE;
        flushed
    }
}

impl UnsafeCryptoBuffers {
    /// Walks the input and output buffers in lockstep, invoking `cb` with
    /// matching input and output chunk pointers and the chunk length (at most
    /// `chunk_size`).
    fn for_data_chunks(
        self,
        chunk_size: usize,
        mut cb: impl FnMut(NonNull<u8>, NonNull<u8>, usize),
    ) {
        let input = pointer_chunks(self.input, chunk_size);
        let output = pointer_chunks(self.output, chunk_size);

        for (input, output, len) in input
            .zip(output)
            .map(|((input, len), (output, _))| (input, output, len))
        {
            cb(input, output, len)
        }
    }

    #[cfg(aes_dma)]
    fn first_n(self, n: usize) -> UnsafeCryptoBuffers {
        // Restrict both views to the first `n` bytes (or the full buffers, if
        // they are shorter than `n`).
        let len = n.min(self.input.len());
        Self {
            input: NonNull::slice_from_raw_parts(self.input.cast(), len),
            output: NonNull::slice_from_raw_parts(self.output.cast(), len),
        }
    }
}

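/// Splits a raw slice pointer into consecutive chunks of at most `chunk`
/// elements, yielding each chunk's start pointer and length. For a 20-byte
/// slice with `chunk = 16` this yields a 16-byte chunk followed by a 4-byte
/// remainder.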
fn pointer_chunks<T>(
    ptr: NonNull<[T]>,
    chunk: usize,
) -> impl Iterator<Item = (NonNull<T>, usize)> + Clone {
    let mut len = ptr.len();
    let mut ptr = ptr.cast::<T>();
    core::iter::from_fn(move || {
        let advance = if len > chunk {
            chunk
        } else if len > 0 {
            len
        } else {
            return None;
        };

        let retval = (ptr, advance);

        unsafe { ptr = ptr.add(advance) };
        len -= advance;
        Some(retval)
    })
}

/// XORs `len` bytes read from `a` into `out`.
fn xor_into(mut out: NonNull<u8>, mut a: NonNull<u8>, len: usize) {
    let end = unsafe { out.add(len) };
    while out < end {
        unsafe {
            out.write(out.read() ^ a.read());
            a = a.add(1);
            out = out.add(1);
        }
    }
}

/// Copies `len` bytes from `from` to `out`.
fn copy(out: NonNull<u8>, from: NonNull<u8>, len: usize) {
    unsafe {
        out.copy_from(from, len);
    }
}