commit 7240aa2b749c431e414ff5318d921307157a07d4 Author: Andrey Tkachenko Date: Thu Dec 17 17:35:11 2020 +0400 MessageBus: Initial diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..96ef6c0 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +/target +Cargo.lock diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..7c571b7 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "messagebus" +version = "0.4.2" +authors = ["Andrey Tkachenko "] +repository = "https://github.com/andreytkachenko/messagebus.git" +keywords = ["futures", "async", "tokio", "message", "bus"] +categories = ["network-programming", "asynchronous"] +description = "MessageBus allows you intercommunicate with messages between modules" +license = "MIT OR Apache-2.0" +exclude = [".gitignore", ".cargo/config", ".github/**", "codecov.yml"] +edition = "2018" + +[dependencies] +tokio = { version = "0.2", features = ["parking_lot", "rt-threaded", "sync", "stream", "blocking"] } +parking_lot = "0.11.1" +async-trait = "0.1.42" +futures = "0.3.8" +anyhow = "1.0.34" +crossbeam = "0.8.0" + +[dev-dependencies] +tokio = { version = "0.2", features = ["parking_lot", "rt-threaded", "sync", "stream", "macros"] } diff --git a/README.md b/README.md new file mode 100644 index 0000000..ac9a73e --- /dev/null +++ b/README.md @@ -0,0 +1,121 @@ +
+

Message Bus

+

+ Async Message Bus for Rust +

+

+

