From 8ea6684335cef59fea05686d99fcccfa9abf601d Mon Sep 17 00:00:00 2001 From: numzero Date: Sat, 11 Jan 2025 00:18:58 +0300 Subject: [PATCH] A fully async mailbox --- src/mailbox.rs | 86 ++++++++++++++++++++++++++++++++------------------ 1 file changed, 56 insertions(+), 30 deletions(-) diff --git a/src/mailbox.rs b/src/mailbox.rs index ddb04ef..4d6084f 100644 --- a/src/mailbox.rs +++ b/src/mailbox.rs @@ -1,5 +1,6 @@ use std::{ future::Future, + pin::Pin, sync::Mutex, task::{Context, Poll, Waker}, }; @@ -8,45 +9,70 @@ pub struct Mailbox(Mutex>); impl Mailbox { pub fn new() -> Self { - Mailbox(Mutex::new(MailboxInner { - value: None, - waker: None, - })) + Mailbox(Mutex::new(MailboxInner::Empty(None))) } } -struct MailboxInner { - value: Option, - waker: Option, +enum MailboxInner { + Empty(Option), + Full(Option, T), } -struct MailboxFuture<'a, T>(&'a Mailbox); -impl<'a, T> Future for MailboxFuture<'a, T> { +struct MailboxPut<'a, T>(&'a Mailbox, Option); +impl<'a, T> Unpin for MailboxPut<'a, T> {} +impl<'a, T> Future for MailboxPut<'a, T> { + type Output = (); + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut mb = self.0 .0.lock().unwrap(); + match &mut *mb { + MailboxInner::Empty(ref mut waker) => { + let waker = waker.take(); + *mb = MailboxInner::Full(None, self.1.take().expect("Mailbox future overpooled!")); + drop(mb); + if let Some(waker) = waker { + waker.wake(); + } + Poll::Ready(()) + } + MailboxInner::Full(ref mut waker, _) => { + *waker = Some(cx.waker().clone()); + Poll::Pending + } + } + } +} + +struct MailboxGet<'a, T>(&'a Mailbox); +impl<'a, T> Future for MailboxGet<'a, T> { type Output = T; - fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let mut mb = self.0 .0.lock().unwrap(); - if let Some(value) = mb.value.take() { - return Poll::Ready(value); + let old = std::mem::replace(&mut *mb, MailboxInner::Empty(None)); + match old { + MailboxInner::Empty(_) => { + *mb = MailboxInner::Empty(Some(cx.waker().clone())); + Poll::Pending + } + MailboxInner::Full(waker, value) => { + drop(mb); + if let Some(waker) = waker { + waker.wake(); + } + Poll::Ready(value) + } } - mb.waker = Some(cx.waker().clone()); - Poll::Pending } } impl Mailbox { - pub fn put(&self, value: T) { - let mut mb = self.0.lock().unwrap(); - mb.value = Some(value); - let Some(waker) = mb.waker.take() else { - return; - }; - drop(mb); - waker.wake(); + pub fn put(&self, value: T) -> impl Future + '_ { + MailboxPut(&self, Some(value)) } pub fn get(&self) -> impl Future + '_ { - MailboxFuture(&self) + MailboxGet(&self) } } @@ -78,7 +104,7 @@ mod tests { #[test] fn test_mailbox_once() { let mb = Mailbox::::new(); - mb.put(42); + pollster::block_on(mb.put(42)); assert_eq!(pollster::block_on(mb.get()), 42); assert!(!ready(mb.get())); } @@ -87,16 +113,16 @@ mod tests { fn test_mailbox_once_oor() { let mb = Mailbox::::new(); let f = mb.get(); - mb.put(42); + pollster::block_on(mb.put(42)); assert_eq!(pollster::block_on(f), 42); } #[test] fn test_mailbox_twice_no_wait() { let mb = Mailbox::::new(); - mb.put(42); - mb.put(13); - assert_eq!(pollster::block_on(mb.get()), 13); + pollster::block_on(mb.put(42)); + assert!(!ready(mb.put(13))); + assert_eq!(pollster::block_on(mb.get()), 42); assert!(!ready(mb.get())); } @@ -104,10 +130,10 @@ mod tests { fn test_mailbox_twice_in_order() { let mb = Mailbox::::new(); let f = mb.get(); - mb.put(42); + pollster::block_on(mb.put(42)); assert_eq!(pollster::block_on(f), 42); let f = mb.get(); - mb.put(13); + pollster::block_on(mb.put(13)); assert_eq!(pollster::block_on(f), 13); assert!(!ready(mb.get())); }