1use core::{marker::PhantomData, ptr::NonNull, task::Poll};
17
18use portable_atomic::{AtomicBool, Ordering};
19use procmacros::{handler, ram};
20
21use crate::{
22 Async,
23 Blocking,
24 DriverMode,
25 asynch::AtomicWaker,
26 interrupt::InterruptHandler,
27 pac,
28 peripherals::{Interrupt, RSA},
29 system::{Cpu, GenericPeripheralGuard, Peripheral as PeripheralEnable},
30 trm_markdown_link,
31 work_queue::{self, Status, VTable, WorkQueue, WorkQueueDriver, WorkQueueFrontend},
32};
33
/// RSA accelerator driver.
///
/// `Dm` selects the driver mode ([`Blocking`] or [`Async`]). Besides the
/// peripheral singleton, the driver owns a `GenericPeripheralGuard` for the RSA
/// peripheral and, on chips other than the ESP32, a guard that keeps the RSA
/// memory powered. Both are released when the driver is dropped.
pub struct Rsa<'d, Dm: DriverMode> {
    rsa: RSA<'d>,
    phantom: PhantomData<Dm>,
    // Declared before `_guard`. Struct fields drop in declaration order, so the
    // RSA memory is powered down before the peripheral guard is released.
    #[cfg(not(esp32))]
    _memory_guard: RsaMemoryPowerGuard,
    _guard: GenericPeripheralGuard<{ PeripheralEnable::Rsa as u8 }>,
}
42
43const WORDS_PER_INCREMENT: u32 = property!("rsa.size_increment") / 32;
52
53#[cfg(not(esp32))]
54struct RsaMemoryPowerGuard;
55
56#[cfg(not(esp32))]
57impl RsaMemoryPowerGuard {
58 fn new() -> Self {
59 crate::peripherals::SYSTEM::regs()
60 .rsa_pd_ctrl()
61 .modify(|_, w| {
62 w.rsa_mem_force_pd().clear_bit();
63 w.rsa_mem_force_pu().set_bit();
64 w.rsa_mem_pd().clear_bit()
65 });
66 Self
67 }
68}
69
70#[cfg(not(esp32))]
71impl Drop for RsaMemoryPowerGuard {
72 fn drop(&mut self) {
73 crate::peripherals::SYSTEM::regs()
74 .rsa_pd_ctrl()
75 .modify(|_, w| {
76 w.rsa_mem_force_pd().clear_bit();
77 w.rsa_mem_force_pu().clear_bit();
78 w.rsa_mem_pd().set_bit()
79 });
80 }
81}
82
83impl<'d> Rsa<'d, Blocking> {
84 pub fn new(rsa: RSA<'d>) -> Self {
88 let guard = GenericPeripheralGuard::new();
89
90 let this = Self {
91 rsa,
92 phantom: PhantomData,
93 #[cfg(not(esp32))]
94 _memory_guard: RsaMemoryPowerGuard::new(),
95 _guard: guard,
96 };
97
98 while !this.ready() {}
99
100 this
101 }
102
103 pub fn into_async(mut self) -> Rsa<'d, Async> {
105 self.set_interrupt_handler(rsa_interrupt_handler);
106 self.enable_disable_interrupt(true);
107
108 Rsa {
109 rsa: self.rsa,
110 phantom: PhantomData,
111 #[cfg(not(esp32))]
112 _memory_guard: self._memory_guard,
113 _guard: self._guard,
114 }
115 }
116
117 pub fn enable_disable_interrupt(&mut self, enable: bool) {
122 self.internal_enable_disable_interrupt(enable);
123 }
124
125 #[instability::unstable]
130 pub fn set_interrupt_handler(&mut self, handler: InterruptHandler) {
131 self.rsa.disable_peri_interrupt_on_all_cores();
132 self.rsa.bind_peri_interrupt(handler);
133 }
134}
135
136impl crate::private::Sealed for Rsa<'_, Blocking> {}
137
138#[instability::unstable]
139impl crate::interrupt::InterruptConfigurable for Rsa<'_, Blocking> {
140 fn set_interrupt_handler(&mut self, handler: InterruptHandler) {
141 self.set_interrupt_handler(handler);
142 }
143}
144
145impl<'d> Rsa<'d, Async> {
146 pub fn into_blocking(self) -> Rsa<'d, Blocking> {
148 self.internal_enable_disable_interrupt(false);
149 self.rsa.disable_peri_interrupt_on_all_cores();
150
151 crate::interrupt::disable(Cpu::current(), Interrupt::RSA);
152 Rsa {
153 rsa: self.rsa,
154 phantom: PhantomData,
155 #[cfg(not(esp32))]
156 _memory_guard: self._memory_guard,
157 _guard: self._guard,
158 }
159 }
160}
161
162impl<'d, Dm: DriverMode> Rsa<'d, Dm> {
163 fn internal_enable_disable_interrupt(&self, enable: bool) {
164 cfg_if::cfg_if! {
165 if #[cfg(esp32)] {
166 self.regs().interrupt().write(|w| w.interrupt().bit(enable));
168 } else {
169 self.regs().int_ena().write(|w| w.int_ena().bit(enable));
170 }
171 }
172 }
173
174 fn regs(&self) -> &pac::rsa::RegisterBlock {
175 self.rsa.register_block()
176 }
177
178 fn ready(&self) -> bool {
183 cfg_if::cfg_if! {
184 if #[cfg(any(esp32, esp32s2, esp32s3))] {
185 self.regs().clean().read().clean().bit_is_set()
186 } else {
187 self.regs().query_clean().read().query_clean().bit_is_set()
188 }
189 }
190 }
191
192 fn start_modexp(&self) {
194 cfg_if::cfg_if! {
195 if #[cfg(any(esp32, esp32s2, esp32s3))] {
196 self.regs()
197 .modexp_start()
198 .write(|w| w.modexp_start().set_bit());
199 } else {
200 self.regs()
201 .set_start_modexp()
202 .write(|w| w.set_start_modexp().set_bit());
203 }
204 }
205 }
206
207 fn start_multi(&self) {
209 cfg_if::cfg_if! {
210 if #[cfg(any(esp32, esp32s2, esp32s3))] {
211 self.regs().mult_start().write(|w| w.mult_start().set_bit());
212 } else {
213 self.regs()
214 .set_start_mult()
215 .write(|w| w.set_start_mult().set_bit());
216 }
217 }
218 }
219
220 fn start_modmulti(&self) {
222 cfg_if::cfg_if! {
223 if #[cfg(esp32)] {
224 self.start_multi();
226 } else if #[cfg(any(esp32s2, esp32s3))] {
227 self.regs()
228 .modmult_start()
229 .write(|w| w.modmult_start().set_bit());
230 } else {
231 self.regs()
232 .set_start_modmult()
233 .write(|w| w.set_start_modmult().set_bit());
234 }
235 }
236 }
237
238 fn clear_interrupt(&mut self) {
240 cfg_if::cfg_if! {
241 if #[cfg(esp32)] {
242 self.regs().interrupt().write(|w| w.interrupt().set_bit());
243 } else {
244 self.regs().int_clr().write(|w| w.int_clr().set_bit());
245 }
246 }
247 }
248
249 fn is_idle(&self) -> bool {
251 cfg_if::cfg_if! {
252 if #[cfg(esp32)] {
253 self.regs().interrupt().read().interrupt().bit_is_set()
254 } else if #[cfg(any(esp32s2, esp32s3))] {
255 self.regs().idle().read().idle().bit_is_set()
256 } else {
257 self.regs().query_idle().read().query_idle().bit_is_set()
258 }
259 }
260 }
261
262 fn wait_for_idle(&mut self) {
263 while !self.is_idle() {}
264 self.clear_interrupt();
265 }
266
267 fn write_multi_mode(&mut self, mode: u32, modular: bool) {
269 let mode = if cfg!(esp32) && !modular {
270 const NON_MODULAR: u32 = 8;
271 mode | NON_MODULAR
272 } else {
273 mode
274 };
275
276 cfg_if::cfg_if! {
277 if #[cfg(esp32)] {
278 self.regs().mult_mode().write(|w| unsafe { w.bits(mode) });
279 } else {
280 self.regs().mode().write(|w| unsafe { w.bits(mode) });
281 }
282 }
283 }
284
285 fn write_modexp_mode(&mut self, mode: u32) {
287 cfg_if::cfg_if! {
288 if #[cfg(esp32)] {
289 self.regs().modexp_mode().write(|w| unsafe { w.bits(mode) });
290 } else {
291 self.regs().mode().write(|w| unsafe { w.bits(mode) });
292 }
293 }
294 }
295
296 fn write_operand_b(&mut self, operand: &[u32]) {
297 for (reg, op) in self.regs().y_mem_iter().zip(operand.iter().copied()) {
298 reg.write(|w| unsafe { w.bits(op) });
299 }
300 }
301
302 fn write_modulus(&mut self, modulus: &[u32]) {
303 for (reg, op) in self.regs().m_mem_iter().zip(modulus.iter().copied()) {
304 reg.write(|w| unsafe { w.bits(op) });
305 }
306 }
307
308 fn write_mprime(&mut self, m_prime: u32) {
309 self.regs().m_prime().write(|w| unsafe { w.bits(m_prime) });
310 }
311
312 fn write_operand_a(&mut self, operand: &[u32]) {
313 for (reg, op) in self.regs().x_mem_iter().zip(operand.iter().copied()) {
314 reg.write(|w| unsafe { w.bits(op) });
315 }
316 }
317
318 fn write_multi_operand_b(&mut self, operand: &[u32]) {
319 for (reg, op) in self
320 .regs()
321 .z_mem_iter()
322 .skip(operand.len())
323 .zip(operand.iter().copied())
324 {
325 reg.write(|w| unsafe { w.bits(op) });
326 }
327 }
328
329 fn write_r(&mut self, r: &[u32]) {
330 for (reg, op) in self.regs().z_mem_iter().zip(r.iter().copied()) {
331 reg.write(|w| unsafe { w.bits(op) });
332 }
333 }
334
335 fn read_out(&self, outbuf: &mut [u32]) {
336 for (reg, op) in self.regs().z_mem_iter().zip(outbuf.iter_mut()) {
337 *op = reg.read().bits();
338 }
339 }
340
341 fn read_results(&mut self, outbuf: &mut [u32]) {
342 self.wait_for_idle();
343 self.read_out(outbuf);
344 }
345
346 #[doc = trm_markdown_link!("rsa")]
357 #[cfg(not(esp32))]
358 pub fn disable_constant_time(&mut self, disable: bool) {
359 self.regs()
360 .constant_time()
361 .write(|w| w.constant_time().bit(disable));
362 }
363
364 #[doc = trm_markdown_link!("rsa")]
374 #[cfg(not(esp32))]
375 pub fn search_acceleration(&mut self, enable: bool) {
376 self.regs()
377 .search_enable()
378 .write(|w| w.search_enable().bit(enable));
379 }
380
381 #[cfg(not(esp32))]
383 fn is_search_enabled(&mut self) -> bool {
384 self.regs()
385 .search_enable()
386 .read()
387 .search_enable()
388 .bit_is_set()
389 }
390
391 #[cfg(not(esp32))]
393 fn write_search_position(&mut self, search_position: u32) {
394 self.regs()
395 .search_pos()
396 .write(|w| unsafe { w.bits(search_position) });
397 }
398}
399
/// Sealed marker trait for a supported RSA operand size.
pub trait RsaMode: crate::private::Sealed {
    /// Word buffer holding one operand of this size. The impls generated in
    /// `operand_sizes` use `[u32; bits / 32]`.
    type InputType: AsRef<[u32]> + AsMut<[u32]>;
}

