esp_hal/rsa/
mod.rs

1//! # RSA (Rivest–Shamir–Adleman) accelerator.
2//!
3//! ## Overview
4//!
//! The RSA accelerator provides hardware support for the high-precision
//! (big-number) arithmetic used in various RSA asymmetric cipher algorithms,
//! significantly reducing their software complexity. Compared with an
//! implementation done purely in software, the hardware accelerator also runs
//! these algorithms considerably faster.
10//!
11//! ## Configuration
12//!
//! The RSA accelerator supports operands of different lengths, which provides
//! more flexibility during the computation. The operand size of an operation is
//! selected with one of the marker types in [`operand_sizes`].
15
16use core::{marker::PhantomData, ptr::NonNull, task::Poll};
17
18use portable_atomic::{AtomicBool, Ordering};
19use procmacros::{handler, ram};
20
21use crate::{
22    Async,
23    Blocking,
24    DriverMode,
25    asynch::AtomicWaker,
26    interrupt::InterruptHandler,
27    pac,
28    peripherals::{Interrupt, RSA},
29    system::{Cpu, GenericPeripheralGuard, Peripheral as PeripheralEnable},
30    trm_markdown_link,
31    work_queue::{self, Status, VTable, WorkQueue, WorkQueueDriver, WorkQueueFrontend},
32};
33
/// RSA peripheral driver.
///
/// The `Dm` parameter selects the driver mode ([Blocking] or [Async]).
pub struct Rsa<'d, Dm: DriverMode> {
    // The RSA peripheral singleton; register access goes through it.
    rsa: RSA<'d>,
    // Zero-sized marker recording the driver mode.
    phantom: PhantomData<Dm>,
    // Peripheral-enable guard for the RSA block (presumably keeps it clocked/out of reset while
    // the driver exists — confirm against `GenericPeripheralGuard`).
    _guard: GenericPeripheralGuard<{ PeripheralEnable::Rsa as u8 }>,
}
40
// There are two distinct peripheral versions: the one in the ESP32, and the one in all later
// chips. Register names also differ between some of the later chips, and each chip uses its own
// (memory size, operand size increment) parameters, but otherwise the peripherals are largely the
// same.
44
/// Number of 32-bit words in one operand size increment.
///
/// I.e. if the RSA hardware works with operands of 512, 1024, 1536, ... bits, the increment is 512
/// bits, or 16 words. The mode registers are programmed with a size code of
/// `words / WORDS_PER_INCREMENT - 1`.
const WORDS_PER_INCREMENT: u32 = property!("rsa.size_increment") / 32;
50
51impl<'d> Rsa<'d, Blocking> {
52    /// Create a new instance in [Blocking] mode.
53    ///
54    /// Optionally an interrupt handler can be bound.
55    pub fn new(rsa: RSA<'d>) -> Self {
56        let guard = GenericPeripheralGuard::new();
57
58        let this = Self {
59            rsa,
60            phantom: PhantomData,
61            _guard: guard,
62        };
63
64        while !this.ready() {}
65
66        this
67    }
68
69    /// Reconfigures the RSA driver to operate in asynchronous mode.
70    pub fn into_async(mut self) -> Rsa<'d, Async> {
71        self.set_interrupt_handler(rsa_interrupt_handler);
72        self.enable_disable_interrupt(true);
73
74        Rsa {
75            rsa: self.rsa,
76            phantom: PhantomData,
77            _guard: self._guard,
78        }
79    }
80
81    /// Enables/disables rsa interrupt.
82    ///
83    /// When enabled rsa peripheral would generate an interrupt when a operation
84    /// is finished.
85    pub fn enable_disable_interrupt(&mut self, enable: bool) {
86        self.internal_enable_disable_interrupt(enable);
87    }
88
89    /// Registers an interrupt handler for the RSA peripheral.
90    ///
91    /// Note that this will replace any previously registered interrupt
92    /// handlers.
93    #[instability::unstable]
94    pub fn set_interrupt_handler(&mut self, handler: InterruptHandler) {
95        self.rsa.disable_peri_interrupt();
96
97        self.rsa.bind_peri_interrupt(handler.handler());
98        self.rsa.enable_peri_interrupt(handler.priority());
99    }
100}
101
// Seals `Rsa<'_, Blocking>` (presumably needed because `InterruptConfigurable` has
// `crate::private::Sealed` as a supertrait — confirm).
impl crate::private::Sealed for Rsa<'_, Blocking> {}

