From db1d22df7f511ce504ee042092877723e46aad28 Mon Sep 17 00:00:00 2001 From: Andrey Tkachenko Date: Tue, 22 Jun 2021 18:43:06 +0400 Subject: [PATCH] Request/Response --- Cargo.toml | 5 +- examples/demo.rs | 79 ------- examples/demo_async.rs | 6 +- examples/demo_backpressure.rs | 13 +- examples/demo_req_resp.rs | 172 +++++++++++++++ src/builder.rs | 36 ++-- src/handler.rs | 1 - src/lib.rs | 129 +++++++---- src/receiver.rs | 273 +++++++++++++++--------- src/receivers/buffer_unordered/async.rs | 143 +++++++------ src/receivers/buffer_unordered/sync.rs | 145 +++++++------ src/receivers/mod.rs | 9 +- 12 files changed, 639 insertions(+), 372 deletions(-) delete mode 100644 examples/demo.rs create mode 100644 examples/demo_req_resp.rs diff --git a/Cargo.toml b/Cargo.toml index 1fdc18b..4448f39 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,17 +11,16 @@ exclude = [".gitignore", ".cargo/config", ".github/**", "codecov.yml"] edition = "2018" [dependencies] -tokio = { version = "1", features = ["parking_lot", "rt-multi-thread", "io-util", "sync"] } +tokio = { version = "1", features = ["parking_lot", "rt-multi-thread", "sync"] } parking_lot = "0.11.1" async-trait = "0.1.42" futures = "0.3.8" anyhow = "1.0.34" -crossbeam = "0.8.1" -uuid = "0.8.2" tokio-util = "0.6.7" async-stream = "0.3.2" smallvec = "1.6.1" log = "0.4.14" +sharded-slab = "0.1.1" [dev-dependencies] tokio = { version = "1", features = ["macros", "parking_lot", "rt-multi-thread", "io-util", "sync"] } diff --git a/examples/demo.rs b/examples/demo.rs deleted file mode 100644 index c83d523..0000000 --- a/examples/demo.rs +++ /dev/null @@ -1,79 +0,0 @@ -use messagebus::{receivers, Bus, Handler, Result as MbusResult}; - -struct TmpReceiver; -struct TmpReceiver2; - -impl Handler for TmpReceiver { - fn handle(&self, msg: f32, bus: &Bus) -> MbusResult { - bus.try_send(1u16).unwrap(); - - println!("---> f32 {}", msg); - Ok(()) - } -} - -impl Handler for TmpReceiver { - fn handle(&self, msg: u16, bus: &Bus) -> MbusResult { - bus.try_send(1u32).unwrap(); - println!("---> u16 {}", msg); - Ok(()) - } -} - -impl Handler for TmpReceiver { - fn handle(&self, msg: u32, bus: &Bus) -> MbusResult { - bus.try_send(2i32).unwrap(); - println!("---> u32 {}", msg); - Ok(()) - } -} - -impl Handler for TmpReceiver { - fn handle(&self, msg: i32, bus: &Bus) -> MbusResult { - bus.try_send(3i16).unwrap(); - println!("---> i32 {}", msg); - Ok(()) - } -} - -impl Handler for TmpReceiver { - fn handle(&self, msg: i16, _bus: &Bus) -> MbusResult { - println!("---> i16 {}", msg); - Ok(()) - } -} - -impl Handler for TmpReceiver2 { - fn handle(&self, msg: i32, bus: &Bus) -> MbusResult { - bus.try_send(3i16).unwrap(); - println!("---> 2 i32 {}", msg); - Ok(()) - } -} - -impl Handler for TmpReceiver2 { - fn handle(&self, msg: i16, _bus: &Bus) -> MbusResult { - println!("---> 2 i16 {}", msg); - Ok(()) - } -} - -#[tokio::main] -async fn main() { - let (b, poller) = Bus::build() - .register(TmpReceiver) - .subscribe::>(Default::default()) - .subscribe::>(Default::default()) - .subscribe::>(Default::default()) - .subscribe::>(Default::default()) - .subscribe::>(Default::default()) - .done() - .register(TmpReceiver2) - .subscribe::>(Default::default()) - .subscribe::>(Default::default()) - .done() - .build(); - - b.send(32f32).await.unwrap(); - poller.await -} diff --git a/examples/demo_async.rs b/examples/demo_async.rs index dcbd72c..a480210 100644 --- a/examples/demo_async.rs +++ b/examples/demo_async.rs @@ -144,10 +144,10 @@ async fn main() { .subscribe::, _, _>(8, Default::default()) .done() .register(TmpReceiver2) - .subscribe::, _, _>(8, Default::default()) - .subscribe::, _, _>(8, Default::default()) + .subscribe::, _, _>(8, Default::default()) + .subscribe::, _, _>(8, Default::default()) .done() - .build(); + .build(); b.send(0f32).await.unwrap(); diff --git a/examples/demo_backpressure.rs b/examples/demo_backpressure.rs index 7c21799..7b041e0 100644 --- a/examples/demo_backpressure.rs +++ b/examples/demo_backpressure.rs @@ -1,5 +1,5 @@ use async_trait::async_trait; -use messagebus::{receivers, Bus, AsyncHandler}; +use messagebus::{receivers, AsyncHandler, Bus}; struct TmpReceiver; @@ -22,10 +22,13 @@ impl AsyncHandler for TmpReceiver { async fn main() { let (b, poller) = Bus::build() .register(TmpReceiver) - .subscribe::, _, _>(1, receivers::BufferUnorderedConfig { - buffer_size: 1, - max_parallel: 1, - }) + .subscribe::, _, _>( + 1, + receivers::BufferUnorderedConfig { + buffer_size: 1, + max_parallel: 1, + }, + ) .done() .build(); diff --git a/examples/demo_req_resp.rs b/examples/demo_req_resp.rs new file mode 100644 index 0000000..8cabfed --- /dev/null +++ b/examples/demo_req_resp.rs @@ -0,0 +1,172 @@ +use async_trait::async_trait; +use messagebus::{receivers, AsyncHandler, Bus}; + +struct TmpReceiver1; +struct TmpReceiver2; + +#[async_trait] +impl AsyncHandler for TmpReceiver1 { + type Error = anyhow::Error; + type Response = f32; + + async fn handle(&self, msg: i32, bus: &Bus) -> Result { + let resp1 = bus.request::<_, f32>(10i16, Default::default()).await?; + let resp2 = bus.request::<_, f32>(20u16, Default::default()).await?; + + Ok(msg as f32 + resp1 + resp2) + } + + async fn sync(&self, _bus: &Bus) -> Result<(), Self::Error> { + println!("TmpReceiver1 i32: sync"); + + Ok(()) + } +} + +#[async_trait] +impl AsyncHandler for TmpReceiver1 { + type Error = anyhow::Error; + type Response = f32; + + async fn handle(&self, msg: u32, _bus: &Bus) -> Result { + Ok(msg as f32) + } + async fn sync(&self, _bus: &Bus) -> Result<(), Self::Error> { + println!("TmpReceiver1 u32: sync"); + + Ok(()) + } +} + +#[async_trait] +impl AsyncHandler for TmpReceiver1 { + type Error = anyhow::Error; + type Response = f32; + + async fn handle(&self, msg: i16, bus: &Bus) -> Result { + let resp1 = bus.request::<_, f32>(1i8, Default::default()).await?; + let resp2 = bus.request::<_, f32>(2u8, Default::default()).await?; + + Ok(msg as f32 + resp1 + resp2) + } + + async fn sync(&self, _bus: &Bus) -> Result<(), Self::Error> { + println!("TmpReceiver i16: sync"); + + Ok(()) + } +} + +#[async_trait] +impl AsyncHandler for TmpReceiver1 { + type Error = anyhow::Error; + type Response = f32; + + async fn handle(&self, msg: u16, _bus: &Bus) -> Result { + Ok(msg as f32) + } + + async fn sync(&self, _bus: &Bus) -> Result<(), Self::Error> { + println!("TmpReceiver i16: sync"); + + Ok(()) + } +} + +#[async_trait] +impl AsyncHandler for TmpReceiver1 { + type Error = anyhow::Error; + type Response = f32; + + async fn handle(&self, msg: i8, _bus: &Bus) -> Result { + Ok(msg as f32) + } + + async fn sync(&self, _bus: &Bus) -> Result<(), Self::Error> { + println!("TmpReceiver1 i8: sync"); + + Ok(()) + } +} + +#[async_trait] +impl AsyncHandler for TmpReceiver1 { + type Error = anyhow::Error; + type Response = f32; + + async fn handle(&self, msg: u8, _bus: &Bus) -> Result { + Ok(msg as f32) + } + async fn sync(&self, _bus: &Bus) -> Result<(), Self::Error> { + println!("TmpReceiver1 u8: sync"); + + Ok(()) + } +} + +#[async_trait] +impl AsyncHandler for TmpReceiver2 { + type Error = anyhow::Error; + type Response = f64; + + async fn handle(&self, msg: f64, bus: &Bus) -> Result { + let resp1 = bus.request::<_, f32>(100i32, Default::default()).await? as f64; + let resp2 = bus.request::<_, f32>(200u32, Default::default()).await? as f64; + let resp3 = bus.request::<_, f32>(300f32, Default::default()).await? as f64; + + Ok(msg + resp1 + resp2 + resp3) + } + + async fn sync(&self, _bus: &Bus) -> Result<(), Self::Error> { + println!("TmpReceiver1 f64: sync"); + + Ok(()) + } +} + +#[async_trait] +impl AsyncHandler for TmpReceiver2 { + type Error = anyhow::Error; + type Response = f32; + + async fn handle(&self, msg: f32, _bus: &Bus) -> Result { + Ok(msg) + } + async fn sync(&self, _bus: &Bus) -> Result<(), Self::Error> { + println!("TmpReceiver2: f32: sync"); + + Ok(()) + } +} + +#[tokio::main] +async fn main() { + let (b, poller) = Bus::build() + .register(TmpReceiver1) + .subscribe::, _, _>(8, Default::default()) + .subscribe::, _, _>(8, Default::default()) + .subscribe::, _, _>(8, Default::default()) + .subscribe::, _, _>(8, Default::default()) + .subscribe::, _, _>(8, Default::default()) + .subscribe::, _, _>(8, Default::default()) + .done() + .register(TmpReceiver2) + .subscribe::, _, _>(8, Default::default()) + .subscribe::, _, _>(8, Default::default()) + .done() + .build(); + + println!( + "{:?}", + b.request::<_, f64>(1000f64, Default::default()).await + ); + + println!("flush"); + b.flush().await; + + println!("close"); + b.close().await; + + poller.await; + println!("[done]"); +} diff --git a/src/builder.rs b/src/builder.rs index a15d1c5..0af7a5b 100644 --- a/src/builder.rs +++ b/src/builder.rs @@ -3,19 +3,31 @@ use std::{any::TypeId, collections::HashMap, marker::PhantomData, pin::Pin, sync use futures::{Future, FutureExt}; use tokio::sync::Mutex; -use crate::{Bus, BusInner, Message, Untyped, receiver::{Receiver, ReciveTypedReceiver, SendTypedReceiver, SendUntypedReceiver}}; +use crate::{ + receiver::{Receiver, ReciveTypedReceiver, SendTypedReceiver, SendUntypedReceiver}, + Bus, BusInner, Message, Untyped, +}; - -pub trait ReceiverSubscriberBuilder: SendUntypedReceiver + SendTypedReceiver + ReciveTypedReceiver - where - T: 'static, - M: Message, - R: Message, - E: crate::Error +pub trait ReceiverSubscriberBuilder: + SendUntypedReceiver + SendTypedReceiver + ReciveTypedReceiver +where + T: 'static, + M: Message, + R: Message, + E: crate::Error, { type Config: Default; - fn build(cfg: Self::Config) -> (Self, Box Box Pin + Send>>>>) where Self: Sized; + fn build( + cfg: Self::Config, + ) -> ( + Self, + Box< + dyn FnOnce(Untyped) -> Box Pin + Send>>>, + >, + ) + where + Self: Sized; } pub struct SyncEntry; @@ -35,7 +47,7 @@ pub struct RegisterEntry { ) -> Box Pin + Send>>>, >, - Box Pin + Send>>> + Box Pin + Send>>>, )>, >, _m: PhantomData<(K, T)>, @@ -68,7 +80,7 @@ impl RegisterEntry { { let (inner, poller) = S::build(cfg); - let receiver = Receiver::new(queue, inner); + let receiver = Receiver::new::(queue, inner); let poller2 = receiver.start_polling_events::(); self.receivers .entry(TypeId::of::()) @@ -90,7 +102,7 @@ impl RegisterEntry { { let (inner, poller) = S::build(cfg); - let receiver = Receiver::new(queue, inner); + let receiver = Receiver::new::(queue, inner); let poller2 = receiver.start_polling_events::(); self.receivers .entry(TypeId::of::()) diff --git a/src/handler.rs b/src/handler.rs index 26b15ec..09d1353 100644 --- a/src/handler.rs +++ b/src/handler.rs @@ -106,7 +106,6 @@ pub trait LocalAsyncHandler { } } - pub trait LocalBatchHandler { type Error: crate::Error; type Response: Message; diff --git a/src/lib.rs b/src/lib.rs index a1b241a..96009a8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -9,24 +9,30 @@ mod trait_object; #[macro_use] extern crate log; +use anyhow::bail; use builder::BusBuilder; pub use envelop::Message; -use futures::{Future, FutureExt, future::poll_fn}; pub use handler::*; pub use receiver::SendError; use receiver::{Receiver, ReceiverStats}; use smallvec::SmallVec; -use tokio::sync::{mpsc, oneshot}; +use tokio::sync::oneshot; -use core::any::{Any, TypeId}; -use std::{collections::HashMap, sync::{Arc, atomic::{AtomicBool, AtomicU64, Ordering}}}; use crate::receiver::Permit; +use core::any::{Any, TypeId}; +use std::{ + collections::HashMap, + sync::{ + atomic::{AtomicBool, AtomicU64, Ordering}, + Arc, + }, +}; pub type Untyped = Arc; // pub trait ErrorTrait: std::error::Error + Send + Sync + 'static {} pub trait Error: Into + Send + Sync + 'static {} -impl + Send + Sync + 'static> Error for T {} +impl + Send + Sync + 'static> Error for T {} static ID_COUNTER: AtomicU64 = AtomicU64::new(1); @@ -54,7 +60,8 @@ impl BusInner { let mut receivers = HashMap::new(); for (key, value) in input { - receivers.entry(key) + receivers + .entry(key) .or_insert_with(SmallVec::new) .push(value); } @@ -83,10 +90,9 @@ impl BusInner { iters += 1; let mut flushed = false; for (_, rs) in &self.receivers { - for r in rs { if r.need_flush() { - flushed = true; + flushed = true; r.flush().await; } } @@ -99,11 +105,14 @@ impl BusInner { } if !breaked { - warn!("!!! WARNING: unable to reach equilibrium in {} iterations !!!", fuse_count); + warn!( + "!!! WARNING: unable to reach equilibrium in {} iterations !!!", + fuse_count + ); } else { info!("flushed in {} iterations !!!", iters); } - } + } pub async fn flash_and_sync(&self) { self.flush().await; @@ -122,7 +131,7 @@ impl BusInner { // .map(|r| r.stats()) // } - fn try_reserve(&self, rs: &[Receiver]) -> Option> { + fn try_reserve(&self, rs: &[Receiver]) -> Option> { let mut permits = SmallVec::<[Permit; 32]>::new(); for r in rs { @@ -135,13 +144,17 @@ impl BusInner { Some(permits) } - + #[inline] pub fn try_send(&self, msg: M) -> core::result::Result<(), SendError> { self.try_send_ext(msg, SendOptions::Broadcast) } - pub fn try_send_ext(&self, msg: M, _options: SendOptions) -> core::result::Result<(), SendError> { + pub fn try_send_ext( + &self, + msg: M, + _options: SendOptions, + ) -> core::result::Result<(), SendError> { if self.closed.load(Ordering::SeqCst) { warn!("Bus closed. Skipping send!"); return Ok(()); @@ -174,7 +187,10 @@ impl BusInner { } } - warn!("Unhandled message {:?}: no receivers", core::any::type_name::()); + warn!( + "Unhandled message {:?}: no receivers", + core::any::type_name::() + ); Ok(()) } @@ -185,16 +201,24 @@ impl BusInner { } #[inline] - pub fn send_blocking_ext(&self, msg: M, options: SendOptions) -> core::result::Result<(), SendError> { + pub fn send_blocking_ext( + &self, + msg: M, + options: SendOptions, + ) -> core::result::Result<(), SendError> { futures::executor::block_on(self.send_ext(msg, options)) } #[inline] - pub async fn send(&self, msg: M, ) -> core::result::Result<(), SendError> { + pub async fn send(&self, msg: M) -> core::result::Result<(), SendError> { self.send_ext(msg, SendOptions::Broadcast).await } - pub async fn send_ext(&self, msg: M, _options: SendOptions) -> core::result::Result<(), SendError> { + pub async fn send_ext( + &self, + msg: M, + _options: SendOptions, + ) -> core::result::Result<(), SendError> { if self.closed.load(Ordering::SeqCst) { return Err(SendError::Closed(msg)); } @@ -214,7 +238,10 @@ impl BusInner { } } - warn!("Unhandled message {:?}: no receivers", core::any::type_name::()); + warn!( + "Unhandled message {:?}: no receivers", + core::any::type_name::() + ); Ok(()) } @@ -224,7 +251,11 @@ impl BusInner { self.force_send_ext(msg, SendOptions::Broadcast) } - pub fn force_send_ext(&self, msg: M, _options: SendOptions) -> core::result::Result<(), SendError> { + pub fn force_send_ext( + &self, + msg: M, + _options: SendOptions, + ) -> core::result::Result<(), SendError> { if self.closed.load(Ordering::SeqCst) { return Err(SendError::Closed(msg)); } @@ -244,32 +275,56 @@ impl BusInner { } } - warn!("Unhandled message {:?}: no receivers", core::any::type_name::()); + warn!( + "Unhandled message {:?}: no receivers", + core::any::type_name::() + ); Ok(()) } - // pub fn request(&self, req: M, options: SendOptions) -> impl Future> { - // let mid = ID_COUNTER.fetch_add(1, Ordering::Relaxed); - // let tid = TypeId::of::(); - // let rid = TypeId::of::(); + pub async fn request( + &self, + req: M, + options: SendOptions, + ) -> anyhow::Result { + let tid = TypeId::of::(); + let rid = TypeId::of::(); - // let mut iter = self.select_receivers(options, Some(rid)); - // let first = iter.next(); + let mut iter = self.select_receivers(tid, options, Some(rid)); + if let Some(rc) = iter.next() { + let (tx, rx) = oneshot::channel(); + let mid = (rc.add_response_waiter(tx).unwrap() | 1 << (usize::BITS - 1)) as u64; + rc.send(mid, rc.reserve().await, req)?; - // for rs in iter { - // let _ = rs.send(mid, rs.reserve().await, req.clone()); - // } - - // first.send(mid, first.reserve().await, req); + Ok(rx.await?) + } else { + bail!("No Receivers!"); + } + } - // let (tx, rx) = tokio::sync::oneshot::channel(); - // self.response_waiters.insert(mid, tx); + #[inline] + fn select_receivers( + &self, + tid: TypeId, + _options: SendOptions, + rid: Option, + ) -> impl Iterator + '_ { + self.receivers + .get(&tid) + .into_iter() + .map(|item| item.iter()) + .flatten() + .filter(move |x| { + let ret_ty = if let Some(rid) = rid { + x.resp_type_id() == rid + } else { + true + }; - // poll_fn(move |cx| { - // rx.poll_unpin(cx) - // }) - // } + ret_ty + }) + } } #[derive(Clone)] diff --git a/src/receiver.rs b/src/receiver.rs index 1b3f072..27191a3 100644 --- a/src/receiver.rs +++ b/src/receiver.rs @@ -1,10 +1,30 @@ -use crate::{Bus, Error, Message, msgs, trait_object::TraitObject}; -use core::{any::TypeId, fmt, marker::PhantomData, mem, pin::Pin, task::{Context, Poll}}; +use crate::{msgs, trait_object::TraitObject, Bus, Error, Message}; +use core::{ + any::TypeId, + fmt, + marker::PhantomData, + mem, + pin::Pin, + task::{Context, Poll}, +}; use futures::future::poll_fn; -use tokio::sync::Notify; -use std::{borrow::Cow, sync::{Arc, atomic::{AtomicBool, AtomicU64, Ordering}}}; use futures::Future; +use std::{ + any::Any, + borrow::Cow, + sync::{ + atomic::{AtomicBool, AtomicU64, Ordering}, + Arc, + }, +}; +use tokio::sync::{oneshot, Notify}; +struct SlabCfg; +impl sharded_slab::Config for SlabCfg { + const RESERVED_BITS: usize = 1; +} + +type Slab = sharded_slab::Slab; pub trait SendUntypedReceiver: Send + Sync { fn send(&self, msg: Action) -> Result<(), SendError>; @@ -14,9 +34,10 @@ pub trait SendTypedReceiver: Sync { fn send(&self, mid: u64, msg: M) -> Result<(), SendError>; } -pub trait ReciveTypedReceiver: Sync - where M: Message, - E: crate::Error +pub trait ReciveTypedReceiver: Sync +where + M: Message, + E: crate::Error, { fn poll_events(&self, ctx: &mut Context<'_>) -> Poll>; } @@ -24,7 +45,6 @@ pub trait ReciveTypedReceiver: Sync pub trait ReceiverTrait: Send + Sync { fn typed(&self) -> AnyReceiver<'_>; fn poller(&self) -> AnyPoller<'_>; - fn type_id(&self) -> TypeId; fn stats(&self) -> Result<(), SendError<()>>; fn close(&self) -> Result<(), SendError<()>>; fn sync(&self) -> Result<(), SendError<()>>; @@ -39,7 +59,6 @@ pub trait PermitDrop { fn permit_drop(&self); } - #[derive(Debug, Clone)] pub struct Stats { pub has_queue: bool, @@ -75,20 +94,22 @@ pub enum Event { } struct ReceiverWrapper - where M: Message, - R: Message, - E: Error, - S: 'static -{ - inner: S, - _m: PhantomData<(M, R, E)> +where + M: Message, + R: Message, + E: Error, + S: 'static, +{ + inner: S, + _m: PhantomData<(M, R, E)>, } -impl ReceiverTrait for ReceiverWrapper - where M: Message, - R: Message, - E: Error, - S: SendUntypedReceiver + SendTypedReceiver + ReciveTypedReceiver + 'static +impl ReceiverTrait for ReceiverWrapper +where + M: Message, + R: Message, + E: Error, + S: SendUntypedReceiver + SendTypedReceiver + ReciveTypedReceiver + 'static, { fn typed(&self) -> AnyReceiver<'_> { AnyReceiver::new(&self.inner) @@ -98,30 +119,26 @@ impl ReceiverTrait for ReceiverWrapper AnyPoller::new(&self.inner) } - fn type_id(&self) -> TypeId { - TypeId::of::() - } - fn stats(&self) -> Result<(), SendError<()>> { - SendUntypedReceiver::send(&self.inner, Action::Stats).map_err(|_|SendError::Closed(())) + SendUntypedReceiver::send(&self.inner, Action::Stats).map_err(|_| SendError::Closed(())) } fn close(&self) -> Result<(), SendError<()>> { - SendUntypedReceiver::send(&self.inner, Action::Close).map_err(|_|SendError::Closed(())) + SendUntypedReceiver::send(&self.inner, Action::Close).map_err(|_| SendError::Closed(())) } fn sync(&self) -> Result<(), SendError<()>> { - SendUntypedReceiver::send(&self.inner, Action::Sync).map_err(|_|SendError::Closed(())) + SendUntypedReceiver::send(&self.inner, Action::Sync).map_err(|_| SendError::Closed(())) } fn flush(&self) -> Result<(), SendError<()>> { - SendUntypedReceiver::send(&self.inner, Action::Flush).map_err(|_|SendError::Closed(())) + SendUntypedReceiver::send(&self.inner, Action::Flush).map_err(|_| SendError::Closed(())) } } pub struct Permit { pub(crate) fuse: bool, - pub(crate) inner: Arc + pub(crate) inner: Arc, } impl Drop for Permit { @@ -167,11 +184,11 @@ pub struct AnyPoller<'a> { unsafe impl Send for AnyPoller<'_> {} impl<'a> AnyPoller<'a> { - pub fn new(rcvr: &'a R) -> Self - where - M: Message, - E: crate::Error, - R: ReciveTypedReceiver + 'static + pub fn new(rcvr: &'a R) -> Self + where + M: Message, + E: crate::Error, + R: ReciveTypedReceiver + 'static, { let trcvr = rcvr as &(dyn ReciveTypedReceiver); @@ -182,7 +199,9 @@ impl<'a> AnyPoller<'a> { } } - pub fn dyn_typed_receiver(&'a self) -> &'a dyn ReciveTypedReceiver { + pub fn dyn_typed_receiver( + &'a self, + ) -> &'a dyn ReciveTypedReceiver { assert_eq!(self.type_id, TypeId::of::>()); unsafe { mem::transmute(self.dyn_typed_receiver_trait_object) } @@ -242,6 +261,7 @@ impl fmt::Display for ReceiverStats { } struct ReceiverContext { + resp_type_id: TypeId, limit: u64, processing: AtomicU64, need_flush: AtomicBool, @@ -261,6 +281,7 @@ impl PermitDrop for ReceiverContext { pub struct Receiver { inner: Arc, context: Arc, + waiters: Arc, } impl fmt::Debug for Receiver { @@ -280,32 +301,38 @@ impl core::cmp::Eq for Receiver {} impl Receiver { #[inline] - pub(crate) fn new(limit: u64, inner: S) -> Self - where M: Message, - R: Message, - E: Error, - S: SendUntypedReceiver + SendTypedReceiver + ReciveTypedReceiver + 'static + pub(crate) fn new(limit: u64, inner: S) -> Self + where + M: Message, + R: Message, + E: Error, + S: SendUntypedReceiver + SendTypedReceiver + ReciveTypedReceiver + 'static, { - let context = Arc::new(ReceiverContext { - limit, - processing: AtomicU64::new(0), - need_flush: AtomicBool::new(false), - flushed: Notify::new(), - synchronized: Notify::new(), - closed: Notify::new(), - response: Notify::new(), - statistics: Notify::new(), - }); - - Self { inner: Arc::new(ReceiverWrapper{ - inner, - _m: Default::default() - }), context } + Self { + context: Arc::new(ReceiverContext { + resp_type_id: TypeId::of::(), + limit, + processing: AtomicU64::new(0), + need_flush: AtomicBool::new(false), + flushed: Notify::new(), + synchronized: Notify::new(), + closed: Notify::new(), + response: Notify::new(), + statistics: Notify::new(), + }), + inner: Arc::new(ReceiverWrapper { + inner, + _m: Default::default(), + }), + waiters: Arc::new(sharded_slab::Slab::>::new_with_config::< + SlabCfg, + >()), + } } #[inline] - pub fn type_id(&self) -> TypeId { - self.inner.type_id() + pub fn resp_type_id(&self) -> TypeId { + self.context.resp_type_id } #[inline] @@ -313,12 +340,16 @@ impl Receiver { self.context.need_flush.load(Ordering::SeqCst) } - #[inline] pub async fn reserve(&self) -> Permit { loop { let count = self.context.processing.load(Ordering::Relaxed); if count < self.context.limit { - let res = self.context.processing.compare_exchange(count, count + 1, Ordering::SeqCst, Ordering::SeqCst); + let res = self.context.processing.compare_exchange( + count, + count + 1, + Ordering::SeqCst, + Ordering::SeqCst, + ); if res.is_ok() { break Permit { fuse: false, @@ -328,19 +359,22 @@ impl Receiver { // continue } else { - self.context.response.notified() - .await + self.context.response.notified().await } } } - #[inline] pub fn try_reserve(&self) -> Option { loop { let count = self.context.processing.load(Ordering::Relaxed); - + if count < self.context.limit { - let res = self.context.processing.compare_exchange(count, count + 1, Ordering::SeqCst, Ordering::SeqCst); + let res = self.context.processing.compare_exchange( + count, + count + 1, + Ordering::SeqCst, + Ordering::SeqCst, + ); if res.is_ok() { break Some(Permit { fuse: false, @@ -356,7 +390,12 @@ impl Receiver { } #[inline] - pub fn send(&self, mid: u64, mut permit: Permit, msg: M) -> Result<(), SendError> { + pub fn send( + &self, + mid: u64, + mut permit: Permit, + msg: M, + ) -> Result<(), SendError> { let any_receiver = self.inner.typed(); let receiver = any_receiver.dyn_typed_receiver::(); let res = receiver.send(mid, msg); @@ -374,6 +413,7 @@ impl Receiver { let any_receiver = self.inner.typed(); let receiver = any_receiver.dyn_typed_receiver::(); let res = receiver.send(mid, msg); + self.context.processing.fetch_add(1, Ordering::SeqCst); if !res.is_err() { self.context.need_flush.store(true, Ordering::SeqCst); @@ -382,43 +422,75 @@ impl Receiver { res } - pub fn start_polling_events(&self) -> Box Pin + Send>>> - where - M: Message, - E: crate::Error + pub fn start_polling_events( + &self, + ) -> Box Pin + Send>>> + where + R: Message, + E: crate::Error, { let ctx_clone = self.context.clone(); let inner_clone = self.inner.clone(); + let waiters = self + .waiters + .clone() + .downcast::>>() + .unwrap(); - Box::new(move |bus| Box::pin(async move { - let any_receiver = inner_clone.poller(); - let receiver = any_receiver.dyn_typed_receiver::(); + Box::new(move |bus| { + Box::pin(async move { + let any_receiver = inner_clone.poller(); + let receiver = any_receiver.dyn_typed_receiver::(); - loop { - let event = poll_fn(move |ctx| receiver.poll_events(ctx)) - .await; + loop { + let event = poll_fn(move |ctx| receiver.poll_events(ctx)).await; - match event { - Event::Exited => { - ctx_clone.closed.notify_waiters(); - break; - }, - - Event::Flushed => ctx_clone.flushed.notify_waiters(), - Event::Synchronized(_res) => ctx_clone.synchronized.notify_waiters(), - Event::Response(_mid, resp) => { - ctx_clone.processing.fetch_sub(1, Ordering::SeqCst); - ctx_clone.response.notify_one(); - - match resp { - Ok(_msg) => (), - Err(err) => { bus.try_send(msgs::Error(Arc::new(err.into()))).ok(); } + match event { + Event::Exited => { + ctx_clone.closed.notify_waiters(); + break; } - }, - _ => unimplemented!() + + Event::Flushed => ctx_clone.flushed.notify_waiters(), + Event::Synchronized(_res) => ctx_clone.synchronized.notify_waiters(), + Event::Response(mid, resp) => { + ctx_clone.processing.fetch_sub(1, Ordering::SeqCst); + ctx_clone.response.notify_one(); + + match resp { + Ok(msg) => { + if let Some(waiter) = waiters.take(mid as usize) { + if let Err(_msg) = waiter.send(msg) { + error!("Response cannot be processed!"); + } + } else if TypeId::of::() != TypeId::of::<()>() { + warn!("Non-void response has no listeners!"); + } + } + Err(err) => { + bus.try_send(msgs::Error(Arc::new(err.into()))).ok(); + } + } + } + _ => unimplemented!(), + } } - } - })) + }) + }) + } + + #[inline] + pub(crate) fn add_response_waiter( + &self, + waiter: oneshot::Sender, + ) -> Option { + let idx = self + .waiters + .downcast_ref::>>() + .unwrap() + .insert(waiter)?; + + Some(idx) } // #[inline] @@ -434,8 +506,7 @@ impl Receiver { #[inline] pub async fn close(&self) { if self.inner.close().is_ok() { - self.context.closed.notified() - .await + self.context.closed.notified().await } else { warn!("close failed!"); } @@ -444,8 +515,7 @@ impl Receiver { #[inline] pub async fn sync(&self) { if self.inner.sync().is_ok() { - self.context.synchronized.notified() - .await + self.context.synchronized.notified().await } else { warn!("sync failed!"); } @@ -454,9 +524,8 @@ impl Receiver { #[inline] pub async fn flush(&self) { if self.inner.flush().is_ok() { - self.context.flushed.notified() - .await; - + self.context.flushed.notified().await; + self.context.need_flush.store(false, Ordering::SeqCst); } else { warn!("flush failed!"); diff --git a/src/receivers/buffer_unordered/async.rs b/src/receivers/buffer_unordered/async.rs index e9571ed..018c498 100644 --- a/src/receivers/buffer_unordered/async.rs +++ b/src/receivers/buffer_unordered/async.rs @@ -7,9 +7,12 @@ use std::{ task::{Context, Poll}, }; -use crate::{receiver::{Action, Event, ReceiverStats, ReciveTypedReceiver, SendUntypedReceiver}, receivers::Request}; +use crate::{ + receiver::{Action, Event, ReceiverStats, ReciveTypedReceiver, SendUntypedReceiver}, + receivers::Request, +}; use anyhow::Result; -use futures::{Future, StreamExt, stream::FuturesUnordered}; +use futures::{stream::FuturesUnordered, Future, StreamExt}; use super::{BufferUnorderedConfig, BufferUnorderedStats}; use crate::{ @@ -20,7 +23,6 @@ use crate::{ use parking_lot::Mutex; use tokio::sync::mpsc; - fn buffer_unordered_poller( mut rx: mpsc::UnboundedReceiver>, bus: Bus, @@ -33,7 +35,7 @@ where T: AsyncHandler + 'static, M: Message, R: Message, - E: crate::Error + E: crate::Error, { let ut = ut.downcast::().unwrap(); let mut queue = FuturesUnordered::new(); @@ -46,33 +48,33 @@ where if !rx_closed && !need_flush && !need_sync { while queue.len() < cfg.max_parallel { match rx.poll_recv(cx) { - Poll::Ready(Some(a)) => { - match a { - Request::Request(mid, msg) => { - stats.buffer.fetch_sub(1, Ordering::Relaxed); - stats.parallel.fetch_add(1, Ordering::Relaxed); + Poll::Ready(Some(a)) => match a { + Request::Request(mid, msg) => { + stats.buffer.fetch_sub(1, Ordering::Relaxed); + stats.parallel.fetch_add(1, Ordering::Relaxed); - let bus = bus.clone(); - let ut = ut.clone(); - queue.push(tokio::task::spawn(async move { (mid, ut.handle(msg, &bus).await) })); - }, - Request::Action(Action::Flush) => need_flush = true, - Request::Action(Action::Sync) => need_sync = true, - Request::Action(Action::Close) => rx.close(), - _ => unimplemented!() + let bus = bus.clone(); + let ut = ut.clone(); + queue.push(tokio::task::spawn(async move { + (mid, ut.handle(msg, &bus).await) + })); } + Request::Action(Action::Flush) => need_flush = true, + Request::Action(Action::Sync) => need_sync = true, + Request::Action(Action::Close) => rx.close(), + _ => unimplemented!(), }, Poll::Ready(None) => { need_sync = true; rx_closed = true; - }, + } Poll::Pending => break, } } } let queue_len = queue.len(); - + loop { if queue_len != 0 { loop { @@ -80,7 +82,7 @@ where Poll::Pending => return Poll::Pending, Poll::Ready(Some(Ok((mid, res)))) => { stx.send(Event::Response(mid, res)).ok(); - }, + } Poll::Ready(None) => break, _ => {} } @@ -98,7 +100,7 @@ where Poll::Pending => { sync_future = Some(fut); return Poll::Pending; - }, + } Poll::Ready(res) => { need_sync = false; stx.send(Event::Synchronized(res)).ok(); @@ -107,9 +109,7 @@ where } else { let ut = ut.clone(); let bus_clone = bus.clone(); - sync_future.replace(Box::pin(async move { - ut.sync(&bus_clone).await - })); + sync_future.replace(Box::pin(async move { ut.sync(&bus_clone).await })); } } else { break; @@ -117,16 +117,20 @@ where } if queue_len == queue.len() { - return if rx_closed { Poll::Ready(()) } else { Poll::Pending }; + return if rx_closed { + Poll::Ready(()) + } else { + Poll::Pending + }; } }) } -pub struct BufferUnorderedAsync - where - M: Message, - R: Message, - E: crate::Error +pub struct BufferUnorderedAsync +where + M: Message, + R: Message, + E: crate::Error, { tx: mpsc::UnboundedSender>, stats: Arc, @@ -134,16 +138,22 @@ pub struct BufferUnorderedAsync } impl ReceiverSubscriberBuilder for BufferUnorderedAsync - where - T: AsyncHandler + 'static, - R: Message, - M: Message, - E: crate::Error +where + T: AsyncHandler + 'static, + R: Message, + M: Message, + E: crate::Error, { type Config = BufferUnorderedConfig; - fn build(cfg: Self::Config) -> (Self, Box Box Pin + Send>>>>) - { + fn build( + cfg: Self::Config, + ) -> ( + Self, + Box< + dyn FnOnce(Untyped) -> Box Pin + Send>>>, + >, + ) { let stats = Arc::new(BufferUnorderedStats { buffer: AtomicU64::new(0), buffer_total: AtomicU64::new(cfg.buffer_size as _), @@ -157,39 +167,48 @@ impl ReceiverSubscriberBuilder for BufferUnorderedAsync< let poller = Box::new(move |ut| { Box::new(move |bus| { - Box::pin(buffer_unordered_poller::(rx, bus, ut, stats_clone, cfg, stx)) - as Pin + Send>> + Box::pin(buffer_unordered_poller::( + rx, + bus, + ut, + stats_clone, + cfg, + stx, + )) as Pin + Send>> }) as Box Pin + Send>>> }); - (BufferUnorderedAsync:: { - tx, - stats, - srx: Mutex::new(srx), - }, poller) + ( + BufferUnorderedAsync:: { + tx, + stats, + srx: Mutex::new(srx), + }, + poller, + ) } } -impl SendUntypedReceiver for BufferUnorderedAsync - where - M: Message, - R: Message, - E: crate::Error +impl SendUntypedReceiver for BufferUnorderedAsync +where + M: Message, + R: Message, + E: crate::Error, { fn send(&self, m: Action) -> Result<(), SendError> { match self.tx.send(Request::Action(m)) { Ok(_) => Ok(()), Err(mpsc::error::SendError(Request::Action(msg))) => Err(SendError::Closed(msg)), - _ => unimplemented!() + _ => unimplemented!(), } } } -impl SendTypedReceiver for BufferUnorderedAsync - where - M: Message, - R: Message, - E: crate::Error +impl SendTypedReceiver for BufferUnorderedAsync +where + M: Message, + R: Message, + E: crate::Error, { fn send(&self, mid: u64, m: M) -> Result<(), SendError> { match self.tx.send(Request::Request(mid, m)) { @@ -199,16 +218,16 @@ impl SendTypedReceiver for BufferUnorderedAsync Ok(()) } Err(mpsc::error::SendError(Request::Request(_, msg))) => Err(SendError::Closed(msg)), - _ => unimplemented!() + _ => unimplemented!(), } } } -impl ReciveTypedReceiver for BufferUnorderedAsync - where - M: Message, - R: Message, - E: crate::Error +impl ReciveTypedReceiver for BufferUnorderedAsync +where + M: Message, + R: Message, + E: crate::Error, { fn poll_events(&self, ctx: &mut Context<'_>) -> Poll> { let poll = self.srx.lock().poll_recv(ctx); @@ -218,4 +237,4 @@ impl ReciveTypedReceiver for BufferUnorderedAsync Poll::Ready(None) => Poll::Ready(Event::Exited), } } -} \ No newline at end of file +} diff --git a/src/receivers/buffer_unordered/sync.rs b/src/receivers/buffer_unordered/sync.rs index 92756a3..fd55f22 100644 --- a/src/receivers/buffer_unordered/sync.rs +++ b/src/receivers/buffer_unordered/sync.rs @@ -7,20 +7,22 @@ use std::{ task::{Context, Poll}, }; -use crate::{receiver::{Action, Event, ReceiverStats, ReciveTypedReceiver, SendUntypedReceiver}, receivers::Request}; +use crate::{ + receiver::{Action, Event, ReceiverStats, ReciveTypedReceiver, SendUntypedReceiver}, + receivers::Request, +}; use anyhow::Result; -use futures::{Future, StreamExt, stream::FuturesUnordered}; +use futures::{stream::FuturesUnordered, Future, StreamExt}; use super::{BufferUnorderedConfig, BufferUnorderedStats}; use crate::{ builder::ReceiverSubscriberBuilder, receiver::{SendError, SendTypedReceiver}, - Handler, Bus, Message, Untyped, + Bus, Handler, Message, Untyped, }; use parking_lot::Mutex; use tokio::sync::mpsc; - fn buffer_unordered_poller( mut rx: mpsc::UnboundedReceiver>, bus: Bus, @@ -33,7 +35,7 @@ where T: Handler + 'static, M: Message, R: Message, - E: crate::Error + E: crate::Error, { let ut = ut.downcast::().unwrap(); let mut queue = FuturesUnordered::new(); @@ -46,33 +48,33 @@ where if !rx_closed && !need_flush && !need_sync { while queue.len() < cfg.max_parallel { match rx.poll_recv(cx) { - Poll::Ready(Some(a)) => { - match a { - Request::Request(mid, msg) => { - stats.buffer.fetch_sub(1, Ordering::Relaxed); - stats.parallel.fetch_add(1, Ordering::Relaxed); + Poll::Ready(Some(a)) => match a { + Request::Request(mid, msg) => { + stats.buffer.fetch_sub(1, Ordering::Relaxed); + stats.parallel.fetch_add(1, Ordering::Relaxed); - let bus = bus.clone(); - let ut = ut.clone(); - queue.push( tokio::task::spawn_blocking(move || (mid, ut.handle(msg, &bus)))); - }, - Request::Action(Action::Flush) => need_flush = true, - Request::Action(Action::Sync) => need_sync = true, - Request::Action(Action::Close) => rx.close(), - _ => unimplemented!() + let bus = bus.clone(); + let ut = ut.clone(); + queue.push(tokio::task::spawn_blocking(move || { + (mid, ut.handle(msg, &bus)) + })); } + Request::Action(Action::Flush) => need_flush = true, + Request::Action(Action::Sync) => need_sync = true, + Request::Action(Action::Close) => rx.close(), + _ => unimplemented!(), }, Poll::Ready(None) => { need_sync = true; rx_closed = true; - }, + } Poll::Pending => break, } } } let queue_len = queue.len(); - + loop { if queue_len != 0 { loop { @@ -80,7 +82,7 @@ where Poll::Pending => return Poll::Pending, Poll::Ready(Some(Ok((mid, res)))) => { stx.send(Event::Response(mid, res)).ok(); - }, + } Poll::Ready(None) => break, _ => {} } @@ -98,7 +100,7 @@ where Poll::Pending => { sync_future = Some(fut); return Poll::Pending; - }, + } Poll::Ready(res) => { need_sync = false; stx.send(Event::Synchronized(res)).ok(); @@ -108,7 +110,9 @@ where let ut = ut.clone(); let bus_clone = bus.clone(); sync_future.replace(Box::pin(async move { - tokio::task::spawn_blocking(move || ut.sync(&bus_clone)).await.unwrap() + tokio::task::spawn_blocking(move || ut.sync(&bus_clone)) + .await + .unwrap() })); } } else { @@ -117,16 +121,20 @@ where } if queue_len == queue.len() { - return if rx_closed { Poll::Ready(()) } else { Poll::Pending }; + return if rx_closed { + Poll::Ready(()) + } else { + Poll::Pending + }; } }) } -pub struct BufferUnorderedSync - where - M: Message, - R: Message, - E: crate::Error +pub struct BufferUnorderedSync +where + M: Message, + R: Message, + E: crate::Error, { tx: mpsc::UnboundedSender>, stats: Arc, @@ -134,15 +142,22 @@ pub struct BufferUnorderedSync } impl ReceiverSubscriberBuilder for BufferUnorderedSync - where - T: Handler + 'static, - R: Message, - M: Message, - E: crate::Error +where + T: Handler + 'static, + R: Message, + M: Message, + E: crate::Error, { type Config = BufferUnorderedConfig; - fn build(cfg: Self::Config) -> (Self, Box Box Pin + Send>>>>) { + fn build( + cfg: Self::Config, + ) -> ( + Self, + Box< + dyn FnOnce(Untyped) -> Box Pin + Send>>>, + >, + ) { let stats = Arc::new(BufferUnorderedStats { buffer: AtomicU64::new(0), buffer_total: AtomicU64::new(cfg.buffer_size as _), @@ -156,40 +171,48 @@ impl ReceiverSubscriberBuilder for BufferUnorderedSync(rx, bus, ut, stats_clone, cfg, stx)) - as Pin + Send>> + Box::pin(buffer_unordered_poller::( + rx, + bus, + ut, + stats_clone, + cfg, + stx, + )) as Pin + Send>> }) as Box Pin + Send>>> }); - (BufferUnorderedSync:: { - tx, - stats, - srx: Mutex::new(srx), - }, poller) + ( + BufferUnorderedSync:: { + tx, + stats, + srx: Mutex::new(srx), + }, + poller, + ) } } -impl SendUntypedReceiver for BufferUnorderedSync - where - M: Message, - R: Message, - E: crate::Error +impl SendUntypedReceiver for BufferUnorderedSync +where + M: Message, + R: Message, + E: crate::Error, { fn send(&self, msg: Action) -> Result<(), SendError> { match self.tx.send(Request::Action(msg)) { Ok(_) => Ok(()), Err(mpsc::error::SendError(Request::Action(msg))) => Err(SendError::Closed(msg)), - _ => unimplemented!() + _ => unimplemented!(), } } } - -impl SendTypedReceiver for BufferUnorderedSync - where - M: Message, - R: Message, - E: crate::Error +impl SendTypedReceiver for BufferUnorderedSync +where + M: Message, + R: Message, + E: crate::Error, { fn send(&self, mid: u64, m: M) -> Result<(), SendError> { match self.tx.send(Request::Request(mid, m)) { @@ -199,16 +222,16 @@ impl SendTypedReceiver for BufferUnorderedSync Ok(()) } Err(mpsc::error::SendError(Request::Request(_, msg))) => Err(SendError::Closed(msg)), - _ => unimplemented!() + _ => unimplemented!(), } } } -impl ReciveTypedReceiver for BufferUnorderedSync - where - M: Message, - R: Message, - E: crate::Error +impl ReciveTypedReceiver for BufferUnorderedSync +where + M: Message, + R: Message, + E: crate::Error, { fn poll_events(&self, ctx: &mut Context<'_>) -> Poll> { let poll = self.srx.lock().poll_recv(ctx); @@ -218,4 +241,4 @@ impl ReciveTypedReceiver for BufferUnorderedSync Poll::Ready(None) => Poll::Ready(Event::Exited), } } -} \ No newline at end of file +} diff --git a/src/receivers/mod.rs b/src/receivers/mod.rs index dcf850b..3daa2cb 100644 --- a/src/receivers/mod.rs +++ b/src/receivers/mod.rs @@ -4,25 +4,20 @@ mod buffer_unordered; // mod synchronize_batched; // mod synchronized; - // mod mpsc; // mod mpsc { // pub use super::mpsc_futures::*; // } -pub use buffer_unordered::{ - BufferUnorderedAsync, BufferUnorderedConfig, - BufferUnorderedSync, -}; +pub use buffer_unordered::{BufferUnorderedAsync, BufferUnorderedConfig, BufferUnorderedSync}; use crate::receiver::Action; pub(crate) enum Request { Action(Action), - Request(u64, M) + Request(u64, M), } - // pub use buffer_unordered_batched::{ // BufferUnorderedBatchedAsync, BufferUnorderedBatchedAsyncSubscriber, BufferUnorderedBatchedConfig, // BufferUnorderedBatchedSync, BufferUnorderedBatchedSyncSubscriber,