A fully async mailbox

This commit is contained in:
numzero 2025-01-11 00:18:58 +03:00
parent a617d21623
commit 8ea6684335

View File

@ -1,5 +1,6 @@
use std::{ use std::{
future::Future, future::Future,
pin::Pin,
sync::Mutex, sync::Mutex,
task::{Context, Poll, Waker}, task::{Context, Poll, Waker},
}; };
@ -8,45 +9,70 @@ pub struct Mailbox<T>(Mutex<MailboxInner<T>>);
impl<T> Mailbox<T> { impl<T> Mailbox<T> {
pub fn new() -> Self { pub fn new() -> Self {
Mailbox(Mutex::new(MailboxInner { Mailbox(Mutex::new(MailboxInner::Empty(None)))
value: None,
waker: None,
}))
} }
} }
struct MailboxInner<T> { enum MailboxInner<T> {
value: Option<T>, Empty(Option<Waker>),
waker: Option<Waker>, Full(Option<Waker>, T),
} }
struct MailboxFuture<'a, T>(&'a Mailbox<T>); struct MailboxPut<'a, T>(&'a Mailbox<T>, Option<T>);
impl<'a, T> Future for MailboxFuture<'a, T> { impl<'a, T> Unpin for MailboxPut<'a, T> {}
impl<'a, T> Future for MailboxPut<'a, T> {
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut mb = self.0 .0.lock().unwrap();
match &mut *mb {
MailboxInner::Empty(ref mut waker) => {
let waker = waker.take();
*mb = MailboxInner::Full(None, self.1.take().expect("Mailbox future overpooled!"));
drop(mb);
if let Some(waker) = waker {
waker.wake();
}
Poll::Ready(())
}
MailboxInner::Full(ref mut waker, _) => {
*waker = Some(cx.waker().clone());
Poll::Pending
}
}
}
}
struct MailboxGet<'a, T>(&'a Mailbox<T>);
impl<'a, T> Future for MailboxGet<'a, T> {
type Output = T; type Output = T;
fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut mb = self.0 .0.lock().unwrap(); let mut mb = self.0 .0.lock().unwrap();
if let Some(value) = mb.value.take() { let old = std::mem::replace(&mut *mb, MailboxInner::Empty(None));
return Poll::Ready(value); match old {
} MailboxInner::Empty(_) => {
mb.waker = Some(cx.waker().clone()); *mb = MailboxInner::Empty(Some(cx.waker().clone()));
Poll::Pending Poll::Pending
} }
MailboxInner::Full(waker, value) => {
drop(mb);
if let Some(waker) = waker {
waker.wake();
}
Poll::Ready(value)
}
}
}
} }
impl<T> Mailbox<T> { impl<T> Mailbox<T> {
pub fn put(&self, value: T) { pub fn put(&self, value: T) -> impl Future<Output = ()> + '_ {
let mut mb = self.0.lock().unwrap(); MailboxPut(&self, Some(value))
mb.value = Some(value);
let Some(waker) = mb.waker.take() else {
return;
};
drop(mb);
waker.wake();
} }
pub fn get(&self) -> impl Future<Output = T> + '_ { pub fn get(&self) -> impl Future<Output = T> + '_ {
MailboxFuture(&self) MailboxGet(&self)
} }
} }
@ -78,7 +104,7 @@ mod tests {
#[test] #[test]
fn test_mailbox_once() { fn test_mailbox_once() {
let mb = Mailbox::<i32>::new(); let mb = Mailbox::<i32>::new();
mb.put(42); pollster::block_on(mb.put(42));
assert_eq!(pollster::block_on(mb.get()), 42); assert_eq!(pollster::block_on(mb.get()), 42);
assert!(!ready(mb.get())); assert!(!ready(mb.get()));
} }
@ -87,16 +113,16 @@ mod tests {
fn test_mailbox_once_oor() { fn test_mailbox_once_oor() {
let mb = Mailbox::<i32>::new(); let mb = Mailbox::<i32>::new();
let f = mb.get(); let f = mb.get();
mb.put(42); pollster::block_on(mb.put(42));
assert_eq!(pollster::block_on(f), 42); assert_eq!(pollster::block_on(f), 42);
} }
#[test] #[test]
fn test_mailbox_twice_no_wait() { fn test_mailbox_twice_no_wait() {
let mb = Mailbox::<i32>::new(); let mb = Mailbox::<i32>::new();
mb.put(42); pollster::block_on(mb.put(42));
mb.put(13); assert!(!ready(mb.put(13)));
assert_eq!(pollster::block_on(mb.get()), 13); assert_eq!(pollster::block_on(mb.get()), 42);
assert!(!ready(mb.get())); assert!(!ready(mb.get()));
} }
@ -104,10 +130,10 @@ mod tests {
fn test_mailbox_twice_in_order() { fn test_mailbox_twice_in_order() {
let mb = Mailbox::<i32>::new(); let mb = Mailbox::<i32>::new();
let f = mb.get(); let f = mb.get();
mb.put(42); pollster::block_on(mb.put(42));
assert_eq!(pollster::block_on(f), 42); assert_eq!(pollster::block_on(f), 42);
let f = mb.get(); let f = mb.get();
mb.put(13); pollster::block_on(mb.put(13));
assert_eq!(pollster::block_on(f), 13); assert_eq!(pollster::block_on(f), 13);
assert!(!ready(mb.get())); assert!(!ready(mb.get()));
} }