/// Allows binding an interrupt handler through the generic
/// [`crate::interrupt::InterruptConfigurable`] trait.
#[instability::unstable]
impl crate::interrupt::InterruptConfigurable for Rsa<'_, Blocking> {
    fn set_interrupt_handler(&mut self, handler: InterruptHandler) {
        // Forwards to the inherent method of the same name (inherent methods take
        // precedence over trait methods during resolution, so this does not recurse).
        self.set_interrupt_handler(handler);
    }
}
110
111impl<'d> Rsa<'d, Async> {
112    /// Create a new instance in [crate::Blocking] mode.
113    pub fn into_blocking(self) -> Rsa<'d, Blocking> {
114        self.internal_enable_disable_interrupt(false);
115        self.rsa.disable_peri_interrupt();
116
117        crate::interrupt::disable(Cpu::current(), Interrupt::RSA);
118        Rsa {
119            rsa: self.rsa,
120            phantom: PhantomData,
121            _guard: self._guard,
122        }
123    }
124}
125
126impl<'d, Dm: DriverMode> Rsa<'d, Dm> {
127    fn internal_enable_disable_interrupt(&self, enable: bool) {
128        cfg_if::cfg_if! {
129            if #[cfg(esp32)] {
130                // Can't seem to actually disable the interrupt, but esp-idf still writes the register
131                self.regs().interrupt().write(|w| w.interrupt().bit(enable));
132            } else {
133                self.regs().int_ena().write(|w| w.int_ena().bit(enable));
134            }
135        }
136    }
137
138    fn regs(&self) -> &pac::rsa::RegisterBlock {
139        self.rsa.register_block()
140    }
141
142    /// After the RSA accelerator is released from reset, the memory blocks
143    /// needs to be initialized, only after that peripheral should be used.
144    /// This function would return without an error if the memory is
145    /// initialized.
146    fn ready(&self) -> bool {
147        cfg_if::cfg_if! {
148            if #[cfg(any(esp32, esp32s2, esp32s3))] {
149                self.regs().clean().read().clean().bit_is_set()
150            } else {
151                self.regs().query_clean().read().query_clean().bit_is_set()
152            }
153        }
154    }
155
156    /// Starts the modular exponentiation operation.
157    fn start_modexp(&self) {
158        cfg_if::cfg_if! {
159            if #[cfg(any(esp32, esp32s2, esp32s3))] {
160                self.regs()
161                    .modexp_start()
162                    .write(|w| w.modexp_start().set_bit());
163            } else {
164                self.regs()
165                    .set_start_modexp()
166                    .write(|w| w.set_start_modexp().set_bit());
167            }
168        }
169    }
170
171    /// Starts the multiplication operation.
172    fn start_multi(&self) {
173        cfg_if::cfg_if! {
174            if #[cfg(any(esp32, esp32s2, esp32s3))] {
175                self.regs().mult_start().write(|w| w.mult_start().set_bit());
176            } else {
177                self.regs()
178                    .set_start_mult()
179                    .write(|w| w.set_start_mult().set_bit());
180            }
181        }
182    }
183
184    /// Starts the modular multiplication operation.
185    fn start_modmulti(&self) {
186        cfg_if::cfg_if! {
187            if #[cfg(esp32)] {
188                // modular-ness is encoded in the multi_mode register value
189                self.start_multi();
190            } else if #[cfg(any(esp32s2, esp32s3))] {
191                self.regs()
192                    .modmult_start()
193                    .write(|w| w.modmult_start().set_bit());
194            } else {
195                self.regs()
196                    .set_start_modmult()
197                    .write(|w| w.set_start_modmult().set_bit());
198            }
199        }
200    }
201
202    /// Clears the RSA interrupt flag.
203    fn clear_interrupt(&mut self) {
204        cfg_if::cfg_if! {
205            if #[cfg(esp32)] {
206                self.regs().interrupt().write(|w| w.interrupt().set_bit());
207            } else {
208                self.regs().int_clr().write(|w| w.int_clr().set_bit());
209            }
210        }
211    }
212
213    /// Checks if the RSA peripheral is idle.
214    fn is_idle(&self) -> bool {
215        cfg_if::cfg_if! {
216            if #[cfg(esp32)] {
217                self.regs().interrupt().read().interrupt().bit_is_set()
218            } else if #[cfg(any(esp32s2, esp32s3))] {
219                self.regs().idle().read().idle().bit_is_set()
220            } else {
221                self.regs().query_idle().read().query_idle().bit_is_set()
222            }
223        }
224    }
225
226    fn wait_for_idle(&mut self) {
227        while !self.is_idle() {}
228        self.clear_interrupt();
229    }
230
231    /// Writes the result size of the multiplication.
232    fn write_multi_mode(&mut self, mode: u32, modular: bool) {
233        let mode = if cfg!(esp32) && !modular {
234            const NON_MODULAR: u32 = 8;
235            mode | NON_MODULAR
236        } else {
237            mode
238        };
239
240        cfg_if::cfg_if! {
241            if #[cfg(esp32)] {
242                self.regs().mult_mode().write(|w| unsafe { w.bits(mode) });
243            } else {
244                self.regs().mode().write(|w| unsafe { w.bits(mode) });
245            }
246        }
247    }
248
249    /// Writes the result size of the modular exponentiation.
250    fn write_modexp_mode(&mut self, mode: u32) {
251        cfg_if::cfg_if! {
252            if #[cfg(esp32)] {
253                self.regs().modexp_mode().write(|w| unsafe { w.bits(mode) });
254            } else {
255                self.regs().mode().write(|w| unsafe { w.bits(mode) });
256            }
257        }
258    }
259
260    fn write_operand_b(&mut self, operand: &[u32]) {
261        for (reg, op) in self.regs().y_mem_iter().zip(operand.iter().copied()) {
262            reg.write(|w| unsafe { w.bits(op) });
263        }
264    }
265
266    fn write_modulus(&mut self, modulus: &[u32]) {
267        for (reg, op) in self.regs().m_mem_iter().zip(modulus.iter().copied()) {
268            reg.write(|w| unsafe { w.bits(op) });
269        }
270    }
271
272    fn write_mprime(&mut self, m_prime: u32) {
273        self.regs().m_prime().write(|w| unsafe { w.bits(m_prime) });
274    }
275
276    fn write_operand_a(&mut self, operand: &[u32]) {
277        for (reg, op) in self.regs().x_mem_iter().zip(operand.iter().copied()) {
278            reg.write(|w| unsafe { w.bits(op) });
279        }
280    }
281
282    fn write_multi_operand_b(&mut self, operand: &[u32]) {
283        for (reg, op) in self
284            .regs()
285            .z_mem_iter()
286            .skip(operand.len())
287            .zip(operand.iter().copied())
288        {
289            reg.write(|w| unsafe { w.bits(op) });
290        }
291    }
292
293    fn write_r(&mut self, r: &[u32]) {
294        for (reg, op) in self.regs().z_mem_iter().zip(r.iter().copied()) {
295            reg.write(|w| unsafe { w.bits(op) });
296        }
297    }
298
299    fn read_out(&self, outbuf: &mut [u32]) {
300        for (reg, op) in self.regs().z_mem_iter().zip(outbuf.iter_mut()) {
301            *op = reg.read().bits();
302        }
303    }
304
305    fn read_results(&mut self, outbuf: &mut [u32]) {
306        self.wait_for_idle();
307        self.read_out(outbuf);
308    }
309
310    /// Enables/disables constant time operation.
311    ///
312    /// Disabling constant time operation increases the performance of modular
313    /// exponentiation by simplifying the calculation concerning the 0 bits
314    /// of the exponent. I.e. the less the Hamming weight, the greater the
315    /// performance.
316    ///
317    /// Note: this compromises security by enabling timing-based side-channel attacks.
318    ///
319    /// For more information refer to the
320    #[doc = trm_markdown_link!("rsa")]
321    #[cfg(not(esp32))]
322    pub fn disable_constant_time(&mut self, disable: bool) {
323        self.regs()
324            .constant_time()
325            .write(|w| w.constant_time().bit(disable));
326    }
327
328    /// Enables/disables search acceleration.
329    ///
330    /// When enabled it would increase the performance of modular
331    /// exponentiation by discarding the exponent's bits before the most
332    /// significant set bit.
333    ///
334    /// Note: this compromises security by effectively decreasing the key length.
335    ///
336    /// For more information refer to the
337    #[doc = trm_markdown_link!("rsa")]
338    #[cfg(not(esp32))]
339    pub fn search_acceleration(&mut self, enable: bool) {
340        self.regs()
341            .search_enable()
342            .write(|w| w.search_enable().bit(enable));
343    }
344
345    /// Checks if the search functionality is enabled in the RSA hardware.
346    #[cfg(not(esp32))]
347    fn is_search_enabled(&mut self) -> bool {
348        self.regs()
349            .search_enable()
350            .read()
351            .search_enable()
352            .bit_is_set()
353    }
354
355    /// Sets the search position in the RSA hardware.
356    #[cfg(not(esp32))]
357    fn write_search_position(&mut self, search_position: u32) {
358        self.regs()
359            .search_pos()
360            .write(|w| unsafe { w.bits(search_position) });
361    }
362}
363
/// Defines the input size of an RSA operation.
///
/// Implemented by the marker types in [`operand_sizes`]. The trait is sealed
/// through `crate::private::Sealed`.
pub trait RsaMode: crate::private::Sealed {
    /// The input data type used for the operation.
    ///
    /// For the types in [`operand_sizes`] this is a `[u32; bits / 32]` array.
    type InputType: AsRef<[u32]> + AsMut<[u32]>;
}
369
/// Defines the output type of RSA multiplications.
///
/// In [`operand_sizes`] this is implemented only for the sizes listed by
/// `for_each_rsa_multiplication!`.
pub trait Multi: RsaMode {
    /// The type of the output produced by the operation.
    ///
    /// For the types in [`operand_sizes`] the output is twice as wide as the
    /// input: `[u32; bits * 2 / 32]`.
    type OutputType: AsRef<[u32]> + AsMut<[u32]>;
}
375
/// Defines the exponentiation and multiplication lengths for RSA operations.
pub mod operand_sizes {
    // Generates one `Op<bits>` marker type per size listed by `for_each_rsa_exponentiation!`
    // (presumably the chip's supported exponentiation sizes). Its input is `bits / 32` words.
    for_each_rsa_exponentiation!(
        ($x:literal) => {
            paste::paste! {
                #[doc = concat!(stringify!($x), "-bit RSA operation.")]
                pub struct [<Op $x>];

                impl crate::private::Sealed for [<Op $x>] {}
                impl crate::rsa::RsaMode for [<Op $x>] {
                    type InputType = [u32; $x / 32];
                }
            }
        };
    );