/// Operand sizes that also support non-modular multiplication.
pub trait Multi: RsaMode {
    /// Word buffer holding the product. The generated impls make it twice as
    /// many words as `InputType`.
    type OutputType: AsRef<[u32]> + AsMut<[u32]>;
}
411
/// Marker types for the RSA operand sizes supported by this chip.
///
/// One `Op<bits>` type is generated per supported size. Its `InputType` is an
/// array of `bits / 32` words. Sizes that also support plain multiplication
/// implement `Multi`, with an `OutputType` of `bits * 2 / 32` words.
pub mod operand_sizes {
    // Generates `Op<bits>`, its `Sealed` impl and its `RsaMode` impl for every
    // supported exponentiation size.
    for_each_rsa_exponentiation!(
        ($x:literal) => {
            paste::paste! {
                #[doc = concat!(stringify!($x), "-bit RSA operation.")]
                pub struct [<Op $x>];

                impl crate::private::Sealed for [<Op $x>] {}
                impl crate::rsa::RsaMode for [<Op $x>] {
                    type InputType = [u32; $x / 32];
                }
            }
        };
    );

    // Adds `Multi` for the sizes that support non-modular multiplication. The
    // product needs twice as many words as an operand.
    for_each_rsa_multiplication!(
        ($x:literal) => {
            impl crate::rsa::Multi for paste::paste!( [<Op $x>] ) {
                type OutputType = [u32; $x * 2 / 32];
            }
        };
    );
}
436
/// Helper for modular exponentiation (`base ^ exponent mod modulus`).
///
/// The exponent and modulus are loaded once in `new`. Each call then supplies
/// a base. The helper borrows the driver mutably for its entire lifetime; `T`
/// selects the operand size.
pub struct RsaModularExponentiation<'a, 'd, T: RsaMode, Dm: DriverMode> {
    rsa: &'a mut Rsa<'d, Dm>,
    phantom: PhantomData<T>,
}
445
446impl<'a, 'd, T: RsaMode, Dm: DriverMode, const N: usize> RsaModularExponentiation<'a, 'd, T, Dm>
447where
448 T: RsaMode<InputType = [u32; N]>,
449{
450 #[doc = trm_markdown_link!("rsa")]
457 pub fn new(
458 rsa: &'a mut Rsa<'d, Dm>,
459 exponent: &T::InputType,
460 modulus: &T::InputType,
461 m_prime: u32,
462 ) -> Self {
463 Self::write_mode(rsa);
464 rsa.write_operand_b(exponent);
465 rsa.write_modulus(modulus);
466 rsa.write_mprime(m_prime);
467
468 #[cfg(not(esp32))]
469 if rsa.is_search_enabled() {
470 rsa.write_search_position(Self::find_search_pos(exponent));
471 }
472
473 Self {
474 rsa,
475 phantom: PhantomData,
476 }
477 }
478
479 fn set_up_exponentiation(&mut self, base: &T::InputType, r: &T::InputType) {
480 self.rsa.write_operand_a(base);
481 self.rsa.write_r(r);
482 }
483
484 #[doc = trm_markdown_link!("rsa")]
490 pub fn start_exponentiation(&mut self, base: &T::InputType, r: &T::InputType) {
491 self.set_up_exponentiation(base, r);
492 self.rsa.start_modexp();
493 }
494
495 pub fn read_results(&mut self, outbuf: &mut T::InputType) {
501 self.rsa.read_results(outbuf);
502 }
503
504 #[cfg(not(esp32))]
505 fn find_search_pos(exponent: &T::InputType) -> u32 {
506 for (i, byte) in exponent.iter().rev().enumerate() {
507 if *byte == 0 {
508 continue;
509 }
510 return (exponent.len() * 32) as u32 - (byte.leading_zeros() + i as u32 * 32) - 1;
511 }
512 0
513 }
514
515 fn write_mode(rsa: &mut Rsa<'d, Dm>) {
517 rsa.write_modexp_mode(N as u32 / WORDS_PER_INCREMENT - 1);
518 }
519}
520
/// Helper for modular multiplication (`operand_a * operand_b mod modulus`).
///
/// `operand_a`, the modulus, `r` and M' are loaded in `new`. Each call then
/// supplies `operand_b`. The helper borrows the driver mutably for its
/// lifetime.
pub struct RsaModularMultiplication<'a, 'd, T, Dm>
where
    T: RsaMode,
    Dm: DriverMode,
{
    rsa: &'a mut Rsa<'d, Dm>,
    phantom: PhantomData<T>,
}
533
534impl<'a, 'd, T, Dm, const N: usize> RsaModularMultiplication<'a, 'd, T, Dm>
535where
536 T: RsaMode<InputType = [u32; N]>,
537 Dm: DriverMode,
538{
539 #[doc = trm_markdown_link!("rsa")]
546 pub fn new(
547 rsa: &'a mut Rsa<'d, Dm>,
548 operand_a: &T::InputType,
549 modulus: &T::InputType,
550 r: &T::InputType,
551 m_prime: u32,
552 ) -> Self {
553 rsa.write_multi_mode(N as u32 / WORDS_PER_INCREMENT - 1, true);
554
555 rsa.write_mprime(m_prime);
556 rsa.write_modulus(modulus);
557 rsa.write_operand_a(operand_a);
558 rsa.write_r(r);
559
560 Self {
561 rsa,
562 phantom: PhantomData,
563 }
564 }
565
566 #[doc = trm_markdown_link!("rsa")]
570 pub fn start_modular_multiplication(&mut self, operand_b: &T::InputType) {
571 self.set_up_modular_multiplication(operand_b);
572 self.rsa.start_modmulti();
573 }
574
575 pub fn read_results(&mut self, outbuf: &mut T::InputType) {
581 self.rsa.read_results(outbuf);
582 }
583
584 fn set_up_modular_multiplication(&mut self, operand_b: &T::InputType) {
585 if cfg!(esp32) {
586 self.rsa.start_multi();
587 self.rsa.wait_for_idle();
588
589 self.rsa.write_operand_a(operand_b);
590 } else {
591 self.rsa.write_operand_b(operand_b);
592 }
593 }
594}
595
/// Helper for non-modular multiplication (`operand_a * operand_b`).
///
/// The result is `T::OutputType`, which is double-width in the generated
/// impls. The helper borrows the driver mutably for its lifetime.
pub struct RsaMultiplication<'a, 'd, T, Dm>
where
    T: RsaMode + Multi,
    Dm: DriverMode,
{
    rsa: &'a mut Rsa<'d, Dm>,
    phantom: PhantomData<T>,
}
608
609impl<'a, 'd, T, Dm, const N: usize> RsaMultiplication<'a, 'd, T, Dm>
610where
611 T: RsaMode<InputType = [u32; N]>,
612 T: Multi,
613 Dm: DriverMode,
614{
615 pub fn new(rsa: &'a mut Rsa<'d, Dm>, operand_a: &T::InputType) -> Self {
617 rsa.write_multi_mode(2 * N as u32 / WORDS_PER_INCREMENT - 1, false);
619 rsa.write_operand_a(operand_a);
620
621 Self {
622 rsa,
623 phantom: PhantomData,
624 }
625 }
626
627 pub fn start_multiplication(&mut self, operand_b: &T::InputType) {
629 self.set_up_multiplication(operand_b);
630 self.rsa.start_multi();
631 }
632
633 pub fn read_results<const O: usize>(&mut self, outbuf: &mut T::OutputType)
639 where
640 T: Multi<OutputType = [u32; O]>,
641 {
642 self.rsa.read_results(outbuf);
643 }
644
645 fn set_up_multiplication(&mut self, operand_b: &T::InputType) {
646 self.rsa.write_multi_operand_b(operand_b);
647 }
648}
649
/// Waker of the task awaiting the current async RSA operation.
static WAKER: AtomicWaker = AtomicWaker::new();
/// Set by `rsa_interrupt_handler` when an operation completes; reset by
/// `RsaFuture::new`.
static SIGNALED: AtomicBool = AtomicBool::new(false);
653
654#[must_use = "futures do nothing unless you `.await` or poll them"]
656struct RsaFuture<'a, 'd> {
657 driver: &'a Rsa<'d, Async>,
658}
659
660impl<'a, 'd> RsaFuture<'a, 'd> {
661 fn new(driver: &'a Rsa<'d, Async>) -> Self {
662 SIGNALED.store(false, Ordering::Relaxed);
663
664 driver.internal_enable_disable_interrupt(true);
665
666 Self { driver }
667 }
668
669 fn is_done(&self) -> bool {
670 SIGNALED.load(Ordering::Acquire)
671 }
672}
673
674impl Drop for RsaFuture<'_, '_> {
675 fn drop(&mut self) {
676 self.driver.internal_enable_disable_interrupt(false);
677 }
678}
679
680impl core::future::Future for RsaFuture<'_, '_> {
681 type Output = ();
682
683 fn poll(
684 self: core::pin::Pin<&mut Self>,
685 cx: &mut core::task::Context<'_>,
686 ) -> core::task::Poll<Self::Output> {
687 WAKER.register(cx.waker());
688 if self.is_done() {
689 Poll::Ready(())
690 } else {
691 Poll::Pending
692 }
693 }
694}
695
696impl<T: RsaMode, const N: usize> RsaModularExponentiation<'_, '_, T, Async>
697where
698 T: RsaMode<InputType = [u32; N]>,
699{
700 pub async fn exponentiation(
702 &mut self,
703 base: &T::InputType,
704 r: &T::InputType,
705 outbuf: &mut T::InputType,
706 ) {
707 self.set_up_exponentiation(base, r);
708 let fut = RsaFuture::new(self.rsa);
709 self.rsa.start_modexp();
710 fut.await;
711 self.rsa.read_out(outbuf);
712 }
713}
714
715impl<T: RsaMode, const N: usize> RsaModularMultiplication<'_, '_, T, Async>
716where
717 T: RsaMode<InputType = [u32; N]>,
718{
719 pub async fn modular_multiplication(
721 &mut self,
722 operand_b: &T::InputType,
723 outbuf: &mut T::InputType,
724 ) {
725 if cfg!(esp32) {
726 let fut = RsaFuture::new(self.rsa);
727 self.rsa.start_multi();
728 fut.await;
729
730 self.rsa.write_operand_a(operand_b);
731 } else {
732 self.set_up_modular_multiplication(operand_b);
733 }
734
735 let fut = RsaFuture::new(self.rsa);
736 self.rsa.start_modmulti();
737 fut.await;
738 self.rsa.read_out(outbuf);
739 }
740}
741
742impl<T: RsaMode + Multi, const N: usize> RsaMultiplication<'_, '_, T, Async>
743where
744 T: RsaMode<InputType = [u32; N]>,
745{
746 pub async fn multiplication<const O: usize>(
748 &mut self,
749 operand_b: &T::InputType,
750 outbuf: &mut T::OutputType,
751 ) where
752 T: Multi<OutputType = [u32; O]>,
753 {
754 self.set_up_multiplication(operand_b);
755 let fut = RsaFuture::new(self.rsa);
756 self.rsa.start_multi();
757 fut.await;
758 self.rsa.read_out(outbuf);
759 }
760}
761
762#[handler]
763pub(super) fn rsa_interrupt_handler() {
765 let rsa = RSA::regs();
766 SIGNALED.store(true, Ordering::Release);
767 cfg_if::cfg_if! {
768 if #[cfg(esp32)] {
769 rsa.interrupt().write(|w| w.interrupt().set_bit());
770 } else {
771 rsa.int_clr().write(|w| w.int_clr().set_bit());
772 }
773 }
774
775 WAKER.wake();
776}
777
778static RSA_WORK_QUEUE: WorkQueue<RsaWorkItem> = WorkQueue::new();
779const RSA_VTABLE: VTable<RsaWorkItem> = VTable {
780 post: |driver, item| {
781 let driver = unsafe { RsaBackend::from_raw(driver) };
783 Some(driver.process_item(item))
784 },
785 poll: |driver, item| {
786 let driver = unsafe { RsaBackend::from_raw(driver) };
787 driver.process_item(item)
788 },
789 cancel: |driver, item| {
790 let driver = unsafe { RsaBackend::from_raw(driver) };
791 driver.cancel(item)
792 },
793 stop: |driver| {
794 let driver = unsafe { RsaBackend::from_raw(driver) };
795 driver.deinitialize()
796 },
797};
798
/// Step of the `RsaBackend` processing state machine.
#[derive(Default)]
enum RsaBackendState<'d> {
    /// No driver is held, and therefore none of its peripheral or memory
    /// guards.
    #[default]
    Idle,
    /// Guards acquired; waiting for the RSA memory to report ready.
    Initializing(Rsa<'d, Blocking>),
    /// Ready to program the next work item.
    Ready(Rsa<'d, Blocking>),
    /// ESP32 only: the first step of a modular multiplication is running.
    #[cfg(esp32)]
    ModularMultiplicationRoundOne(Rsa<'d, Blocking>),
    /// An operation was started; waiting for the hardware to go idle.
    Processing(Rsa<'d, Blocking>),
}
809
#[procmacros::doc_replace]
/// Work-queue backend that executes operations posted through `RsaContext` on
/// the RSA peripheral.
pub struct RsaBackend<'d> {
    // Peripheral singleton. It is cloned (unchecked) into a blocking `Rsa`
    // driver when the first work item is processed.
    peri: RSA<'d>,
    // Current step of the processing state machine.
    state: RsaBackendState<'d>,
}
844
845impl<'d> RsaBackend<'d> {
846 #[procmacros::doc_replace]
847 pub fn new(rsa: RSA<'d>) -> Self {
859 Self {
860 peri: rsa,
861 state: RsaBackendState::Idle,
862 }
863 }
864
865 #[procmacros::doc_replace]
866 pub fn start(&mut self) -> RsaWorkQueueDriver<'_, 'd> {
882 RsaWorkQueueDriver {
883 inner: WorkQueueDriver::new(self, RSA_VTABLE, &RSA_WORK_QUEUE),
884 }
885 }
886
887 unsafe fn from_raw<'any>(ptr: NonNull<()>) -> &'any mut Self {
890 unsafe { ptr.cast::<RsaBackend<'_>>().as_mut() }
891 }
892
893 fn process_item(&mut self, item: &mut RsaWorkItem) -> work_queue::Poll {
894 match core::mem::take(&mut self.state) {
895 RsaBackendState::Idle => {
896 let driver = Rsa {
897 rsa: unsafe { self.peri.clone_unchecked() },
898 phantom: PhantomData,
899 #[cfg(not(esp32))]
900 _memory_guard: RsaMemoryPowerGuard::new(),
901 _guard: GenericPeripheralGuard::new(),
902 };
903 self.state = RsaBackendState::Initializing(driver);
904 work_queue::Poll::Pending(true)
905 }
906 RsaBackendState::Initializing(mut rsa) => {
907 self.state = if rsa.ready() {
910 rsa.set_interrupt_handler(rsa_work_queue_handler);
911 rsa.enable_disable_interrupt(true);
912 RsaBackendState::Ready(rsa)
913 } else {
914 RsaBackendState::Initializing(rsa)
915 };
916 work_queue::Poll::Pending(true)
917 }
918 RsaBackendState::Ready(mut rsa) => {
919 #[cfg(not(esp32))]
920 {
921 rsa.disable_constant_time(!item.constant_time);
922 rsa.search_acceleration(item.search_acceleration);
923 }
924
925 match item.operation {
926 RsaOperation::Multiplication { x, y } => {
927 let n = x.len() as u32;
928 rsa.write_operand_a(unsafe { x.as_ref() });
929
930 rsa.write_multi_mode(2 * n / WORDS_PER_INCREMENT - 1, false);
932 rsa.write_multi_operand_b(unsafe { y.as_ref() });
933 rsa.start_multi();
934 }
935
936 RsaOperation::ModularMultiplication {
937 x,
938 #[cfg(not(esp32))]
939 y,
940 m,
941 m_prime,
942 r: r_inv,
943 ..
944 } => {
945 let n = x.len() as u32;
946 rsa.write_operand_a(unsafe { x.as_ref() });
947
948 rsa.write_multi_mode(n / WORDS_PER_INCREMENT - 1, true);
949
950 #[cfg(not(esp32))]
951 rsa.write_operand_b(unsafe { y.as_ref() });
952
953 rsa.write_modulus(unsafe { m.as_ref() });
954 rsa.write_mprime(m_prime);
955 rsa.write_r(unsafe { r_inv.as_ref() });
956
957 rsa.start_modmulti();
958
959 #[cfg(esp32)]
960 {
961 self.state = RsaBackendState::ModularMultiplicationRoundOne(rsa);
964
965 return work_queue::Poll::Pending(false);
966 }
967 }
968 RsaOperation::ModularExponentiation {
969 x,
970 y,
971 m,
972 m_prime,
973 r_inv,
974 } => {
975 let n = x.len() as u32;
976 rsa.write_operand_a(unsafe { x.as_ref() });
977
978 rsa.write_modexp_mode(n / WORDS_PER_INCREMENT - 1);
979 rsa.write_operand_b(unsafe { y.as_ref() });
980 rsa.write_modulus(unsafe { m.as_ref() });
981 rsa.write_mprime(m_prime);
982 rsa.write_r(unsafe { r_inv.as_ref() });
983
984 #[cfg(not(esp32))]
985 if item.search_acceleration {
986 fn find_search_pos(exponent: &[u32]) -> u32 {
987 for (i, byte) in exponent.iter().rev().enumerate() {
988 if *byte == 0 {
989 continue;
990 }
991 return (exponent.len() * 32) as u32
992 - (byte.leading_zeros() + i as u32 * 32)
993 - 1;
994 }
995 0
996 }
997 rsa.write_search_position(find_search_pos(unsafe { y.as_ref() }));
998 }
999
1000 rsa.start_modexp();
1001 }
1002 }
1003
1004 self.state = RsaBackendState::Processing(rsa);
1005
1006 work_queue::Poll::Pending(false)
1007 }
1008
1009 #[cfg(esp32)]
1010 RsaBackendState::ModularMultiplicationRoundOne(mut rsa) => {
1011 if rsa.is_idle() {
1012 let RsaOperation::ModularMultiplication { y, .. } = item.operation else {
1013 unreachable!();
1014 };
1015
1016 rsa.write_operand_a(unsafe { y.as_ref() });
1018 rsa.start_modmulti();
1019
1020 self.state = RsaBackendState::Processing(rsa);
1021 } else {
1022 self.state = RsaBackendState::ModularMultiplicationRoundOne(rsa);
1024 }
1025 work_queue::Poll::Pending(false)
1026 }
1027
1028 RsaBackendState::Processing(rsa) => {
1029 if rsa.is_idle() {
1030 rsa.read_out(unsafe { item.result.as_mut() });
1031
1032 self.state = RsaBackendState::Ready(rsa);
1033 work_queue::Poll::Ready(Status::Completed)
1034 } else {
1035 self.state = RsaBackendState::Processing(rsa);
1036 work_queue::Poll::Pending(false)
1037 }
1038 }
1039 }
1040 }
1041
1042 fn cancel(&mut self, _item: &mut RsaWorkItem) {
1043 self.state = RsaBackendState::Idle;
1046 }
1047
1048 fn deinitialize(&mut self) {
1049 self.state = RsaBackendState::Idle;
1050 }
1051}
1052
/// Handle returned by `RsaBackend::start` while the backend is registered with
/// the RSA work queue.
pub struct RsaWorkQueueDriver<'t, 'd> {
    inner: WorkQueueDriver<'t, RsaBackend<'d>, RsaWorkItem>,
}
1061
1062impl<'t, 'd> RsaWorkQueueDriver<'t, 'd> {
1063 pub fn stop(self) -> impl Future<Output = ()> {
1065 self.inner.stop()
1066 }
1067}
1068
/// Work item posted to `RSA_WORK_QUEUE`.
///
/// It holds the requested operation, the acceleration options to apply and
/// the buffer that receives the result.
#[derive(Clone)]
struct RsaWorkItem {
    // Value written to the `search_enable` bit before the operation starts.
    #[cfg(not(esp32))]
    search_acceleration: bool,
    // `true` (the default set by `RsaContext::new`) requests constant-time
    // operation; `RsaContext::enable_acceleration` clears it.
    #[cfg(not(esp32))]
    constant_time: bool,

    operation: RsaOperation,
    // Buffer the backend copies the Z memory into when the operation completes.
    result: NonNull<[u32]>,
}

// SAFETY: NOTE(review) — the raw pointers refer to caller buffers borrowed for
// the lifetime `'t` of the `RsaHandle` returned by `RsaContext`. This
// presumably relies on the work queue only dereferencing them while that
// borrow is alive. Confirm against `work_queue`.
unsafe impl Sync for RsaWorkItem {}
unsafe impl Send for RsaWorkItem {}
1084
/// Operation carried by an `RsaWorkItem`.
///
/// The pointers refer to the caller's operand buffers.
#[derive(Clone)]
enum RsaOperation {
    /// `x * y` without reduction. The product is twice as wide as the
    /// operands.
    Multiplication {
        x: NonNull<[u32]>,
        y: NonNull<[u32]>,
    },
    /// `x * y mod m`, with the precomputed `r` and `m_prime` values.
    ModularMultiplication {
        x: NonNull<[u32]>,
        y: NonNull<[u32]>,
        m: NonNull<[u32]>,
        r: NonNull<[u32]>,
        m_prime: u32,
    },
    /// `x ^ y mod m` (`x` is the base, `y` the exponent), with the
    /// precomputed `r_inv` and `m_prime` values.
    ModularExponentiation {
        x: NonNull<[u32]>,
        y: NonNull<[u32]>,
        m: NonNull<[u32]>,
        r_inv: NonNull<[u32]>,
        m_prime: u32,
    },
}
1110
1111#[handler]
1112#[ram]
1113fn rsa_work_queue_handler() {
1114 if !RSA_WORK_QUEUE.process() {
1115 cfg_if::cfg_if! {
1118 if #[cfg(esp32)] {
1119 RSA::regs().interrupt().write(|w| w.interrupt().set_bit());
1120 } else {
1121 RSA::regs().int_clr().write(|w| w.int_clr().set_bit());
1122 }
1123 }
1124 }
1125}
1126
/// Front end for queuing RSA operations on the shared RSA work queue.
///
/// Operations are executed by a started `RsaBackend`.
#[cfg_attr(
    not(esp32),
    doc = " \nThe context is created with a secure configuration by default. You can enable hardware acceleration
    options using [enable_search_acceleration][Self::enable_search_acceleration] and
    [enable_acceleration][Self::enable_acceleration] when appropriate."
)]
#[derive(Clone)]
pub struct RsaContext {
    frontend: WorkQueueFrontend<RsaWorkItem>,
}
1143
1144impl Default for RsaContext {
1145 fn default() -> Self {
1146 Self::new()
1147 }
1148}
1149
1150impl RsaContext {
1151 pub fn new() -> Self {
1153 Self {
1154 frontend: WorkQueueFrontend::new(RsaWorkItem {
1155 #[cfg(not(esp32))]
1156 search_acceleration: false,
1157 #[cfg(not(esp32))]
1158 constant_time: true,
1159 operation: RsaOperation::Multiplication {
1160 x: NonNull::from(&[]),
1161 y: NonNull::from(&[]),
1162 },
1163 result: NonNull::from(&mut []),
1164 }),
1165 }
1166 }
1167
1168 #[cfg(not(esp32))]
1169 #[doc = trm_markdown_link!("rsa")]
1179 pub fn enable_search_acceleration(&mut self) {
1180 self.frontend.data_mut().search_acceleration = true;
1181 }
1182
1183 #[cfg(not(esp32))]
1184 #[doc = trm_markdown_link!("rsa")]
1195 pub fn enable_acceleration(&mut self) {
1196 self.frontend.data_mut().constant_time = false;
1197 }
1198
1199 fn post(&mut self) -> RsaHandle<'_> {
1200 RsaHandle(self.frontend.post(&RSA_WORK_QUEUE))
1201 }
1202
1203 #[procmacros::doc_replace]
1204 pub fn modular_exponentiate<'t, OP>(
1269 &'t mut self,
1270 x: &'t OP::InputType,
1271 y: &'t OP::InputType,
1272 m: &'t OP::InputType,
1273 r: &'t OP::InputType,
1274 m_prime: u32,
1275 result: &'t mut OP::InputType,
1276 ) -> RsaHandle<'t>
1277 where
1278 OP: RsaMode,
1279 {
1280 self.frontend.data_mut().operation = RsaOperation::ModularExponentiation {
1281 x: NonNull::from(x.as_ref()),
1282 y: NonNull::from(y.as_ref()),
1283 m: NonNull::from(m.as_ref()),
1284 r_inv: NonNull::from(r.as_ref()),
1285 m_prime,
1286 };
1287 self.frontend.data_mut().result = NonNull::from(result.as_mut());
1288 self.post()
1289 }
1290
1291 pub fn modular_multiply<'t, OP>(
1307 &'t mut self,
1308 x: &'t OP::InputType,
1309 y: &'t OP::InputType,
1310 m: &'t OP::InputType,
1311 r: &'t OP::InputType,
1312 m_prime: u32,
1313 result: &'t mut OP::InputType,
1314 ) -> RsaHandle<'t>
1315 where
1316 OP: RsaMode,
1317 {
1318 self.frontend.data_mut().operation = RsaOperation::ModularMultiplication {
1319 x: NonNull::from(x.as_ref()),
1320 y: NonNull::from(y.as_ref()),
1321 m: NonNull::from(m.as_ref()),
1322 r: NonNull::from(r.as_ref()),
1323 m_prime,
1324 };
1325 self.frontend.data_mut().result = NonNull::from(result.as_mut());
1326 self.post()
1327 }
1328
1329 #[procmacros::doc_replace]
1330 pub fn multiply<'t, OP>(
1360 &'t mut self,
1361 x: &'t OP::InputType,
1362 y: &'t OP::InputType,
1363 result: &'t mut OP::OutputType,
1364 ) -> RsaHandle<'t>
1365 where
1366 OP: Multi,
1367 {
1368 self.frontend.data_mut().operation = RsaOperation::Multiplication {
1369 x: NonNull::from(x.as_ref()),
1370 y: NonNull::from(y.as_ref()),
1371 };
1372 self.frontend.data_mut().result = NonNull::from(result.as_mut());
1373 self.post()
1374 }
1375}
1376
1377pub struct RsaHandle<'t>(work_queue::Handle<'t, RsaWorkItem>);
1379
1380impl RsaHandle<'_> {
1381 #[inline]
1383 pub fn poll(&mut self) -> bool {
1384 self.0.poll()
1385 }
1386
1387 #[inline]
1389 pub fn wait_blocking(mut self) {
1390 while !self.poll() {}
1391 }
1392
1393 #[inline]
1395 pub fn wait(&mut self) -> impl Future<Output = Status> {
1396 self.0.wait()
1397 }
1398}