From 33db3cb15d2503e2797223b485e3d5095c6b6c3f Mon Sep 17 00:00:00 2001 From: Andrey Tkachenko Date: Fri, 17 Nov 2023 16:54:15 +0400 Subject: [PATCH] AsynIterator --- Cargo.toml | 1 + examples/demo.rs | 27 ++++++--- src/async_iter.rs | 141 ++++++++++++++++++++++++++++++++++++++++++++++ src/builder.rs | 8 ++- src/chan.rs | 12 ++-- src/error.rs | 61 ++++++++++++-------- src/handler.rs | 89 +++++++++++++++++------------ src/lib.rs | 90 +++++++++++++++++++++++------ src/message.rs | 69 ++++++++++------------- 9 files changed, 369 insertions(+), 129 deletions(-) create mode 100644 src/async_iter.rs diff --git a/Cargo.toml b/Cargo.toml index 1ea929b..35ff9d9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,6 +6,7 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +anyhow = "1.0.75" boxcar = "0.2.3" dashmap = "5.5.0" futures = "0.3.28" diff --git a/examples/demo.rs b/examples/demo.rs index 426be92..17c44a4 100644 --- a/examples/demo.rs +++ b/examples/demo.rs @@ -3,7 +3,8 @@ use std::sync::Arc; -use messagebus::{Builder, Bus, Error, Handler, IntoMessageStream, Message}; +use anyhow::Error; +use messagebus::{Builder, Bus, Handler, IntoMessages, Message}; #[derive(Debug, Clone)] pub struct Msg(pub i32); @@ -15,34 +16,42 @@ pub struct Processor { impl Handler for Processor { type Result = (); + type Error = Error; + async fn handle( &mut self, _msg: Msg, _stream_id: u32, _task_id: u32, - ) -> Result, Error> { - Ok(()) - } - - async fn finalize(self) -> Result<(), Error> { + _bus: Bus, + ) -> Result, Self::Error> { Ok(()) } async fn handle_error( &mut self, - _err: Error, + _err: messagebus::Error, _stream_id: u32, _task_id: u32, - ) -> Result, Error> { + _bus: messagebus::Bus, + ) -> Result + Send + '_, Self::Error> { Ok(None) } + + async fn finalize(self, _bus: Bus) -> Result<(), Self::Error> { + Ok(()) + } } struct ProcSpawner; impl Builder for ProcSpawner { type Context = Processor; - async fn build(&self, stream_id: u32, _task_id: u32) -> Result { + async fn build( + &self, + stream_id: u32, + _task_id: u32, + ) -> Result { Ok(Processor { _state: stream_id as _, }) diff --git a/src/async_iter.rs b/src/async_iter.rs new file mode 100644 index 0000000..0681b12 --- /dev/null +++ b/src/async_iter.rs @@ -0,0 +1,141 @@ +use std::{marker::PhantomData, pin::Pin}; + +use futures::{Stream, StreamExt}; + +pub trait AsyncIterator: Send { + type Item: Send; + + fn next(self: Pin<&mut Self>) -> impl futures::Future> + Send + '_; + + fn size_hint(&self) -> (usize, Option) { + (0, None) + } + + fn map U>(self, cb: C) -> impl AsyncIterator + where + Self: Sized, + { + Map { inner: self, cb } + } +} + +pin_project_lite::pin_project! { + pub struct MapU> { + #[pin] + inner: I, + cb: F, + } +} + +impl U> AsyncIterator for Map { + type Item = U; + + async fn next(self: Pin<&mut Self>) -> Option { + let this = self.project(); + Some((this.cb)(this.inner.next().await?)) + } +} + +pub struct Iter { + inner: I, +} + +impl AsyncIterator for Iter +where + I::Item: Send, +{ + type Item = I::Item; + + #[inline] + async fn next(self: Pin<&mut Self>) -> Option { + self.get_mut().inner.next() + } + + fn size_hint(&self) -> (usize, Option) { + self.inner.size_hint() + } +} + +pub fn iter(inner: I) -> Iter +where + I::IntoIter: Send + Unpin, + I::Item: Send, +{ + Iter { + inner: inner.into_iter(), + } +} +pin_project_lite::pin_project! { + pub struct StreamIter { + #[pin] + inner: S, + } +} + +impl AsyncIterator for StreamIter +where + S::Item: Send, +{ + type Item = S::Item; + + #[inline] + async fn next(self: Pin<&mut Self>) -> Option { + self.project().inner.next().await + } + + fn size_hint(&self) -> (usize, Option) { + self.inner.size_hint() + } +} + +pub fn stream(inner: S) -> StreamIter +where + S::Item: Send, +{ + StreamIter { inner } +} + +pub struct Once(Option); +impl AsyncIterator for Once { + type Item = I; + + async fn next(self: Pin<&mut Self>) -> Option { + self.get_mut().0.take() + } +} + +pub fn once(item: I) -> Once { + Once(Some(item)) +} + +pub struct Empty(PhantomData); +impl AsyncIterator for Empty { + type Item = I; + + async fn next(self: Pin<&mut Self>) -> Option { + None + } +} + +pub fn empty() -> Empty { + Empty(Default::default()) +} + +impl AsyncIterator for tokio::sync::mpsc::Receiver { + type Item = I; + + async fn next(self: Pin<&mut Self>) -> Option { + self.get_mut().recv().await + } +} + +impl AsyncIterator for tokio::sync::oneshot::Receiver { + type Item = I; + + async fn next(self: Pin<&mut Self>) -> Option { + match self.await { + Ok(item) => Some(item), + Err(_) => None, + } + } +} diff --git a/src/builder.rs b/src/builder.rs index 91f1d23..a241364 100644 --- a/src/builder.rs +++ b/src/builder.rs @@ -149,9 +149,13 @@ where return Ok(self.stream_handlers.get(&stream_id).unwrap().clone()); } - let val = Arc::new((self.callback)(stream_id, task_id).await?); + let val = match (self.callback)(stream_id, task_id).await { + Ok(val) => Arc::new(val), + Err(err) => return Err(err), + }; + self.stream_handlers.insert(stream_id, val.clone()); - Ok(val.clone()) + Ok(val) } fn config(&self, _stream_id: u32) -> Config { diff --git a/src/chan.rs b/src/chan.rs index 44b3260..657f828 100644 --- a/src/chan.rs +++ b/src/chan.rs @@ -88,13 +88,17 @@ pub(crate) struct Sender { impl Sender { pub async fn send(&self, msg: T) -> Result<(), Error> { - self.inner.send(ChannelItem::Value(msg)).await?; - Ok(()) + self.inner + .send(ChannelItem::Value(msg)) + .await + .map_err(Error::SendError) } pub async fn stop(&self) -> Result<(), Error> { - self.inner.send(ChannelItem::Close).await?; - Ok(()) + self.inner + .send(ChannelItem::Close) + .await + .map_err(Error::SendError) } pub fn close(&self) -> Result<(), Error> { diff --git a/src/error.rs b/src/error.rs index 013ddff..6518c54 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,19 +1,43 @@ -use std::fmt; +use std::{fmt, sync::Arc}; -use kanal::ReceiveError; +use kanal::{ReceiveError, SendError}; -#[derive(Debug, PartialEq, Eq)] +use crate::message::ErrorMessage; + +#[derive(Debug)] pub enum Error { HandlerIsNotRegistered, Aborted, - SendError(String), + SendError(SendError), ReceiveError(kanal::ReceiveError), ReorderingDropMessage(u64), + HandlerError(Arc), +} + +impl std::error::Error for Error {} + +impl Eq for Error {} + +impl PartialEq for Error { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (Error::HandlerIsNotRegistered, Error::HandlerIsNotRegistered) => true, + (Error::Aborted, Error::Aborted) => true, + (Error::SendError(err1), Error::SendError(err2)) => err1.eq(err2), + (Error::ReceiveError(err1), Error::ReceiveError(err2)) => err1.eq(err2), + (Error::ReorderingDropMessage(idx1), Error::ReorderingDropMessage(idx2)) => { + idx1.eq(idx2) + } + (Error::HandlerError(err1), Error::HandlerError(err2)) => Arc::ptr_eq(err1, err2), + _ => false, + } + } } impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { + Error::HandlerError(err) => writeln!(f, "Handler Error: {}", err)?, Error::HandlerIsNotRegistered => writeln!(f, "Handle is not registered!")?, Error::Aborted => writeln!(f, "Operation Aborted!")?, Error::SendError(reason) => writeln!(f, "Channel send error; reason {}", reason)?, @@ -29,12 +53,14 @@ impl fmt::Display for Error { } } -impl std::error::Error for Error {} - impl Clone for Error { fn clone(&self) -> Self { match self { - Error::SendError(err) => Error::SendError(err.clone()), + Error::HandlerError(err) => Error::HandlerError(err.clone()), + Error::SendError(err) => match err { + SendError::Closed => Error::SendError(SendError::Closed), + SendError::ReceiveClosed => Error::SendError(SendError::ReceiveClosed), + }, Error::ReceiveError(err) => match err { ReceiveError::Closed => Error::ReceiveError(ReceiveError::Closed), ReceiveError::SendClosed => Error::ReceiveError(ReceiveError::SendClosed), @@ -46,20 +72,11 @@ impl Clone for Error { } } -impl From for Error { - fn from(value: kanal::SendError) -> Self { - Self::SendError(format!("{}", value)) +impl From for Error +where + E: ErrorMessage, +{ + fn from(error: E) -> Self { + Self::HandlerError(Arc::new(error)) } } - -impl From for Error { - fn from(value: kanal::ReceiveError) -> Self { - Self::ReceiveError(value) - } -} - -// impl From> for Error { -// fn from(value: mpsc::error::SendError) -> Self { -// Self::SendError(format!("{}", value)) -// } -// } diff --git a/src/handler.rs b/src/handler.rs index d522ff1..976b117 100644 --- a/src/handler.rs +++ b/src/handler.rs @@ -8,38 +8,44 @@ use std::{ }, }; -use futures::{future, Future, Stream, StreamExt}; +use futures::Future; use tokio::sync::Notify; use crate::{ builder::Builder, chan::Receiver, - message::Msg, + message::{IntoMessages, Msg}, task::{TaskCounter, TaskSpawner}, - BusInner, Error, IntoMessageStream, Message, + AsyncIterator, Bus, BusInner, Error, ErrorMessage, Message, }; pub trait Handler: Send + Sync + 'static { - type Result: Message; + type Result: Message + Unpin; + type Error: ErrorMessage + Unpin; fn handle( &mut self, msg: M, stream_id: u32, task_id: u32, - ) -> impl Future + '_, Error>> + Send + '_; + bus: crate::Bus, + ) -> impl Future< + Output = Result + Send + '_, Self::Error>, + > + Send + + '_; fn handle_error( &mut self, _err: Error, _stream_id: u32, _task_id: u32, - ) -> impl Future + '_, Error>> + Send + '_ - { - future::ready(Ok(None)) - } + _bus: crate::Bus, + ) -> impl Future< + Output = Result + Send + '_, Self::Error>, + > + Send + + '_; - fn finalize(self) -> impl Future> + Send; + fn finalize(self, bus: crate::Bus) -> impl Future> + Send; } pub(crate) struct HandlerSpawner { @@ -72,7 +78,8 @@ where bus: Arc, ) -> Pin> + Send + '_>> { Box::pin(async move { - let bus = bus.clone(); + let bus = Bus { inner: bus.clone() }; + let config = self.builder.config(stream_id); let mut ctx = self.builder.build(stream_id, task_id).await?; @@ -85,42 +92,45 @@ where let res = match msg.inner { Some(Ok(m)) => { send_result( - &bus, + &bus.inner, &index_counter, msg.index, stream_id, &config, Some( - ctx.handle(m, stream_id, task_id) + ctx.handle(m, stream_id, task_id, bus.clone()) .await - .map(IntoMessageStream::into_message_stream), + .map(IntoMessages::into_messages), ), ) .await } Some(Err(err)) => { send_result( - &bus, + &bus.inner, &index_counter, msg.index, stream_id, &config, Some( - ctx.handle_error(err, stream_id, task_id) + ctx.handle_error(err, stream_id, task_id, bus.clone()) .await - .map(IntoMessageStream::into_message_stream), + .map(IntoMessages::into_messages), ), ) .await } None => { - send_result::<>::Result>( - &bus, + send_result::< + >::Result, + >::Error, + >( + &bus.inner, &index_counter, msg.index, stream_id, &config, - None::, _>>, + None::, _>>, ) .await } @@ -143,7 +153,7 @@ where std::any::type_name::() ); - if let Err(err) = ctx.finalize().await { + if let Err(err) = ctx.finalize(bus.clone()).await { println!("TASK FINALIZE ERROR: {:?}", err); } }); @@ -160,13 +170,13 @@ where } } -async fn send_result<'a, M: Message>( +async fn send_result<'a, M: Message, E: ErrorMessage>( bus: &Arc, index_counter: &AtomicU64, index: u64, stream_id: u32, config: &crate::builder::Config, - res: Option> + Send + 'a, Error>>, + res: Option> + Send + 'a, E>>, ) -> Result<(), Error> { let reorder_buff = if config.ordered && config.task_count > 1 { config.task_count @@ -175,21 +185,22 @@ async fn send_result<'a, M: Message>( }; let one = match res { - Some(Ok(stream)) => { - let hint = stream.size_hint(); - + Some(Ok(iter)) => { + let hint = iter.size_hint(); + let mut iter = pin!(iter); match hint { - (0, Some(0)) => None, - (1, Some(1)) => { - let mut stream = pin!(stream); - stream.next().await - } + (_, Some(0)) => None, + (_, Some(1)) => iter.next().await, _ => { - let mut stream = pin!(stream); - while let Some(item) = stream.next().await { + while let Some(item) = iter.as_mut().next().await { let index = index_counter.fetch_add(1, Ordering::Relaxed); - bus.send::(Some(item), index, stream_id, reorder_buff) - .await?; + bus.send::( + Some(item.map_err(Into::into)), + index, + stream_id, + reorder_buff, + ) + .await?; } return Ok(()); } @@ -199,7 +210,13 @@ async fn send_result<'a, M: Message>( None => None, }; - bus.send(one, index, stream_id, reorder_buff).await?; + bus.send( + one.map(|x| x.map_err(Into::into)), + index, + stream_id, + reorder_buff, + ) + .await?; Ok(()) } diff --git a/src/lib.rs b/src/lib.rs index ba0eb55..5f2831a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,7 @@ #![feature(return_position_impl_trait_in_trait)] #![feature(async_fn_in_trait)] +mod async_iter; mod builder; mod chan; mod error; @@ -29,10 +30,11 @@ use reorder_queue::ReorderQueueInner; use task::{TaskCounter, TaskSpawnerWrapper}; use tokio::sync::{Notify, RwLock}; +pub use async_iter::*; pub use builder::{Builder, DefaultBuilder, SharedBuilder}; pub use error::Error; pub use handler::Handler; -pub use message::{IntoMessageStream, Message, MessageIterator, MessageStream, MessageTryIterator}; +pub use message::{ErrorMessage, IntoMessages, Message}; pub const DEFAUL_STREAM_ID: u32 = u32::MAX; pub const DEFAUL_TASK_ID: u32 = 0; @@ -283,9 +285,8 @@ impl BusInner { index, stream_id, }) - .await?; - - Ok(()) + .await + .map_err(Error::SendError) } pub async fn register>(self: Arc, builder: B) @@ -424,8 +425,9 @@ mod tests { use rand::RngCore; use crate::{ - handler::Handler, Bus, DefaultBuilder, Error, IntoMessageStream, Message, SharedBuilder, + stream, Bus, DefaultBuilder, Error, Handler, IntoMessages, Message, SharedBuilder, }; + impl Message for u64 {} impl Message for u32 {} impl Message for i16 {} @@ -435,21 +437,33 @@ mod tests { struct TestProducer; impl Handler for TestProducer { type Result = u64; + type Error = anyhow::Error; async fn handle( &mut self, _msg: u32, _stream_id: u32, _task_id: u32, - ) -> Result, Error> { - Ok(crate::message::MessageStream(stream! { + _bus: Bus, + ) -> Result, Self::Error> { + Ok(stream(stream! { for i in 0u64..10 { yield Ok(i) } })) } - async fn finalize(self) -> Result<(), Error> { + async fn handle_error( + &mut self, + _err: Error, + _stream_id: u32, + _task_id: u32, + _bus: Bus, + ) -> Result, Self::Error> { + Ok(None) + } + + async fn finalize(self, _bus: Bus) -> Result<(), Self::Error> { println!("producer finalized"); Ok(()) } @@ -464,22 +478,35 @@ mod tests { impl Handler for Arc { type Result = (); + type Error = anyhow::Error; async fn handle( &mut self, msg: u64, stream_id: u32, task_id: u32, - ) -> Result, Error> { + _bus: Bus, + ) -> Result, Self::Error> { tokio::time::sleep(Duration::from_millis(1000)).await; println!( "[{}] shared consumer handle {}u64 ({}:{})", self.0, msg, stream_id, task_id ); + Ok(()) } - async fn finalize(self) -> Result<(), Error> { + async fn handle_error( + &mut self, + _err: Error, + _stream_id: u32, + _task_id: u32, + _bus: Bus, + ) -> Result, Self::Error> { + Ok(None) + } + + async fn finalize(self, _bus: Bus) -> Result<(), Self::Error> { println!("[{}] shared consumer finalized", self.0); Ok(()) } @@ -487,13 +514,15 @@ mod tests { impl Handler for TestConsumer { type Result = (); + type Error = anyhow::Error; async fn handle( &mut self, msg: u64, stream_id: u32, task_id: u32, - ) -> Result, Error> { + _bus: Bus, + ) -> Result, Self::Error> { tokio::time::sleep(Duration::from_millis(100)).await; println!( "[{}] consumer handle {}u64 ({}:{})", @@ -502,7 +531,17 @@ mod tests { Ok(()) } - async fn finalize(self) -> Result<(), Error> { + async fn handle_error( + &mut self, + _err: Error, + _stream_id: u32, + _task_id: u32, + _bus: Bus, + ) -> Result, Self::Error> { + Ok(None) + } + + async fn finalize(self, _bus: Bus) -> Result<(), Self::Error> { println!("[{}] consumer finalized", self.0); Ok(()) } @@ -512,13 +551,15 @@ mod tests { impl Handler for Arc { type Result = u16; + type Error = anyhow::Error; async fn handle( &mut self, msg: i16, _stream_id: u32, task_id: u32, - ) -> Result, Error> { + _bus: Bus, + ) -> Result, Self::Error> { if task_id % 2 == 0 { tokio::time::sleep(Duration::from_millis(13)).await; } else { @@ -526,10 +567,20 @@ mod tests { } println!("handle {}", msg); - Ok(msg as u16) + Ok([msg as u16]) } - async fn finalize(self) -> Result<(), Error> { + async fn handle_error( + &mut self, + _err: Error, + _stream_id: u32, + _task_id: u32, + _bus: Bus, + ) -> Result, Self::Error> { + Ok(None) + } + + async fn finalize(self, _bus: Bus) -> Result<(), Self::Error> { Ok(()) } } @@ -541,20 +592,22 @@ mod tests { impl Handler for TestCollector { type Result = (); + type Error = anyhow::Error; async fn handle( &mut self, msg: u16, _stream_id: u32, _task_id: u32, - ) -> Result, Error> { + _bus: Bus, + ) -> Result, Self::Error> { println!("{}", msg); self.inner.push(msg); Ok(None) } - async fn finalize(self) -> Result<(), Error> { + async fn finalize(self, _bus: Bus) -> Result<(), Self::Error> { println!("Checking"); assert_eq!(self.inner, (0u16..1024).collect::>()); Ok(()) @@ -565,7 +618,8 @@ mod tests { err: Error, _stream_id: u32, _task_id: u32, - ) -> Result, Error> { + _bus: Bus, + ) -> Result, Self::Error> { println!("{:?}", err); Ok(None) diff --git a/src/message.rs b/src/message.rs index a883383..91bef14 100644 --- a/src/message.rs +++ b/src/message.rs @@ -3,7 +3,16 @@ use std::any::Any; use futures::Stream; -use crate::Error; +use crate::{async_iter::AsyncIterator, Error, Iter, StreamIter}; + +pub trait Message: Any + fmt::Debug + Clone + Send + Sync + Unpin {} +pub trait ErrorMessage: Any + fmt::Debug + fmt::Display + Send + Sync + Unpin {} +pub trait IntoMessages { + fn into_messages(self) -> impl AsyncIterator>; +} + +impl Message for () {} +impl ErrorMessage for anyhow::Error {} #[derive(Debug, Clone, PartialEq, Eq)] pub(crate) struct Msg { @@ -12,59 +21,43 @@ pub(crate) struct Msg { pub(crate) stream_id: u32, } -pub struct MessageStream(pub S); -pub struct MessageIterator(pub I); -pub struct MessageTryIterator(pub I); - -pub trait Message: Any + fmt::Debug + Clone + Send + Sync + 'static {} -impl Message for () {} - -pub trait IntoMessageStream: Send { - fn into_message_stream(self) -> impl Stream> + Send; -} - -impl> + Send> IntoMessageStream - for MessageStream +impl> + Send> IntoMessages + for StreamIter { - fn into_message_stream(self) -> impl Stream> + Send { - self.0 + fn into_messages(self) -> impl AsyncIterator> { + self } } -impl + Send> IntoMessageStream for MessageIterator { - fn into_message_stream(self) -> impl Stream> { - futures::stream::iter(self.0.map(Ok)) - } -} - -impl> + Send> IntoMessageStream - for MessageTryIterator +impl IntoMessages for Iter +where + I::Item: Message, { - fn into_message_stream(self) -> impl Stream> + Send { - futures::stream::iter(self.0) + fn into_messages(self) -> impl AsyncIterator> { + self.map(Ok) } } -impl IntoMessageStream for Option { - fn into_message_stream(self) -> impl Stream> + Send { - futures::stream::iter(self.into_iter().map(Ok)) +impl IntoMessages<(), E> for () { + fn into_messages(self) -> impl AsyncIterator> { + crate::empty() } } -impl IntoMessageStream for Vec { - fn into_message_stream(self) -> impl Stream> + Send { - futures::stream::iter(self.into_iter().map(Ok)) +impl IntoMessages for [M; N] { + fn into_messages(self) -> impl AsyncIterator> { + crate::iter(self.into_iter().map(Ok)) } } -impl IntoMessageStream for [M; N] { - fn into_message_stream(self) -> impl Stream> + Send { - futures::stream::iter(self.map(Ok)) +impl IntoMessages for Vec { + fn into_messages(self) -> impl AsyncIterator> { + crate::iter(self.into_iter().map(Ok)) } } -impl IntoMessageStream for M { - fn into_message_stream(self) -> impl Stream> + Send { - futures::stream::iter([Ok(self)]) +impl IntoMessages for Option { + fn into_messages(self) -> impl AsyncIterator> { + crate::iter(self.into_iter().map(Ok)) } }