    // Sizes listed by `for_each_rsa_multiplication!` additionally support plain multiplication;
    // the product is twice as wide as the operands.
    for_each_rsa_multiplication!(
        ($x:literal) => {
            impl crate::rsa::Multi for paste::paste!( [<Op $x>] ) {
                type OutputType = [u32; $x * 2 / 32];
            }
        };
    );
}
400
401/// Support for RSA peripheral's modular exponentiation feature that could be
402/// used to find the `(base ^ exponent) mod modulus`.
403///
404/// Each operand is a little endian byte array of the same size
405pub struct RsaModularExponentiation<'a, 'd, T: RsaMode, Dm: DriverMode> {
406    rsa: &'a mut Rsa<'d, Dm>,
407    phantom: PhantomData<T>,
408}
409
410impl<'a, 'd, T: RsaMode, Dm: DriverMode, const N: usize> RsaModularExponentiation<'a, 'd, T, Dm>
411where
412    T: RsaMode<InputType = [u32; N]>,
413{
414    /// Creates an instance of `RsaModularExponentiation`.
415    ///
416    /// `m_prime` could be calculated using `-(modular multiplicative inverse of
417    /// modulus) mod 2^32`.
418    ///
419    /// For more information refer to the
420    #[doc = trm_markdown_link!("rsa")]
421    pub fn new(
422        rsa: &'a mut Rsa<'d, Dm>,
423        exponent: &T::InputType,
424        modulus: &T::InputType,
425        m_prime: u32,
426    ) -> Self {
427        Self::write_mode(rsa);
428        rsa.write_operand_b(exponent);
429        rsa.write_modulus(modulus);
430        rsa.write_mprime(m_prime);
431
432        #[cfg(not(esp32))]
433        if rsa.is_search_enabled() {
434            rsa.write_search_position(Self::find_search_pos(exponent));
435        }
436
437        Self {
438            rsa,
439            phantom: PhantomData,
440        }
441    }
442
443    fn set_up_exponentiation(&mut self, base: &T::InputType, r: &T::InputType) {
444        self.rsa.write_operand_a(base);
445        self.rsa.write_r(r);
446    }
447
448    /// Starts the modular exponentiation operation.
449    ///
450    /// `r` can be calculated using `2 ^ ( bitlength * 2 ) mod modulus`.
451    ///
452    /// For more information refer to the
453    #[doc = trm_markdown_link!("rsa")]
454    pub fn start_exponentiation(&mut self, base: &T::InputType, r: &T::InputType) {
455        self.set_up_exponentiation(base, r);
456        self.rsa.start_modexp();
457    }
458
459    /// Reads the result to the given buffer.
460    ///
461    /// This is a blocking function: it waits for the RSA operation to complete,
462    /// then reads the results into the provided buffer. `start_exponentiation` must be
463    /// called before calling this function.
464    pub fn read_results(&mut self, outbuf: &mut T::InputType) {
465        self.rsa.read_results(outbuf);
466    }
467
468    #[cfg(not(esp32))]
469    fn find_search_pos(exponent: &T::InputType) -> u32 {
470        for (i, byte) in exponent.iter().rev().enumerate() {
471            if *byte == 0 {
472                continue;
473            }
474            return (exponent.len() * 32) as u32 - (byte.leading_zeros() + i as u32 * 32) - 1;
475        }
476        0
477    }
478
479    /// Sets the modular exponentiation mode for the RSA hardware.
480    fn write_mode(rsa: &mut Rsa<'d, Dm>) {
481        rsa.write_modexp_mode(N as u32 / WORDS_PER_INCREMENT - 1);
482    }
483}
484
485/// Support for RSA peripheral's modular multiplication feature that could be
486/// used to find the `(operand a * operand b) mod modulus`.
487///
488/// Each operand is a little endian byte array of the same size
489pub struct RsaModularMultiplication<'a, 'd, T, Dm>
490where
491    T: RsaMode,
492    Dm: DriverMode,
493{
494    rsa: &'a mut Rsa<'d, Dm>,
495    phantom: PhantomData<T>,
496}
497
498impl<'a, 'd, T, Dm, const N: usize> RsaModularMultiplication<'a, 'd, T, Dm>
499where
500    T: RsaMode<InputType = [u32; N]>,
501    Dm: DriverMode,
502{
503    /// Creates an instance of `RsaModularMultiplication`.
504    ///
505    /// - `r` can be calculated using `2 ^ ( bitlength * 2 ) mod modulus`.
506    /// - `m_prime` can be calculated using `-(modular multiplicative inverse of modulus) mod 2^32`.
507    ///
508    /// For more information refer to the
509    #[doc = trm_markdown_link!("rsa")]
510    pub fn new(
511        rsa: &'a mut Rsa<'d, Dm>,
512        operand_a: &T::InputType,
513        modulus: &T::InputType,
514        r: &T::InputType,
515        m_prime: u32,
516    ) -> Self {
517        rsa.write_multi_mode(N as u32 / WORDS_PER_INCREMENT - 1, true);
518
519        rsa.write_mprime(m_prime);
520        rsa.write_modulus(modulus);
521        rsa.write_operand_a(operand_a);
522        rsa.write_r(r);
523
524        Self {
525            rsa,
526            phantom: PhantomData,
527        }
528    }
529
530    /// Starts the modular multiplication operation.
531    ///
532    /// For more information refer to the
533    #[doc = trm_markdown_link!("rsa")]
534    pub fn start_modular_multiplication(&mut self, operand_b: &T::InputType) {
535        self.set_up_modular_multiplication(operand_b);
536        self.rsa.start_modmulti();
537    }
538
539    /// Reads the result to the given buffer.
540    ///
541    /// This is a blocking function: it waits for the RSA operation to complete,
542    /// then reads the results into the provided buffer. `start_modular_multiplication` must be
543    /// called before calling this function.
544    pub fn read_results(&mut self, outbuf: &mut T::InputType) {
545        self.rsa.read_results(outbuf);
546    }
547
548    fn set_up_modular_multiplication(&mut self, operand_b: &T::InputType) {
549        if cfg!(esp32) {
550            self.rsa.start_multi();
551            self.rsa.wait_for_idle();
552
553            self.rsa.write_operand_a(operand_b);
554        } else {
555            self.rsa.write_operand_b(operand_b);
556        }
557    }
558}
559
560/// Support for RSA peripheral's large number multiplication feature that could
561/// be used to find the `operand a * operand b`.
562///
563/// Each operand is a little endian byte array of the same size
564pub struct RsaMultiplication<'a, 'd, T, Dm>
565where
566    T: RsaMode + Multi,
567    Dm: DriverMode,
568{
569    rsa: &'a mut Rsa<'d, Dm>,
570    phantom: PhantomData<T>,
571}
572
573impl<'a, 'd, T, Dm, const N: usize> RsaMultiplication<'a, 'd, T, Dm>
574where
575    T: RsaMode<InputType = [u32; N]>,
576    T: Multi,
577    Dm: DriverMode,
578{
579    /// Creates an instance of `RsaMultiplication`.
580    pub fn new(rsa: &'a mut Rsa<'d, Dm>, operand_a: &T::InputType) -> Self {
581        // Non-modular multiplication result is twice as wide as its operands.
582        rsa.write_multi_mode(2 * N as u32 / WORDS_PER_INCREMENT - 1, false);
583        rsa.write_operand_a(operand_a);
584
585        Self {
586            rsa,
587            phantom: PhantomData,
588        }
589    }
590
591    /// Starts the multiplication operation.
592    pub fn start_multiplication(&mut self, operand_b: &T::InputType) {
593        self.set_up_multiplication(operand_b);
594        self.rsa.start_multi();
595    }
596
597    /// Reads the result to the given buffer.
598    ///
599    /// This is a blocking function: it waits for the RSA operation to complete,
600    /// then reads the results into the provided buffer. `start_multiplication` must be
601    /// called before calling this function.
602    pub fn read_results<const O: usize>(&mut self, outbuf: &mut T::OutputType)
603    where
604        T: Multi<OutputType = [u32; O]>,
605    {
606        self.rsa.read_results(outbuf);
607    }
608
609    fn set_up_multiplication(&mut self, operand_b: &T::InputType) {
610        self.rsa.write_multi_operand_b(operand_b);
611    }
612}
613
// Waker of the task currently awaiting an `RsaFuture`; woken by `rsa_interrupt_handler`.
static WAKER: AtomicWaker = AtomicWaker::new();
// Completion flag: cleared by `RsaFuture::new`, set by `rsa_interrupt_handler`.
// TODO: this should only be needed for ESP32
static SIGNALED: AtomicBool = AtomicBool::new(false);
617
/// `Future` that waits for the RSA operation to complete.
///
/// Completion is signalled through the `SIGNALED` flag, which the RSA
/// interrupt handler sets.
#[must_use = "futures do nothing unless you `.await` or poll them"]
struct RsaFuture<'a, 'd> {
    // The async driver whose interrupt is armed while this future is alive.
    driver: &'a Rsa<'d, Async>,
}
623
624impl<'a, 'd> RsaFuture<'a, 'd> {
625    fn new(driver: &'a Rsa<'d, Async>) -> Self {
626        SIGNALED.store(false, Ordering::Relaxed);
627
628        driver.internal_enable_disable_interrupt(true);
629
630        Self { driver }
631    }
632
633    fn is_done(&self) -> bool {
634        SIGNALED.load(Ordering::Acquire)
635    }
636}
637
638impl Drop for RsaFuture<'_, '_> {
639    fn drop(&mut self) {
640        self.driver.internal_enable_disable_interrupt(false);
641    }
642}
643
644impl core::future::Future for RsaFuture<'_, '_> {
645    type Output = ();
646
647    fn poll(
648        self: core::pin::Pin<&mut Self>,
649        cx: &mut core::task::Context<'_>,
650    ) -> core::task::Poll<Self::Output> {
651        WAKER.register(cx.waker());
652        if self.is_done() {
653            Poll::Ready(())
654        } else {
655            Poll::Pending
656        }
657    }
658}
659
impl<T: RsaMode, const N: usize> RsaModularExponentiation<'_, '_, T, Async>
where
    T: RsaMode<InputType = [u32; N]>,
{
    /// Asynchronously performs an RSA modular exponentiation operation.
    ///
    /// Uses the exponent, modulus and `m_prime` configured by
    /// [`RsaModularExponentiation::new`]. `r` has the same meaning as for
    /// `start_exponentiation`. The result is written to `outbuf`.
    pub async fn exponentiation(
        &mut self,
        base: &T::InputType,
        r: &T::InputType,
        outbuf: &mut T::InputType,
    ) {
        self.set_up_exponentiation(base, r);
        // Create the future (which clears `SIGNALED`) before starting, so a fast completion
        // is not lost.
        let fut = RsaFuture::new(self.rsa);
        self.rsa.start_modexp();
        fut.await;
        // The interrupt handler already cleared the flag; just copy the result out.
        self.rsa.read_out(outbuf);
    }
}
678
impl<T: RsaMode, const N: usize> RsaModularMultiplication<'_, '_, T, Async>
where
    T: RsaMode<InputType = [u32; N]>,
{
    /// Asynchronously performs an RSA modular multiplication operation.
    ///
    /// Multiplies the operand A configured by [`RsaModularMultiplication::new`]
    /// with `operand_b` modulo the configured modulus, writing the result to
    /// `outbuf`.
    pub async fn modular_multiplication(
        &mut self,
        operand_b: &T::InputType,
        outbuf: &mut T::InputType,
    ) {
        if cfg!(esp32) {
            // ESP32 two-step process: await a first multiplication round, then load operand B
            // into the X memory.
            let fut = RsaFuture::new(self.rsa);
            self.rsa.start_multi();
            fut.await;

            self.rsa.write_operand_a(operand_b);
        } else {
            self.set_up_modular_multiplication(operand_b);
        }

        // Arm the future before starting the (final) operation so completion is not lost.
        let fut = RsaFuture::new(self.rsa);
        self.rsa.start_modmulti();
        fut.await;
        self.rsa.read_out(outbuf);
    }
}
705
impl<T: RsaMode + Multi, const N: usize> RsaMultiplication<'_, '_, T, Async>
where
    T: RsaMode<InputType = [u32; N]>,
{
    /// Asynchronously performs an RSA multiplication operation.
    ///
    /// Multiplies the operand A configured by [`RsaMultiplication::new`] with
    /// `operand_b`, writing the double-width product to `outbuf`.
    pub async fn multiplication<const O: usize>(
        &mut self,
        operand_b: &T::InputType,
        outbuf: &mut T::OutputType,
    ) where
        T: Multi<OutputType = [u32; O]>,
    {
        self.set_up_multiplication(operand_b);
        // Arm the future before starting so completion is not lost.
        let fut = RsaFuture::new(self.rsa);
        self.rsa.start_multi();
        fut.await;
        self.rsa.read_out(outbuf);
    }
}
725
#[handler]
/// Interrupt handler for RSA.
///
/// Sets the completion flag read by `RsaFuture`, acknowledges the interrupt
/// at the peripheral and wakes the registered task.
pub(super) fn rsa_interrupt_handler() {
    let rsa = RSA::regs();
    // Publish completion; pairs with the `Acquire` load in `RsaFuture::is_done`.
    SIGNALED.store(true, Ordering::Release);
    cfg_if::cfg_if! {
        if #[cfg(esp32)] {
            rsa.interrupt().write(|w| w.interrupt().set_bit());
        } else  {
            rsa.int_clr().write(|w| w.int_clr().set_bit());
        }
    }

    WAKER.wake();
}
741
/// Work queue holding pending RSA work items for the [`RsaBackend`].
static RSA_WORK_QUEUE: WorkQueue<RsaWorkItem> = WorkQueue::new();
/// Callbacks through which the work queue drives the [`RsaBackend`].
///
/// Every callback receives the type-erased driver pointer and converts it back
/// with `RsaBackend::from_raw`.
// NOTE(review): soundness relies on `driver` being the pointer registered by
// `RsaBackend::start` via `WorkQueueDriver::new` — confirm against `work_queue`.
const RSA_VTABLE: VTable<RsaWorkItem> = VTable {
    post: |driver, item| {
        // Start processing immediately.
        let driver = unsafe { RsaBackend::from_raw(driver) };
        Some(driver.process_item(item))
    },
    // Advances the backend state machine for `item`.
    poll: |driver, item| {
        let driver = unsafe { RsaBackend::from_raw(driver) };
        driver.process_item(item)
    },
    cancel: |driver, item| {
        let driver = unsafe { RsaBackend::from_raw(driver) };
        driver.cancel(item)
    },
    // Called when the work-queue driver stops.
    stop: |driver| {
        let driver = unsafe { RsaBackend::from_raw(driver) };
        driver.deinitialize()
    },
};
762
/// State of the [`RsaBackend`] processing state machine.
#[derive(Default)]
enum RsaBackendState<'d> {
    /// No driver has been created yet. This is the default state, which
    /// `core::mem::take` leaves behind while an item is being processed.
    #[default]
    Idle,
    /// A driver exists; waiting for the RSA memory blocks to finish initializing.
    Initializing(Rsa<'d, Blocking>),
    /// The driver is initialized and can accept a new work item.
    Ready(Rsa<'d, Blocking>),
    /// ESP32 only: the first round of the two-step modular multiplication has
    /// been started; operand Y still has to be written to the X memory.
    #[cfg(esp32)]
    ModularMultiplicationRoundOne(Rsa<'d, Blocking>),
    /// An operation has been started on the hardware.
    Processing(Rsa<'d, Blocking>),
}
773
#[procmacros::doc_replace]
/// RSA processing backend.
///
/// The backend processes work items placed in the RSA work queue. The backend needs to be created
/// and started for operations to be processed. This allows you to perform operations on the RSA
/// accelerator without carrying around the peripheral singleton, or the driver.
///
/// The [`RsaContext`] struct can enqueue work items that this backend will process.
///
/// ## Example
///
/// ```rust, no_run
/// # {before_snippet}
/// use esp_hal::rsa::{RsaBackend, RsaContext, operand_sizes::Op512};
/// #
/// let mut rsa_backend = RsaBackend::new(peripherals.RSA);
/// let _driver = rsa_backend.start();
///
/// async fn perform_512bit_big_number_multiplication(
///     operand_a: &[u32; 16],
///     operand_b: &[u32; 16],
///     result: &mut [u32; 32],
/// ) {
///     let mut rsa = RsaContext::new();
///
///     let mut handle = rsa.multiply::<Op512>(operand_a, operand_b, result);
///     handle.wait().await;
/// }
/// # {after_snippet}
/// ```
pub struct RsaBackend<'d> {
    // The RSA peripheral. A driver is created from an unchecked clone of it when a work item is
    // processed while the backend is `Idle`.
    peri: RSA<'d>,
    // Current state of the processing state machine.
    state: RsaBackendState<'d>,
}
808
809impl<'d> RsaBackend<'d> {
810    #[procmacros::doc_replace]
811    /// Creates a new RSA backend.
812    ///
813    /// ## Example
814    ///
815    /// ```rust, no_run
816    /// # {before_snippet}
817    /// use esp_hal::rsa::RsaBackend;
818    /// #
819    /// let mut rsa = RsaBackend::new(peripherals.RSA);
820    /// # {after_snippet}
821    /// ```
822    pub fn new(rsa: RSA<'d>) -> Self {
823        Self {
824            peri: rsa,
825            state: RsaBackendState::Idle,
826        }
827    }
828
829    #[procmacros::doc_replace]
830    /// Registers the RSA driver to process RSA operations.
831    ///
832    /// The driver stops operating when the returned object is dropped.
833    ///
834    /// ## Example
835    ///
836    /// ```rust, no_run
837    /// # {before_snippet}
838    /// use esp_hal::rsa::RsaBackend;
839    /// #
840    /// let mut rsa = RsaBackend::new(peripherals.RSA);
841    /// // Start the backend, which allows processing RSA operations.
842    /// let _backend = rsa.start();
843    /// # {after_snippet}
844    /// ```
845    pub fn start(&mut self) -> RsaWorkQueueDriver<'_, 'd> {
846        RsaWorkQueueDriver {
847            inner: WorkQueueDriver::new(self, RSA_VTABLE, &RSA_WORK_QUEUE),
848        }
849    }
850
    // WorkQueue callbacks. They may run in any context.

    /// Reinterprets a type-erased work-queue driver pointer as an `RsaBackend`.
    ///
    /// # Safety
    ///
    /// `ptr` must point to a live `RsaBackend` (presumably the one registered by
    /// `start` through `WorkQueueDriver::new` — confirm), and no other reference
    /// to that backend may be in use for the caller-chosen lifetime `'any`.
    unsafe fn from_raw<'any>(ptr: NonNull<()>) -> &'any mut Self {
        // SAFETY: validity and exclusivity of `ptr` are guaranteed by the caller (see `# Safety`).
        unsafe { ptr.cast::<RsaBackend<'_>>().as_mut() }
    }
856
857    fn process_item(&mut self, item: &mut RsaWorkItem) -> work_queue::Poll {
858        match core::mem::take(&mut self.state) {
859            RsaBackendState::Idle => {
860                let driver = Rsa {
861                    rsa: unsafe { self.peri.clone_unchecked() },
862                    phantom: PhantomData,
863                    _guard: GenericPeripheralGuard::new(),
864                };
865                self.state = RsaBackendState::Initializing(driver);
866                work_queue::Poll::Pending(true)
867            }
868            RsaBackendState::Initializing(mut rsa) => {
869                // Wait for the peripheral to finish initializing. Ideally we need a way to
870                // instruct the work queue to wake the polling task immediately.
871                self.state = if rsa.ready() {
872                    rsa.set_interrupt_handler(rsa_work_queue_handler);
873                    rsa.enable_disable_interrupt(true);
874                    RsaBackendState::Ready(rsa)
875                } else {
876                    RsaBackendState::Initializing(rsa)
877                };
878                work_queue::Poll::Pending(true)
879            }
880            RsaBackendState::Ready(mut rsa) => {
881                #[cfg(not(esp32))]
882                {
883                    rsa.disable_constant_time(!item.constant_time);
884                    rsa.search_acceleration(item.search_acceleration);
885                }
886
887                match item.operation {
888                    RsaOperation::Multiplication { x, y } => {
889                        let n = x.len() as u32;
890                        rsa.write_operand_a(unsafe { x.as_ref() });
891
892                        // Non-modular multiplication result is twice as wide as its operands.
893                        rsa.write_multi_mode(2 * n / WORDS_PER_INCREMENT - 1, false);
894                        rsa.write_multi_operand_b(unsafe { y.as_ref() });
895                        rsa.start_multi();
896                    }
897
898                    RsaOperation::ModularMultiplication {
899                        x,
900                        #[cfg(not(esp32))]
901                        y,
902                        m,
903                        m_prime,
904                        r: r_inv,
905                        ..
906                    } => {
907                        let n = x.len() as u32;
908                        rsa.write_operand_a(unsafe { x.as_ref() });
909
910                        rsa.write_multi_mode(n / WORDS_PER_INCREMENT - 1, true);
911
912                        #[cfg(not(esp32))]
913                        rsa.write_operand_b(unsafe { y.as_ref() });
914
915                        rsa.write_modulus(unsafe { m.as_ref() });
916                        rsa.write_mprime(m_prime);
917                        rsa.write_r(unsafe { r_inv.as_ref() });
918
919                        rsa.start_modmulti();
920
921                        #[cfg(esp32)]
922                        {
923                            // ESP32 requires a two-step process where Y needs to be written to the
924                            // X memory.
925                            self.state = RsaBackendState::ModularMultiplicationRoundOne(rsa);
926
927                            return work_queue::Poll::Pending(false);
928                        }
929                    }
930                    RsaOperation::ModularExponentiation {
931                        x,
932                        y,
933                        m,
934                        m_prime,
935                        r_inv,
936                    } => {
937                        let n = x.len() as u32;
938                        rsa.write_operand_a(unsafe { x.as_ref() });
939
940                        rsa.write_modexp_mode(n / WORDS_PER_INCREMENT - 1);
941                        rsa.write_operand_b(unsafe { y.as_ref() });
942                        rsa.write_modulus(unsafe { m.as_ref() });
943                        rsa.write_mprime(m_prime);
944                        rsa.write_r(unsafe { r_inv.as_ref() });
945
946                        #[cfg(not(esp32))]
947                        if item.search_acceleration {
948                            fn find_search_pos(exponent: &[u32]) -> u32 {
949                                for (i, byte) in exponent.iter().rev().enumerate() {
950                                    if *byte == 0 {
951                                        continue;
952                                    }
953                                    return (exponent.len() * 32) as u32
954                                        - (byte.leading_zeros() + i as u32 * 32)
955                                        - 1;
956                                }
957                                0
958                            }
959                            rsa.write_search_position(find_search_pos(unsafe { y.as_ref() }));
960                        }
961
962                        rsa.start_modexp();
963                    }
964                }
965
966                self.state = RsaBackendState::Processing(rsa);
967
968                work_queue::Poll::Pending(false)
969            }
970
971            #[cfg(esp32)]
972            RsaBackendState::ModularMultiplicationRoundOne(mut rsa) => {
973                if rsa.is_idle() {
974                    let RsaOperation::ModularMultiplication { y, .. } = item.operation else {
975                        unreachable!();
976                    };
977
978                    // Y needs to be written to the X memory.
979                    rsa.write_operand_a(unsafe { y.as_ref() });
980                    rsa.start_modmulti();
981
982                    self.state = RsaBackendState::Processing(rsa);
983                } else {
984                    // Wait for the operation to complete
985                    self.state = RsaBackendState::ModularMultiplicationRoundOne(rsa);
986                }
987                work_queue::Poll::Pending(false)
988            }
989
990            RsaBackendState::Processing(rsa) => {
991                if rsa.is_idle() {
992                    rsa.read_out(unsafe { item.result.as_mut() });
993
994                    self.state = RsaBackendState::Ready(rsa);
995                    work_queue::Poll::Ready(Status::Completed)
996                } else {
997                    self.state = RsaBackendState::Processing(rsa);
998                    work_queue::Poll::Pending(false)
999                }
1000            }
1001        }
1002    }
1003
    /// Cancels the work item currently owned by the backend.
    fn cancel(&mut self, _item: &mut RsaWorkItem) {
        // Replacing the state with `Idle` drops the driver held by the previous state (if any),
        // which resets it. The result is never read out into the work item's result buffer, so
        // the work item remains unchanged, effectively cancelling it.
        self.state = RsaBackendState::Idle;
    }

    /// Tears the backend down by returning it to the `Idle` state.
    fn deinitialize(&mut self) {
        // As in `cancel`, this drops whatever driver the previous state held (if any).
        self.state = RsaBackendState::Idle;
    }
1013}
1014
1015/// An active work queue driver.
1016///
1017/// This object must be kept around, otherwise RSA operations will never complete.
1018///
1019/// For a usage example, see [`RsaBackend`].
1020pub struct RsaWorkQueueDriver<'t, 'd> {
1021    inner: WorkQueueDriver<'t, RsaBackend<'d>, RsaWorkItem>,
1022}
1023
1024impl<'t, 'd> RsaWorkQueueDriver<'t, 'd> {
1025    /// Finishes processing the current work queue item, then stops the driver.
1026    pub fn stop(self) -> impl Future<Output = ()> {
1027        self.inner.stop()
1028    }
1029}
1030
/// A single RSA request as seen by the work queue: acceleration options, the operation to
/// execute, and the buffer that receives the result.
///
/// Operand and result buffers are stored as `NonNull` slice pointers. They are set from the
/// borrowed slices passed to the [`RsaContext`] operation methods.
struct RsaWorkItem {
    // Acceleration options (not available on ESP32).
    #[cfg(not(esp32))]
    search_acceleration: bool,
    #[cfg(not(esp32))]
    constant_time: bool,

    // The operation to execute.
    operation: RsaOperation,
    // Destination buffer. The backend reads the hardware result into it once the peripheral
    // reports idle after the operation.
    result: NonNull<[u32]>,
}

// SAFETY: NOTE(review): `NonNull` is neither `Send` nor `Sync`, so these impls are asserted
// manually. They presumably rely on the `RsaContext` operation methods tying every stored
// pointer to the `'t` borrow of the returned `RsaHandle<'t>`, so the pointed-to buffers outlive
// any access by the work queue (including from the RSA interrupt handler). Confirm against the
// work queue implementation.
unsafe impl Sync for RsaWorkItem {}
unsafe impl Send for RsaWorkItem {}
1045
/// The arithmetic operation carried by an [`RsaWorkItem`].
///
/// All operands point to `u32` word slices taken from the borrowed inputs of the
/// [`RsaContext`] operation methods.
enum RsaOperation {
    // Z = X * Y
    // len(Z) = len(X) + len(Y)
    Multiplication {
        x: NonNull<[u32]>,
        y: NonNull<[u32]>,
    },
    // Z = X * Y mod M
    //
    // `r` and `m_prime` are the caller-precomputed values described on
    // `RsaContext::modular_multiply`. On ESP32 the backend writes `y` in a second round, after
    // the first round has completed.
    ModularMultiplication {
        x: NonNull<[u32]>,
        y: NonNull<[u32]>,
        m: NonNull<[u32]>,
        r: NonNull<[u32]>,
        m_prime: u32,
    },
    // Z = X ^ Y mod M
    //
    // `r_inv` holds the `r` argument of `RsaContext::modular_exponentiate`, which is documented
    // as the same `2 ^ (bitlength * 2) mod M` value as `r` above, despite the different name.
    ModularExponentiation {
        x: NonNull<[u32]>,
        y: NonNull<[u32]>,
        m: NonNull<[u32]>,
        r_inv: NonNull<[u32]>,
        m_prime: u32,
    },
}
1070
1071#[handler]
1072#[ram]
1073fn rsa_work_queue_handler() {
1074    if !RSA_WORK_QUEUE.process() {
1075        // The queue may indicate that it needs to be polled again. In this case, we do not clear
1076        // the interrupt bit, which causes the interrupt to be re-handled.
1077        cfg_if::cfg_if! {
1078            if #[cfg(esp32)] {
1079                RSA::regs().interrupt().write(|w| w.interrupt().set_bit());
1080            } else {
1081                RSA::regs().int_clr().write(|w| w.int_clr().set_bit());
1082            }
1083        }
1084    }
1085}
1086
/// An RSA work queue user.
///
/// This object allows performing [big number multiplication][Self::multiply], [big number modular
/// multiplication][Self::modular_multiply] and [big number modular
/// exponentiation][Self::modular_exponentiate] with hardware acceleration. To perform these
/// operations, the [`RsaBackend`] must be started, otherwise these operations will never complete.
///
/// Each operation method borrows the context mutably for as long as the returned [`RsaHandle`]
/// is alive, so a single context has at most one operation in flight at a time.
#[cfg_attr(
    not(esp32),
    doc = " \nThe context is created with a secure configuration by default. You can enable hardware acceleration
    options using [enable_search_acceleration][Self::enable_search_acceleration] and
    [enable_acceleration][Self::enable_acceleration] when appropriate."
)]
pub struct RsaContext {
    // Work queue frontend owning this context's single work item (options, operation and
    // result buffer).
    frontend: WorkQueueFrontend<RsaWorkItem>,
}
1102
1103impl Default for RsaContext {
1104    fn default() -> Self {
1105        Self::new()
1106    }
1107}
1108
1109impl RsaContext {
1110    /// Creates a new context.
1111    pub fn new() -> Self {
1112        Self {
1113            frontend: WorkQueueFrontend::new(RsaWorkItem {
1114                #[cfg(not(esp32))]
1115                search_acceleration: false,
1116                #[cfg(not(esp32))]
1117                constant_time: true,
1118                operation: RsaOperation::Multiplication {
1119                    x: NonNull::from(&[]),
1120                    y: NonNull::from(&[]),
1121                },
1122                result: NonNull::from(&mut []),
1123            }),
1124        }
1125    }
1126
1127    #[cfg(not(esp32))]
1128    /// Enables search acceleration.
1129    ///
1130    /// When enabled it would increase the performance of modular
1131    /// exponentiation by discarding the exponent's bits before the most
1132    /// significant set bit.
1133    ///
1134    /// > ⚠️ Note: this compromises security by effectively decreasing the key length.
1135    ///
1136    /// For more information refer to the
1137    #[doc = trm_markdown_link!("rsa")]
1138    pub fn enable_search_acceleration(&mut self) {
1139        self.frontend.data_mut().search_acceleration = true;
1140    }
1141
1142    #[cfg(not(esp32))]
1143    /// Enables acceleration by disabling constant time operation.
1144    ///
1145    /// Disabling constant time operation increases the performance of modular
1146    /// exponentiation by simplifying the calculation concerning the 0 bits
1147    /// of the exponent. I.e. the less the Hamming weight, the greater the
1148    /// performance.
1149    ///
1150    /// > ⚠️ Note: this compromises security by enabling timing-based side-channel attacks.
1151    ///
1152    /// For more information refer to the
1153    #[doc = trm_markdown_link!("rsa")]
1154    pub fn enable_acceleration(&mut self) {
1155        self.frontend.data_mut().constant_time = false;
1156    }
1157
1158    fn post(&mut self) -> RsaHandle<'_> {
1159        RsaHandle(self.frontend.post(&RSA_WORK_QUEUE))
1160    }
1161
1162    #[procmacros::doc_replace]
1163    /// Starts a modular exponentiation operation, performing `Z = X ^ Y mod M`.
1164    ///
1165    /// Software needs to pre-calculate the following values:
1166    ///
1167    /// - `r`: `2 ^ ( bitlength * 2 ) mod M`.
1168    /// - `m_prime` can be calculated using `-(modular multiplicative inverse of M) mod 2^32`.
1169    ///
1170    /// It is relatively easy to calculate these values using the `crypto-bigint` crate:
1171    ///
1172    /// ```rust,no_run
1173    /// # {before_snippet}
1174    /// use crypto_bigint::{U512, Uint};
1175    /// const fn compute_r(modulus: &U512) -> U512 {
1176    ///     let mut d = [0_u32; U512::LIMBS * 2 + 1];
1177    ///     d[d.len() - 1] = 1;
1178    ///     let d = Uint::from_words(d);
1179    ///     d.const_rem(&modulus.resize()).0.resize()
1180    /// }
1181    ///
1182    /// const fn compute_mprime(modulus: &U512) -> u32 {
1183    ///     let m_inv = modulus.inv_mod2k(32).to_words()[0];
1184    ///     (-1 * m_inv as i64 & (u32::MAX as i64)) as u32
1185    /// }
1186    ///
1187    /// // Inputs
1188    /// const X: U512 = Uint::from_be_hex(
1189    ///     "c7f61058f96db3bd87dbab08ab03b4f7f2f864eac249144adea6a65f97803b719d8ca980b7b3c0389c1c7c6\
1190    ///    7dc353c5e0ec11f5fc8ce7f6073796cc8f73fa878",
1191    /// );
1192    /// const Y: U512 = Uint::from_be_hex(
1193    ///     "1763db3344e97be15d04de4868badb12a38046bb793f7630d87cf100aa1c759afac15a01f3c4c83ec2d2f66\
1194    ///    6bd22f71c3c1f075ec0e2cb0cb29994d091b73f51",
1195    /// );
1196    /// const M: U512 = Uint::from_be_hex(
1197    ///     "6b6bb3d2b6cbeb45a769eaa0384e611e1b89b0c9b45a045aca1c5fd6e8785b38df7118cf5dd45b9b63d293b\
1198    ///    67aeafa9ba25feb8712f188cb139b7d9b9af1c361",
1199    /// );
1200    ///
1201    /// // Values derived using the functions we defined above:
1202    /// let r = compute_r(&M);
1203    /// let mprime = compute_mprime(&M);
1204    ///
1205    /// use esp_hal::rsa::{RsaContext, operand_sizes::Op512};
1206    ///
1207    /// // Now perform the actual computation:
1208    /// let mut rsa = RsaContext::new();
1209    /// let mut outbuf = [0; 16];
1210    /// let mut handle = rsa.modular_multiply::<Op512>(
1211    ///     X.as_words(),
1212    ///     Y.as_words(),
1213    ///     M.as_words(),
1214    ///     r.as_words(),
1215    ///     mprime,
1216    ///     &mut outbuf,
1217    /// );
1218    /// handle.wait_blocking();
1219    /// # {after_snippet}
1220    /// ```
1221    ///
1222    /// The calculation is done asynchronously. This function returns an [`RsaHandle`] that can be
1223    /// used to poll the status of the calculation, to wait for it to finish, or to cancel the
1224    /// operation (by dropping the handle).
1225    ///
1226    /// When the operation is completed, the result will be stored in `result`.
1227    pub fn modular_exponentiate<'t, OP>(
1228        &'t mut self,
1229        x: &'t OP::InputType,
1230        y: &'t OP::InputType,
1231        m: &'t OP::InputType,
1232        r: &'t OP::InputType,
1233        m_prime: u32,
1234        result: &'t mut OP::InputType,
1235    ) -> RsaHandle<'t>
1236    where
1237        OP: RsaMode,
1238    {
1239        self.frontend.data_mut().operation = RsaOperation::ModularExponentiation {
1240            x: NonNull::from(x.as_ref()),
1241            y: NonNull::from(y.as_ref()),
1242            m: NonNull::from(m.as_ref()),
1243            r_inv: NonNull::from(r.as_ref()),
1244            m_prime,
1245        };
1246        self.frontend.data_mut().result = NonNull::from(result.as_mut());
1247        self.post()
1248    }
1249
1250    /// Starts a modular multiplication operation, performing `Z = X * Y mod M`.
1251    ///
1252    /// Software needs to pre-calculate the following values:
1253    ///
1254    /// - `r`: `2 ^ ( bitlength * 2 ) mod M`.
1255    /// - `m_prime` can be calculated using `-(modular multiplicative inverse of M) mod 2^32`.
1256    ///
1257    /// For an example how these values can be calculated and used, see
1258    /// [Self::modular_exponentiate].
1259    ///
1260    /// The calculation is done asynchronously. This function returns an [`RsaHandle`] that can be
1261    /// used to poll the status of the calculation, to wait for it to finish, or to cancel the
1262    /// operation (by dropping the handle).
1263    ///
1264    /// When the operation is completed, the result will be stored in `result`.
1265    pub fn modular_multiply<'t, OP>(
1266        &'t mut self,
1267        x: &'t OP::InputType,
1268        y: &'t OP::InputType,
1269        m: &'t OP::InputType,
1270        r: &'t OP::InputType,
1271        m_prime: u32,
1272        result: &'t mut OP::InputType,
1273    ) -> RsaHandle<'t>
1274    where
1275        OP: RsaMode,
1276    {
1277        self.frontend.data_mut().operation = RsaOperation::ModularMultiplication {
1278            x: NonNull::from(x.as_ref()),
1279            y: NonNull::from(y.as_ref()),
1280            m: NonNull::from(m.as_ref()),
1281            r: NonNull::from(r.as_ref()),
1282            m_prime,
1283        };
1284        self.frontend.data_mut().result = NonNull::from(result.as_mut());
1285        self.post()
1286    }
1287
1288    #[procmacros::doc_replace]
1289    /// Starts a multiplication operation, performing `Z = X * Y`.
1290    ///
1291    /// The calculation is done asynchronously. This function returns an [`RsaHandle`] that can be
1292    /// used to poll the status of the calculation, to wait for it to finish, or to cancel the
1293    /// operation (by dropping the handle).
1294    ///
1295    /// When the operation is completed, the result will be stored in `result`. The `result` is
1296    /// twice as wide as the inputs.
1297    ///
1298    /// ## Example
1299    ///
1300    /// ```rust,no_run
1301    /// # {before_snippet}
1302    ///
1303    /// // Inputs
1304    /// # let x: [u32; 16] = [0; 16];
1305    /// # let y: [u32; 16] = [0; 16];
1306    /// // let x: [u32; 16] = [...];
1307    /// // let y: [u32; 16] = [...];
1308    /// let mut outbuf = [0; 32];
1309    ///
1310    /// use esp_hal::rsa::{RsaContext, operand_sizes::Op512};
1311    ///
1312    /// // Now perform the actual computation:
1313    /// let mut rsa = RsaContext::new();
1314    /// let mut handle = rsa.multiply::<Op512>(&x, &y, &mut outbuf);
1315    /// handle.wait_blocking();
1316    /// # {after_snippet}
1317    /// ```
1318    pub fn multiply<'t, OP>(
1319        &'t mut self,
1320        x: &'t OP::InputType,
1321        y: &'t OP::InputType,
1322        result: &'t mut OP::OutputType,
1323    ) -> RsaHandle<'t>
1324    where
1325        OP: Multi,
1326    {
1327        self.frontend.data_mut().operation = RsaOperation::Multiplication {
1328            x: NonNull::from(x.as_ref()),
1329            y: NonNull::from(y.as_ref()),
1330        };
1331        self.frontend.data_mut().result = NonNull::from(result.as_mut());
1332        self.post()
1333    }
1334}
1335
1336/// The handle to the pending RSA operation.
1337pub struct RsaHandle<'t>(work_queue::Handle<'t, RsaWorkItem>);
1338
1339impl RsaHandle<'_> {
1340    /// Polls the status of the work item.
1341    #[inline]
1342    pub fn poll(&mut self) -> bool {
1343        self.0.poll()
1344    }
1345
1346    /// Blocks until the work item to be processed.
1347    #[inline]
1348    pub fn wait_blocking(mut self) {
1349        while !self.poll() {}
1350    }
1351
1352    /// Waits for the work item to be processed.
1353    #[inline]
1354    pub fn wait(&mut self) -> impl Future<Output = Status> {
1355        self.0.wait()
1356    }
1357}