From 4d1cced70e35c568778b809be38e2650cc246a9b Mon Sep 17 00:00:00 2001 From: Andrey Tkachenko Date: Mon, 13 Nov 2023 17:57:01 +0400 Subject: [PATCH] Reordering queue part 2 --- Cargo.toml | 1 + examples/demo.rs | 6 +- src/builder.rs | 10 + src/chan.rs | 25 ++- src/error.rs | 26 ++- src/handler.rs | 154 +++++++++++--- src/lib.rs | 481 ++++++++++++++++++++++++------------------- src/message.rs | 70 +++++-- src/producer.rs | 145 ------------- src/reorder_queue.rs | 47 +++-- src/task.rs | 70 ++++--- 11 files changed, 571 insertions(+), 464 deletions(-) delete mode 100644 src/producer.rs diff --git a/Cargo.toml b/Cargo.toml index 46a98c0..1ea929b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,6 +11,7 @@ dashmap = "5.5.0" futures = "0.3.28" kanal = "0.1.0-pre8" log = "0.4.20" +pin-project-lite = "0.2.13" priority-queue = "1.3.2" rand = { version = "0.8.5", default-features = false, features = ["std_rng", "std"] } rand_xorshift = "0.3.0" diff --git a/examples/demo.rs b/examples/demo.rs index 5d890b1..426be92 100644 --- a/examples/demo.rs +++ b/examples/demo.rs @@ -3,7 +3,7 @@ use std::sync::Arc; -use messagebus::{Builder, Bus, Error, Handler, IntoMessage, Message}; +use messagebus::{Builder, Bus, Error, Handler, IntoMessageStream, Message}; #[derive(Debug, Clone)] pub struct Msg(pub i32); @@ -20,7 +20,7 @@ impl Handler for Processor { _msg: Msg, _stream_id: u32, _task_id: u32, - ) -> Result, Error> { + ) -> Result, Error> { Ok(()) } @@ -33,7 +33,7 @@ impl Handler for Processor { _err: Error, _stream_id: u32, _task_id: u32, - ) -> Result, Error> { + ) -> Result, Error> { Ok(None) } } diff --git a/src/builder.rs b/src/builder.rs index 0f96e87..91f1d23 100644 --- a/src/builder.rs +++ b/src/builder.rs @@ -57,6 +57,16 @@ impl DefaultBuilder { _m: PhantomData, } } + + pub fn ordered(self) -> Self { + let mut config = self.config; + config.ordered = true; + + Self { + config, + _m: PhantomData, + } + } } impl Builder for DefaultBuilder { diff --git a/src/chan.rs b/src/chan.rs index 20924ab..44b3260 100644 --- a/src/chan.rs +++ b/src/chan.rs @@ -1,4 +1,7 @@ -use std::{any::Any, pin::Pin}; +use std::{ + any::{Any, TypeId}, + pin::Pin, +}; use futures::Future; @@ -9,6 +12,26 @@ enum ChannelItem { Close, } +pub(crate) trait AbstractSender: Any + Send + Sync { + fn upcast(&self) -> &(dyn Any + Send + Sync); + fn close(&self); + fn message_type_id(&self) -> TypeId; +} + +impl AbstractSender for kanal::AsyncSender { + fn upcast(&self) -> &(dyn Any + Send + Sync) { + self + } + + fn close(&self) { + self.close(); + } + + fn message_type_id(&self) -> TypeId { + TypeId::of::() + } +} + pub(crate) trait BusSenderClose: Any + Send + Sync { fn upcast(&self) -> &(dyn Any + Send + Sync); fn is_producer(&self) -> bool; diff --git a/src/error.rs b/src/error.rs index 8fcbb0b..013ddff 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,3 +1,5 @@ +use std::fmt; + use kanal::ReceiveError; #[derive(Debug, PartialEq, Eq)] @@ -6,9 +8,29 @@ pub enum Error { Aborted, SendError(String), ReceiveError(kanal::ReceiveError), - ReorderingMissedMessage(u64), + ReorderingDropMessage(u64), } +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Error::HandlerIsNotRegistered => writeln!(f, "Handle is not registered!")?, + Error::Aborted => writeln!(f, "Operation Aborted!")?, + Error::SendError(reason) => writeln!(f, "Channel send error; reason {}", reason)?, + Error::ReceiveError(reason) => writeln!(f, "Channel receive error; reason {}", reason)?, + Error::ReorderingDropMessage(index) => writeln!( + f, + "Reordering drop message #{} (out of bound the queue)", + index + )?, + } + + Ok(()) + } +} + +impl std::error::Error for Error {} + impl Clone for Error { fn clone(&self) -> Self { match self { @@ -19,7 +41,7 @@ impl Clone for Error { }, Error::HandlerIsNotRegistered => Error::HandlerIsNotRegistered, Error::Aborted => Error::Aborted, - Error::ReorderingMissedMessage(idx) => Error::ReorderingMissedMessage(*idx), + Error::ReorderingDropMessage(idx) => Error::ReorderingDropMessage(*idx), } } } diff --git a/src/handler.rs b/src/handler.rs index aba8320..d522ff1 100644 --- a/src/handler.rs +++ b/src/handler.rs @@ -1,6 +1,14 @@ -use std::{any::Any, marker::PhantomData, pin::Pin, sync::Arc}; +use std::{ + any::{type_name, Any}, + marker::PhantomData, + pin::{pin, Pin}, + sync::{ + atomic::{AtomicU64, Ordering}, + Arc, + }, +}; -use futures::Future; +use futures::{future, Future, Stream, StreamExt}; use tokio::sync::Notify; use crate::{ @@ -8,7 +16,7 @@ use crate::{ chan::Receiver, message::Msg, task::{TaskCounter, TaskSpawner}, - BusInner, Error, IntoMessage, Message, + BusInner, Error, IntoMessageStream, Message, }; pub trait Handler: Send + Sync + 'static { @@ -19,14 +27,17 @@ pub trait Handler: Send + Sync + 'static { msg: M, stream_id: u32, task_id: u32, - ) -> impl Future, Error>> + Send + '_; + ) -> impl Future + '_, Error>> + Send + '_; fn handle_error( &mut self, - err: Error, - stream_id: u32, - task_id: u32, - ) -> impl Future, Error>> + Send + '_; + _err: Error, + _stream_id: u32, + _task_id: u32, + ) -> impl Future + '_, Error>> + Send + '_ + { + future::ready(Ok(None)) + } fn finalize(self) -> impl Future> + Send; } @@ -56,47 +67,82 @@ where task_id: u32, _abort: Arc, task_counter: Arc, + spawn_counter: Arc, + index_counter: Arc, bus: Arc, ) -> Pin> + Send + '_>> { Box::pin(async move { let bus = bus.clone(); + let config = self.builder.config(stream_id); let mut ctx = self.builder.build(stream_id, task_id).await?; let _handle = tokio::spawn(async move { + let _test = spawn_counter.clone().lease_unit(|| true); + while let Some(msg) = rx.recv().await { - task_counter.inc_running(); + let _test = task_counter.clone().lease_unit(|| rx.is_empty()); - let resp = match msg.inner { - Ok(m) => ctx - .handle(m, stream_id, task_id) + let res = match msg.inner { + Some(Ok(m)) => { + send_result( + &bus, + &index_counter, + msg.index, + stream_id, + &config, + Some( + ctx.handle(m, stream_id, task_id) + .await + .map(IntoMessageStream::into_message_stream), + ), + ) .await - .map(IntoMessage::into_message), - - Err(err) => ctx - .handle_error(err, stream_id, task_id) + } + Some(Err(err)) => { + send_result( + &bus, + &index_counter, + msg.index, + stream_id, + &config, + Some( + ctx.handle_error(err, stream_id, task_id) + .await + .map(IntoMessageStream::into_message_stream), + ), + ) .await - .map(IntoMessage::into_message), + } + None => { + send_result::<>::Result>( + &bus, + &index_counter, + msg.index, + stream_id, + &config, + None::, _>>, + ) + .await + } }; - let Some(inner) = resp.transpose() else { - let _ = bus.send_skip::(stream_id, msg.index).await; - continue; - }; - - if let Err(err) = bus - .send(Msg { - inner, - index: msg.index, + if let Err(err) = res { + println!( + "Messagebus Send Error: {}/{} {}: {}", stream_id, - }) - .await - { - println!("BUS SEND ERROR: {:?}", err); + task_id, + type_name::<>::Result>(), + err, + ); } - - task_counter.dec_running(rx.is_empty()); } + println!( + "TASK #{} of type `{}` ENDED", + task_id, + std::any::type_name::() + ); + if let Err(err) = ctx.finalize().await { println!("TASK FINALIZE ERROR: {:?}", err); } @@ -113,3 +159,47 @@ where self.builder.config(stream_id) } } + +async fn send_result<'a, M: Message>( + bus: &Arc, + index_counter: &AtomicU64, + index: u64, + stream_id: u32, + config: &crate::builder::Config, + res: Option> + Send + 'a, Error>>, +) -> Result<(), Error> { + let reorder_buff = if config.ordered && config.task_count > 1 { + config.task_count + } else { + 0 + }; + + let one = match res { + Some(Ok(stream)) => { + let hint = stream.size_hint(); + + match hint { + (0, Some(0)) => None, + (1, Some(1)) => { + let mut stream = pin!(stream); + stream.next().await + } + _ => { + let mut stream = pin!(stream); + while let Some(item) = stream.next().await { + let index = index_counter.fetch_add(1, Ordering::Relaxed); + bus.send::(Some(item), index, stream_id, reorder_buff) + .await?; + } + return Ok(()); + } + } + } + Some(Err(err)) => Some(Err(err)), + None => None, + }; + + bus.send(one, index, stream_id, reorder_buff).await?; + + Ok(()) +} diff --git a/src/lib.rs b/src/lib.rs index de4fb2b..ba0eb55 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,7 +6,6 @@ mod chan; mod error; mod handler; mod message; -mod producer; mod rand; mod reorder_queue; mod task; @@ -15,38 +14,40 @@ use std::{ any::{Any, TypeId}, collections::HashMap, sync::{ - atomic::{AtomicU64, Ordering}, + atomic::{AtomicBool, AtomicU64, Ordering}, Arc, }, }; use builder::Config; -use chan::{channel, BusSender, BusSenderClose, Receiver, Sender}; +use chan::{channel, AbstractSender, BusSender, BusSenderClose, Receiver, Sender}; use dashmap::DashMap; use futures::Future; use message::Msg; use rand::RndGen; -use reorder_queue::ReorderQueue; +use reorder_queue::ReorderQueueInner; use task::{TaskCounter, TaskSpawnerWrapper}; use tokio::sync::{Notify, RwLock}; pub use builder::{Builder, DefaultBuilder, SharedBuilder}; pub use error::Error; pub use handler::Handler; -pub use message::{IntoMessage, Message}; -pub use producer::Producer; +pub use message::{IntoMessageStream, Message, MessageIterator, MessageStream, MessageTryIterator}; pub const DEFAUL_STREAM_ID: u32 = u32::MAX; pub const DEFAUL_TASK_ID: u32 = 0; #[derive(Default)] struct BusInner { - senders: DashMap<(u32, u32, TypeId), Arc>, - spawners: RwLock>>, - reordering: DashMap<(u32, TypeId), Box>, - counters: DashMap, + senders: DashMap<(u32, u32, TypeId), Arc>, + spawners: RwLock>>, + reordering: DashMap<(u32, TypeId), Arc>, + counters: DashMap<(u32, TypeId), Arc>, abort_notify: Arc, task_counter: Arc, + spawn_counter: Arc, + stopping: AtomicBool, + stopping_notify: Arc, rng: RndGen, } @@ -91,6 +92,7 @@ impl BusInner { stream_id: u32, task_id: u32, ) -> Result<(), Error> { + let type_id = TypeId::of::(); let sender = spawner .spawn_task( chan, @@ -98,21 +100,58 @@ impl BusInner { task_id, self.abort_notify.clone(), self.task_counter.clone(), + self.spawn_counter.clone(), + self.counters + .entry((stream_id, type_id)) + .or_default() + .value() + .clone(), self.clone(), ) .await?; - self.senders.insert( - (stream_id, task_id, TypeId::of::()), - Arc::new(sender) as _, - ); + self.senders + .insert((stream_id, task_id, type_id), Arc::new(sender) as _); Ok(()) } - async fn send_inner( + pub async fn send( self: &Arc, - msg: Result, + msg: Option>, + index: u64, + stream_id: u32, + reorder_buff: u32, + ) -> Result<(), Error> { + let type_id = TypeId::of::(); + + if type_id == TypeId::of::<()>() { + return Ok(()); + } + + let config = if let Some(spawner) = self + .spawners + .read() + .await + .get(&type_id) + .and_then(|x| x.downcast_ref::>()) + { + spawner.config(stream_id) + } else { + Config::default() + }; + + if reorder_buff > 1 { + self.send_ordered(msg, index, stream_id, config, reorder_buff) + .await + } else { + self.send_inner(msg, index, stream_id, config).await + } + } + + pub async fn send_inner( + self: &Arc, + msg: Option>, index: u64, stream_id: u32, config: Config, @@ -182,159 +221,132 @@ impl BusInner { Ok(()) } - pub async fn send(self: &Arc, msg: Msg) -> Result<(), Error> { - let type_id = TypeId::of::(); - if type_id == TypeId::of::<()>() { - return Ok(()); - } - - let stream_id = msg.stream_id; - - let config = if let Some(spawner) = self - .spawners - .read() - .await - .get(&type_id) - .and_then(|x| x.downcast_ref::>()) - { - spawner.config(stream_id) - } else { - Config::default() - }; - - if config.ordered { - let mut queue = self - .reordering - .entry((stream_id, type_id)) - .or_insert_with(|| { - Box::new(ReorderQueue::>>::new( - config.task_count as _, - )) - }); - - let queue = queue - .downcast_mut::>>>() - .unwrap(); - - if let Some(index) = queue.push(msg.index, Some(msg.inner)) { - self.send_inner( - Err::(Error::ReorderingMissedMessage(index)), - index, - stream_id, - config, - ) - .await?; - } - - while let Some(msg) = queue.pop() { - if let (index, Some(Some(msg))) = msg { - self.send_inner(msg, index, stream_id, config).await?; - } - } - - Ok(()) - } else { - self.send_inner(msg.inner, msg.index, stream_id, config) - .await - } - } - - pub async fn send_skip( + pub async fn send_ordered( self: &Arc, - stream_id: u32, + msg: Option>, index: u64, + stream_id: u32, + config: Config, + reorder_buff: u32, ) -> Result<(), Error> { let type_id = TypeId::of::(); - if type_id == TypeId::of::<()>() { - return Ok(()); - } - let config = if let Some(spawner) = self - .spawners - .read() - .await - .get(&type_id) - .and_then(|x| x.downcast_ref::>()) - { - spawner.config(stream_id) - } else { - Config::default() - }; + let sender = self + .reordering + .entry((stream_id, type_id)) + .or_insert_with(|| { + let (tx, rx) = kanal::bounded_async::>(4); - if config.ordered { - let mut queue = self - .reordering - .entry((stream_id, type_id)) - .or_insert_with(|| { - Box::new(ReorderQueue::>>::new( - config.task_count as _, - )) + let bus = self.clone(); + tokio::spawn(async move { + let mut queue = ReorderQueueInner::new(reorder_buff as _); + + while let Ok(msg) = rx.recv().await { + println!(" > {:?}", msg); + + if let Some(index) = queue.push(msg.index, msg.inner) { + bus.send_inner( + Some(Err::(Error::ReorderingDropMessage(index))), + index, + stream_id, + config, + ) + .await + .unwrap(); + } + + while let Some((index, msg)) = queue.pop() { + if let Err(err) = bus.send_inner(msg, index, stream_id, config).await { + println!("Err: {}", err); + } + } + } + + while let Some((index, msg)) = queue.force_pop() { + if let Err(err) = bus.send_inner(msg, index, stream_id, config).await { + println!("Err: {}", err); + } + } }); - let queue = queue - .downcast_mut::>>>() - .unwrap(); + Arc::new(tx) as _ + }) + .downgrade() + .clone(); - if let Some(index) = queue.push(index, None) { - self.send_inner( - Err::(Error::ReorderingMissedMessage(index)), - index, - stream_id, - config, - ) - .await?; - } - - while let Some(msg) = queue.pop() { - if let (index, Some(Some(msg))) = msg { - self.send_inner(msg, index, stream_id, config).await?; - } - } - } + sender + .upcast() + .downcast_ref::>>() + .unwrap() + .send(Msg { + inner: msg, + index, + stream_id, + }) + .await?; Ok(()) } - pub async fn register_dispatcher>(self: Arc, builder: B) + pub async fn register>(self: Arc, builder: B) where B::Context: Handler, { let type_id = TypeId::of::(); - self.counters.insert(type_id, AtomicU64::new(0)); self.spawners.write().await.insert( type_id, Box::new(TaskSpawnerWrapper::from_handler(builder)) as _, ); } - pub async fn register_producer>(self: Arc, builder: B) - where - B::Context: Producer, - { - self.spawners.write().await.insert( - TypeId::of::(), - Box::new(TaskSpawnerWrapper::from_producer(builder)) as _, - ); - } - #[inline] pub async fn close(&self, force: bool) { if force { self.abort_notify.notify_waiters(); } - for item in self.senders.iter() { - if force { - let _ = item.value().terminate(); - } else if item.is_producer() { - let _ = item.value().stop().await; - } - } + self.stopping.store(true, Ordering::Relaxed); + self.stopping_notify.notify_waiters(); } #[inline] pub async fn wait(&self) { - self.task_counter.wait().await; + while !self.stopping.load(Ordering::Relaxed) { + self.stopping_notify.notified().await; + } + + loop { + self.task_counter.wait().await; + + for queue in self.reordering.iter() { + queue.value().close(); + } + + if self.check_stopped() { + break; + } + } + + for sender in self.senders.iter() { + let _ = sender.value().stop().await; + } + + self.spawn_counter.wait().await; + } + + fn check_stopped(&self) -> bool { + println!("Checking Stopped:"); + for sender in self.senders.iter() { + let load = sender.value().load(); + println!("{:?}: {}/{}", sender.key(), load.0, load.1); + if sender.value().load().0 > 0 { + return false; + } + } + + println!("Checking Reordering Queues:"); + + true } } @@ -355,35 +367,13 @@ impl Bus { where B::Context: Handler, { - self.inner.clone().register_dispatcher(builder).await; - self - } - - #[inline] - pub async fn register_producer>(&self, builder: B) -> &Self - where - B::Context: Producer, - { - self.inner.clone().register_producer(builder).await; + self.inner.clone().register(builder).await; self } #[inline] pub async fn send(&self, inner: M) -> Result<(), Error> { - let index = self - .inner - .counters - .get(&TypeId::of::()) - .map(|x| x.fetch_add(1, Ordering::Relaxed)) - .unwrap_or(0); - - self.inner - .send(Msg { - inner: Ok(inner), - index, - stream_id: DEFAUL_STREAM_ID, - }) - .await + self.send_with_stream(DEFAUL_STREAM_ID, inner).await } #[inline] @@ -392,20 +382,16 @@ impl Bus { stream_id: u32, inner: M, ) -> Result<(), Error> { + let type_id = TypeId::of::(); let index = self .inner .counters - .get(&TypeId::of::()) - .map(|x| x.fetch_add(1, Ordering::Relaxed)) - .unwrap_or(0); + .entry((stream_id, type_id)) + .or_default() + .value() + .fetch_add(1, Ordering::Relaxed); - self.inner - .send(Msg { - inner: Ok(inner), - index, - stream_id, - }) - .await + self.inner.send(Some(Ok(inner)), index, stream_id, 0).await } /// @@ -435,41 +421,32 @@ mod tests { use std::{sync::Arc, time::Duration}; use async_stream::stream; - use futures::Stream; use rand::RngCore; use crate::{ - handler::Handler, producer::Producer, Bus, DefaultBuilder, Error, IntoMessage, Message, - SharedBuilder, + handler::Handler, Bus, DefaultBuilder, Error, IntoMessageStream, Message, SharedBuilder, }; impl Message for u64 {} impl Message for u32 {} + impl Message for i16 {} + impl Message for u16 {} #[derive(Default)] struct TestProducer; - impl Producer for TestProducer { - type Item = u64; + impl Handler for TestProducer { + type Result = u64; - fn stream( + async fn handle( &mut self, _msg: u32, _stream_id: u32, _task_id: u32, - ) -> impl Stream, Error>> + Send + '_ { - stream! { + ) -> Result, Error> { + Ok(crate::message::MessageStream(stream! { for i in 0u64..10 { yield Ok(i) } - } - } - - async fn handle_error( - &mut self, - _err: Error, - _stream_id: u32, - _task_id: u32, - ) -> Result, Error> { - Ok(None) + })) } async fn finalize(self) -> Result<(), Error> { @@ -478,10 +455,10 @@ mod tests { } } - struct TestConsumer(u32); + struct TestConsumer(u16); impl Default for TestConsumer { fn default() -> Self { - Self(rand::thread_rng().next_u32()) + Self(rand::thread_rng().next_u32() as _) } } @@ -493,7 +470,7 @@ mod tests { msg: u64, stream_id: u32, task_id: u32, - ) -> Result, Error> { + ) -> Result, Error> { tokio::time::sleep(Duration::from_millis(1000)).await; println!( "[{}] shared consumer handle {}u64 ({}:{})", @@ -501,14 +478,6 @@ mod tests { ); Ok(()) } - async fn handle_error( - &mut self, - _err: Error, - _stream_id: u32, - _task_id: u32, - ) -> Result, Error> { - Ok(None) - } async fn finalize(self) -> Result<(), Error> { println!("[{}] shared consumer finalized", self.0); @@ -524,8 +493,8 @@ mod tests { msg: u64, stream_id: u32, task_id: u32, - ) -> Result, Error> { - tokio::time::sleep(Duration::from_millis(1000)).await; + ) -> Result, Error> { + tokio::time::sleep(Duration::from_millis(100)).await; println!( "[{}] consumer handle {}u64 ({}:{})", self.0, msg, stream_id, task_id @@ -533,30 +502,85 @@ mod tests { Ok(()) } - async fn handle_error( - &mut self, - _err: Error, - _stream_id: u32, - _task_id: u32, - ) -> Result, Error> { - Ok(None) - } - async fn finalize(self) -> Result<(), Error> { println!("[{}] consumer finalized", self.0); Ok(()) } } + struct TestHandler {} + + impl Handler for Arc { + type Result = u16; + + async fn handle( + &mut self, + msg: i16, + _stream_id: u32, + task_id: u32, + ) -> Result, Error> { + if task_id % 2 == 0 { + tokio::time::sleep(Duration::from_millis(13)).await; + } else { + tokio::time::sleep(Duration::from_millis(22)).await; + } + + println!("handle {}", msg); + Ok(msg as u16) + } + + async fn finalize(self) -> Result<(), Error> { + Ok(()) + } + } + + #[derive(Default)] + struct TestCollector { + inner: Vec, + } + + impl Handler for TestCollector { + type Result = (); + + async fn handle( + &mut self, + msg: u16, + _stream_id: u32, + _task_id: u32, + ) -> Result, Error> { + println!("{}", msg); + self.inner.push(msg); + + Ok(None) + } + + async fn finalize(self) -> Result<(), Error> { + println!("Checking"); + assert_eq!(self.inner, (0u16..1024).collect::>()); + Ok(()) + } + + async fn handle_error( + &mut self, + err: Error, + _stream_id: u32, + _task_id: u32, + ) -> Result, Error> { + println!("{:?}", err); + + Ok(None) + } + } + #[tokio::test] - #[ignore = ""] + #[ignore] async fn test_streams() { let bus = Bus::default(); bus.register(DefaultBuilder::::new(2)) .await; - bus.register_producer(DefaultBuilder::::new(2)) + bus.register(DefaultBuilder::::new(2)) .await; for start in 0u32..10 { @@ -568,16 +592,16 @@ mod tests { } #[tokio::test] - #[ignore = ""] + #[ignore] async fn test_tasks_shared() { let bus = Bus::default(); - bus.register(SharedBuilder::new(2, 5, |sid, _tid| async move { - Ok(TestConsumer(sid)) + bus.register(SharedBuilder::new(2, 5, |_sid, _tid| async move { + Ok(TestConsumer::default()) })) .await; - bus.register_producer(DefaultBuilder::::new(2)) + bus.register(DefaultBuilder::::new(2)) .await; for start in 0u32..10 { @@ -587,4 +611,29 @@ mod tests { bus.close().await; bus.wait().await; } + + #[tokio::test] + // #[ignore] + async fn test_reordering() { + let bus = Bus::default(); + + bus.register( + SharedBuilder::new(4, 128, |sid, tid| async move { + println!("NEW HANDLER {}/{}", sid, tid); + Ok(TestHandler {}) + }) + .ordered(), + ) + .await; + + bus.register(DefaultBuilder::<_, TestCollector>::new(4)) + .await; + + for i in 0i16..1024 { + bus.send(i).await.unwrap(); + } + + bus.close().await; + bus.wait().await; + } } diff --git a/src/message.rs b/src/message.rs index c6ccf43..a883383 100644 --- a/src/message.rs +++ b/src/message.rs @@ -1,40 +1,70 @@ use core::fmt; +use std::any::Any; + +use futures::Stream; use crate::Error; #[derive(Debug, Clone, PartialEq, Eq)] pub(crate) struct Msg { - pub(crate) inner: Result, + pub(crate) inner: Option>, pub(crate) index: u64, pub(crate) stream_id: u32, } -// impl Msg { -// pub(crate) fn new(m: M, index: u64, stream_id: u32) -> Self { -// Self { -// inner: Ok(m), -// index, -// stream_id, -// } -// } -// } - -pub trait Message: fmt::Debug + Clone + Send + Sync + 'static {} +pub struct MessageStream(pub S); +pub struct MessageIterator(pub I); +pub struct MessageTryIterator(pub I); +pub trait Message: Any + fmt::Debug + Clone + Send + Sync + 'static {} impl Message for () {} -pub trait IntoMessage: Send { - fn into_message(self) -> Option; +pub trait IntoMessageStream: Send { + fn into_message_stream(self) -> impl Stream> + Send; } -impl IntoMessage for Option { - fn into_message(self) -> Option { - self +impl> + Send> IntoMessageStream + for MessageStream +{ + fn into_message_stream(self) -> impl Stream> + Send { + self.0 } } -impl IntoMessage for M { - fn into_message(self) -> Option { - Some(self) +impl + Send> IntoMessageStream for MessageIterator { + fn into_message_stream(self) -> impl Stream> { + futures::stream::iter(self.0.map(Ok)) + } +} + +impl> + Send> IntoMessageStream + for MessageTryIterator +{ + fn into_message_stream(self) -> impl Stream> + Send { + futures::stream::iter(self.0) + } +} + +impl IntoMessageStream for Option { + fn into_message_stream(self) -> impl Stream> + Send { + futures::stream::iter(self.into_iter().map(Ok)) + } +} + +impl IntoMessageStream for Vec { + fn into_message_stream(self) -> impl Stream> + Send { + futures::stream::iter(self.into_iter().map(Ok)) + } +} + +impl IntoMessageStream for [M; N] { + fn into_message_stream(self) -> impl Stream> + Send { + futures::stream::iter(self.map(Ok)) + } +} + +impl IntoMessageStream for M { + fn into_message_stream(self) -> impl Stream> + Send { + futures::stream::iter([Ok(self)]) } } diff --git a/src/producer.rs b/src/producer.rs deleted file mode 100644 index cbf89c9..0000000 --- a/src/producer.rs +++ /dev/null @@ -1,145 +0,0 @@ -use std::{ - marker::PhantomData, - pin::{pin, Pin}, - sync::Arc, -}; - -use futures::{Future, Stream, StreamExt}; -use tokio::sync::Notify; - -use crate::{ - builder::Builder, - chan::Receiver, - message::Msg, - task::{TaskCounter, TaskSpawner}, - BusInner, Error, IntoMessage, Message, -}; - -pub trait Producer: Send + Sync + 'static { - type Item: Message; - - fn stream( - &mut self, - msg: M, - stream_id: u32, - task_id: u32, - ) -> impl Stream, Error>> + Send + '_; - - fn handle_error( - &mut self, - err: Error, - stream_id: u32, - task_id: u32, - ) -> impl Future, Error>> + Send + '_; - - fn finalize(self) -> impl Future> + Send; -} - -pub(crate) struct ProducerSpawner { - pub(crate) builder: B, - _m: PhantomData, -} - -impl ProducerSpawner { - pub(crate) fn new(builder: B) -> Self { - Self { - builder, - _m: PhantomData, - } - } -} - -impl> TaskSpawner for ProducerSpawner -where - B::Context: Producer, -{ - fn spawn_task( - &self, - rx: Receiver>, - stream_id: u32, - task_id: u32, - abort: Arc, - task_counter: Arc, - bus: Arc, - ) -> Pin> + Send + '_>> { - Box::pin(async move { - let mut ctx = self.builder.build(stream_id, task_id).await?; - - let _handle = tokio::spawn(async move { - while let Some(recv_msg) = rx.recv().await { - task_counter.inc_running(); - - match recv_msg.inner { - Ok(msg) => { - let mut stream = pin!(ctx - .stream(msg, stream_id, task_id) - .take_until(abort.notified())); - - let mut index = 0; - - while let Some(res) = stream.next().await { - if let Some(inner) = res.map(IntoMessage::into_message).transpose() - { - index += 1; - if let Err(err) = bus - .send(Msg { - inner, - index: index - 1, - stream_id, - }) - .await - { - println!("BUS SEND ERROR: {:?}", err); - continue; - } - } - } - } - Err(err) => { - let Some(inner) = ctx - .handle_error(err, stream_id, task_id) - .await - .map(IntoMessage::into_message) - .transpose() - else { - if let Err(err) = - bus.send_skip::(stream_id, recv_msg.index).await - { - println!("BUS SEND ERROR: {:?}", err); - } - continue; - }; - - if let Err(err) = bus - .send(Msg { - inner, - index: recv_msg.index, - stream_id, - }) - .await - { - println!("BUS SEND ERROR: {:?}", err); - } - } - } - - task_counter.dec_running(rx.is_empty()); - } - - if let Err(err) = ctx.finalize().await { - println!("TASK FINALIZE ERROR: {:?}", err); - } - }); - - Ok(()) - }) - } - - fn is_producer(&self) -> bool { - true - } - - fn config(&self, stream_id: u32) -> crate::builder::Config { - self.builder.config(stream_id) - } -} diff --git a/src/reorder_queue.rs b/src/reorder_queue.rs index e8ff3fd..dd9190e 100644 --- a/src/reorder_queue.rs +++ b/src/reorder_queue.rs @@ -1,7 +1,7 @@ use std::{cmp::Ordering, collections::BinaryHeap}; struct Entry { - inner: Option, + inner: M, index: u64, } @@ -25,13 +25,13 @@ impl PartialEq for Entry { impl Eq for Entry {} -pub(crate) struct ReorderQueue { +pub(crate) struct ReorderQueueInner { cap: usize, recent_index: Option, heap: BinaryHeap>, } -impl ReorderQueue { +impl ReorderQueueInner { pub fn new(cap: usize) -> Self { Self { cap, @@ -40,11 +40,8 @@ impl ReorderQueue { } } - pub fn push(&mut self, index: u64, msg: M) -> Option { - self.heap.push(Entry { - inner: Some(msg), - index, - }); + pub fn push(&mut self, index: u64, inner: M) -> Option { + self.heap.push(Entry { inner, index }); if self.heap.len() == self.cap { self.recent_index = self.recent_index.map(|x| x + 1); @@ -54,7 +51,7 @@ impl ReorderQueue { } } - pub fn pop(&mut self) -> Option<(u64, Option)> { + pub fn pop(&mut self) -> Option<(u64, M)> { match self.recent_index { None => { let e = self.heap.pop()?; @@ -72,22 +69,28 @@ impl ReorderQueue { } } } + + pub fn force_pop(&mut self) -> Option<(u64, M)> { + let e = self.heap.pop()?; + self.recent_index = Some(e.index); + Some((e.index, e.inner)) + } } #[cfg(test)] mod tests { use crate::Message; - use super::ReorderQueue; + use super::ReorderQueueInner; impl Message for i32 {} #[test] fn test_reordering() { - let mut queue = ReorderQueue::new(8); + let mut queue = ReorderQueueInner::new(8); assert_eq!(queue.push(0, 0), None); - assert_eq!(queue.pop(), Some((0, Some(0)))); + assert_eq!(queue.pop(), Some((0, 0))); assert_eq!(queue.pop(), None); assert_eq!(queue.push(3, 3), None); @@ -100,18 +103,18 @@ mod tests { assert_eq!(queue.pop(), None); assert_eq!(queue.push(1, 1), None); - assert_eq!(queue.pop(), Some((1, Some(1)))); - assert_eq!(queue.pop(), Some((2, Some(2)))); - assert_eq!(queue.pop(), Some((3, Some(3)))); - assert_eq!(queue.pop(), Some((4, Some(4)))); + assert_eq!(queue.pop(), Some((1, 1))); + assert_eq!(queue.pop(), Some((2, 2))); + assert_eq!(queue.pop(), Some((3, 3))); + assert_eq!(queue.pop(), Some((4, 4))); assert_eq!(queue.pop(), None); } #[test] fn test_overflow() { - let mut queue = ReorderQueue::new(4); + let mut queue = ReorderQueueInner::new(4); assert_eq!(queue.push(0, 0), None); - assert_eq!(queue.pop(), Some((0, Some(0)))); + assert_eq!(queue.pop(), Some((0, 0))); assert_eq!(queue.pop(), None); assert_eq!(queue.push(4, 4), None); @@ -124,10 +127,10 @@ mod tests { assert_eq!(queue.pop(), None); assert_eq!(queue.push(5, 5), Some(1)); - assert_eq!(queue.pop(), Some((2, Some(2)))); - assert_eq!(queue.pop(), Some((3, Some(3)))); - assert_eq!(queue.pop(), Some((4, Some(4)))); - assert_eq!(queue.pop(), Some((5, Some(5)))); + assert_eq!(queue.pop(), Some((2, 2))); + assert_eq!(queue.pop(), Some((3, 3))); + assert_eq!(queue.pop(), Some((4, 4))); + assert_eq!(queue.pop(), Some((5, 5))); assert_eq!(queue.pop(), None); } } diff --git a/src/task.rs b/src/task.rs index 3106db0..b5453b4 100644 --- a/src/task.rs +++ b/src/task.rs @@ -1,7 +1,7 @@ use std::{ pin::Pin, sync::{ - atomic::{AtomicUsize, Ordering}, + atomic::{AtomicU64, AtomicUsize, Ordering}, Arc, }, }; @@ -14,29 +14,46 @@ use crate::{ chan::{BusSender, Receiver, Sender}, handler::HandlerSpawner, message::Msg, - producer::ProducerSpawner, - Builder, BusInner, Error, Handler, Message, Producer, + Builder, BusInner, Error, Handler, Message, }; #[derive(Default)] pub(crate) struct TaskCounter { - running: AtomicUsize, + pub(crate) running: AtomicUsize, notify: Notify, } -impl TaskCounter { - #[inline] - pub fn inc_running(&self) { - self.running.fetch_add(1, Ordering::Relaxed); - } +pub(crate) struct TaskCounterLease bool> { + need_notify: S, + counter: Arc, +} + +impl bool> Drop for TaskCounterLease { + fn drop(&mut self) { + let notify = (self.need_notify)(); + let prev = self.counter.running.fetch_sub(1, Ordering::Relaxed); - #[inline] - pub fn dec_running(&self, notify: bool) { - let prev = self.running.fetch_sub(1, Ordering::Relaxed); if notify && prev == 1 { - self.notify.notify_waiters(); + self.counter.notify.notify_waiters(); } } +} + +impl bool> TaskCounterLease { + fn new(counter: Arc, need_notify: S) -> Self { + counter.running.fetch_add(1, Ordering::Relaxed); + + Self { + counter, + need_notify, + } + } +} + +impl TaskCounter { + pub fn lease_unit bool>(self: Arc, need_notify: S) -> TaskCounterLease { + TaskCounterLease::new(self, need_notify) + } #[inline] pub async fn wait(&self) { @@ -47,6 +64,8 @@ impl TaskCounter { pub(crate) trait TaskSpawner: Send + Sync { fn config(&self, stream_id: u32) -> Config; fn is_producer(&self) -> bool; + + #[allow(clippy::too_many_arguments)] fn spawn_task( &self, rx: Receiver>, @@ -54,6 +73,8 @@ pub(crate) trait TaskSpawner: Send + Sync { task_id: u32, abort: Arc, task_counter: Arc, + spawn_counter: Arc, + index_counter: Arc, bus: Arc, ) -> Pin> + Send + '_>>; } @@ -80,16 +101,8 @@ impl TaskSpawnerWrapper { } } - pub fn from_producer + 'static>(builder: B) -> Self - where - B::Context: Producer, - { - Self { - inner: Arc::new(ProducerSpawner::new(builder)) as _, - } - } - #[inline] + #[allow(clippy::too_many_arguments)] pub async fn spawn_task( &self, (tx, rx): (Sender>, Receiver>), @@ -97,10 +110,21 @@ impl TaskSpawnerWrapper { task_id: u32, abort: Arc, task_counter: Arc, + spawn_counter: Arc, + index_counter: Arc, bus: Arc, ) -> Result, Error> { self.inner - .spawn_task(rx, stream_id, task_id, abort, task_counter, bus) + .spawn_task( + rx, + stream_id, + task_id, + abort, + task_counter, + spawn_counter, + index_counter, + bus, + ) .await?; Ok(BusSender::new(self.inner.is_producer(), tx))