+ +Inspired by Actix + +### Basics +1. Can deliver messages between actors using receivers (usually a queue implementations) +2. Messages distincts and delivers by TypeId +3. Messages delivers in a broadcast fashion to many receivers (Cloned) +4. There are different kind of receivers implemented: + - BufferUnordered Receiver (in sync and async version depending by handler) + - Synchronized (also sync and async) if receiving part needs syncronization + - SynchronizeBuffered (also sync and async) +here are the implmented handlers definitions: +```rust + +// Handler is Sync and we can spawn many of concurrent tasks +pub trait Handler: Send + Sync { + fn handle(&self, msg: M, bus: &Bus) -> anyhow::Result<()>; + fn sync(&self, _bus: &Bus) -> anyhow::Result<()> {Ok(())} +} + +#[async_trait] +pub trait AsyncHandler: Send + Sync { + async fn handle(&self, msg: M, bus: &Bus) -> anyhow::Result<()>; + async fn sync(&self, _bus: &Bus) -> anyhow::Result<()> {Ok(())} +} + +// Handler is not Sync and we cannot spawn many of concurrent tasks same time (uses synchronization primitives such as Mutex or RwLock) +pub trait SynchronizedHandler: Send { + fn handle(&mut self, msg: M, bus: &Bus) -> anyhow::Result<()>; + fn sync(&mut self, _bus: &Bus) -> anyhow::Result<()> {Ok(())} +} + +#[async_trait] +pub trait AsyncSynchronizedHandler: Send { + async fn handle(&mut self, msg: M, bus: &Bus) -> anyhow::Result<()>; + async fn sync(&mut self, _bus: &Bus) -> anyhow::Result<()> {Ok(())} +} + +// Handler is not Sync and handler will process items in batched mode +pub trait BatchSynchronizedHandler: Send { + fn handle(&mut self, msg: Vec, bus: &Bus) -> anyhow::Result<()>; + fn sync(&mut self, _bus: &Bus) -> anyhow::Result<()> {Ok(())} +} + +#[async_trait] +pub trait AsyncBatchSynchronizedHandler: Send { + async fn handle(&mut self, msg: Vec, bus: &Bus) -> anyhow::Result<()>; + async fn sync(&mut self, _bus: &Bus) -> anyhow::Result<()> {Ok(())} +} + +``` +4. Handler Kinds: + 1. No Synchronization needed (Handler is `Send` + `Sync`) + * Not batched operations **(implemented)** + - sync (spawn_blocking) + - async (spawn) + * Batched + - sync (spawn_blocking) + - async (spawn) + 2. Synchronization needed (Handler is `Sync` + `!Send`) + * Not batched operations **(implemented)** + - sync (spawn_blocking) + - async (spawn) + * Batched **(implemented)** + - sync (spawn_blocking) + - async (spawn) + 3. Synchronization needed and thread dedicated (Handler is `!Sync` + `!Send`) + * Not batched operations + - sync (spawn_blocking) + - async (spawn) + * Batched + - sync (spawn_blocking) + - async (spawn) + +5. Example: +```rust +use messagebus::{Bus, AsyncHandler, Result as MbusResult, receivers}; +use async_trait::async_trait; + +struct TmpReceiver; + +#[async_trait] +impl AsyncHandler for TmpReceiver { + async fn handle(&self, msg: i32, bus: &Bus) -> MbusResult { + println!("---> i32 {}", msg); + + bus.send(2i64).await?; + + Ok(()) + } +} + +#[async_trait] +impl AsyncHandler for TmpReceiver { + async fn handle(&self, msg: i64, _bus: &Bus) -> MbusResult { + println!("---> i64 {}", msg); + + Ok(()) + } +} + +#[tokio::main] +async fn main() { + let (b, poller) = Bus::build() + .register(TmpReceiver) + .subscribe::>(Default::default()) + .subscribe::>(Default::default()) + .done() + .build(); + + b.send(1i32).await.unwrap(); + poller.await +} +``` \ No newline at end of file diff --git a/examples/demo.rs b/examples/demo.rs new file mode 100644 index 0000000..c83d523 --- /dev/null +++ b/examples/demo.rs @@ -0,0 +1,79 @@ +use messagebus::{receivers, Bus, Handler, Result as MbusResult}; + +struct TmpReceiver; +struct TmpReceiver2; + +impl Handler for TmpReceiver { + fn handle(&self, msg: f32, bus: &Bus) -> MbusResult { + bus.try_send(1u16).unwrap(); + + println!("---> f32 {}", msg); + Ok(()) + } +} + +impl Handler for TmpReceiver { + fn handle(&self, msg: u16, bus: &Bus) -> MbusResult { + bus.try_send(1u32).unwrap(); + println!("---> u16 {}", msg); + Ok(()) + } +} + +impl Handler for TmpReceiver { + fn handle(&self, msg: u32, bus: &Bus) -> MbusResult { + bus.try_send(2i32).unwrap(); + println!("---> u32 {}", msg); + Ok(()) + } +} + +impl Handler for TmpReceiver { + fn handle(&self, msg: i32, bus: &Bus) -> MbusResult { + bus.try_send(3i16).unwrap(); + println!("---> i32 {}", msg); + Ok(()) + } +} + +impl Handler for TmpReceiver { + fn handle(&self, msg: i16, _bus: &Bus) -> MbusResult { + println!("---> i16 {}", msg); + Ok(()) + } +} + +impl Handler for TmpReceiver2 { + fn handle(&self, msg: i32, bus: &Bus) -> MbusResult { + bus.try_send(3i16).unwrap(); + println!("---> 2 i32 {}", msg); + Ok(()) + } +} + +impl Handler for TmpReceiver2 { + fn handle(&self, msg: i16, _bus: &Bus) -> MbusResult { + println!("---> 2 i16 {}", msg); + Ok(()) + } +} + +#[tokio::main] +async fn main() { + let (b, poller) = Bus::build() + .register(TmpReceiver) + .subscribe::>(Default::default()) + .subscribe::>(Default::default()) + .subscribe::>(Default::default()) + .subscribe::>(Default::default()) + .subscribe::>(Default::default()) + .done() + .register(TmpReceiver2) + .subscribe::>(Default::default()) + .subscribe::>(Default::default()) + .done() + .build(); + + b.send(32f32).await.unwrap(); + poller.await +} diff --git a/examples/demo_async.rs b/examples/demo_async.rs new file mode 100644 index 0000000..6330cbb --- /dev/null +++ b/examples/demo_async.rs @@ -0,0 +1,93 @@ +use async_trait::async_trait; +use messagebus::{receivers, AsyncHandler, Bus, Handler, Result as MbusResult}; + +struct TmpReceiver; +struct TmpReceiver2; + +#[async_trait] +impl AsyncHandler for TmpReceiver { + async fn handle(&self, msg: f32, bus: &Bus) -> MbusResult { + bus.send(1u16).await?; + + println!("---> f32 {}", msg); + + Ok(()) + } +} + +#[async_trait] +impl AsyncHandler for TmpReceiver { + async fn handle(&self, msg: u16, bus: &Bus) -> MbusResult { + bus.send(2u32).await?; + println!("---> u16 {}", msg); + + Ok(()) + } +} + +#[async_trait] +impl AsyncHandler for TmpReceiver { + async fn handle(&self, msg: u32, bus: &Bus) -> MbusResult { + bus.send(3i32).await?; + println!("---> u32 {}", msg); + + Ok(()) + } +} + +#[async_trait] +impl AsyncHandler for TmpReceiver { + async fn handle(&self, msg: i32, bus: &Bus) -> MbusResult { + bus.send(4i16).await?; + println!("---> i32 {}", msg); + + Ok(()) + } +} + +#[async_trait] +impl AsyncHandler for TmpReceiver { + async fn handle(&self, msg: i16, _bus: &Bus) -> MbusResult { + println!("---> i16 {}", msg); + + Ok(()) + } +} + +#[async_trait] +impl AsyncHandler for TmpReceiver2 { + async fn handle(&self, msg: i32, bus: &Bus) -> MbusResult { + bus.send(5i16).await?; + println!("---> 2 i32 {}", msg); + + Ok(()) + } +} + +impl Handler for TmpReceiver2 { + fn handle(&self, msg: i16, _bus: &Bus) -> MbusResult { + println!("---> 2 i16 {}", msg); + + Ok(()) + } +} + +#[tokio::main] +async fn main() { + let (b, poller) = Bus::build() + .register(TmpReceiver) + .subscribe::>(Default::default()) + .subscribe::>(Default::default()) + .subscribe::>(Default::default()) + .subscribe::>(Default::default()) + .subscribe::>(Default::default()) + .done() + .register(TmpReceiver2) + .subscribe::>(Default::default()) + .subscribe::>(Default::default()) + .done() + .build(); + + b.send(0f32).await.unwrap(); + poller.await +} diff --git a/examples/demo_backpressure.rs b/examples/demo_backpressure.rs new file mode 100644 index 0000000..8fe85a3 --- /dev/null +++ b/examples/demo_backpressure.rs @@ -0,0 +1,63 @@ +use messagebus::{receivers, Bus, Handler, Result as MbusResult}; + +struct TmpReceiver; + +impl Handler for TmpReceiver { + fn handle(&self, msg: f32, _bus: &Bus) -> MbusResult { + println!("---> f32 {}", msg); + + std::thread::sleep(std::time::Duration::from_secs(1)); + + println!("done"); + Ok(()) + } +} + +#[tokio::main] +async fn main() { + let (b, poller) = Bus::build() + .register(TmpReceiver) + .subscribe::>(receivers::BufferUnorderedConfig { + buffer_size: 1, + max_parallel: 1, + }) + .done() + .build(); + + println!("sending 1"); + b.send(32f32).await.unwrap(); + + println!("sending 2"); + b.send(32f32).await.unwrap(); + + println!("sending 3"); + b.send(32f32).await.unwrap(); + + println!("sending 4"); + b.send(32f32).await.unwrap(); + + println!("sending 5"); + b.send(32f32).await.unwrap(); + + println!("sending 6"); + b.send(32f32).await.unwrap(); + + println!("sending 7"); + b.send(32f32).await.unwrap(); + + println!("sending 8"); + b.send(32f32).await.unwrap(); + + println!("sending 9"); + b.send(32f32).await.unwrap(); + + println!("sending 10"); + b.send(32f32).await.unwrap(); + + println!("sending 11"); + b.send(32f32).await.unwrap(); + + println!("finish"); + + poller.await; +} diff --git a/examples/demo_slow.rs b/examples/demo_slow.rs new file mode 100644 index 0000000..ffcedaa --- /dev/null +++ b/examples/demo_slow.rs @@ -0,0 +1,46 @@ +use messagebus::{receivers, Bus, Handler, Result as MbusResult}; + +struct TmpReceiver; + +impl Handler for TmpReceiver { + fn handle(&self, msg: f32, _bus: &Bus) -> MbusResult { + println!("---> f32 {}", msg); + + std::thread::sleep(std::time::Duration::from_secs(5)); + + println!("done"); + + Ok(()) + } +} + +impl Handler for TmpReceiver { + fn handle(&self, msg: u16, _bus: &Bus) -> MbusResult { + println!("---> u16 {}", msg); + Ok(()) + } +} + +impl Handler for TmpReceiver { + fn handle(&self, msg: u32, _bus: &Bus) -> MbusResult { + println!("---> u32 {}", msg); + Ok(()) + } +} + +#[tokio::main] +async fn main() { + let (b, poller) = Bus::build() + .register(TmpReceiver) + .subscribe::>(Default::default()) + .subscribe::>(Default::default()) + .subscribe::>(Default::default()) + .done() + .build(); + + b.send(32f32).await.unwrap(); + b.send(11u16).await.unwrap(); + b.send(32u32).await.unwrap(); + + poller.await +} diff --git a/examples/non_sync.rs b/examples/non_sync.rs new file mode 100644 index 0000000..128715e --- /dev/null +++ b/examples/non_sync.rs @@ -0,0 +1,57 @@ +use messagebus::{receivers, Bus, Result as MbusResult, SynchronizedHandler}; +use receivers::SynchronizedConfig; + +struct TmpReceiver; + +impl SynchronizedHandler for TmpReceiver { + fn handle(&mut self, msg: f32, _bus: &Bus) -> MbusResult { + println!("---> f32 {}", msg); + + std::thread::sleep(std::time::Duration::from_secs(1)); + + println!("done"); + Ok(()) + } +} + +impl SynchronizedHandler for TmpReceiver { + fn handle(&mut self, msg: i16, _bus: &Bus) -> MbusResult { + println!("---> i16 {}", msg); + + std::thread::sleep(std::time::Duration::from_secs(1)); + + println!("done"); + Ok(()) + } +} + +#[tokio::main] +async fn main() { + let (b, poller) = Bus::build() + .register_unsync(TmpReceiver) + .subscribe::>(SynchronizedConfig { buffer_size: 1 }) + .subscribe::>(Default::default()) + .done() + .build(); + + b.send(12.0f32).await.unwrap(); + b.send(1i16).await.unwrap(); + b.send(12.0f32).await.unwrap(); + b.send(1i16).await.unwrap(); + b.send(12.0f32).await.unwrap(); + b.send(1i16).await.unwrap(); + b.send(12.0f32).await.unwrap(); + b.send(1i16).await.unwrap(); + b.send(12.0f32).await.unwrap(); + b.send(1i16).await.unwrap(); + b.send(12.0f32).await.unwrap(); + b.send(1i16).await.unwrap(); + b.send(12.0f32).await.unwrap(); + b.send(1i16).await.unwrap(); + b.send(12.0f32).await.unwrap(); + b.send(1i16).await.unwrap(); + + println!("finish"); + + poller.await; +} diff --git a/src/builder.rs b/src/builder.rs new file mode 100644 index 0000000..c9237bf --- /dev/null +++ b/src/builder.rs @@ -0,0 +1,160 @@ +use std::{any::TypeId, collections::HashMap, marker::PhantomData, pin::Pin, sync::Arc}; + +use futures::{Future, FutureExt}; +use receiver::ReceiverTrait; +use tokio::sync::Mutex; + +use crate::{ + receiver::{self, Receiver}, + Bus, BusInner, Message, Untyped, +}; + +pub trait ReceiverSubscriber { + fn subscribe( + self, + ) -> ( + Arc, + Box< + dyn FnOnce(Untyped) -> Box Pin + Send>>>, + >, + ); +} + +pub trait ReceiverSubscriberBuilder { + type Entry: ReceiverSubscriber; + type Config: Default; + + fn build(cfg: Self::Config) -> Self::Entry; +} + +pub struct SyncEntry; +pub struct UnsyncEntry; + +#[must_use] +pub struct RegisterEntry { + item: Untyped, + builder: BusBuilder, + receivers: HashMap< + TypeId, + Vec<( + Receiver, + Box< + dyn FnOnce( + Untyped, + ) + -> Box Pin + Send>>>, + >, + )>, + >, + _m: PhantomData<(K, T)>, +} + +impl RegisterEntry { + pub fn done(self) -> BusBuilder { + let mut builder = self.builder; + + for (tid, v) in self.receivers { + for (r, poller) in v { + let poller = poller(self.item.clone()); + + builder.add_recevier((tid, r), poller); + } + } + + builder + } +} + +impl RegisterEntry { + pub fn subscribe(mut self, cfg: R::Config) -> Self + where + T: Send + 'static, + M: Message + 'static, + R: ReceiverSubscriberBuilder + 'static, + { + let (inner, poller) = R::build(cfg).subscribe(); + + let receiver = Receiver::new(inner); + self.receivers + .entry(TypeId::of::()) + .or_insert_with(Vec::new) + .push((receiver, poller)); + + self + } +} + +impl RegisterEntry { + pub fn subscribe(mut self, cfg: R::Config) -> Self + where + T: Send + 'static, + M: Message + 'static, + R: ReceiverSubscriberBuilder + 'static, + { + let (inner, poller) = R::build(cfg).subscribe(); + + let receiver = Receiver::new(inner); + self.receivers + .entry(TypeId::of::()) + .or_insert_with(Vec::new) + .push((receiver, poller)); + + self + } +} + +pub struct BusBuilder { + receivers: Vec<(TypeId, Receiver)>, + pollings: Vec Pin + Send>>>>, +} + +impl BusBuilder { + pub fn new() -> Self { + Self { + receivers: Vec::new(), + pollings: Vec::new(), + } + } + + pub fn register(self, item: T) -> RegisterEntry { + RegisterEntry { + item: Arc::new(item) as Untyped, + builder: self, + receivers: HashMap::new(), + _m: Default::default(), + } + } + + pub fn register_unsync(self, item: T) -> RegisterEntry { + RegisterEntry { + item: Arc::new(Mutex::new(item)) as Untyped, + builder: self, + receivers: HashMap::new(), + _m: Default::default(), + } + } + + pub fn add_recevier( + &mut self, + val: (TypeId, Receiver), + poller: Box Pin + Send>>>, + ) { + self.receivers.push(val); + self.pollings.push(poller); + } + + pub fn build(self) -> (Bus, impl Future) { + let bus = Bus { + inner: Arc::new(BusInner::new(self.receivers)), + }; + + let mut futs = Vec::with_capacity(self.pollings.len()); + for poller in self.pollings { + futs.push(tokio::task::spawn(poller(bus.clone()))); + } + + let poller = futures::future::join_all(futs).map(|_| ()).map(|_| ()); + + (bus, poller) + } +} diff --git a/src/envelop.rs b/src/envelop.rs new file mode 100644 index 0000000..7c365e9 --- /dev/null +++ b/src/envelop.rs @@ -0,0 +1,102 @@ +use core::any::{self, Any}; +use core::fmt; +// use erased_serde::{Deserializer, Serialize}; + +pub trait Message: Any + fmt::Debug/*Serialize + for<'a> Deserializer<'a> + */ + Unpin + Clone + Send + Sync + 'static {} +impl Message for T {} + +trait SafeMessage: Any + fmt::Debug/*+ Serialize + for<'a> Deserializer<'a>*/ + Unpin + Send + Sync + 'static { + fn type_name(&self) -> &'static str; + fn clone_boxed(&self) -> Box; +} + +impl SafeMessage for T { + fn type_name(&self) -> &'static str { + any::type_name::() + } + + fn clone_boxed(&self) -> Box { + Box::new(self.clone()) + } +} + +// pub struct BoxedEnvelop { +// inner: Box, +// } + +// impl BoxedEnvelop { +// pub fn from_message(m: M) -> Self { +// Self { +// inner: Box::new(m) +// } +// } + +// pub fn as_ref(&self) -> Envelop<'_> { +// Envelop { inner: &*self.inner } +// } + +// pub fn downcast(self) -> Option> { +// if (*self.inner).type_id() == TypeId::of::() { +// unsafe { +// let raw: *mut dyn SafeMessage = Box::into_raw(self.inner); + +// Some(Box::from_raw(raw as *mut T)) +// } +// } else { +// None +// } +// } +// } + +#[derive(Copy, Clone)] +pub struct Envelop<'inner> { + inner: &'inner dyn SafeMessage, +} + +impl<'inner> fmt::Debug for Envelop<'inner> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Envelop(")?; + self.inner.fmt(f)?; + write!(f, ")")?; + + Ok(()) + } +} + +impl<'inner> Envelop<'inner> { + // pub fn new(inner: &'inner T) -> Self { + // Self { inner } + // } + + // #[inline] + // pub fn downcast_to(&self) -> Option<&T> { + // if self.inner.type_id() == TypeId::of::() { + // unsafe { Some(&*(self.inner as *const dyn SafeMessage as *const T)) } + // } else { + // None + // } + // } + + // #[inline] + // pub fn type_id(&self) -> TypeId { + // self.inner.type_id() + // } + + // #[inline] + // pub fn type_name(&self) -> &'static str { + // self.inner.type_name() + // } + + // #[inline] + // pub fn clone_boxed(&self) -> BoxedEnvelop { + // BoxedEnvelop { + // inner: self.inner.clone_boxed(), + // } + // } +} + +// impl<'inner> serde::Serialize for Envelop<'inner> { +// fn serialize(&self, serializer: S) -> Result { +// erased_serde::serialize(self.inner, serializer) +// } +// } diff --git a/src/handler.rs b/src/handler.rs new file mode 100644 index 0000000..7facd70 --- /dev/null +++ b/src/handler.rs @@ -0,0 +1,92 @@ +use crate::{Bus, Message}; +use async_trait::async_trait; + +pub trait Handler: Send + Sync { + fn handle(&self, msg: M, bus: &Bus) -> anyhow::Result<()>; + fn sync(&self, _bus: &Bus) -> anyhow::Result<()> { + Ok(()) + } +} + +#[async_trait] +pub trait AsyncHandler: Send + Sync { + async fn handle(&self, msg: M, bus: &Bus) -> anyhow::Result<()>; + async fn sync(&self, _bus: &Bus) -> anyhow::Result<()> { + Ok(()) + } +} + +pub trait SynchronizedHandler: Send { + fn handle(&mut self, msg: M, bus: &Bus) -> anyhow::Result<()>; + fn sync(&mut self, _bus: &Bus) -> anyhow::Result<()> { + Ok(()) + } +} + +#[async_trait] +pub trait AsyncSynchronizedHandler: Send { + async fn handle(&mut self, msg: M, bus: &Bus) -> anyhow::Result<()>; + async fn sync(&mut self, _bus: &Bus) -> anyhow::Result<()> { + Ok(()) + } +} + +pub trait BatchHandler: Send + Sync { + fn handle(&self, msg: Vec, bus: &Bus) -> anyhow::Result<()>; + fn sync(&self, _bus: &Bus) -> anyhow::Result<()> { + Ok(()) + } +} + +#[async_trait] +pub trait AsyncBatchHandler: Send + Sync { + async fn handle(&self, msg: Vec, bus: &Bus) -> anyhow::Result<()>; + async fn sync(&self, _bus: &Bus) -> anyhow::Result<()> { + Ok(()) + } +} + +pub trait BatchSynchronizedHandler: Send { + fn handle(&mut self, msg: Vec, bus: &Bus) -> anyhow::Result<()>; + fn sync(&mut self, _bus: &Bus) -> anyhow::Result<()> { + Ok(()) + } +} + +#[async_trait] +pub trait AsyncBatchSynchronizedHandler: Send { + async fn handle(&mut self, msg: Vec, bus: &Bus) -> anyhow::Result<()>; + async fn sync(&mut self, _bus: &Bus) -> anyhow::Result<()> { + Ok(()) + } +} + +pub trait LocalHandler { + fn handle(&mut self, msg: Vec, bus: &Bus) -> anyhow::Result<()>; + fn sync(&mut self, _bus: &Bus) -> anyhow::Result<()> { + Ok(()) + } +} + +#[async_trait] +pub trait LocalAsyncHandler { + async fn handle(&mut self, msg: Vec, bus: &Bus) -> anyhow::Result<()>; + async fn sync(&mut self, _bus: &Bus) -> anyhow::Result<()> { + Ok(()) + } +} + +pub trait LocalBatchHandler { + fn handle(&mut self, msg: Vec, bus: &Bus) -> anyhow::Result<()>; + fn sync(&mut self, _bus: &Bus) -> anyhow::Result<()> { + Ok(()) + } +} + +#[async_trait] +pub trait LocalAsyncBatchHandler { + async fn handle(&mut self, msg: Vec, bus: &Bus) -> anyhow::Result<()>; + async fn sync(&mut self, _bus: &Bus) -> anyhow::Result<()> { + Ok(()) + } +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..ac7dfe7 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,126 @@ +mod builder; +mod envelop; +mod handler; +pub mod msgs; +mod receiver; +pub mod receivers; +mod trait_object; +mod utils; + +use builder::BusBuilder; +pub use envelop::Message; +pub use handler::*; +pub use receiver::SendError; +use receiver::{Receiver, ReceiverStats}; +use utils::binary_search_range_by_key; + +use core::any::{Any, TypeId}; +use std::sync::{ + atomic::{AtomicBool, Ordering}, + Arc, +}; + +pub type Untyped = Arc; +pub type Result = anyhow::Result<()>; + +pub struct BusInner { + receivers: Vec<(TypeId, Receiver)>, + closed: AtomicBool, +} + +impl BusInner { + pub(crate) fn new(mut receivers: Vec<(TypeId, Receiver)>) -> Self { + receivers.sort_unstable_by_key(|(k, _)| *k); + + Self { + receivers, + closed: AtomicBool::new(false), + } + } + + pub fn close(&self) { + self.closed.store(true, Ordering::SeqCst); + + for (_, r) in &self.receivers { + r.close(); + } + } + + pub async fn sync(&self) { + for (_, r) in &self.receivers { + r.sync().await; + } + } + + pub fn stats(&self) -> impl Iterator + '_ { + self.receivers.iter().map(|(_, r)| r.stats()) + } + + pub fn try_send(&self, msg: M) -> core::result::Result<(), SendError> { + if self.closed.load(Ordering::SeqCst) { + println!("Bus closed. Skipping send!"); + return Ok(()); + } + + let tid = TypeId::of::(); + let range = binary_search_range_by_key(&self.receivers, &tid, |(k, _)| *k); + + for i in (range.start + 1)..range.end { + self.receivers[i].1.try_broadcast(msg.clone())?; + } + + if let Some((_, r)) = self.receivers.get(range.start) { + r.try_broadcast(msg.clone())?; + } else { + println!("Unhandled message {:?}", core::any::type_name::()); + } + + Ok(()) + } + + #[inline] + pub fn send_blocking(&self, msg: M) -> core::result::Result<(), SendError> { + futures::executor::block_on(self.send(msg)) + } + + pub async fn send(&self, msg: M) -> core::result::Result<(), SendError> { + if self.closed.load(Ordering::SeqCst) { + return Err(SendError::Closed(msg)); + } + + let tid = TypeId::of::(); + let range = binary_search_range_by_key(&self.receivers, &tid, |(k, _)| *k); + + for i in (range.start + 1)..range.end { + self.receivers[i].1.broadcast(msg.clone()).await?; + } + + if let Some((_, r)) = self.receivers.get(range.start) { + r.broadcast(msg.clone()).await?; + } else { + println!("Unhandled message {:?}", core::any::type_name::()); + } + + Ok(()) + } +} + +#[derive(Clone)] +pub struct Bus { + inner: Arc, +} + +impl core::ops::Deref for Bus { + type Target = BusInner; + + fn deref(&self) -> &Self::Target { + self.inner.as_ref() + } +} + +impl Bus { + #[inline] + pub fn build() -> BusBuilder { + BusBuilder::new() + } +} diff --git a/src/msgs.rs b/src/msgs.rs new file mode 100644 index 0000000..209146e --- /dev/null +++ b/src/msgs.rs @@ -0,0 +1,10 @@ +use std::sync::Arc; + +#[derive(Clone, Debug)] +pub struct Error(pub Arc); + +impl> From for Error { + fn from(e: T) -> Self { + Self(Arc::new(e.into())) + } +} diff --git a/src/receiver.rs b/src/receiver.rs new file mode 100644 index 0000000..b5d7907 --- /dev/null +++ b/src/receiver.rs @@ -0,0 +1,212 @@ +use crate::{trait_object::TraitObject, Bus, Message}; +use core::{ + any::TypeId, + fmt, + future::Future, + marker::PhantomData, + mem, + pin::Pin, + task::{Context, Poll}, +}; +use futures::future::poll_fn; +use std::{borrow::Cow, sync::Arc}; + +pub struct AnyReceiver<'a> { + dyn_typed_receiver_trait_object: TraitObject, + type_id: TypeId, + _m: PhantomData<&'a usize>, +} + +impl<'a> AnyReceiver<'a> { + pub fn new + 'static>(rcvr: &'a R) -> Self { + let trcvr = rcvr as &(dyn TypedReceiver); + + Self { + dyn_typed_receiver_trait_object: unsafe { mem::transmute(trcvr) }, + type_id: TypeId::of::>(), + _m: Default::default(), + } + } + + pub fn dyn_typed_receiver(&'a self) -> &'a dyn TypedReceiver { + assert_eq!(self.type_id, TypeId::of::>()); + + unsafe { mem::transmute(self.dyn_typed_receiver_trait_object) } + } +} + +pub enum SendError { + Full(M), + Closed(M), +} + +impl fmt::Debug for SendError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + SendError::Full(m) => write!(f, "SendError::Full({:?})", m)?, + SendError::Closed(m) => write!(f, "SendError::Closed({:?})", m)?, + } + + Ok(()) + } +} + +impl fmt::Display for SendError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + SendError::Full(m) => write!(f, "SendError::Full({:?})", m)?, + SendError::Closed(m) => write!(f, "SendError::Closed({:?})", m)?, + } + + Ok(()) + } +} + +impl std::error::Error for SendError {} + +#[derive(Debug, Clone)] +pub struct ReceiverStats { + pub name: Cow<'static, str>, + pub fields: Vec<(Cow<'static, str>, u64)>, +} + +impl fmt::Display for ReceiverStats { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "-- {}: {{ ", self.name)?; + + for (idx, (k, v)) in self.fields.iter().enumerate() { + if idx != 0 { + write!(f, ", ")?; + } + + write!(f, "{}: {}", k, v)?; + } + + write!(f, " }}")?; + Ok(()) + } +} + +pub trait TypedReceiver: Sync { + fn poll_ready(&self, ctx: &mut Context<'_>) -> Poll<()>; + fn try_send(&self, msg: M) -> Result<(), SendError>; +} + +pub trait ReceiverTrait: Send + Sync { + fn typed(&self) -> AnyReceiver<'_>; + fn type_id(&self) -> TypeId; + fn close(&self); + fn stats(&self) -> ReceiverStats; + fn sync(&self); + fn poll_synchronized(&self, ctx: &mut Context<'_>) -> Poll<()>; +} + +pub trait ReceiverPollerBuilder { + fn build(bus: Bus) -> Box>; +} + +pub struct Receiver { + inner: Arc, +} + +impl fmt::Debug for Receiver { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Receiver({:?})", self.inner.type_id())?; + Ok(()) + } +} + +impl core::cmp::PartialEq for Receiver { + fn eq(&self, other: &Receiver) -> bool { + self.inner.type_id() == other.inner.type_id() + } +} + +impl core::cmp::Eq for Receiver {} + +pub struct ReceiverPoller<'a, M: Message> { + inner: &'a dyn ReceiverTrait, + msg: Option, +} + +impl<'a, M: Message> Unpin for ReceiverPoller<'a, M> {} +impl<'a, M: Message> Future for ReceiverPoller<'a, M> { + type Output = Result<(), SendError>; + + fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll { + let this = self.get_mut(); + + let mut msg = if let Some(msg) = this.msg.take() { + msg + } else { + return Poll::Ready(Ok(())); + }; + + let any_receiver = this.inner.typed(); + let receiver = any_receiver.dyn_typed_receiver::(); + + loop { + match receiver.poll_ready(ctx) { + Poll::Ready(_) => (), + Poll::Pending => { + this.msg = Some(msg); + return Poll::Pending; + } + } + + msg = match receiver.try_send(msg) { + Ok(_) => break Poll::Ready(Ok(())), + Err(SendError::Full(m)) => m, + Err(err) => break Poll::Ready(Err(err)), + } + } + } +} + +impl Receiver { + #[inline] + pub(crate) fn new(inner: Arc) -> Self { + Self { inner } + } + + #[inline] + pub fn type_id(&self) -> TypeId { + self.inner.type_id() + } + + #[inline] + pub fn broadcast( + &self, + msg: M, + ) -> impl Future>> + '_ { + ReceiverPoller { + inner: self.inner.as_ref(), + msg: Some(msg), + } + } + + #[inline] + pub fn try_broadcast(&self, msg: M) -> Result<(), SendError> { + let any_receiver = self.inner.typed(); + let receiver = any_receiver.dyn_typed_receiver::(); + + receiver.try_send(msg) + } + + #[inline] + pub fn close(&self) { + self.inner.close(); + } + + #[inline] + pub fn stats(&self) -> ReceiverStats { + self.inner.stats() + } + + #[inline] + pub fn sync(&self) -> impl Future + '_ { + self.inner.sync(); + + poll_fn(move |ctx| self.inner.poll_synchronized(ctx)) + } +} diff --git a/src/receivers/buffer_unordered/async.rs b/src/receivers/buffer_unordered/async.rs new file mode 100644 index 0000000..155296d --- /dev/null +++ b/src/receivers/buffer_unordered/async.rs @@ -0,0 +1,203 @@ +use std::{ + any::TypeId, + marker::PhantomData, + pin::Pin, + sync::{ + atomic::{AtomicU64, Ordering}, + Arc, + }, + task::{Context, Poll}, +}; + +use crate::{receiver::ReceiverStats, receivers::mpsc}; +use futures::{Future, StreamExt}; + +use super::{BufferUnorderedConfig, BufferUnorderedStats}; +use crate::{ + builder::{ReceiverSubscriber, ReceiverSubscriberBuilder}, + msgs, + receiver::{AnyReceiver, ReceiverTrait, SendError, TypedReceiver}, + AsyncHandler, Bus, Message, Untyped, +}; + +pub struct BufferUnorderedAsyncSubscriber +where + T: AsyncHandler + 'static, + M: Message, +{ + cfg: BufferUnorderedConfig, + _m: PhantomData<(T, M)>, +} + +impl ReceiverSubscriber for BufferUnorderedAsyncSubscriber +where + T: AsyncHandler + 'static, + M: Message, +{ + fn subscribe( + self, + ) -> ( + Arc, + Box< + dyn FnOnce(Untyped) -> Box Pin + Send>>>, + >, + ) { + let cfg = self.cfg; + let stats = Arc::new(BufferUnorderedStats { + buffer: AtomicU64::new(0), + buffer_total: AtomicU64::new(cfg.buffer_size as _), + parallel: AtomicU64::new(0), + parallel_total: AtomicU64::new(cfg.max_parallel as _), + }); + + let (tx, rx) = mpsc::channel(cfg.buffer_size); + let arc = Arc::new(BufferUnorderedAsync:: { + tx, + stats: stats.clone(), + }); + + let poller = Box::new(move |ut| { + Box::new(move |bus| { + Box::pin(buffer_unordered_poller::(rx, bus, ut, stats, cfg)) + as Pin + Send>> + }) as Box Pin + Send>>> + }); + + (arc, poller) + } +} + +async fn buffer_unordered_poller( + rx: mpsc::Receiver, + bus: Bus, + ut: Untyped, + stats: Arc, + cfg: BufferUnorderedConfig, +) where + T: AsyncHandler + 'static, + M: Message, +{ + let ut = ut.downcast::().unwrap(); + + let mut x = rx + .map(|msg| { + stats.buffer.fetch_sub(1, Ordering::Relaxed); + stats.parallel.fetch_add(1, Ordering::Relaxed); + let bus = bus.clone(); + let ut = ut.clone(); + + tokio::task::spawn(async move { ut.handle(msg, &bus).await }) + }) + .buffer_unordered(cfg.max_parallel); + + while let Some(err) = x.next().await { + stats.parallel.fetch_sub(1, Ordering::Relaxed); + + match err { + Ok(Err(err)) => { + let _ = bus.send(msgs::Error(Arc::new(err))).await; + } + _ => (), + } + } + + let ut = ut.clone(); + let bus_clone = bus.clone(); + let res = tokio::task::spawn(async move { ut.sync(&bus_clone).await }).await; + + match res { + Ok(Err(err)) => { + let _ = bus.send(msgs::Error(Arc::new(err))).await; + } + _ => (), + } + + println!( + "[EXIT] BufferUnorderedAsync<{}>", + std::any::type_name::() + ); +} + +pub struct BufferUnorderedAsync { + tx: mpsc::Sender, + stats: Arc, +} + +impl ReceiverSubscriberBuilder for BufferUnorderedAsync +where + T: AsyncHandler + 'static, + M: Message, +{ + type Entry = BufferUnorderedAsyncSubscriber; + type Config = BufferUnorderedConfig; + + fn build(cfg: Self::Config) -> Self::Entry { + BufferUnorderedAsyncSubscriber { + cfg, + _m: Default::default(), + } + } +} + +impl TypedReceiver for BufferUnorderedAsync { + fn poll_ready(&self, ctx: &mut Context<'_>) -> Poll<()> { + match self.tx.poll_ready(ctx) { + Poll::Ready(_) => Poll::Ready(()), + Poll::Pending => Poll::Pending, + } + } + + fn try_send(&self, m: M) -> Result<(), SendError> { + match self.tx.try_send(m) { + Ok(_) => { + self.stats.buffer.fetch_add(1, Ordering::Relaxed); + + Ok(()) + } + Err(err) => Err(err), + } + } +} + +impl ReceiverTrait for BufferUnorderedAsync { + fn typed(&self) -> AnyReceiver<'_> { + AnyReceiver::new(self) + } + + fn type_id(&self) -> TypeId { + TypeId::of::>() + } + + fn close(&self) { + self.tx.close(); + } + + fn stats(&self) -> ReceiverStats { + ReceiverStats { + name: std::any::type_name::().into(), + fields: vec![ + ("buffer".into(), self.stats.buffer.load(Ordering::SeqCst)), + ( + "buffer_total".into(), + self.stats.buffer_total.load(Ordering::SeqCst), + ), + ( + "parallel".into(), + self.stats.parallel.load(Ordering::SeqCst), + ), + ( + "parallel_total".into(), + self.stats.parallel_total.load(Ordering::SeqCst), + ), + ], + } + } + + fn sync(&self) { + self.tx.flush(); + } + + fn poll_synchronized(&self, _ctx: &mut Context<'_>) -> Poll<()> { + Poll::Ready(()) + } +} diff --git a/src/receivers/buffer_unordered/mod.rs b/src/receivers/buffer_unordered/mod.rs new file mode 100644 index 0000000..b5fed0b --- /dev/null +++ b/src/receivers/buffer_unordered/mod.rs @@ -0,0 +1,30 @@ +mod r#async; +mod sync; + +use std::sync::atomic::AtomicU64; + +pub use r#async::{BufferUnorderedAsync, BufferUnorderedAsyncSubscriber}; +pub use sync::{BufferUnorderedSync, BufferUnorderedSyncSubscriber}; + +#[derive(Debug)] +pub struct BufferUnorderedStats { + pub buffer: AtomicU64, + pub buffer_total: AtomicU64, + pub parallel: AtomicU64, + pub parallel_total: AtomicU64, +} + +#[derive(Copy, Clone, Debug)] +pub struct BufferUnorderedConfig { + pub buffer_size: usize, + pub max_parallel: usize, +} + +impl Default for BufferUnorderedConfig { + fn default() -> Self { + Self { + buffer_size: 8, + max_parallel: 8, + } + } +} diff --git a/src/receivers/buffer_unordered/sync.rs b/src/receivers/buffer_unordered/sync.rs new file mode 100644 index 0000000..6984cd0 --- /dev/null +++ b/src/receivers/buffer_unordered/sync.rs @@ -0,0 +1,203 @@ +use crate::{receiver::ReceiverStats, receivers::mpsc}; +use futures::{Future, StreamExt}; +use std::{ + any::TypeId, + marker::PhantomData, + pin::Pin, + sync::{ + atomic::{AtomicU64, Ordering}, + Arc, + }, + task::{Context, Poll}, +}; + +use super::{BufferUnorderedConfig, BufferUnorderedStats}; +use crate::{ + builder::{ReceiverSubscriber, ReceiverSubscriberBuilder}, + msgs, + receiver::{AnyReceiver, ReceiverTrait, SendError, TypedReceiver}, + Bus, Handler, Message, Untyped, +}; + +pub struct BufferUnorderedSyncSubscriber +where + T: Handler + 'static, + M: Message, +{ + cfg: BufferUnorderedConfig, + _m: PhantomData<(M, T)>, +} + +impl ReceiverSubscriber for BufferUnorderedSyncSubscriber +where + T: Handler + 'static, + M: Message, +{ + fn subscribe( + self, + ) -> ( + Arc, + Box< + dyn FnOnce(Untyped) -> Box Pin + Send>>>, + >, + ) { + let cfg = self.cfg; + let (tx, rx) = mpsc::channel(cfg.buffer_size); + let stats = Arc::new(BufferUnorderedStats { + buffer: AtomicU64::new(0), + buffer_total: AtomicU64::new(cfg.buffer_size as _), + parallel: AtomicU64::new(0), + parallel_total: AtomicU64::new(cfg.max_parallel as _), + }); + + let arc = Arc::new(BufferUnorderedSync:: { + tx, + stats: stats.clone(), + }); + + let poller = Box::new(move |ut| { + Box::new(move |bus| { + Box::pin(buffer_unordered_poller::(rx, bus, ut, stats, cfg)) + as Pin + Send>> + }) as Box Pin + Send>>> + }); + + (arc, poller) + } +} + +async fn buffer_unordered_poller( + rx: mpsc::Receiver, + bus: Bus, + ut: Untyped, + stats: Arc, + cfg: BufferUnorderedConfig, +) where + T: Handler + 'static, + M: Message, +{ + let ut = ut.downcast::().unwrap(); + + let mut x = rx + .map(|msg| { + stats.buffer.fetch_sub(1, Ordering::Relaxed); + stats.parallel.fetch_add(1, Ordering::Relaxed); + + let bus = bus.clone(); + let ut = ut.clone(); + + tokio::task::spawn_blocking(move || ut.handle(msg, &bus)) + }) + .buffer_unordered(cfg.max_parallel); + + while let Some(err) = x.next().await { + stats.parallel.fetch_sub(1, Ordering::Relaxed); + + match err { + Ok(Err(err)) => { + let _ = bus.send(msgs::Error(Arc::new(err))).await; + } + _ => (), + } + } + + let ut = ut.clone(); + let bus_clone = bus.clone(); + let res = tokio::task::spawn_blocking(move || ut.sync(&bus_clone)).await; + + match res { + Ok(Err(err)) => { + let _ = bus.send(msgs::Error(Arc::new(err))).await; + } + _ => (), + } + + println!( + "[EXIT] BufferUnorderedAsync<{}>", + std::any::type_name::() + ); +} + +pub struct BufferUnorderedSync { + tx: mpsc::Sender, + stats: Arc, +} + +impl ReceiverSubscriberBuilder for BufferUnorderedSync +where + T: Handler + 'static, + M: Message, +{ + type Entry = BufferUnorderedSyncSubscriber; + type Config = BufferUnorderedConfig; + + fn build(cfg: Self::Config) -> Self::Entry { + BufferUnorderedSyncSubscriber { + cfg, + _m: Default::default(), + } + } +} + +impl TypedReceiver for BufferUnorderedSync { + fn poll_ready(&self, ctx: &mut Context<'_>) -> Poll<()> { + match self.tx.poll_ready(ctx) { + Poll::Ready(_) => Poll::Ready(()), + Poll::Pending => Poll::Pending, + } + } + + fn try_send(&self, m: M) -> Result<(), SendError> { + match self.tx.try_send(m) { + Ok(_) => { + self.stats.buffer.fetch_add(1, Ordering::Relaxed); + + Ok(()) + } + Err(err) => Err(err), + } + } +} + +impl ReceiverTrait for BufferUnorderedSync { + fn typed(&self) -> AnyReceiver<'_> { + AnyReceiver::new(self) + } + + fn type_id(&self) -> TypeId { + TypeId::of::>() + } + + fn stats(&self) -> ReceiverStats { + ReceiverStats { + name: std::any::type_name::().into(), + fields: vec![ + ("buffer".into(), self.stats.buffer.load(Ordering::SeqCst)), + ( + "buffer_total".into(), + self.stats.buffer_total.load(Ordering::SeqCst), + ), + ( + "parallel".into(), + self.stats.parallel.load(Ordering::SeqCst), + ), + ( + "parallel_total".into(), + self.stats.parallel_total.load(Ordering::SeqCst), + ), + ], + } + } + + fn close(&self) { + self.tx.close(); + } + + fn sync(&self) { + self.tx.flush(); + } + + fn poll_synchronized(&self, _ctx: &mut Context<'_>) -> Poll<()> { + Poll::Ready(()) + } +} diff --git a/src/receivers/mod.rs b/src/receivers/mod.rs new file mode 100644 index 0000000..3076c9f --- /dev/null +++ b/src/receivers/mod.rs @@ -0,0 +1,23 @@ +mod buffer_unordered; +mod mpsc_futures; +mod synchronize_batched; +mod synchronized; + +mod mpsc { + pub use super::mpsc_futures::*; +} + +pub use buffer_unordered::{ + BufferUnorderedAsync, BufferUnorderedAsyncSubscriber, BufferUnorderedConfig, + BufferUnorderedSync, BufferUnorderedSyncSubscriber, +}; + +pub use synchronized::{ + SynchronizedAsync, SynchronizedAsyncSubscriber, SynchronizedConfig, SynchronizedSync, + SynchronizedSyncSubscriber, +}; + +pub use synchronize_batched::{ + SynchronizeBatchedAsync, SynchronizeBatchedAsyncSubscriber, SynchronizeBatchedConfig, + SynchronizeBatchedSync, SynchronizeBatchedSyncSubscriber, +}; diff --git a/src/receivers/mpsc.rs b/src/receivers/mpsc.rs new file mode 100644 index 0000000..6a62d8c --- /dev/null +++ b/src/receivers/mpsc.rs @@ -0,0 +1,142 @@ +use futures::{Stream, StreamExt}; +use core::pin::Pin; +use crossbeam::queue::ArrayQueue; +use crossbeam::atomic::AtomicCell; +use core::task::{Waker, Context, Poll}; +use std::sync::{Arc, atomic::*}; + +use crate::receiver::SendError; + + + +struct ChannelInner { + queue: ArrayQueue, + send_waker: AtomicCell>>, + recv_waker: AtomicCell>>, + closed: AtomicBool, +} + +pub fn channel(buffer: usize) -> (Sender, Receiver) { + let inner = Arc::new(ChannelInner { + queue: ArrayQueue::new(buffer), + send_waker: AtomicCell::new(None), + recv_waker: AtomicCell::new(None), + closed: AtomicBool::new(false), + }); + + ( + Sender { + inner: inner.clone(), + }, + Receiver { + inner, + } + ) +} + + +#[derive(Clone)] +pub struct Sender { + inner: Arc> +} + +impl Sender { + pub fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<()> { + if self.inner.closed.load(Ordering::SeqCst) { + return Poll::Ready(()); + } + + if self.inner.queue.is_full() { + self.inner.send_waker.store(Some(Box::new(cx.waker().clone()))); + } + + let mut counter = 4; + loop { + if self.inner.queue.is_full() { + if counter > 0 { + counter -= 1; + continue; + } else { + break Poll::Pending; + } + } else { + break Poll::Ready(()); + } + } + } + + pub fn try_send(&self, mut item: T) -> Result<(), SendError> { + if self.inner.closed.load(Ordering::SeqCst) { + return Err(SendError::Closed(item)); + } + + let mut counter = 0; + loop { + match self.inner.queue.push(item) { + Ok(_) => { + if let Some(waker) = self.inner.recv_waker.take() { + waker.wake(); + } + + break Ok(()); + } + + Err(inner) => { + if counter >= 4 { + break Err(SendError::Full(inner)); + } else { + item = inner; + counter += 1; + } + } + } + } + } + + pub fn close(&self) { + self.inner.closed.store(true, Ordering::SeqCst); + if let Some(waker) = self.inner.recv_waker.take() { + waker.wake(); + } + } +} + +pub struct Receiver { + inner: Arc> +} + +impl Stream for Receiver { + type Item = T; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + let mut counter = 0; + + loop { + match this.inner.queue.pop() { + Some(inner) => { + if let Some(waker) = this.inner.send_waker.take() { + waker.wake(); + } + + break Poll::Ready(Some(inner)); + }, + None => { + if this.inner.closed.load(Ordering::SeqCst) { + break Poll::Ready(None); + } else { + if counter == 0 { + this.inner.recv_waker.store(Some(Box::new(cx.waker().clone()))); + } + + if counter >= 8 { + break Poll::Pending; + } else { + counter += 1; + } + } + } + } + } + } +} diff --git a/src/receivers/mpsc_futures.rs b/src/receivers/mpsc_futures.rs new file mode 100644 index 0000000..cec8835 --- /dev/null +++ b/src/receivers/mpsc_futures.rs @@ -0,0 +1,99 @@ +use crate::receiver::SendError; +use core::pin::Pin; +use core::task::{Context, Poll}; +use crossbeam::queue::SegQueue; +use futures::{channel::mpsc, Stream}; +use std::{ + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, + task::Waker, +}; + +pub struct State { + buffer: usize, + counter: AtomicUsize, + send_wakers: SegQueue, +} + +pub fn channel(buffer: usize) -> (Sender, Receiver) { + let state = Arc::new(State { + buffer, + counter: AtomicUsize::new(0), + send_wakers: SegQueue::new(), + }); + let (tx, rx) = mpsc::unbounded(); + + ( + Sender { + inner: tx, + state: state.clone(), + }, + Receiver { inner: rx, state }, + ) +} + +pub struct Sender { + inner: mpsc::UnboundedSender, + state: Arc, +} + +impl Sender { + pub fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<()> { + if self.state.counter.load(Ordering::SeqCst) >= self.state.buffer { + self.state.send_wakers.push(cx.waker().clone()); + return Poll::Pending; + } + + Poll::Ready(()) + } + + pub fn try_send(&self, item: T) -> Result<(), SendError> { + if self.state.counter.load(Ordering::Relaxed) >= self.state.buffer { + return Err(SendError::Full(item)); + } + + self.state.counter.fetch_add(1, Ordering::SeqCst); + match self.inner.unbounded_send(item) { + Ok(_) => Ok(()), + Err(err) if err.is_full() => Err(SendError::Full(err.into_inner())), + Err(err) => Err(SendError::Closed(err.into_inner())), + } + } + + #[inline] + pub fn flush(&self) {} + + #[inline] + pub fn close(&self) { + self.inner.close_channel(); + } +} + +pub struct Receiver { + inner: mpsc::UnboundedReceiver, + state: Arc, +} + +impl Stream for Receiver { + type Item = T; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + match Pin::new(&mut this.inner).poll_next(cx) { + Poll::Ready(inner) => { + let val = this.state.counter.fetch_sub(1, Ordering::SeqCst); + + if val <= this.state.buffer { + if let Some(waker) = this.state.send_wakers.pop() { + waker.wake(); + } + } + + Poll::Ready(inner) + } + Poll::Pending => Poll::Pending, + } + } +} diff --git a/src/receivers/synchronize_batched/async.rs b/src/receivers/synchronize_batched/async.rs new file mode 100644 index 0000000..5a18331 --- /dev/null +++ b/src/receivers/synchronize_batched/async.rs @@ -0,0 +1,220 @@ +use std::{ + any::TypeId, + marker::PhantomData, + pin::Pin, + sync::{ + atomic::{AtomicU64, Ordering}, + Arc, + }, + task::{Context, Poll}, +}; + +use crate::{receiver::ReceiverStats, receivers::mpsc}; +use futures::{Future, StreamExt}; +use tokio::sync::Mutex; + +use super::{SynchronizeBatchedConfig, SynchronizeBatchedStats}; +use crate::{ + builder::{ReceiverSubscriber, ReceiverSubscriberBuilder}, + msgs, + receiver::{AnyReceiver, ReceiverTrait, SendError, TypedReceiver}, + AsyncBatchSynchronizedHandler, Bus, Message, Untyped, +}; + +pub struct SynchronizeBatchedAsyncSubscriber +where + T: AsyncBatchSynchronizedHandler + 'static, + M: Message, +{ + cfg: SynchronizeBatchedConfig, + _m: PhantomData<(T, M)>, +} + +impl ReceiverSubscriber for SynchronizeBatchedAsyncSubscriber +where + T: AsyncBatchSynchronizedHandler + 'static, + M: Message, +{ + fn subscribe( + self, + ) -> ( + Arc, + Box< + dyn FnOnce(Untyped) -> Box Pin + Send>>>, + >, + ) { + let cfg = self.cfg; + let (tx, rx) = mpsc::channel(cfg.buffer_size); + let stats = Arc::new(SynchronizeBatchedStats { + buffer: AtomicU64::new(0), + buffer_total: AtomicU64::new(cfg.buffer_size as _), + batch: AtomicU64::new(0), + batch_size: AtomicU64::new(cfg.batch_size as _), + }); + + let arc = Arc::new(SynchronizeBatchedAsync:: { + tx, + stats: stats.clone(), + }); + + let poller = Box::new(move |ut| { + Box::new(move |bus| { + Box::pin(buffer_unordered_poller::(rx, bus, ut, stats, cfg)) + as Pin + Send>> + }) as Box Pin + Send>>> + }); + + (arc, poller) + } +} + +async fn buffer_unordered_poller( + mut rx: mpsc::Receiver, + bus: Bus, + ut: Untyped, + stats: Arc, + cfg: SynchronizeBatchedConfig, +) where + T: AsyncBatchSynchronizedHandler + 'static, + M: Message, +{ + let mut batch = Vec::with_capacity(cfg.batch_size); + let ut = ut.downcast::>().unwrap(); + + while let Some(msg) = rx.next().await { + stats.buffer.fetch_sub(1, Ordering::Relaxed); + batch.push(msg); + stats.batch.fetch_add(1, Ordering::Relaxed); + + if batch.len() >= cfg.batch_size { + let bus_clone = bus.clone(); + let ut = ut.clone(); + let msgs = batch.drain(..).collect::>(); + stats.batch.store(0, Ordering::Relaxed); + + let res = + tokio::task::spawn(async move { ut.lock().await.handle(msgs, &bus_clone).await }) + .await; + + match res { + Ok(Err(err)) => { + let _ = bus.send(msgs::Error(Arc::new(err))).await; + } + _ => (), + } + } + } + + if !batch.is_empty() { + let ut = ut.clone(); + let bus_clone = bus.clone(); + stats.batch.store(0, Ordering::Relaxed); + let res = + tokio::task::spawn(async move { ut.lock().await.handle(batch, &bus_clone).await }) + .await; + + match res { + Ok(Err(err)) => { + let _ = bus.send(msgs::Error(Arc::new(err))).await; + } + _ => (), + } + } + + let ut = ut.clone(); + let bus_clone = bus.clone(); + let res = tokio::task::spawn(async move { ut.lock().await.sync(&bus_clone).await }).await; + + match res { + Ok(Err(err)) => { + let _ = bus.send(msgs::Error(Arc::new(err))).await; + } + _ => (), + } + + println!( + "[EXIT] SynchronizeBatchedAsync<{}>", + std::any::type_name::() + ); +} + +pub struct SynchronizeBatchedAsync { + tx: mpsc::Sender, + stats: Arc, +} + +impl ReceiverSubscriberBuilder for SynchronizeBatchedAsync +where + T: AsyncBatchSynchronizedHandler + 'static, + M: Message, +{ + type Entry = SynchronizeBatchedAsyncSubscriber; + type Config = SynchronizeBatchedConfig; + + fn build(cfg: Self::Config) -> Self::Entry { + SynchronizeBatchedAsyncSubscriber { + cfg, + _m: Default::default(), + } + } +} + +impl TypedReceiver for SynchronizeBatchedAsync { + fn poll_ready(&self, ctx: &mut Context<'_>) -> Poll<()> { + match self.tx.poll_ready(ctx) { + Poll::Ready(_) => Poll::Ready(()), + Poll::Pending => Poll::Pending, + } + } + + fn try_send(&self, m: M) -> Result<(), SendError> { + match self.tx.try_send(m) { + Ok(_) => { + self.stats.buffer.fetch_add(1, Ordering::Relaxed); + + Ok(()) + } + Err(err) => Err(err), + } + } +} + +impl ReceiverTrait for SynchronizeBatchedAsync { + fn typed(&self) -> AnyReceiver<'_> { + AnyReceiver::new(self) + } + + fn type_id(&self) -> TypeId { + TypeId::of::() + } + + fn close(&self) { + self.tx.close(); + } + + fn stats(&self) -> ReceiverStats { + ReceiverStats { + name: std::any::type_name::().into(), + fields: vec![ + ("buffer".into(), self.stats.buffer.load(Ordering::SeqCst)), + ( + "buffer_total".into(), + self.stats.buffer_total.load(Ordering::SeqCst), + ), + ("batch".into(), self.stats.batch.load(Ordering::SeqCst)), + ( + "batch_size".into(), + self.stats.batch_size.load(Ordering::SeqCst), + ), + ], + } + } + + fn sync(&self) { + self.tx.flush(); + } + + fn poll_synchronized(&self, _ctx: &mut Context<'_>) -> Poll<()> { + Poll::Ready(()) + } +} diff --git a/src/receivers/synchronize_batched/mod.rs b/src/receivers/synchronize_batched/mod.rs new file mode 100644 index 0000000..dccb9ad --- /dev/null +++ b/src/receivers/synchronize_batched/mod.rs @@ -0,0 +1,31 @@ +mod r#async; +mod sync; + +use std::sync::atomic::AtomicU64; + +pub use sync::{SynchronizeBatchedSync, SynchronizeBatchedSyncSubscriber}; + +pub use r#async::{SynchronizeBatchedAsync, SynchronizeBatchedAsyncSubscriber}; + +#[derive(Debug)] +pub struct SynchronizeBatchedStats { + pub buffer: AtomicU64, + pub buffer_total: AtomicU64, + pub batch: AtomicU64, + pub batch_size: AtomicU64, +} + +#[derive(Copy, Clone, Debug)] +pub struct SynchronizeBatchedConfig { + pub buffer_size: usize, + pub batch_size: usize, +} + +impl Default for SynchronizeBatchedConfig { + fn default() -> Self { + Self { + buffer_size: 4, + batch_size: 16, + } + } +} diff --git a/src/receivers/synchronize_batched/sync.rs b/src/receivers/synchronize_batched/sync.rs new file mode 100644 index 0000000..c3ebe6e --- /dev/null +++ b/src/receivers/synchronize_batched/sync.rs @@ -0,0 +1,223 @@ +use super::{SynchronizeBatchedConfig, SynchronizeBatchedStats}; +use crate::{ + builder::{ReceiverSubscriber, ReceiverSubscriberBuilder}, + msgs, + receiver::{AnyReceiver, ReceiverTrait, SendError, TypedReceiver}, + BatchSynchronizedHandler, Bus, Message, Untyped, +}; +use crate::{receiver::ReceiverStats, receivers::mpsc}; +use futures::{Future, StreamExt}; +use std::{ + any::TypeId, + marker::PhantomData, + pin::Pin, + sync::{ + atomic::{AtomicU64, Ordering}, + Arc, + }, + task::{Context, Poll}, +}; +use tokio::sync::Mutex; + +pub struct SynchronizeBatchedSyncSubscriber +where + T: BatchSynchronizedHandler + 'static, + M: Message, +{ + cfg: SynchronizeBatchedConfig, + _m: PhantomData<(M, T)>, +} + +impl ReceiverSubscriber for SynchronizeBatchedSyncSubscriber +where + T: BatchSynchronizedHandler + 'static, + M: Message, +{ + fn subscribe( + self, + ) -> ( + Arc, + Box< + dyn FnOnce(Untyped) -> Box Pin + Send>>>, + >, + ) { + let cfg = self.cfg; + let (tx, rx) = mpsc::channel(cfg.buffer_size); + let stats = Arc::new(SynchronizeBatchedStats { + buffer: AtomicU64::new(0), + buffer_total: AtomicU64::new(cfg.buffer_size as _), + batch: AtomicU64::new(0), + batch_size: AtomicU64::new(cfg.batch_size as _), + }); + + let arc = Arc::new(SynchronizeBatchedSync:: { + tx, + stats: stats.clone(), + }); + let poller = Box::new(move |ut| { + Box::new(move |bus| { + Box::pin(buffer_unordered_poller::(rx, bus, ut, stats, cfg)) + as Pin + Send>> + }) as Box Pin + Send>>> + }); + + (arc, poller) + } +} + +async fn buffer_unordered_poller( + mut rx: mpsc::Receiver, + bus: Bus, + ut: Untyped, + stats: Arc, + cfg: SynchronizeBatchedConfig, +) where + T: BatchSynchronizedHandler + 'static, + M: Message, +{ + let mut batch = Vec::with_capacity(cfg.batch_size); + let ut = ut.downcast::>().unwrap(); + + while let Some(msg) = rx.next().await { + stats.buffer.fetch_sub(1, Ordering::Relaxed); + batch.push(msg); + stats.batch.fetch_add(1, Ordering::Relaxed); + + if batch.len() >= cfg.batch_size { + let ut = ut.clone(); + let bus_clone = bus.clone(); + let msgs = batch.drain(..).collect::>(); + stats.batch.store(0, Ordering::Relaxed); + + let res = tokio::task::spawn_blocking(move || { + let mut uut = futures::executor::block_on(ut.lock()); + uut.handle(msgs, &bus_clone) + }) + .await; + + match res { + Ok(Err(err)) => { + let _ = bus.send(msgs::Error(Arc::new(err))).await; + } + _ => (), + } + } + } + + if !batch.is_empty() { + let ut = ut.clone(); + let bus_clone = bus.clone(); + stats.batch.store(0, Ordering::Relaxed); + let res = tokio::task::spawn_blocking(move || { + futures::executor::block_on(ut.lock()).handle(batch, &bus_clone) + }) + .await; + + match res { + Ok(Err(err)) => { + let _ = bus.send(msgs::Error(Arc::new(err))).await; + } + _ => (), + } + } + + let ut = ut.clone(); + let bus_clone = bus.clone(); + let res = tokio::task::spawn_blocking(move || { + futures::executor::block_on(ut.lock()).sync(&bus_clone) + }) + .await; + + match res { + Ok(Err(err)) => { + let _ = bus.send(msgs::Error(Arc::new(err))).await; + } + _ => (), + } + + println!( + "[EXIT] SynchronizeBatchedSync<{}>", + std::any::type_name::() + ); +} + +pub struct SynchronizeBatchedSync { + tx: mpsc::Sender, + stats: Arc, +} + +impl ReceiverSubscriberBuilder for SynchronizeBatchedSync +where + T: BatchSynchronizedHandler + 'static, + M: Message, +{ + type Entry = SynchronizeBatchedSyncSubscriber; + type Config = SynchronizeBatchedConfig; + + fn build(cfg: Self::Config) -> Self::Entry { + SynchronizeBatchedSyncSubscriber { + cfg, + _m: Default::default(), + } + } +} + +impl TypedReceiver for SynchronizeBatchedSync { + fn poll_ready(&self, ctx: &mut Context<'_>) -> Poll<()> { + match self.tx.poll_ready(ctx) { + Poll::Ready(_) => Poll::Ready(()), + Poll::Pending => Poll::Pending, + } + } + + fn try_send(&self, m: M) -> Result<(), SendError> { + match self.tx.try_send(m) { + Ok(_) => { + self.stats.buffer.fetch_add(1, Ordering::Relaxed); + + Ok(()) + } + Err(err) => Err(err), + } + } +} + +impl ReceiverTrait for SynchronizeBatchedSync { + fn typed(&self) -> AnyReceiver<'_> { + AnyReceiver::new(self) + } + + fn type_id(&self) -> TypeId { + TypeId::of::() + } + + fn stats(&self) -> ReceiverStats { + ReceiverStats { + name: std::any::type_name::().into(), + fields: vec![ + ("buffer".into(), self.stats.buffer.load(Ordering::SeqCst)), + ( + "buffer_total".into(), + self.stats.buffer_total.load(Ordering::SeqCst), + ), + ("batch".into(), self.stats.batch.load(Ordering::SeqCst)), + ( + "batch_size".into(), + self.stats.batch_size.load(Ordering::SeqCst), + ), + ], + } + } + + fn close(&self) { + self.tx.close(); + } + + fn sync(&self) { + self.tx.flush(); + } + + fn poll_synchronized(&self, _ctx: &mut Context<'_>) -> Poll<()> { + Poll::Ready(()) + } +} diff --git a/src/receivers/synchronized/async.rs b/src/receivers/synchronized/async.rs new file mode 100644 index 0000000..e4952d5 --- /dev/null +++ b/src/receivers/synchronized/async.rs @@ -0,0 +1,186 @@ +use std::{ + any::TypeId, + marker::PhantomData, + pin::Pin, + sync::{ + atomic::{AtomicU64, Ordering}, + Arc, + }, + task::{Context, Poll}, +}; + +use crate::{receiver::ReceiverStats, receivers::mpsc}; +use futures::{Future, StreamExt}; +use tokio::sync::Mutex; + +use super::{SynchronizedConfig, SynchronizedStats}; +use crate::{ + builder::{ReceiverSubscriber, ReceiverSubscriberBuilder}, + msgs, + receiver::{AnyReceiver, ReceiverTrait, SendError, TypedReceiver}, + AsyncSynchronizedHandler, Bus, Message, Untyped, +}; + +pub struct SynchronizedAsyncSubscriber +where + T: AsyncSynchronizedHandler + 'static, + M: Message, +{ + cfg: SynchronizedConfig, + _m: PhantomData<(T, M)>, +} + +impl ReceiverSubscriber for SynchronizedAsyncSubscriber +where + T: AsyncSynchronizedHandler + 'static, + M: Message, +{ + fn subscribe( + self, + ) -> ( + Arc, + Box< + dyn FnOnce(Untyped) -> Box Pin + Send>>>, + >, + ) { + let cfg = self.cfg; + let (tx, rx) = mpsc::channel(cfg.buffer_size); + let stats = Arc::new(SynchronizedStats { + buffer: AtomicU64::new(0), + buffer_total: AtomicU64::new(cfg.buffer_size as _), + }); + + let arc = Arc::new(SynchronizedAsync:: { + tx, + stats: stats.clone(), + }); + + let poller = Box::new(move |ut| { + Box::new(move |bus| { + Box::pin(buffer_unordered_poller::(rx, bus, ut, stats, cfg)) + as Pin + Send>> + }) as Box Pin + Send>>> + }); + + (arc, poller) + } +} + +async fn buffer_unordered_poller( + rx: mpsc::Receiver, + bus: Bus, + ut: Untyped, + stats: Arc, + _cfg: SynchronizedConfig, +) where + T: AsyncSynchronizedHandler + 'static, + M: Message, +{ + let ut = ut.downcast::>().unwrap(); + let mut x = rx.then(|msg| { + let bus = bus.clone(); + let ut = ut.clone(); + + tokio::task::spawn(async move { ut.lock().await.handle(msg, &bus).await }) + }); + + while let Some(err) = x.next().await { + stats.buffer.fetch_sub(1, Ordering::Relaxed); + + match err { + Ok(Err(err)) => { + let _ = bus.send(msgs::Error(Arc::new(err))).await; + } + _ => (), + } + } + + let ut = ut.clone(); + let bus_clone = bus.clone(); + let res = tokio::task::spawn(async move { ut.lock().await.sync(&bus_clone).await }).await; + + match res { + Ok(Err(err)) => { + let _ = bus.send(msgs::Error(Arc::new(err))).await; + } + _ => (), + } + + println!("[EXIT] SynchronizedAsync<{}>", std::any::type_name::()); +} + +pub struct SynchronizedAsync { + tx: mpsc::Sender, + stats: Arc, +} + +impl ReceiverSubscriberBuilder for SynchronizedAsync +where + T: AsyncSynchronizedHandler + 'static, + M: Message, +{ + type Entry = SynchronizedAsyncSubscriber; + type Config = SynchronizedConfig; + + fn build(cfg: Self::Config) -> Self::Entry { + SynchronizedAsyncSubscriber { + cfg, + _m: Default::default(), + } + } +} + +impl TypedReceiver for SynchronizedAsync { + fn poll_ready(&self, ctx: &mut Context<'_>) -> Poll<()> { + match self.tx.poll_ready(ctx) { + Poll::Ready(_) => Poll::Ready(()), + Poll::Pending => Poll::Pending, + } + } + + fn try_send(&self, m: M) -> Result<(), SendError> { + match self.tx.try_send(m) { + Ok(_) => { + self.stats.buffer.fetch_add(1, Ordering::Relaxed); + + Ok(()) + } + Err(err) => Err(err), + } + } +} + +impl ReceiverTrait for SynchronizedAsync { + fn typed(&self) -> AnyReceiver<'_> { + AnyReceiver::new(self) + } + + fn type_id(&self) -> TypeId { + TypeId::of::>() + } + + fn stats(&self) -> ReceiverStats { + ReceiverStats { + name: std::any::type_name::().into(), + fields: vec![ + ("buffer".into(), self.stats.buffer.load(Ordering::SeqCst)), + ( + "buffer_total".into(), + self.stats.buffer_total.load(Ordering::SeqCst), + ), + ], + } + } + + fn close(&self) { + self.tx.close(); + } + + fn sync(&self) { + self.tx.flush(); + } + + fn poll_synchronized(&self, _ctx: &mut Context<'_>) -> Poll<()> { + Poll::Ready(()) + } +} diff --git a/src/receivers/synchronized/mod.rs b/src/receivers/synchronized/mod.rs new file mode 100644 index 0000000..56671b8 --- /dev/null +++ b/src/receivers/synchronized/mod.rs @@ -0,0 +1,25 @@ +mod r#async; +mod sync; + +use std::sync::atomic::AtomicU64; + +pub use sync::{SynchronizedSync, SynchronizedSyncSubscriber}; + +pub use r#async::{SynchronizedAsync, SynchronizedAsyncSubscriber}; + +#[derive(Debug)] +pub struct SynchronizedStats { + pub buffer: AtomicU64, + pub buffer_total: AtomicU64, +} + +#[derive(Copy, Clone, Debug)] +pub struct SynchronizedConfig { + pub buffer_size: usize, +} + +impl Default for SynchronizedConfig { + fn default() -> Self { + Self { buffer_size: 1 } + } +} diff --git a/src/receivers/synchronized/sync.rs b/src/receivers/synchronized/sync.rs new file mode 100644 index 0000000..30e72f2 --- /dev/null +++ b/src/receivers/synchronized/sync.rs @@ -0,0 +1,190 @@ +use super::{SynchronizedConfig, SynchronizedStats}; +use crate::{receiver::ReceiverStats, receivers::mpsc}; +use futures::{executor::block_on, Future, StreamExt}; +use std::{ + any::TypeId, + marker::PhantomData, + pin::Pin, + sync::{ + atomic::{AtomicU64, Ordering}, + Arc, + }, + task::{Context, Poll}, +}; +use tokio::sync::Mutex; + +use crate::{ + builder::{ReceiverSubscriber, ReceiverSubscriberBuilder}, + msgs, + receiver::{AnyReceiver, ReceiverTrait, SendError, TypedReceiver}, + Bus, Message, SynchronizedHandler, Untyped, +}; + +pub struct SynchronizedSyncSubscriber +where + T: SynchronizedHandler + 'static, + M: Message, +{ + cfg: SynchronizedConfig, + _m: PhantomData<(M, T)>, +} + +impl ReceiverSubscriber for SynchronizedSyncSubscriber +where + T: SynchronizedHandler + 'static, + M: Message, +{ + fn subscribe( + self, + ) -> ( + Arc, + Box< + dyn FnOnce(Untyped) -> Box Pin + Send>>>, + >, + ) { + let cfg = self.cfg; + let (tx, rx) = mpsc::channel(cfg.buffer_size); + let stats = Arc::new(SynchronizedStats { + buffer: AtomicU64::new(0), + buffer_total: AtomicU64::new(cfg.buffer_size as _), + }); + + let arc = Arc::new(SynchronizedSync:: { + tx, + stats: stats.clone(), + }); + let poller = Box::new(move |ut| { + Box::new(move |bus| { + Box::pin(buffer_unordered_poller::(rx, bus, ut, stats, cfg)) + as Pin + Send>> + }) as Box Pin + Send>>> + }); + + (arc, poller) + } +} + +async fn buffer_unordered_poller( + rx: mpsc::Receiver, + bus: Bus, + ut: Untyped, + stats: Arc, + _cfg: SynchronizedConfig, +) where + T: SynchronizedHandler + 'static, + M: Message, +{ + let ut = ut.downcast::>().unwrap(); + let mut x = rx.then(|msg| { + let ut = ut.clone(); + let bus = bus.clone(); + + tokio::task::spawn_blocking(move || block_on(ut.lock()).handle(msg, &bus)) + }); + + while let Some(err) = x.next().await { + stats.buffer.fetch_sub(1, Ordering::Relaxed); + + match err { + Ok(Err(err)) => { + let _ = bus.send(msgs::Error(Arc::new(err))).await; + } + _ => (), + } + } + + let ut = ut.clone(); + let bus_clone = bus.clone(); + let res = tokio::task::spawn_blocking(move || { + futures::executor::block_on(ut.lock()).sync(&bus_clone) + }) + .await; + + match res { + Ok(Err(err)) => { + let _ = bus.send(msgs::Error(Arc::new(err))).await; + } + _ => (), + } + + println!( + "[EXIT] BufferUnorderedAsync<{}>", + std::any::type_name::() + ); +} + +pub struct SynchronizedSync { + tx: mpsc::Sender, + stats: Arc, +} + +impl ReceiverSubscriberBuilder for SynchronizedSync +where + T: SynchronizedHandler + 'static, + M: Message, +{ + type Entry = SynchronizedSyncSubscriber; + type Config = SynchronizedConfig; + + fn build(cfg: Self::Config) -> Self::Entry { + SynchronizedSyncSubscriber { + cfg, + _m: Default::default(), + } + } +} + +impl TypedReceiver for SynchronizedSync { + fn poll_ready(&self, ctx: &mut Context<'_>) -> Poll<()> { + match self.tx.poll_ready(ctx) { + Poll::Ready(_) => Poll::Ready(()), + Poll::Pending => Poll::Pending, + } + } + + fn try_send(&self, m: M) -> Result<(), SendError> { + match self.tx.try_send(m) { + Ok(_) => { + self.stats.buffer.fetch_add(1, Ordering::Relaxed); + + Ok(()) + } + Err(err) => Err(err), + } + } +} + +impl ReceiverTrait for SynchronizedSync { + fn typed(&self) -> AnyReceiver<'_> { + AnyReceiver::new(self) + } + + fn type_id(&self) -> TypeId { + TypeId::of::>() + } + + fn stats(&self) -> ReceiverStats { + ReceiverStats { + name: std::any::type_name::().into(), + fields: vec![ + ("buffer".into(), self.stats.buffer.load(Ordering::SeqCst)), + ( + "buffer_total".into(), + self.stats.buffer_total.load(Ordering::SeqCst), + ), + ], + } + } + + fn close(&self) { + self.tx.close(); + } + + fn sync(&self) { + self.tx.flush(); + } + + fn poll_synchronized(&self, _ctx: &mut Context<'_>) -> Poll<()> { + Poll::Ready(()) + } +} diff --git a/src/trait_object.rs b/src/trait_object.rs new file mode 100644 index 0000000..b9b3372 --- /dev/null +++ b/src/trait_object.rs @@ -0,0 +1,6 @@ +#[derive(Debug, Copy, Clone)] +#[repr(C)] +pub struct TraitObject { + pub data: *mut (), + pub vtable: *mut (), +} diff --git a/src/utils.rs b/src/utils.rs new file mode 100644 index 0000000..1e3ab7d --- /dev/null +++ b/src/utils.rs @@ -0,0 +1,33 @@ +use core::cmp::{Ord, Ordering}; +use core::ops::Range; + +pub fn binary_search_range_by_key<'a, T, B, F>(data: &'a [T], item: &B, mut f: F) -> Range +where + F: FnMut(&'a T) -> B, + B: Ord, +{ + if let Ok(index) = data.binary_search_by_key(item, &mut f) { + let mut begin = index; + let mut end = index + 1; + + for i in (0..index).rev() { + if f(unsafe { data.get_unchecked(i) }).cmp(item) != Ordering::Equal { + break; + } + + begin = i; + } + + for i in end..data.len() { + end = i; + + if f(unsafe { data.get_unchecked(i) }).cmp(item) != Ordering::Equal { + break; + } + } + + begin..end + } else { + data.len()..data.len() + } +}