diff --git a/Cargo.toml b/Cargo.toml index 23e0554..1fdc18b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,6 +17,11 @@ async-trait = "0.1.42" futures = "0.3.8" anyhow = "1.0.34" crossbeam = "0.8.1" +uuid = "0.8.2" +tokio-util = "0.6.7" +async-stream = "0.3.2" +smallvec = "1.6.1" +log = "0.4.14" [dev-dependencies] tokio = { version = "1", features = ["macros", "parking_lot", "rt-multi-thread", "io-util", "sync"] } diff --git a/examples/demo_async.rs b/examples/demo_async.rs index 6330cbb..dcbd72c 100644 --- a/examples/demo_async.rs +++ b/examples/demo_async.rs @@ -1,15 +1,24 @@ use async_trait::async_trait; -use messagebus::{receivers, AsyncHandler, Bus, Handler, Result as MbusResult}; +use messagebus::{receivers, AsyncHandler, Bus, Handler}; struct TmpReceiver; struct TmpReceiver2; #[async_trait] impl AsyncHandler for TmpReceiver { - async fn handle(&self, msg: f32, bus: &Bus) -> MbusResult { + type Error = anyhow::Error; + type Response = (); + + async fn handle(&self, msg: f32, bus: &Bus) -> Result { bus.send(1u16).await?; - println!("---> f32 {}", msg); + println!("TmpReceiver ---> f32 {}", msg); + + Ok(()) + } + + async fn sync(&self, _bus: &Bus) -> Result<(), Self::Error> { + println!("TmpReceiver f32: sync"); Ok(()) } @@ -17,9 +26,18 @@ impl AsyncHandler for TmpReceiver { #[async_trait] impl AsyncHandler for TmpReceiver { - async fn handle(&self, msg: u16, bus: &Bus) -> MbusResult { + type Error = anyhow::Error; + type Response = (); + + async fn handle(&self, msg: u16, bus: &Bus) -> Result { bus.send(2u32).await?; - println!("---> u16 {}", msg); + println!("TmpReceiver ---> u16 {}", msg); + + Ok(()) + } + + async fn sync(&self, _bus: &Bus) -> Result<(), Self::Error> { + println!("TmpReceiver u16: sync"); Ok(()) } @@ -27,9 +45,17 @@ impl AsyncHandler for TmpReceiver { #[async_trait] impl AsyncHandler for TmpReceiver { - async fn handle(&self, msg: u32, bus: &Bus) -> MbusResult { + type Error = anyhow::Error; + type Response = (); + + async fn handle(&self, msg: u32, bus: &Bus) -> Result { bus.send(3i32).await?; - println!("---> u32 {}", msg); + println!("TmpReceiver ---> u32 {}", msg); + + Ok(()) + } + async fn sync(&self, _bus: &Bus) -> Result<(), Self::Error> { + println!("TmpReceiver u32: sync"); Ok(()) } @@ -37,9 +63,18 @@ impl AsyncHandler for TmpReceiver { #[async_trait] impl AsyncHandler for TmpReceiver { - async fn handle(&self, msg: i32, bus: &Bus) -> MbusResult { + type Error = anyhow::Error; + type Response = (); + + async fn handle(&self, msg: i32, bus: &Bus) -> Result { bus.send(4i16).await?; - println!("---> i32 {}", msg); + println!("TmpReceiver ---> i32 {}", msg); + + Ok(()) + } + + async fn sync(&self, _bus: &Bus) -> Result<(), Self::Error> { + println!("TmpReceiver i32: sync"); Ok(()) } @@ -47,8 +82,16 @@ impl AsyncHandler for TmpReceiver { #[async_trait] impl AsyncHandler for TmpReceiver { - async fn handle(&self, msg: i16, _bus: &Bus) -> MbusResult { - println!("---> i16 {}", msg); + type Error = anyhow::Error; + type Response = (); + + async fn handle(&self, msg: i16, _bus: &Bus) -> Result { + println!("TmpReceiver ---> i16 {}", msg); + + Ok(()) + } + async fn sync(&self, _bus: &Bus) -> Result<(), Self::Error> { + println!("TmpReceiver i16: sync"); Ok(()) } @@ -56,17 +99,35 @@ impl AsyncHandler for TmpReceiver { #[async_trait] impl AsyncHandler for TmpReceiver2 { - async fn handle(&self, msg: i32, bus: &Bus) -> MbusResult { + type Error = anyhow::Error; + type Response = (); + + async fn handle(&self, msg: i32, bus: &Bus) -> Result { + println!("!!!! TmpReceiver2: ---> 2 i32 {}", msg); + bus.send(5i16).await?; - println!("---> 2 i32 {}", msg); + + Ok(()) + } + async fn sync(&self, _bus: &Bus) -> Result<(), Self::Error> { + println!("TmpReceiver2: i32: sync"); Ok(()) } } impl Handler for TmpReceiver2 { - fn handle(&self, msg: i16, _bus: &Bus) -> MbusResult { - println!("---> 2 i16 {}", msg); + type Error = anyhow::Error; + type Response = (); + + fn handle(&self, msg: i16, _bus: &Bus) -> Result { + println!("TmpReceiver2: ---> 2 i16 {}", msg); + + Ok(()) + } + + fn sync(&self, _bus: &Bus) -> Result<(), Self::Error> { + println!("TmpReceiver2: i16: sync"); Ok(()) } @@ -76,18 +137,26 @@ impl Handler for TmpReceiver2 { async fn main() { let (b, poller) = Bus::build() .register(TmpReceiver) - .subscribe::>(Default::default()) - .subscribe::>(Default::default()) - .subscribe::>(Default::default()) - .subscribe::>(Default::default()) - .subscribe::>(Default::default()) + .subscribe::, _, _>(8, Default::default()) + .subscribe::, _, _>(8, Default::default()) + .subscribe::, _, _>(8, Default::default()) + .subscribe::, _, _>(8, Default::default()) + .subscribe::, _, _>(8, Default::default()) .done() .register(TmpReceiver2) - .subscribe::>(Default::default()) - .subscribe::>(Default::default()) + .subscribe::, _, _>(8, Default::default()) + .subscribe::, _, _>(8, Default::default()) .done() - .build(); + .build(); b.send(0f32).await.unwrap(); - poller.await + + println!("flush"); + b.flush().await; + + println!("close"); + b.close().await; + + poller.await; + println!("[done]"); } diff --git a/examples/demo_backpressure.rs b/examples/demo_backpressure.rs index 8fe85a3..7c21799 100644 --- a/examples/demo_backpressure.rs +++ b/examples/demo_backpressure.rs @@ -1,9 +1,14 @@ -use messagebus::{receivers, Bus, Handler, Result as MbusResult}; +use async_trait::async_trait; +use messagebus::{receivers, Bus, AsyncHandler}; struct TmpReceiver; -impl Handler for TmpReceiver { - fn handle(&self, msg: f32, _bus: &Bus) -> MbusResult { +#[async_trait] +impl AsyncHandler for TmpReceiver { + type Error = anyhow::Error; + type Response = (); + + async fn handle(&self, msg: f32, _bus: &Bus) -> Result { println!("---> f32 {}", msg); std::thread::sleep(std::time::Duration::from_secs(1)); @@ -17,7 +22,7 @@ impl Handler for TmpReceiver { async fn main() { let (b, poller) = Bus::build() .register(TmpReceiver) - .subscribe::>(receivers::BufferUnorderedConfig { + .subscribe::, _, _>(1, receivers::BufferUnorderedConfig { buffer_size: 1, max_parallel: 1, }) @@ -57,7 +62,12 @@ async fn main() { println!("sending 11"); b.send(32f32).await.unwrap(); - println!("finish"); + println!("flush"); + b.flush().await; + + println!("close"); + b.close().await; poller.await; + println!("[done]"); } diff --git a/examples/demo_slow.rs b/examples/demo_slow.rs index ffcedaa..9031010 100644 --- a/examples/demo_slow.rs +++ b/examples/demo_slow.rs @@ -1,9 +1,12 @@ -use messagebus::{receivers, Bus, Handler, Result as MbusResult}; +use messagebus::{receivers, Bus, Handler}; struct TmpReceiver; impl Handler for TmpReceiver { - fn handle(&self, msg: f32, _bus: &Bus) -> MbusResult { + type Error = anyhow::Error; + type Response = (); + + fn handle(&self, msg: f32, _bus: &Bus) -> Result { println!("---> f32 {}", msg); std::thread::sleep(std::time::Duration::from_secs(5)); @@ -15,14 +18,20 @@ impl Handler for TmpReceiver { } impl Handler for TmpReceiver { - fn handle(&self, msg: u16, _bus: &Bus) -> MbusResult { + type Error = anyhow::Error; + type Response = (); + + fn handle(&self, msg: u16, _bus: &Bus) -> Result { println!("---> u16 {}", msg); Ok(()) } } impl Handler for TmpReceiver { - fn handle(&self, msg: u32, _bus: &Bus) -> MbusResult { + type Error = anyhow::Error; + type Response = (); + + fn handle(&self, msg: u32, _bus: &Bus) -> Result { println!("---> u32 {}", msg); Ok(()) } @@ -32,9 +41,9 @@ impl Handler for TmpReceiver { async fn main() { let (b, poller) = Bus::build() .register(TmpReceiver) - .subscribe::>(Default::default()) - .subscribe::>(Default::default()) - .subscribe::>(Default::default()) + .subscribe::, _, _>(8, Default::default()) + .subscribe::, _, _>(8, Default::default()) + .subscribe::, _, _>(8, Default::default()) .done() .build(); @@ -42,5 +51,12 @@ async fn main() { b.send(11u16).await.unwrap(); b.send(32u32).await.unwrap(); - poller.await + println!("flush"); + b.flush().await; + + println!("close"); + b.close().await; + + poller.await; + println!("[done]"); } diff --git a/src/builder.rs b/src/builder.rs index c9237bf..f7c4b9e 100644 --- a/src/builder.rs +++ b/src/builder.rs @@ -20,7 +20,13 @@ pub trait ReceiverSubscriber { ); } -pub trait ReceiverSubscriberBuilder { +pub trait ReceiverSubscriberBuilder + where + T: 'static, + M: Message, + R: Message, + E: crate::Error +{ type Entry: ReceiverSubscriber; type Config: Default; @@ -44,6 +50,7 @@ pub struct RegisterEntry { ) -> Box Pin + Send>>>, >, + Box Pin + Send>>> )>, >, _m: PhantomData<(K, T)>, @@ -54,10 +61,10 @@ impl RegisterEntry { let mut builder = self.builder; for (tid, v) in self.receivers { - for (r, poller) in v { + for (r, poller, poller2) in v { let poller = poller(self.item.clone()); - builder.add_recevier((tid, r), poller); + builder.add_recevier((tid, r), poller, poller2); } } @@ -65,39 +72,45 @@ impl RegisterEntry { } } -impl RegisterEntry { - pub fn subscribe(mut self, cfg: R::Config) -> Self +impl RegisterEntry { + pub fn subscribe(mut self, queue: u64, cfg: S::Config) -> Self where T: Send + 'static, - M: Message + 'static, - R: ReceiverSubscriberBuilder + 'static, + M: Message, + R: Message, + E: crate::Error, + S: ReceiverSubscriberBuilder + 'static, { - let (inner, poller) = R::build(cfg).subscribe(); + let (inner, poller) = S::build(cfg).subscribe(); - let receiver = Receiver::new(inner); + let receiver = Receiver::new(queue, inner); + let poller2 = receiver.start_polling_events::(); self.receivers .entry(TypeId::of::()) .or_insert_with(Vec::new) - .push((receiver, poller)); + .push((receiver, poller, poller2)); self } } -impl RegisterEntry { - pub fn subscribe(mut self, cfg: R::Config) -> Self +impl RegisterEntry { + pub fn subscribe(mut self, queue: u64, cfg: S::Config) -> Self where - T: Send + 'static, - M: Message + 'static, - R: ReceiverSubscriberBuilder + 'static, + T: Send + Sync + 'static, + M: Message, + R: Message, + E: crate::Error, + S: ReceiverSubscriberBuilder + 'static, { - let (inner, poller) = R::build(cfg).subscribe(); + let (inner, poller) = S::build(cfg).subscribe(); - let receiver = Receiver::new(inner); + let receiver = Receiver::new(queue, inner); + let poller2 = receiver.start_polling_events::(); self.receivers .entry(TypeId::of::()) .or_insert_with(Vec::new) - .push((receiver, poller)); + .push((receiver, poller, poller2)); self } @@ -138,9 +151,11 @@ impl BusBuilder { &mut self, val: (TypeId, Receiver), poller: Box Pin + Send>>>, + poller2: Box Pin + Send>>>, ) { self.receivers.push(val); self.pollings.push(poller); + self.pollings.push(poller2); } pub fn build(self) -> (Bus, impl Future) { @@ -148,7 +163,7 @@ impl BusBuilder { inner: Arc::new(BusInner::new(self.receivers)), }; - let mut futs = Vec::with_capacity(self.pollings.len()); + let mut futs = Vec::with_capacity(self.pollings.len() * 2); for poller in self.pollings { futs.push(tokio::task::spawn(poller(bus.clone()))); } diff --git a/src/handler.rs b/src/handler.rs index 7facd70..26b15ec 100644 --- a/src/handler.rs +++ b/src/handler.rs @@ -2,91 +2,128 @@ use crate::{Bus, Message}; use async_trait::async_trait; pub trait Handler: Send + Sync { - fn handle(&self, msg: M, bus: &Bus) -> anyhow::Result<()>; - fn sync(&self, _bus: &Bus) -> anyhow::Result<()> { + type Error: crate::Error; + type Response: Message; + + fn handle(&self, msg: M, bus: &Bus) -> Result; + fn sync(&self, _bus: &Bus) -> Result<(), Self::Error> { Ok(()) } } #[async_trait] pub trait AsyncHandler: Send + Sync { - async fn handle(&self, msg: M, bus: &Bus) -> anyhow::Result<()>; - async fn sync(&self, _bus: &Bus) -> anyhow::Result<()> { + type Error: crate::Error; + type Response: Message; + + async fn handle(&self, msg: M, bus: &Bus) -> Result; + async fn sync(&self, _bus: &Bus) -> Result<(), Self::Error> { Ok(()) } } pub trait SynchronizedHandler: Send { - fn handle(&mut self, msg: M, bus: &Bus) -> anyhow::Result<()>; - fn sync(&mut self, _bus: &Bus) -> anyhow::Result<()> { + type Error: crate::Error; + type Response: Message; + + fn handle(&mut self, msg: M, bus: &Bus) -> Result; + fn sync(&mut self, _bus: &Bus) -> Result<(), Self::Error> { Ok(()) } } #[async_trait] pub trait AsyncSynchronizedHandler: Send { - async fn handle(&mut self, msg: M, bus: &Bus) -> anyhow::Result<()>; - async fn sync(&mut self, _bus: &Bus) -> anyhow::Result<()> { + type Error: crate::Error; + type Response: Message; + + async fn handle(&mut self, msg: M, bus: &Bus) -> Result; + async fn sync(&mut self, _bus: &Bus) -> Result<(), Self::Error> { Ok(()) } } pub trait BatchHandler: Send + Sync { - fn handle(&self, msg: Vec, bus: &Bus) -> anyhow::Result<()>; - fn sync(&self, _bus: &Bus) -> anyhow::Result<()> { + type Error: crate::Error; + type Response: Message; + + fn handle(&self, msg: Vec, bus: &Bus) -> Result, Self::Error>; + fn sync(&self, _bus: &Bus) -> Result<(), Self::Error> { Ok(()) } } #[async_trait] pub trait AsyncBatchHandler: Send + Sync { - async fn handle(&self, msg: Vec, bus: &Bus) -> anyhow::Result<()>; - async fn sync(&self, _bus: &Bus) -> anyhow::Result<()> { + type Error: crate::Error; + type Response: Message; + + async fn handle(&self, msg: Vec, bus: &Bus) -> Result, Self::Error>; + async fn sync(&self, _bus: &Bus) -> Result<(), Self::Error> { Ok(()) } } pub trait BatchSynchronizedHandler: Send { - fn handle(&mut self, msg: Vec, bus: &Bus) -> anyhow::Result<()>; - fn sync(&mut self, _bus: &Bus) -> anyhow::Result<()> { + type Error: crate::Error; + type Response: Message; + + fn handle(&mut self, msg: Vec, bus: &Bus) -> Result, Self::Error>; + fn sync(&mut self, _bus: &Bus) -> Result<(), Self::Error> { Ok(()) } } #[async_trait] pub trait AsyncBatchSynchronizedHandler: Send { - async fn handle(&mut self, msg: Vec, bus: &Bus) -> anyhow::Result<()>; - async fn sync(&mut self, _bus: &Bus) -> anyhow::Result<()> { + type Error: crate::Error; + type Response: Message; + + async fn handle(&mut self, msg: Vec, bus: &Bus) -> Result, Self::Error>; + async fn sync(&mut self, _bus: &Bus) -> Result<(), Self::Error> { Ok(()) } } pub trait LocalHandler { - fn handle(&mut self, msg: Vec, bus: &Bus) -> anyhow::Result<()>; - fn sync(&mut self, _bus: &Bus) -> anyhow::Result<()> { + type Error: crate::Error; + type Response: Message; + + fn handle(&mut self, msg: Vec, bus: &Bus) -> Result; + fn sync(&mut self, _bus: &Bus) -> Result<(), Self::Error> { Ok(()) } } #[async_trait] pub trait LocalAsyncHandler { - async fn handle(&mut self, msg: Vec, bus: &Bus) -> anyhow::Result<()>; - async fn sync(&mut self, _bus: &Bus) -> anyhow::Result<()> { + type Error: crate::Error; + type Response: Message; + + async fn handle(&mut self, msg: Vec, bus: &Bus) -> Result; + async fn sync(&mut self, _bus: &Bus) -> Result<(), Self::Error> { Ok(()) } } + pub trait LocalBatchHandler { - fn handle(&mut self, msg: Vec, bus: &Bus) -> anyhow::Result<()>; - fn sync(&mut self, _bus: &Bus) -> anyhow::Result<()> { + type Error: crate::Error; + type Response: Message; + + fn handle(&mut self, msg: Vec, bus: &Bus) -> Result, Self::Error>; + fn sync(&mut self, _bus: &Bus) -> Result<(), Self::Error> { Ok(()) } } #[async_trait] pub trait LocalAsyncBatchHandler { - async fn handle(&mut self, msg: Vec, bus: &Bus) -> anyhow::Result<()>; - async fn sync(&mut self, _bus: &Bus) -> anyhow::Result<()> { + type Error: crate::Error; + type Response: Message; + + async fn handle(&mut self, msg: Vec, bus: &Bus) -> Result, Self::Error>; + async fn sync(&mut self, _bus: &Bus) -> Result<(), Self::Error> { Ok(()) } } diff --git a/src/lib.rs b/src/lib.rs index ac7dfe7..b295bad 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,32 +5,44 @@ pub mod msgs; mod receiver; pub mod receivers; mod trait_object; -mod utils; + +#[macro_use] +extern crate log; use builder::BusBuilder; pub use envelop::Message; pub use handler::*; pub use receiver::SendError; use receiver::{Receiver, ReceiverStats}; -use utils::binary_search_range_by_key; +use smallvec::SmallVec; use core::any::{Any, TypeId}; -use std::sync::{ - atomic::{AtomicBool, Ordering}, - Arc, -}; +use std::{collections::HashMap, sync::{Arc, atomic::{AtomicBool, AtomicU64, Ordering}}}; + +use crate::receivers::Permit; pub type Untyped = Arc; -pub type Result = anyhow::Result<()>; + +// pub trait ErrorTrait: std::error::Error + Send + Sync + 'static {} +pub trait Error: Into + Send + Sync + 'static {} +impl + Send + Sync + 'static> Error for T {} + +static ID_COUNTER: AtomicU64 = AtomicU64::new(1); pub struct BusInner { - receivers: Vec<(TypeId, Receiver)>, + receivers: HashMap>, closed: AtomicBool, } impl BusInner { - pub(crate) fn new(mut receivers: Vec<(TypeId, Receiver)>) -> Self { - receivers.sort_unstable_by_key(|(k, _)| *k); + pub(crate) fn new(input: Vec<(TypeId, Receiver)>) -> Self { + let mut receivers = HashMap::new(); + + for (key, value) in input { + receivers.entry(key) + .or_insert_with(SmallVec::new) + .push(value); + } Self { receivers, @@ -38,42 +50,101 @@ impl BusInner { } } - pub fn close(&self) { + pub async fn close(&self) { self.closed.store(true, Ordering::SeqCst); - for (_, r) in &self.receivers { - r.close(); + for (_, rs) in &self.receivers { + for r in rs { + r.close().await; + } } } - pub async fn sync(&self) { - for (_, r) in &self.receivers { - r.sync().await; + pub async fn flush(&self) { + let fuse_count = 32i32; + let mut breaked = false; + let mut iters = 0usize; + for _ in 0..fuse_count { + iters += 1; + let mut flushed = false; + for (_, rs) in &self.receivers { + + for r in rs { + if r.need_flush() { + flushed = true; + r.flush().await; + } + } + } + + if !flushed { + breaked = true; + break; + } + } + + if !breaked { + warn!("!!! WARNING: unable to reach equilibrium in {} iterations !!!", fuse_count); + } else { + info!("flushed in {} iterations !!!", iters); + } + } + + pub async fn flash_and_sync(&self) { + self.flush().await; + + for (_, rs) in &self.receivers { + for r in rs { + r.sync().await; + } } } - pub fn stats(&self) -> impl Iterator + '_ { - self.receivers.iter().map(|(_, r)| r.stats()) - } + // pub fn stats(&self) -> impl Iterator + '_ { + // self.receivers.iter() + // .map(|(_, i)|i.iter()) + // .flatten() + // .map(|r| r.stats()) + // } pub fn try_send(&self, msg: M) -> core::result::Result<(), SendError> { if self.closed.load(Ordering::SeqCst) { - println!("Bus closed. Skipping send!"); + warn!("Bus closed. Skipping send!"); return Ok(()); } + let mid = ID_COUNTER.fetch_add(1, Ordering::Relaxed); let tid = TypeId::of::(); - let range = binary_search_range_by_key(&self.receivers, &tid, |(k, _)| *k); - for i in (range.start + 1)..range.end { - self.receivers[i].1.try_broadcast(msg.clone())?; + if let Some(rs) = self.receivers.get(&tid) { + let mut permits = SmallVec::<[Permit; 32]>::new(); + + for r in rs { + if let Some(prmt) = r.try_reserve() { + permits.push(prmt); + } else { + return Err(SendError::Full(msg)); + }; + } + + let mut iter = permits.into_iter().zip(rs.iter()); + let mut counter = 1; + let total = rs.len(); + + while counter < total { + let (p, r) = iter.next().unwrap(); + let _ = r.send(mid, p, msg.clone()); + + counter += 1; + } + + if let Some((p, r)) = iter.next() { + let _ = r.send(mid, p, msg); + return Ok(()); + } } - if let Some((_, r)) = self.receivers.get(range.start) { - r.try_broadcast(msg.clone())?; - } else { - println!("Unhandled message {:?}", core::any::type_name::()); - } + warn!("Unhandled message {:?}: no receivers", core::any::type_name::()); Ok(()) } @@ -88,19 +159,48 @@ impl BusInner { return Err(SendError::Closed(msg)); } + let mid = ID_COUNTER.fetch_add(1, Ordering::Relaxed); let tid = TypeId::of::(); - let range = binary_search_range_by_key(&self.receivers, &tid, |(k, _)| *k); - for i in (range.start + 1)..range.end { - self.receivers[i].1.broadcast(msg.clone()).await?; + if let Some(rs) = self.receivers.get(&tid) { + if let Some((last, head)) = rs.split_last() { + for r in head { + let _ = r.send(mid, r.reserve().await, msg.clone()); + } + + let _ = last.send(mid, last.reserve().await, msg.clone()); + + return Ok(()); + } } - if let Some((_, r)) = self.receivers.get(range.start) { - r.broadcast(msg.clone()).await?; - } else { - println!("Unhandled message {:?}", core::any::type_name::()); + warn!("Unhandled message {:?}: no receivers", core::any::type_name::()); + + Ok(()) + } + + pub async fn force_send(&self, msg: M) -> core::result::Result<(), SendError> { + if self.closed.load(Ordering::SeqCst) { + return Err(SendError::Closed(msg)); } + let mid = ID_COUNTER.fetch_add(1, Ordering::Relaxed); + let tid = TypeId::of::(); + + if let Some(rs) = self.receivers.get(&tid) { + if let Some((last, head)) = rs.split_last() { + for r in head { + let _ = r.force_send(mid, msg.clone()); + } + + let _ = last.force_send(mid, msg.clone()); + + return Ok(()); + } + } + + warn!("Unhandled message {:?}: no receivers", core::any::type_name::()); + Ok(()) } } diff --git a/src/receiver.rs b/src/receiver.rs index b5d7907..d2ae4e0 100644 --- a/src/receiver.rs +++ b/src/receiver.rs @@ -1,15 +1,10 @@ -use crate::{trait_object::TraitObject, Bus, Message}; -use core::{ - any::TypeId, - fmt, - future::Future, - marker::PhantomData, - mem, - pin::Pin, - task::{Context, Poll}, -}; +use crate::{Bus, Message, msgs, receivers::{Event, Permit, PermitDrop}, trait_object::TraitObject}; +use core::{any::TypeId, fmt, marker::PhantomData, mem, pin::Pin, task::{Context, Poll}}; use futures::future::poll_fn; -use std::{borrow::Cow, sync::Arc}; +use tokio::sync::Notify; +use std::{borrow::Cow, sync::{Arc, atomic::{AtomicBool, AtomicU64, Ordering}}}; +use futures::Future; + pub struct AnyReceiver<'a> { dyn_typed_receiver_trait_object: TraitObject, @@ -17,19 +12,52 @@ pub struct AnyReceiver<'a> { _m: PhantomData<&'a usize>, } +unsafe impl Send for AnyReceiver<'_> {} + impl<'a> AnyReceiver<'a> { - pub fn new + 'static>(rcvr: &'a R) -> Self { - let trcvr = rcvr as &(dyn TypedReceiver); + pub fn new + 'static>(rcvr: &'a R) -> Self { + let trcvr = rcvr as &(dyn SendTypedReceiver); Self { dyn_typed_receiver_trait_object: unsafe { mem::transmute(trcvr) }, - type_id: TypeId::of::>(), + type_id: TypeId::of::>(), _m: Default::default(), } } - pub fn dyn_typed_receiver(&'a self) -> &'a dyn TypedReceiver { - assert_eq!(self.type_id, TypeId::of::>()); + pub fn dyn_typed_receiver(&'a self) -> &'a dyn SendTypedReceiver { + assert_eq!(self.type_id, TypeId::of::>()); + + unsafe { mem::transmute(self.dyn_typed_receiver_trait_object) } + } +} + +pub struct AnyPoller<'a> { + dyn_typed_receiver_trait_object: TraitObject, + type_id: TypeId, + _m: PhantomData<&'a usize>, +} + +unsafe impl Send for AnyPoller<'_> {} + +impl<'a> AnyPoller<'a> { + pub fn new(rcvr: &'a R) -> Self + where + M: Message, + E: crate::Error, + R: ReciveTypedReceiver + 'static + { + let trcvr = rcvr as &(dyn ReciveTypedReceiver); + + Self { + dyn_typed_receiver_trait_object: unsafe { mem::transmute(trcvr) }, + type_id: TypeId::of::>(), + _m: Default::default(), + } + } + + pub fn dyn_typed_receiver(&'a self) -> &'a dyn ReciveTypedReceiver { + assert_eq!(self.type_id, TypeId::of::>()); unsafe { mem::transmute(self.dyn_typed_receiver_trait_object) } } @@ -87,26 +115,51 @@ impl fmt::Display for ReceiverStats { } } -pub trait TypedReceiver: Sync { - fn poll_ready(&self, ctx: &mut Context<'_>) -> Poll<()>; - fn try_send(&self, msg: M) -> Result<(), SendError>; +pub trait SendTypedReceiver: Sync { + fn send(&self, mid: u64, msg: M) -> Result<(), SendError>; +} + +pub trait ReciveTypedReceiver: Sync + where M: Message, + E: crate::Error +{ + fn poll_events(&self, ctx: &mut Context<'_>) -> Poll>; } pub trait ReceiverTrait: Send + Sync { fn typed(&self) -> AnyReceiver<'_>; + fn poller(&self) -> AnyPoller<'_>; fn type_id(&self) -> TypeId; - fn close(&self); - fn stats(&self) -> ReceiverStats; - fn sync(&self); - fn poll_synchronized(&self, ctx: &mut Context<'_>) -> Poll<()>; + fn stats(&self) -> Result<(), SendError<()>>; + fn close(&self) -> Result<(), SendError<()>>; + fn sync(&self) -> Result<(), SendError<()>>; + fn flush(&self) -> Result<(), SendError<()>>; } pub trait ReceiverPollerBuilder { fn build(bus: Bus) -> Box>; } +struct ReceiverContext { + limit: u64, + processing: AtomicU64, + need_flush: AtomicBool, + flushed: Notify, + synchronized: Notify, + closed: Notify, + response: Notify, + statistics: Notify, +} + +impl PermitDrop for ReceiverContext { + fn permit_drop(&self) { + self.processing.fetch_sub(1, Ordering::SeqCst); + } +} + pub struct Receiver { inner: Arc, + context: Arc, } impl fmt::Debug for Receiver { @@ -124,49 +177,21 @@ impl core::cmp::PartialEq for Receiver { impl core::cmp::Eq for Receiver {} -pub struct ReceiverPoller<'a, M: Message> { - inner: &'a dyn ReceiverTrait, - msg: Option, -} - -impl<'a, M: Message> Unpin for ReceiverPoller<'a, M> {} -impl<'a, M: Message> Future for ReceiverPoller<'a, M> { - type Output = Result<(), SendError>; - - fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll { - let this = self.get_mut(); - - let mut msg = if let Some(msg) = this.msg.take() { - msg - } else { - return Poll::Ready(Ok(())); - }; - - let any_receiver = this.inner.typed(); - let receiver = any_receiver.dyn_typed_receiver::(); - - loop { - match receiver.poll_ready(ctx) { - Poll::Ready(_) => (), - Poll::Pending => { - this.msg = Some(msg); - return Poll::Pending; - } - } - - msg = match receiver.try_send(msg) { - Ok(_) => break Poll::Ready(Ok(())), - Err(SendError::Full(m)) => m, - Err(err) => break Poll::Ready(Err(err)), - } - } - } -} - impl Receiver { #[inline] - pub(crate) fn new(inner: Arc) -> Self { - Self { inner } + pub(crate) fn new(limit: u64, inner: Arc) -> Self { + let context = Arc::new(ReceiverContext { + limit, + processing: AtomicU64::new(0), + need_flush: AtomicBool::new(false), + flushed: Notify::new(), + synchronized: Notify::new(), + closed: Notify::new(), + response: Notify::new(), + statistics: Notify::new(), + }); + + Self { inner, context } } #[inline] @@ -175,38 +200,157 @@ impl Receiver { } #[inline] - pub fn broadcast( - &self, - msg: M, - ) -> impl Future>> + '_ { - ReceiverPoller { - inner: self.inner.as_ref(), - msg: Some(msg), + pub fn need_flush(&self) -> bool { + self.context.need_flush.load(Ordering::SeqCst) + } + + #[inline] + pub async fn reserve(&self) -> Permit { + loop { + let count = self.context.processing.load(Ordering::Relaxed); + if count < self.context.limit { + let res = self.context.processing.compare_exchange(count, count + 1, Ordering::SeqCst, Ordering::SeqCst); + if res.is_ok() { + break Permit { + fuse: false, + inner: self.context.clone(), + }; + } + + // continue + } else { + self.context.response.notified() + .await + } } } #[inline] - pub fn try_broadcast(&self, msg: M) -> Result<(), SendError> { + pub fn try_reserve(&self) -> Option { + loop { + let count = self.context.processing.load(Ordering::Relaxed); + + if count < self.context.limit { + let res = self.context.processing.compare_exchange(count, count + 1, Ordering::SeqCst, Ordering::SeqCst); + if res.is_ok() { + break Some(Permit { + fuse: false, + inner: self.context.clone(), + }); + } + + // continue + } else { + break None; + } + } + } + + #[inline] + pub fn send(&self, mid: u64, mut permit: Permit, msg: M) -> Result<(), SendError> { let any_receiver = self.inner.typed(); let receiver = any_receiver.dyn_typed_receiver::(); + let res = receiver.send(mid, msg); + permit.fuse = true; - receiver.try_send(msg) + if !res.is_err() { + self.context.need_flush.store(true, Ordering::SeqCst); + } + + res } #[inline] - pub fn close(&self) { - self.inner.close(); + pub fn force_send(&self, mid: u64, msg: M) -> Result<(), SendError> { + let any_receiver = self.inner.typed(); + let receiver = any_receiver.dyn_typed_receiver::(); + let res = receiver.send(mid, msg); + + if !res.is_err() { + self.context.need_flush.store(true, Ordering::SeqCst); + } + + res + } + + pub fn start_polling_events(&self) -> Box Pin + Send>>> + where + M: Message, + E: crate::Error + { + let ctx_clone = self.context.clone(); + let inner_clone = self.inner.clone(); + + Box::new(move |bus| Box::pin(async move { + let any_receiver = inner_clone.poller(); + let receiver = any_receiver.dyn_typed_receiver::(); + + loop { + let event = poll_fn(move |ctx| receiver.poll_events(ctx)) + .await; + + match event { + Event::Exited => { + ctx_clone.closed.notify_waiters(); + break; + }, + + Event::Flushed => ctx_clone.flushed.notify_waiters(), + Event::Synchronized => ctx_clone.synchronized.notify_waiters(), + Event::Response(_mid, resp) => { + ctx_clone.processing.fetch_sub(1, Ordering::SeqCst); + ctx_clone.response.notify_one(); + + match resp { + Ok(_msg) => (), + Err(err) => { bus.try_send(msgs::Error(Arc::new(err.into()))).ok(); } + } + }, + _ => unimplemented!() + } + } + })) + } + + // #[inline] + // pub fn stats(&self) -> ReceiverStats { + // if self.inner.stats().is_ok() { + // self.context.stats.notified() + // .await + // } else { + // warn!("close failed!"); + // } + // } + + #[inline] + pub async fn close(&self) { + if self.inner.close().is_ok() { + self.context.closed.notified() + .await + } else { + warn!("close failed!"); + } } #[inline] - pub fn stats(&self) -> ReceiverStats { - self.inner.stats() + pub async fn sync(&self) { + if self.inner.sync().is_ok() { + self.context.synchronized.notified() + .await + } else { + warn!("sync failed!"); + } } #[inline] - pub fn sync(&self) -> impl Future + '_ { - self.inner.sync(); - - poll_fn(move |ctx| self.inner.poll_synchronized(ctx)) + pub async fn flush(&self) { + if self.inner.flush().is_ok() { + self.context.flushed.notified() + .await; + + self.context.need_flush.store(false, Ordering::SeqCst); + } else { + warn!("flush failed!"); + } } } diff --git a/src/receivers/buffer_unordered/async.rs b/src/receivers/buffer_unordered/async.rs index 155296d..96b2dbe 100644 --- a/src/receivers/buffer_unordered/async.rs +++ b/src/receivers/buffer_unordered/async.rs @@ -9,30 +9,37 @@ use std::{ task::{Context, Poll}, }; -use crate::{receiver::ReceiverStats, receivers::mpsc}; -use futures::{Future, StreamExt}; +use crate::{receiver::{AnyPoller, ReceiverStats, ReciveTypedReceiver}, receivers::{Action, Event}}; +use anyhow::Result; +use futures::{Future, StreamExt, stream::FuturesUnordered}; use super::{BufferUnorderedConfig, BufferUnorderedStats}; use crate::{ builder::{ReceiverSubscriber, ReceiverSubscriberBuilder}, - msgs, - receiver::{AnyReceiver, ReceiverTrait, SendError, TypedReceiver}, + receiver::{AnyReceiver, ReceiverTrait, SendError, SendTypedReceiver}, AsyncHandler, Bus, Message, Untyped, }; +use parking_lot::Mutex; +use tokio::sync::mpsc; -pub struct BufferUnorderedAsyncSubscriber + +pub struct BufferUnorderedAsyncSubscriber where - T: AsyncHandler + 'static, + T: AsyncHandler + 'static, M: Message, + R: Message, + E: crate::Error { cfg: BufferUnorderedConfig, _m: PhantomData<(T, M)>, } -impl ReceiverSubscriber for BufferUnorderedAsyncSubscriber +impl ReceiverSubscriber for BufferUnorderedAsyncSubscriber where - T: AsyncHandler + 'static, + T: AsyncHandler + 'static, M: Message, + R: Message, + E: crate::Error { fn subscribe( self, @@ -50,15 +57,17 @@ where parallel_total: AtomicU64::new(cfg.max_parallel as _), }); - let (tx, rx) = mpsc::channel(cfg.buffer_size); - let arc = Arc::new(BufferUnorderedAsync:: { + let (stx, srx) = mpsc::unbounded_channel(); + let (tx, rx) = mpsc::unbounded_channel(); + let arc = Arc::new(BufferUnorderedAsync:: { tx, stats: stats.clone(), + srx: Mutex::new(srx), }); let poller = Box::new(move |ut| { Box::new(move |bus| { - Box::pin(buffer_unordered_poller::(rx, bus, ut, stats, cfg)) + Box::pin(buffer_unordered_poller::(rx, bus, ut, stats, cfg, stx)) as Pin + Send>> }) as Box Pin + Send>>> }); @@ -67,68 +76,129 @@ where } } -async fn buffer_unordered_poller( - rx: mpsc::Receiver, +fn buffer_unordered_poller( + mut rx: mpsc::UnboundedReceiver>, bus: Bus, ut: Untyped, stats: Arc, cfg: BufferUnorderedConfig, -) where - T: AsyncHandler + 'static, + stx: mpsc::UnboundedSender>, +) -> impl Future +where + T: AsyncHandler + 'static, M: Message, + R: Message, + E: crate::Error { let ut = ut.downcast::().unwrap(); + let mut queue = FuturesUnordered::new(); + let mut sync_future: Option> + Send>>> = None; + let mut need_sync = false; + let mut rx_closed = false; + let mut need_flush = false; - let mut x = rx - .map(|msg| { - stats.buffer.fetch_sub(1, Ordering::Relaxed); - stats.parallel.fetch_add(1, Ordering::Relaxed); - let bus = bus.clone(); - let ut = ut.clone(); + futures::future::poll_fn(move |cx| loop { + if !rx_closed && !need_flush && !need_sync { + while queue.len() < cfg.max_parallel { + match rx.poll_recv(cx) { + Poll::Ready(Some(a)) => { + match a { + Action::Request(mid, msg) => { + stats.buffer.fetch_sub(1, Ordering::Relaxed); + stats.parallel.fetch_add(1, Ordering::Relaxed); - tokio::task::spawn(async move { ut.handle(msg, &bus).await }) - }) - .buffer_unordered(cfg.max_parallel); - - while let Some(err) = x.next().await { - stats.parallel.fetch_sub(1, Ordering::Relaxed); - - match err { - Ok(Err(err)) => { - let _ = bus.send(msgs::Error(Arc::new(err))).await; + let bus = bus.clone(); + let ut = ut.clone(); + queue.push(tokio::task::spawn(async move { (mid, ut.handle(msg, &bus).await) })); + }, + Action::Flush => need_flush = true, + Action::Sync => need_sync = true, + Action::Close => rx.close(), + _ => unimplemented!() + } + }, + Poll::Ready(None) => { + need_sync = true; + rx_closed = true; + }, + Poll::Pending => break, + } } - _ => (), } - } - let ut = ut.clone(); - let bus_clone = bus.clone(); - let res = tokio::task::spawn(async move { ut.sync(&bus_clone).await }).await; + let queue_len = queue.len(); + + loop { + if queue_len != 0 { + loop { + match queue.poll_next_unpin(cx) { + Poll::Pending => return Poll::Pending, + Poll::Ready(Some(Ok((mid, res)))) => { + stx.send(Event::Response(mid, res)).ok(); + }, + Poll::Ready(None) => break, + _ => {} + } + } + } - match res { - Ok(Err(err)) => { - let _ = bus.send(msgs::Error(Arc::new(err))).await; + if need_flush { + need_flush = false; + stx.send(Event::Flushed).ok(); + } + + if need_sync { + if let Some(mut fut) = sync_future.take() { + match fut.as_mut().poll(cx) { + Poll::Pending => { + sync_future = Some(fut); + return Poll::Pending; + }, + Poll::Ready(res) => { + need_sync = false; + + if let Err(err) = res { + stx.send(Event::SyncResponse(err)).ok(); + } + } + } + } else { + let ut = ut.clone(); + let bus_clone = bus.clone(); + sync_future.replace(Box::pin(async move { + ut.sync(&bus_clone).await + })); + } + } else { + break; + } } - _ => (), - } - println!( - "[EXIT] BufferUnorderedAsync<{}>", - std::any::type_name::() - ); + if queue_len == queue.len() { + return if rx_closed { Poll::Ready(()) } else { Poll::Pending }; + } + }) } -pub struct BufferUnorderedAsync { - tx: mpsc::Sender, - stats: Arc, -} - -impl ReceiverSubscriberBuilder for BufferUnorderedAsync -where - T: AsyncHandler + 'static, - M: Message, +pub struct BufferUnorderedAsync + where + M: Message, + R: Message, + E: crate::Error { - type Entry = BufferUnorderedAsyncSubscriber; + tx: mpsc::UnboundedSender>, + stats: Arc, + srx: Mutex>>, +} + +impl ReceiverSubscriberBuilder for BufferUnorderedAsync + where + T: AsyncHandler + 'static, + R: Message, + M: Message, + E: crate::Error +{ + type Entry = BufferUnorderedAsyncSubscriber; type Config = BufferUnorderedConfig; fn build(cfg: Self::Config) -> Self::Entry { @@ -139,65 +209,102 @@ where } } -impl TypedReceiver for BufferUnorderedAsync { - fn poll_ready(&self, ctx: &mut Context<'_>) -> Poll<()> { - match self.tx.poll_ready(ctx) { - Poll::Ready(_) => Poll::Ready(()), - Poll::Pending => Poll::Pending, - } - } - - fn try_send(&self, m: M) -> Result<(), SendError> { - match self.tx.try_send(m) { +impl SendTypedReceiver for BufferUnorderedAsync + where + M: Message, + R: Message, + E: crate::Error +{ + fn send(&self, mid: u64, m: M) -> Result<(), SendError> { + match self.tx.send(Action::Request(mid, m)) { Ok(_) => { self.stats.buffer.fetch_add(1, Ordering::Relaxed); Ok(()) } - Err(err) => Err(err), + Err(mpsc::error::SendError(Action::Request(_, msg))) => Err(SendError::Closed(msg)), + _ => unimplemented!() } } } -impl ReceiverTrait for BufferUnorderedAsync { +impl ReciveTypedReceiver for BufferUnorderedAsync + where + M: Message, + R: Message, + E: crate::Error +{ + fn poll_events(&self, ctx: &mut Context<'_>) -> Poll> { + let poll = self.srx.lock().poll_recv(ctx); + match poll { + Poll::Pending => Poll::Pending, + Poll::Ready(Some(event)) => Poll::Ready(event), + Poll::Ready(None) => Poll::Ready(Event::Exited), + } + } +} + +impl ReceiverTrait for BufferUnorderedAsync + where + M: Message, + R: Message, + E: crate::Error +{ fn typed(&self) -> AnyReceiver<'_> { AnyReceiver::new(self) } + fn poller(&self) -> AnyPoller<'_> { + AnyPoller::new(self) + } + fn type_id(&self) -> TypeId { - TypeId::of::>() + TypeId::of::>() } - fn close(&self) { - self.tx.close(); + fn stats(&self) -> Result<(), SendError<()>> { + match self.tx.send(Action::Stats) { + Ok(_) => Ok(()), + Err(_) => Err(SendError::Closed(())) + } + // ReceiverStats { + // name: std::any::type_name::().into(), + // fields: vec![ + // ("buffer".into(), self.stats.buffer.load(Ordering::SeqCst)), + // ( + // "buffer_total".into(), + // self.stats.buffer_total.load(Ordering::SeqCst), + // ), + // ( + // "parallel".into(), + // self.stats.parallel.load(Ordering::SeqCst), + // ), + // ( + // "parallel_total".into(), + // self.stats.parallel_total.load(Ordering::SeqCst), + // ), + // ], + // } } - fn stats(&self) -> ReceiverStats { - ReceiverStats { - name: std::any::type_name::().into(), - fields: vec![ - ("buffer".into(), self.stats.buffer.load(Ordering::SeqCst)), - ( - "buffer_total".into(), - self.stats.buffer_total.load(Ordering::SeqCst), - ), - ( - "parallel".into(), - self.stats.parallel.load(Ordering::SeqCst), - ), - ( - "parallel_total".into(), - self.stats.parallel_total.load(Ordering::SeqCst), - ), - ], + fn close(&self) -> Result<(), SendError<()>> { + match self.tx.send(Action::Close) { + Ok(_) => Ok(()), + Err(_) => Err(SendError::Closed(())) } } - fn sync(&self) { - self.tx.flush(); + fn sync(&self) -> Result<(), SendError<()>> { + match self.tx.send(Action::Sync) { + Ok(_) => Ok(()), + Err(_) => Err(SendError::Closed(())) + } } - fn poll_synchronized(&self, _ctx: &mut Context<'_>) -> Poll<()> { - Poll::Ready(()) + fn flush(&self) -> Result<(), SendError<()>> { + match self.tx.send(Action::Flush) { + Ok(_) => Ok(()), + Err(_) => Err(SendError::Closed(())) + } } } diff --git a/src/receivers/buffer_unordered/sync.rs b/src/receivers/buffer_unordered/sync.rs index d8955a9..631c2f1 100644 --- a/src/receivers/buffer_unordered/sync.rs +++ b/src/receivers/buffer_unordered/sync.rs @@ -1,5 +1,3 @@ -use crate::{receiver::ReceiverStats, receivers::mpsc}; -use futures::{Future, StreamExt}; use std::{ any::TypeId, marker::PhantomData, @@ -11,27 +9,37 @@ use std::{ task::{Context, Poll}, }; +use crate::{receiver::{AnyPoller, ReceiverStats, ReciveTypedReceiver}, receivers::{Action, Event}}; +use anyhow::Result; +use futures::{Future, StreamExt, stream::FuturesUnordered}; + use super::{BufferUnorderedConfig, BufferUnorderedStats}; use crate::{ builder::{ReceiverSubscriber, ReceiverSubscriberBuilder}, - msgs, - receiver::{AnyReceiver, ReceiverTrait, SendError, TypedReceiver}, - Bus, Handler, Message, Untyped, + receiver::{AnyReceiver, ReceiverTrait, SendError, SendTypedReceiver}, + Handler, Bus, Message, Untyped, }; +use parking_lot::Mutex; +use tokio::sync::mpsc; -pub struct BufferUnorderedSyncSubscriber + +pub struct BufferUnorderedSyncSubscriber where - T: Handler + 'static, + T: Handler + 'static, M: Message, + R: Message, + E: crate::Error { cfg: BufferUnorderedConfig, - _m: PhantomData<(M, T)>, + _m: PhantomData<(T, M)>, } -impl ReceiverSubscriber for BufferUnorderedSyncSubscriber +impl ReceiverSubscriber for BufferUnorderedSyncSubscriber where - T: Handler + 'static, + T: Handler + 'static, M: Message, + R: Message, + E: crate::Error { fn subscribe( self, @@ -42,7 +50,6 @@ where >, ) { let cfg = self.cfg; - let (tx, rx) = mpsc::channel(cfg.buffer_size); let stats = Arc::new(BufferUnorderedStats { buffer: AtomicU64::new(0), buffer_total: AtomicU64::new(cfg.buffer_size as _), @@ -50,14 +57,17 @@ where parallel_total: AtomicU64::new(cfg.max_parallel as _), }); - let arc = Arc::new(BufferUnorderedSync:: { + let (stx, srx) = mpsc::unbounded_channel(); + let (tx, rx) = mpsc::unbounded_channel(); + let arc = Arc::new(BufferUnorderedSync:: { tx, stats: stats.clone(), + srx: Mutex::new(srx), }); let poller = Box::new(move |ut| { Box::new(move |bus| { - Box::pin(buffer_unordered_poller::(rx, bus, ut, stats, cfg)) + Box::pin(buffer_unordered_poller::(rx, bus, ut, stats, cfg, stx)) as Pin + Send>> }) as Box Pin + Send>>> }); @@ -66,69 +76,129 @@ where } } -async fn buffer_unordered_poller( - rx: mpsc::Receiver, +fn buffer_unordered_poller( + mut rx: mpsc::UnboundedReceiver>, bus: Bus, ut: Untyped, stats: Arc, cfg: BufferUnorderedConfig, -) where - T: Handler + 'static, + stx: mpsc::UnboundedSender>, +) -> impl Future +where + T: Handler + 'static, M: Message, + R: Message, + E: crate::Error { let ut = ut.downcast::().unwrap(); + let mut queue = FuturesUnordered::new(); + let mut sync_future: Option> + Send>>> = None; + let mut need_sync = false; + let mut rx_closed = false; + let mut need_flush = false; - let mut x = rx - .map(|msg| { - stats.buffer.fetch_sub(1, Ordering::Relaxed); - stats.parallel.fetch_add(1, Ordering::Relaxed); + futures::future::poll_fn(move |cx| loop { + if !rx_closed && !need_flush && !need_sync { + while queue.len() < cfg.max_parallel { + match rx.poll_recv(cx) { + Poll::Ready(Some(a)) => { + match a { + Action::Request(mid, msg) => { + stats.buffer.fetch_sub(1, Ordering::Relaxed); + stats.parallel.fetch_add(1, Ordering::Relaxed); - let bus = bus.clone(); - let ut = ut.clone(); - - tokio::task::spawn_blocking(move || ut.handle(msg, &bus)) - }) - .buffer_unordered(cfg.max_parallel); - - while let Some(err) = x.next().await { - stats.parallel.fetch_sub(1, Ordering::Relaxed); - - match err { - Ok(Err(err)) => { - let _ = bus.send(msgs::Error(Arc::new(err))).await; + let bus = bus.clone(); + let ut = ut.clone(); + queue.push( tokio::task::spawn_blocking(move || (mid, ut.handle(msg, &bus)))); + }, + Action::Flush => need_flush = true, + Action::Sync => need_sync = true, + Action::Close => rx.close(), + _ => unimplemented!() + } + }, + Poll::Ready(None) => { + need_sync = true; + rx_closed = true; + }, + Poll::Pending => break, + } } - _ => (), } - } - let ut = ut.clone(); - let bus_clone = bus.clone(); - let res = tokio::task::spawn_blocking(move || ut.sync(&bus_clone)).await; + let queue_len = queue.len(); + + loop { + if queue_len != 0 { + loop { + match queue.poll_next_unpin(cx) { + Poll::Pending => return Poll::Pending, + Poll::Ready(Some(Ok((mid, res)))) => { + stx.send(Event::Response(mid, res)).ok(); + }, + Poll::Ready(None) => break, + _ => {} + } + } + } - match res { - Ok(Err(err)) => { - let _ = bus.send(msgs::Error(Arc::new(err))).await; + if need_flush { + need_flush = false; + stx.send(Event::Flushed).ok(); + } + + if need_sync { + if let Some(mut fut) = sync_future.take() { + match fut.as_mut().poll(cx) { + Poll::Pending => { + sync_future = Some(fut); + return Poll::Pending; + }, + Poll::Ready(res) => { + need_sync = false; + + if let Err(err) = res { + stx.send(Event::SyncResponse(err)).ok(); + } + } + } + } else { + let ut = ut.clone(); + let bus_clone = bus.clone(); + sync_future.replace(Box::pin(async move { + tokio::task::spawn_blocking(move || ut.sync(&bus_clone)).await.unwrap() + })); + } + } else { + break; + } } - _ => (), - } - println!( - "[EXIT] BufferUnorderedSync<{}>", - std::any::type_name::() - ); + if queue_len == queue.len() { + return if rx_closed { Poll::Ready(()) } else { Poll::Pending }; + } + }) } -pub struct BufferUnorderedSync { - tx: mpsc::Sender, - stats: Arc, -} - -impl ReceiverSubscriberBuilder for BufferUnorderedSync -where - T: Handler + 'static, - M: Message, +pub struct BufferUnorderedSync + where + M: Message, + R: Message, + E: crate::Error { - type Entry = BufferUnorderedSyncSubscriber; + tx: mpsc::UnboundedSender>, + stats: Arc, + srx: Mutex>>, +} + +impl ReceiverSubscriberBuilder for BufferUnorderedSync + where + T: Handler + 'static, + R: Message, + M: Message, + E: crate::Error +{ + type Entry = BufferUnorderedSyncSubscriber; type Config = BufferUnorderedConfig; fn build(cfg: Self::Config) -> Self::Entry { @@ -139,65 +209,102 @@ where } } -impl TypedReceiver for BufferUnorderedSync { - fn poll_ready(&self, ctx: &mut Context<'_>) -> Poll<()> { - match self.tx.poll_ready(ctx) { - Poll::Ready(_) => Poll::Ready(()), - Poll::Pending => Poll::Pending, - } - } - - fn try_send(&self, m: M) -> Result<(), SendError> { - match self.tx.try_send(m) { +impl SendTypedReceiver for BufferUnorderedSync + where + M: Message, + R: Message, + E: crate::Error +{ + fn send(&self, mid: u64, m: M) -> Result<(), SendError> { + match self.tx.send(Action::Request(mid, m)) { Ok(_) => { self.stats.buffer.fetch_add(1, Ordering::Relaxed); Ok(()) } - Err(err) => Err(err), + Err(mpsc::error::SendError(Action::Request(_, msg))) => Err(SendError::Closed(msg)), + _ => unimplemented!() } } } -impl ReceiverTrait for BufferUnorderedSync { +impl ReciveTypedReceiver for BufferUnorderedSync + where + M: Message, + R: Message, + E: crate::Error +{ + fn poll_events(&self, ctx: &mut Context<'_>) -> Poll> { + let poll = self.srx.lock().poll_recv(ctx); + match poll { + Poll::Pending => Poll::Pending, + Poll::Ready(Some(event)) => Poll::Ready(event), + Poll::Ready(None) => Poll::Ready(Event::Exited), + } + } +} + +impl ReceiverTrait for BufferUnorderedSync + where + M: Message, + R: Message, + E: crate::Error +{ fn typed(&self) -> AnyReceiver<'_> { AnyReceiver::new(self) } - fn type_id(&self) -> TypeId { - TypeId::of::>() + fn poller(&self) -> AnyPoller<'_> { + AnyPoller::new(self) } - fn stats(&self) -> ReceiverStats { - ReceiverStats { - name: std::any::type_name::().into(), - fields: vec![ - ("buffer".into(), self.stats.buffer.load(Ordering::SeqCst)), - ( - "buffer_total".into(), - self.stats.buffer_total.load(Ordering::SeqCst), - ), - ( - "parallel".into(), - self.stats.parallel.load(Ordering::SeqCst), - ), - ( - "parallel_total".into(), - self.stats.parallel_total.load(Ordering::SeqCst), - ), - ], + fn type_id(&self) -> TypeId { + TypeId::of::>() + } + + fn stats(&self) -> Result<(), SendError<()>> { + match self.tx.send(Action::Stats) { + Ok(_) => Ok(()), + Err(_) => Err(SendError::Closed(())) + } + // ReceiverStats { + // name: std::any::type_name::().into(), + // fields: vec![ + // ("buffer".into(), self.stats.buffer.load(Ordering::SeqCst)), + // ( + // "buffer_total".into(), + // self.stats.buffer_total.load(Ordering::SeqCst), + // ), + // ( + // "parallel".into(), + // self.stats.parallel.load(Ordering::SeqCst), + // ), + // ( + // "parallel_total".into(), + // self.stats.parallel_total.load(Ordering::SeqCst), + // ), + // ], + // } + } + + fn close(&self) -> Result<(), SendError<()>> { + match self.tx.send(Action::Close) { + Ok(_) => Ok(()), + Err(_) => Err(SendError::Closed(())) } } - fn close(&self) { - self.tx.close(); + fn sync(&self) -> Result<(), SendError<()>> { + match self.tx.send(Action::Sync) { + Ok(_) => Ok(()), + Err(_) => Err(SendError::Closed(())) + } } - fn sync(&self) { - self.tx.flush(); - } - - fn poll_synchronized(&self, _ctx: &mut Context<'_>) -> Poll<()> { - Poll::Ready(()) + fn flush(&self) -> Result<(), SendError<()>> { + match self.tx.send(Action::Flush) { + Ok(_) => Ok(()), + Err(_) => Err(SendError::Closed(())) + } } } diff --git a/src/receivers/mod.rs b/src/receivers/mod.rs index a8c8c25..e0cde1f 100644 --- a/src/receivers/mod.rs +++ b/src/receivers/mod.rs @@ -1,11 +1,68 @@ mod buffer_unordered; -mod buffer_unordered_batched; -mod mpsc_futures; -mod synchronize_batched; -mod synchronized; +// mod buffer_unordered_batched; +// mod mpsc_futures; +// mod synchronize_batched; +// mod synchronized; -mod mpsc { - pub use super::mpsc_futures::*; + +// mod mpsc; +// mod mpsc { +// pub use super::mpsc_futures::*; +// } + +use std::sync::Arc; + +#[derive(Debug, Clone)] +pub struct Stats { + pub has_queue: bool, + pub queue_capacity: u64, + pub queue_size: u64, + + pub has_parallel: bool, + pub parallel_capacity: u64, + pub parallel_size: u64, + + pub has_batch: bool, + pub batch_capacity: u64, + pub batch_size: u64, +} + +#[non_exhaustive] +#[derive(Debug)] +pub enum Action { + Request(u64, M), + Flush, + Sync, + Close, + Stats, +} + +#[non_exhaustive] +#[derive(Debug, Clone)] +pub enum Event { + Response(u64, Result), + SyncResponse(E), + Stats(Stats), + Synchronized, + Flushed, + Exited, +} + +pub struct Permit { + pub(crate) fuse: bool, + pub(crate) inner: Arc +} + +pub trait PermitDrop { + fn permit_drop(&self); +} + +impl Drop for Permit { + fn drop(&mut self) { + if !self.fuse { + self.inner.permit_drop(); + } + } } pub use buffer_unordered::{ @@ -13,17 +70,17 @@ pub use buffer_unordered::{ BufferUnorderedSync, BufferUnorderedSyncSubscriber, }; -pub use buffer_unordered_batched::{ - BufferUnorderedBatchedAsync, BufferUnorderedBatchedAsyncSubscriber, BufferUnorderedBatchedConfig, - BufferUnorderedBatchedSync, BufferUnorderedBatchedSyncSubscriber, -}; +// pub use buffer_unordered_batched::{ +// BufferUnorderedBatchedAsync, BufferUnorderedBatchedAsyncSubscriber, BufferUnorderedBatchedConfig, +// BufferUnorderedBatchedSync, BufferUnorderedBatchedSyncSubscriber, +// }; -pub use synchronized::{ - SynchronizedAsync, SynchronizedAsyncSubscriber, SynchronizedConfig, SynchronizedSync, - SynchronizedSyncSubscriber, -}; +// pub use synchronized::{ +// SynchronizedAsync, SynchronizedAsyncSubscriber, SynchronizedConfig, SynchronizedSync, +// SynchronizedSyncSubscriber, +// }; -pub use synchronize_batched::{ - SynchronizeBatchedAsync, SynchronizeBatchedAsyncSubscriber, SynchronizeBatchedConfig, - SynchronizeBatchedSync, SynchronizeBatchedSyncSubscriber, -}; +// pub use synchronize_batched::{ +// SynchronizeBatchedAsync, SynchronizeBatchedAsyncSubscriber, SynchronizeBatchedConfig, +// SynchronizeBatchedSync, SynchronizeBatchedSyncSubscriber, +// }; diff --git a/src/receivers/mpsc.rs b/src/receivers/mpsc.rs deleted file mode 100644 index 6a62d8c..0000000 --- a/src/receivers/mpsc.rs +++ /dev/null @@ -1,142 +0,0 @@ -use futures::{Stream, StreamExt}; -use core::pin::Pin; -use crossbeam::queue::ArrayQueue; -use crossbeam::atomic::AtomicCell; -use core::task::{Waker, Context, Poll}; -use std::sync::{Arc, atomic::*}; - -use crate::receiver::SendError; - - - -struct ChannelInner { - queue: ArrayQueue, - send_waker: AtomicCell>>, - recv_waker: AtomicCell>>, - closed: AtomicBool, -} - -pub fn channel(buffer: usize) -> (Sender, Receiver) { - let inner = Arc::new(ChannelInner { - queue: ArrayQueue::new(buffer), - send_waker: AtomicCell::new(None), - recv_waker: AtomicCell::new(None), - closed: AtomicBool::new(false), - }); - - ( - Sender { - inner: inner.clone(), - }, - Receiver { - inner, - } - ) -} - - -#[derive(Clone)] -pub struct Sender { - inner: Arc> -} - -impl Sender { - pub fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<()> { - if self.inner.closed.load(Ordering::SeqCst) { - return Poll::Ready(()); - } - - if self.inner.queue.is_full() { - self.inner.send_waker.store(Some(Box::new(cx.waker().clone()))); - } - - let mut counter = 4; - loop { - if self.inner.queue.is_full() { - if counter > 0 { - counter -= 1; - continue; - } else { - break Poll::Pending; - } - } else { - break Poll::Ready(()); - } - } - } - - pub fn try_send(&self, mut item: T) -> Result<(), SendError> { - if self.inner.closed.load(Ordering::SeqCst) { - return Err(SendError::Closed(item)); - } - - let mut counter = 0; - loop { - match self.inner.queue.push(item) { - Ok(_) => { - if let Some(waker) = self.inner.recv_waker.take() { - waker.wake(); - } - - break Ok(()); - } - - Err(inner) => { - if counter >= 4 { - break Err(SendError::Full(inner)); - } else { - item = inner; - counter += 1; - } - } - } - } - } - - pub fn close(&self) { - self.inner.closed.store(true, Ordering::SeqCst); - if let Some(waker) = self.inner.recv_waker.take() { - waker.wake(); - } - } -} - -pub struct Receiver { - inner: Arc> -} - -impl Stream for Receiver { - type Item = T; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.get_mut(); - let mut counter = 0; - - loop { - match this.inner.queue.pop() { - Some(inner) => { - if let Some(waker) = this.inner.send_waker.take() { - waker.wake(); - } - - break Poll::Ready(Some(inner)); - }, - None => { - if this.inner.closed.load(Ordering::SeqCst) { - break Poll::Ready(None); - } else { - if counter == 0 { - this.inner.recv_waker.store(Some(Box::new(cx.waker().clone()))); - } - - if counter >= 8 { - break Poll::Pending; - } else { - counter += 1; - } - } - } - } - } - } -} diff --git a/src/receivers/mpsc_futures.rs b/src/receivers/mpsc_futures.rs deleted file mode 100644 index 18bbba9..0000000 --- a/src/receivers/mpsc_futures.rs +++ /dev/null @@ -1,107 +0,0 @@ -use crate::receiver::SendError; -use core::pin::Pin; -use core::task::{Context, Poll}; -use crossbeam::queue::SegQueue; -use futures::{channel::mpsc, Stream}; -use std::{ - sync::{ - atomic::{AtomicUsize, Ordering}, - Arc, - }, - task::Waker, -}; - -pub struct State { - buffer: usize, - counter: AtomicUsize, - send_wakers: SegQueue, -} - -pub fn channel(buffer: usize) -> (Sender, Receiver) { - let state = Arc::new(State { - buffer, - counter: AtomicUsize::new(0), - send_wakers: SegQueue::new(), - }); - let (tx, rx) = mpsc::unbounded(); - - ( - Sender { - inner: tx, - state: state.clone(), - }, - Receiver { inner: rx, state }, - ) -} - -pub struct Sender { - inner: mpsc::UnboundedSender, - state: Arc, -} - -impl Sender { - pub fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<()> { - if self.state.counter.load(Ordering::SeqCst) >= self.state.buffer { - self.state.send_wakers.push(cx.waker().clone()); - return Poll::Pending; - } - - Poll::Ready(()) - } - - pub fn try_send(&self, item: T) -> Result<(), SendError> { - if self.state.counter.load(Ordering::Relaxed) >= self.state.buffer { - return Err(SendError::Full(item)); - } - - self.state.counter.fetch_add(1, Ordering::SeqCst); - match self.inner.unbounded_send(item) { - Ok(_) => Ok(()), - Err(err) if err.is_full() => Err(SendError::Full(err.into_inner())), - Err(err) => Err(SendError::Closed(err.into_inner())), - } - } - - #[inline] - pub fn flush(&self) {} - - #[inline] - pub fn close(&self) { - self.inner.close_channel(); - } -} - -pub struct Receiver { - inner: mpsc::UnboundedReceiver, - state: Arc, -} - -impl Stream for Receiver { - type Item = T; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.get_mut(); - match Pin::new(&mut this.inner).poll_next(cx) { - Poll::Ready(inner) => { - let val = this.state.buffer - this.state.counter.fetch_sub(1, Ordering::SeqCst) + 1; - - for _ in 0..val { - if let Some(waker) = this.state.send_wakers.pop() { - waker.wake(); - } else { - break; - } - } - - Poll::Ready(inner) - } - Poll::Pending => { - while let Some(waker) = this.state.send_wakers.pop() { - waker.wake(); - } - - Poll::Pending - }, - } - } -} diff --git a/src/utils.rs b/src/utils.rs deleted file mode 100644 index 1e3ab7d..0000000 --- a/src/utils.rs +++ /dev/null @@ -1,33 +0,0 @@ -use core::cmp::{Ord, Ordering}; -use core::ops::Range; - -pub fn binary_search_range_by_key<'a, T, B, F>(data: &'a [T], item: &B, mut f: F) -> Range -where - F: FnMut(&'a T) -> B, - B: Ord, -{ - if let Ok(index) = data.binary_search_by_key(item, &mut f) { - let mut begin = index; - let mut end = index + 1; - - for i in (0..index).rev() { - if f(unsafe { data.get_unchecked(i) }).cmp(item) != Ordering::Equal { - break; - } - - begin = i; - } - - for i in end..data.len() { - end = i; - - if f(unsafe { data.get_unchecked(i) }).cmp(item) != Ordering::Equal { - break; - } - } - - begin..end - } else { - data.len()..data.len() - } -}