From 2d35840044d7619367124ccfe4056a21ef530e9f Mon Sep 17 00:00:00 2001 From: Andrey Tkachenko Date: Fri, 15 Jan 2021 13:52:50 +0400 Subject: [PATCH] Local Handlers --- Cargo.toml | 14 +- examples/demo_local.rs | 41 ++++ src/builder.rs | 40 +++- src/handler.rs | 8 +- src/lib.rs | 5 +- src/receivers/buffer_unordered/async.rs | 10 +- src/receivers/buffer_unordered/sync.rs | 17 +- .../buffer_unordered_batched/async.rs | 29 +-- src/receivers/buffer_unordered_batched/mod.rs | 4 +- .../buffer_unordered_batched/sync.rs | 29 +-- src/receivers/local/async.rs | 180 ++++++++++++++++ src/receivers/local/mod.rs | 25 +++ src/receivers/local/sync.rs | 177 ++++++++++++++++ src/receivers/local_batched/async.rs | 196 ++++++++++++++++++ src/receivers/local_batched/mod.rs | 33 +++ src/receivers/local_batched/sync.rs | 193 +++++++++++++++++ src/receivers/mod.rs | 13 +- src/receivers/mpsc_futures.rs | 2 +- src/receivers/synchronize_batched/async.rs | 21 +- src/receivers/synchronize_batched/sync.rs | 31 ++- src/receivers/synchronized/async.rs | 12 +- src/receivers/synchronized/sync.rs | 18 +- src/untyped.rs | 174 ++++++++++++++++ 23 files changed, 1166 insertions(+), 106 deletions(-) create mode 100644 examples/demo_local.rs create mode 100644 src/receivers/local/async.rs create mode 100644 src/receivers/local/mod.rs create mode 100644 src/receivers/local/sync.rs create mode 100644 src/receivers/local_batched/async.rs create mode 100644 src/receivers/local_batched/mod.rs create mode 100644 src/receivers/local_batched/sync.rs create mode 100644 src/untyped.rs diff --git a/Cargo.toml b/Cargo.toml index e5dd380..af1155d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "messagebus" -version = "0.4.6" +version = "0.5.0" authors = ["Andrey Tkachenko "] repository = "https://github.com/andreytkachenko/messagebus.git" keywords = ["futures", "async", "tokio", "message", "bus"] @@ -11,12 +11,12 @@ exclude = [".gitignore", ".cargo/config", ".github/**", "codecov.yml"] edition = "2018" [dependencies] -tokio = { version = "0.2", features = ["parking_lot", "rt-threaded", "sync", "stream", "blocking"] } -parking_lot = "0.11.1" -async-trait = "0.1.42" -futures = "0.3.8" -anyhow = "1.0.34" +tokio = { version = "1", features = ["parking_lot", "rt-multi-thread", "sync"] } +parking_lot = "0.11" +async-trait = "0.1" +futures = "0.3" +anyhow = "1.0" crossbeam = "0.8.0" [dev-dependencies] -tokio = { version = "0.2", features = ["parking_lot", "rt-threaded", "sync", "stream", "macros"] } +tokio = { version = "1", features = ["parking_lot", "rt-multi-thread", "sync", "macros"] } diff --git a/examples/demo_local.rs b/examples/demo_local.rs new file mode 100644 index 0000000..a95e3ee --- /dev/null +++ b/examples/demo_local.rs @@ -0,0 +1,41 @@ +use std::rc::Rc; +use std::cell::Cell; +use async_trait::async_trait; +use messagebus::{receivers, Bus, LocalHandler, LocalAsyncHandler, Result as MbusResult}; + +struct TmpReceiver { + _inner: Rc> +} + + +#[async_trait(?Send)] +impl LocalAsyncHandler for TmpReceiver { + async fn handle(&mut self, msg: f32, bus: &Bus) -> MbusResult { + println!("---> f32 {}", msg); + + bus.send(11u16).await.unwrap(); + + Ok(()) + } +} + +impl LocalHandler for TmpReceiver { + fn handle(&mut self, msg: u16, _bus: &Bus) -> MbusResult { + println!("---> u16 {}", msg); + Ok(()) + } +} + +#[tokio::main] +async fn main() { + let (b, poller) = Bus::build() + .register_local(||TmpReceiver {_inner: Rc::new(Cell::new(12))}) + .subscribe::>(Default::default()) + .subscribe::>(Default::default()) + .done() + .build(); + + b.send(32f32).await.unwrap(); + + poller.await +} diff --git a/src/builder.rs b/src/builder.rs index c9237bf..2470dd8 100644 --- a/src/builder.rs +++ b/src/builder.rs @@ -2,7 +2,6 @@ use std::{any::TypeId, collections::HashMap, marker::PhantomData, pin::Pin, sync use futures::{Future, FutureExt}; use receiver::ReceiverTrait; -use tokio::sync::Mutex; use crate::{ receiver::{self, Receiver}, @@ -30,6 +29,8 @@ pub trait ReceiverSubscriberBuilder { pub struct SyncEntry; pub struct UnsyncEntry; +pub struct LocalEntry; + #[must_use] pub struct RegisterEntry { item: Untyped, @@ -65,6 +66,27 @@ impl RegisterEntry { } } + +impl RegisterEntry { + pub fn subscribe(mut self, cfg: R::Config) -> Self + where + T: 'static, + M: Message + 'static, + R: ReceiverSubscriberBuilder + 'static, + { + let (inner, poller) = R::build(cfg).subscribe(); + + let receiver = Receiver::new(inner); + self.receivers + .entry(TypeId::of::()) + .or_insert_with(Vec::new) + .push((receiver, poller)); + + self + } +} + + impl RegisterEntry { pub fn subscribe(mut self, cfg: R::Config) -> Self where @@ -118,7 +140,7 @@ impl BusBuilder { pub fn register(self, item: T) -> RegisterEntry { RegisterEntry { - item: Arc::new(item) as Untyped, + item: Untyped::new_rwlock(item), builder: self, receivers: HashMap::new(), _m: Default::default(), @@ -127,7 +149,19 @@ impl BusBuilder { pub fn register_unsync(self, item: T) -> RegisterEntry { RegisterEntry { - item: Arc::new(Mutex::new(item)) as Untyped, + item: Untyped::new_mutex(item), + builder: self, + receivers: HashMap::new(), + _m: Default::default(), + } + } + + pub fn register_local( + self, + item: impl FnOnce() -> T + Send + 'static, + ) -> RegisterEntry { + RegisterEntry { + item: Untyped::new_local(item), builder: self, receivers: HashMap::new(), _m: Default::default(), diff --git a/src/handler.rs b/src/handler.rs index 7facd70..d5b2bd0 100644 --- a/src/handler.rs +++ b/src/handler.rs @@ -62,15 +62,15 @@ pub trait AsyncBatchSynchronizedHandler: Send { } pub trait LocalHandler { - fn handle(&mut self, msg: Vec, bus: &Bus) -> anyhow::Result<()>; + fn handle(&mut self, msg: M, bus: &Bus) -> anyhow::Result<()>; fn sync(&mut self, _bus: &Bus) -> anyhow::Result<()> { Ok(()) } } -#[async_trait] +#[async_trait(?Send)] pub trait LocalAsyncHandler { - async fn handle(&mut self, msg: Vec, bus: &Bus) -> anyhow::Result<()>; + async fn handle(&mut self, msg: M, bus: &Bus) -> anyhow::Result<()>; async fn sync(&mut self, _bus: &Bus) -> anyhow::Result<()> { Ok(()) } @@ -83,7 +83,7 @@ pub trait LocalBatchHandler { } } -#[async_trait] +#[async_trait(?Send)] pub trait LocalAsyncBatchHandler { async fn handle(&mut self, msg: Vec, bus: &Bus) -> anyhow::Result<()>; async fn sync(&mut self, _bus: &Bus) -> anyhow::Result<()> { diff --git a/src/lib.rs b/src/lib.rs index ac7dfe7..20356a1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,6 +5,7 @@ pub mod msgs; mod receiver; pub mod receivers; mod trait_object; +mod untyped; mod utils; use builder::BusBuilder; @@ -14,13 +15,13 @@ pub use receiver::SendError; use receiver::{Receiver, ReceiverStats}; use utils::binary_search_range_by_key; -use core::any::{Any, TypeId}; +use core::any::TypeId; use std::sync::{ atomic::{AtomicBool, Ordering}, Arc, }; -pub type Untyped = Arc; +pub use untyped::Untyped; pub type Result = anyhow::Result<()>; pub struct BusInner { diff --git a/src/receivers/buffer_unordered/async.rs b/src/receivers/buffer_unordered/async.rs index 155296d..e9ec82c 100644 --- a/src/receivers/buffer_unordered/async.rs +++ b/src/receivers/buffer_unordered/async.rs @@ -77,7 +77,7 @@ async fn buffer_unordered_poller( T: AsyncHandler + 'static, M: Message, { - let ut = ut.downcast::().unwrap(); + let ut = ut.downcast_sync::().unwrap(); let mut x = rx .map(|msg| { @@ -86,7 +86,9 @@ async fn buffer_unordered_poller( let bus = bus.clone(); let ut = ut.clone(); - tokio::task::spawn(async move { ut.handle(msg, &bus).await }) + tokio::task::spawn( + async move { ut.lock_read().await.get_ref().handle(msg, &bus).await }, + ) }) .buffer_unordered(cfg.max_parallel); @@ -103,7 +105,9 @@ async fn buffer_unordered_poller( let ut = ut.clone(); let bus_clone = bus.clone(); - let res = tokio::task::spawn(async move { ut.sync(&bus_clone).await }).await; + let res = + tokio::task::spawn(async move { ut.lock_read().await.get_ref().sync(&bus_clone).await }) + .await; match res { Ok(Err(err)) => { diff --git a/src/receivers/buffer_unordered/sync.rs b/src/receivers/buffer_unordered/sync.rs index d8955a9..2b732bd 100644 --- a/src/receivers/buffer_unordered/sync.rs +++ b/src/receivers/buffer_unordered/sync.rs @@ -1,5 +1,5 @@ use crate::{receiver::ReceiverStats, receivers::mpsc}; -use futures::{Future, StreamExt}; +use futures::{executor::block_on, Future, StreamExt}; use std::{ any::TypeId, marker::PhantomData, @@ -76,7 +76,7 @@ async fn buffer_unordered_poller( T: Handler + 'static, M: Message, { - let ut = ut.downcast::().unwrap(); + let ut = ut.downcast_sync::().unwrap(); let mut x = rx .map(|msg| { @@ -86,7 +86,9 @@ async fn buffer_unordered_poller( let bus = bus.clone(); let ut = ut.clone(); - tokio::task::spawn_blocking(move || ut.handle(msg, &bus)) + tokio::task::spawn_blocking(move || { + block_on(ut.lock_read()).get_ref().handle(msg, &bus) + }) }) .buffer_unordered(cfg.max_parallel); @@ -103,7 +105,9 @@ async fn buffer_unordered_poller( let ut = ut.clone(); let bus_clone = bus.clone(); - let res = tokio::task::spawn_blocking(move || ut.sync(&bus_clone)).await; + let res = + tokio::task::spawn_blocking(move || block_on(ut.lock_read()).get_ref().sync(&bus_clone)) + .await; match res { Ok(Err(err)) => { @@ -112,10 +116,7 @@ async fn buffer_unordered_poller( _ => (), } - println!( - "[EXIT] BufferUnorderedSync<{}>", - std::any::type_name::() - ); + println!("[EXIT] BufferUnorderedSync<{}>", std::any::type_name::()); } pub struct BufferUnorderedSync { diff --git a/src/receivers/buffer_unordered_batched/async.rs b/src/receivers/buffer_unordered_batched/async.rs index 4b1cbc6..127c214 100644 --- a/src/receivers/buffer_unordered_batched/async.rs +++ b/src/receivers/buffer_unordered_batched/async.rs @@ -79,19 +79,16 @@ async fn buffer_unordered_poller( T: AsyncBatchHandler + 'static, M: Message, { - let ut = ut.downcast::().unwrap(); - let rx = rx - .inspect(|_| { - stats.buffer.fetch_sub(1, Ordering::Relaxed); - stats.batch.fetch_add(1, Ordering::Relaxed); - }); + let ut = ut.downcast_sync::().unwrap(); + let rx = rx.inspect(|_| { + stats.buffer.fetch_sub(1, Ordering::Relaxed); + stats.batch.fetch_add(1, Ordering::Relaxed); + }); let rx = if cfg.when_ready { - rx.ready_chunks(cfg.batch_size) - .left_stream() + rx.ready_chunks(cfg.batch_size).left_stream() } else { - rx.chunks(cfg.batch_size) - .right_stream() + rx.chunks(cfg.batch_size).right_stream() }; let mut rx = rx @@ -102,7 +99,13 @@ async fn buffer_unordered_poller( let bus_clone = bus.clone(); let ut = ut.clone(); - tokio::task::spawn(async move { ut.handle(msgs, &bus_clone).await }) + tokio::task::spawn(async move { + ut.lock_read() + .await + .get_ref() + .handle(msgs, &bus_clone) + .await + }) }) .buffer_unordered(cfg.max_parallel); @@ -119,7 +122,9 @@ async fn buffer_unordered_poller( let ut = ut.clone(); let bus_clone = bus.clone(); - let res = tokio::task::spawn(async move { ut.sync(&bus_clone).await }).await; + let res = + tokio::task::spawn(async move { ut.lock_read().await.get_ref().sync(&bus_clone).await }) + .await; match res { Ok(Err(err)) => { diff --git a/src/receivers/buffer_unordered_batched/mod.rs b/src/receivers/buffer_unordered_batched/mod.rs index 1966b0e..b8ae7dd 100644 --- a/src/receivers/buffer_unordered_batched/mod.rs +++ b/src/receivers/buffer_unordered_batched/mod.rs @@ -27,9 +27,9 @@ pub struct BufferUnorderedBatchedConfig { impl Default for BufferUnorderedBatchedConfig { fn default() -> Self { Self { - buffer_size: 8, + buffer_size: 8, max_parallel: 2, - batch_size: 8, + batch_size: 8, when_ready: false, } } diff --git a/src/receivers/buffer_unordered_batched/sync.rs b/src/receivers/buffer_unordered_batched/sync.rs index 0eef451..7ad05fb 100644 --- a/src/receivers/buffer_unordered_batched/sync.rs +++ b/src/receivers/buffer_unordered_batched/sync.rs @@ -1,5 +1,5 @@ use crate::{receiver::ReceiverStats, receivers::mpsc}; -use futures::{Future, StreamExt}; +use futures::{executor::block_on, Future, StreamExt}; use std::{ any::TypeId, marker::PhantomData, @@ -16,7 +16,7 @@ use crate::{ builder::{ReceiverSubscriber, ReceiverSubscriberBuilder}, msgs, receiver::{AnyReceiver, ReceiverTrait, SendError, TypedReceiver}, - Bus, BatchHandler, Message, Untyped, + BatchHandler, Bus, Message, Untyped, }; pub struct BufferUnorderedBatchedSyncSubscriber @@ -78,19 +78,16 @@ async fn buffer_unordered_poller( T: BatchHandler + 'static, M: Message, { - let ut = ut.downcast::().unwrap(); - let rx = rx - .inspect(|_| { - stats.buffer.fetch_sub(1, Ordering::Relaxed); - stats.batch.fetch_add(1, Ordering::Relaxed); - }); + let ut = ut.downcast_sync::().unwrap(); + let rx = rx.inspect(|_| { + stats.buffer.fetch_sub(1, Ordering::Relaxed); + stats.batch.fetch_add(1, Ordering::Relaxed); + }); let rx = if cfg.when_ready { - rx.ready_chunks(cfg.batch_size) - .left_stream() + rx.ready_chunks(cfg.batch_size).left_stream() } else { - rx.chunks(cfg.batch_size) - .right_stream() + rx.chunks(cfg.batch_size).right_stream() }; let mut rx = rx @@ -101,7 +98,9 @@ async fn buffer_unordered_poller( let bus = bus.clone(); let ut = ut.clone(); - tokio::task::spawn_blocking(move || ut.handle(msgs, &bus)) + tokio::task::spawn_blocking(move || { + block_on(ut.lock_read()).get_ref().handle(msgs, &bus) + }) }) .buffer_unordered(cfg.max_parallel); @@ -118,7 +117,9 @@ async fn buffer_unordered_poller( let ut = ut.clone(); let bus_clone = bus.clone(); - let res = tokio::task::spawn_blocking(move || ut.sync(&bus_clone)).await; + let res = + tokio::task::spawn_blocking(move || block_on(ut.lock_read()).get_ref().sync(&bus_clone)) + .await; match res { Ok(Err(err)) => { diff --git a/src/receivers/local/async.rs b/src/receivers/local/async.rs new file mode 100644 index 0000000..6e5eceb --- /dev/null +++ b/src/receivers/local/async.rs @@ -0,0 +1,180 @@ +use std::{ + any::TypeId, + marker::PhantomData, + pin::Pin, + sync::{ + atomic::{AtomicU64, Ordering}, + Arc, + }, + task::{Context, Poll}, +}; + +use crate::{receiver::ReceiverStats, receivers::mpsc}; +use futures::{pin_mut, Future, StreamExt}; + +use super::{LocalConfig, LocalStats}; +use crate::{ + builder::{ReceiverSubscriber, ReceiverSubscriberBuilder}, + receiver::{AnyReceiver, ReceiverTrait, SendError, TypedReceiver}, + Bus, LocalAsyncHandler, Message, Untyped, +}; + +pub struct LocalAsyncSubscriber +where + T: LocalAsyncHandler + 'static, + M: Message, +{ + cfg: LocalConfig, + _m: PhantomData<(T, M)>, +} + +impl ReceiverSubscriber for LocalAsyncSubscriber +where + T: LocalAsyncHandler + 'static, + M: Message, +{ + fn subscribe( + self, + ) -> ( + Arc, + Box< + dyn FnOnce(Untyped) -> Box Pin + Send>>>, + >, + ) { + let cfg = self.cfg; + let (tx, rx) = mpsc::channel(cfg.buffer_size); + let stats = Arc::new(LocalStats { + buffer: AtomicU64::new(0), + buffer_total: AtomicU64::new(cfg.buffer_size as _), + }); + + let arc = Arc::new(LocalAsync:: { + tx, + stats: stats.clone(), + }); + + let poller = Box::new(move |ut| { + Box::new(move |bus| { + Box::pin(buffer_unordered_poller::(rx, bus, ut, stats, cfg)) + as Pin + Send>> + }) as Box Pin + Send>>> + }); + + (arc, poller) + } +} + +async fn buffer_unordered_poller( + rx: mpsc::Receiver, + bus: Bus, + ut: Untyped, + stats: Arc, + _cfg: LocalConfig, +) where + T: LocalAsyncHandler + 'static, + M: Message, +{ + let ut = ut.downcast_local::().unwrap(); + let bus1 = bus.clone(); + + let x = rx.then(|msg| { + let bus1 = bus1.clone(); + ut.spawn_local(move |item| { + Box::pin(async move { + let _ = item.handle(msg, &bus1).await; + }) + }) + }); + + pin_mut!(x); + + while let Some(_) = x.next().await { + stats.buffer.fetch_sub(1, Ordering::Relaxed); + } + + let bus_clone = bus.clone(); + ut.spawn_local(move |item| { + Box::pin(async move { + let _ = item.sync(&bus_clone).await; + }) + }) + .await; + + println!("[EXIT] LocalAsync<{}>", std::any::type_name::()); +} + +pub struct LocalAsync { + tx: mpsc::Sender, + stats: Arc, +} + +impl ReceiverSubscriberBuilder for LocalAsync +where + T: LocalAsyncHandler + 'static, + M: Message, +{ + type Entry = LocalAsyncSubscriber; + type Config = LocalConfig; + + fn build(cfg: Self::Config) -> Self::Entry { + LocalAsyncSubscriber { + cfg, + _m: Default::default(), + } + } +} + +impl TypedReceiver for LocalAsync { + fn poll_ready(&self, ctx: &mut Context<'_>) -> Poll<()> { + match self.tx.poll_ready(ctx) { + Poll::Ready(_) => Poll::Ready(()), + Poll::Pending => Poll::Pending, + } + } + + fn try_send(&self, m: M) -> Result<(), SendError> { + match self.tx.try_send(m) { + Ok(_) => { + self.stats.buffer.fetch_add(1, Ordering::Relaxed); + + Ok(()) + } + Err(err) => Err(err), + } + } +} + +impl ReceiverTrait for LocalAsync { + fn typed(&self) -> AnyReceiver<'_> { + AnyReceiver::new(self) + } + + fn type_id(&self) -> TypeId { + TypeId::of::>() + } + + fn stats(&self) -> ReceiverStats { + ReceiverStats { + name: std::any::type_name::().into(), + fields: vec![ + ("buffer".into(), self.stats.buffer.load(Ordering::SeqCst)), + ( + "buffer_total".into(), + self.stats.buffer_total.load(Ordering::SeqCst), + ), + ], + } + } + + fn close(&self) { + self.tx.close(); + } + + fn sync(&self) { + self.tx.flush(); + } + + fn poll_synchronized(&self, _ctx: &mut Context<'_>) -> Poll<()> { + Poll::Ready(()) + } +} diff --git a/src/receivers/local/mod.rs b/src/receivers/local/mod.rs new file mode 100644 index 0000000..e0535e1 --- /dev/null +++ b/src/receivers/local/mod.rs @@ -0,0 +1,25 @@ +mod r#async; +mod sync; + +use std::sync::atomic::AtomicU64; + +pub use sync::{LocalSync, LocalSyncSubscriber}; + +pub use r#async::{LocalAsync, LocalAsyncSubscriber}; + +#[derive(Debug)] +pub struct LocalStats { + pub buffer: AtomicU64, + pub buffer_total: AtomicU64, +} + +#[derive(Copy, Clone, Debug)] +pub struct LocalConfig { + pub buffer_size: usize, +} + +impl Default for LocalConfig { + fn default() -> Self { + Self { buffer_size: 1 } + } +} diff --git a/src/receivers/local/sync.rs b/src/receivers/local/sync.rs new file mode 100644 index 0000000..553bbc3 --- /dev/null +++ b/src/receivers/local/sync.rs @@ -0,0 +1,177 @@ +use super::{LocalConfig, LocalStats}; +use crate::{receiver::ReceiverStats, receivers::mpsc}; +use futures::{pin_mut, Future, StreamExt}; +use std::{ + any::TypeId, + marker::PhantomData, + pin::Pin, + sync::{ + atomic::{AtomicU64, Ordering}, + Arc, + }, + task::{Context, Poll}, +}; + +use crate::{ + builder::{ReceiverSubscriber, ReceiverSubscriberBuilder}, + receiver::{AnyReceiver, ReceiverTrait, SendError, TypedReceiver}, + Bus, LocalHandler, Message, Untyped, +}; + +pub struct LocalSyncSubscriber +where + T: LocalHandler + 'static, + M: Message, +{ + cfg: LocalConfig, + _m: PhantomData<(M, T)>, +} + +impl ReceiverSubscriber for LocalSyncSubscriber +where + T: LocalHandler + 'static, + M: Message, +{ + fn subscribe( + self, + ) -> ( + Arc, + Box< + dyn FnOnce(Untyped) -> Box Pin + Send>>>, + >, + ) { + let cfg = self.cfg; + let (tx, rx) = mpsc::channel(cfg.buffer_size); + let stats = Arc::new(LocalStats { + buffer: AtomicU64::new(0), + buffer_total: AtomicU64::new(cfg.buffer_size as _), + }); + + let arc = Arc::new(LocalSync:: { + tx, + stats: stats.clone(), + }); + let poller = Box::new(move |ut| { + Box::new(move |bus| { + Box::pin(buffer_unordered_poller::(rx, bus, ut, stats, cfg)) + as Pin + Send>> + }) as Box Pin + Send>>> + }); + + (arc, poller) + } +} + +async fn buffer_unordered_poller( + rx: mpsc::Receiver, + bus: Bus, + ut: Untyped, + stats: Arc, + _cfg: LocalConfig, +) where + T: LocalHandler + 'static, + M: Message, +{ + let ut = ut.downcast_local::().unwrap(); + let bus1 = bus.clone(); + + let x = rx.then(|msg| { + let bus1 = bus1.clone(); + ut.spawn_local(move |item| { + let _ = item.handle(msg, &bus1); + + Box::pin(async move {}) + }) + }); + + pin_mut!(x); + + while let Some(_) = x.next().await { + stats.buffer.fetch_sub(1, Ordering::Relaxed); + } + + let bus_clone = bus.clone(); + ut.spawn_local(move |item| { + let _ = item.sync(&bus_clone); + Box::pin(async move {}) + }) + .await; + + println!("[EXIT] LocalAsync<{}>", std::any::type_name::()); +} + +pub struct LocalSync { + tx: mpsc::Sender, + stats: Arc, +} + +impl ReceiverSubscriberBuilder for LocalSync +where + T: LocalHandler + 'static, + M: Message, +{ + type Entry = LocalSyncSubscriber; + type Config = LocalConfig; + + fn build(cfg: Self::Config) -> Self::Entry { + LocalSyncSubscriber { + cfg, + _m: Default::default(), + } + } +} + +impl TypedReceiver for LocalSync { + fn poll_ready(&self, ctx: &mut Context<'_>) -> Poll<()> { + match self.tx.poll_ready(ctx) { + Poll::Ready(_) => Poll::Ready(()), + Poll::Pending => Poll::Pending, + } + } + + fn try_send(&self, m: M) -> Result<(), SendError> { + match self.tx.try_send(m) { + Ok(_) => { + self.stats.buffer.fetch_add(1, Ordering::Relaxed); + + Ok(()) + } + Err(err) => Err(err), + } + } +} + +impl ReceiverTrait for LocalSync { + fn typed(&self) -> AnyReceiver<'_> { + AnyReceiver::new(self) + } + + fn type_id(&self) -> TypeId { + TypeId::of::>() + } + + fn stats(&self) -> ReceiverStats { + ReceiverStats { + name: std::any::type_name::().into(), + fields: vec![ + ("buffer".into(), self.stats.buffer.load(Ordering::SeqCst)), + ( + "buffer_total".into(), + self.stats.buffer_total.load(Ordering::SeqCst), + ), + ], + } + } + + fn close(&self) { + self.tx.close(); + } + + fn sync(&self) { + self.tx.flush(); + } + + fn poll_synchronized(&self, _ctx: &mut Context<'_>) -> Poll<()> { + Poll::Ready(()) + } +} diff --git a/src/receivers/local_batched/async.rs b/src/receivers/local_batched/async.rs new file mode 100644 index 0000000..1a69673 --- /dev/null +++ b/src/receivers/local_batched/async.rs @@ -0,0 +1,196 @@ +use std::{ + any::TypeId, + marker::PhantomData, + pin::Pin, + sync::{ + atomic::{AtomicU64, Ordering}, + Arc, + }, + task::{Context, Poll}, +}; + +use crate::{receiver::ReceiverStats, receivers::mpsc}; +use futures::{Future, StreamExt}; + +use super::{LocalBatchedConfig, LocalBatchedStats}; +use crate::{ + builder::{ReceiverSubscriber, ReceiverSubscriberBuilder}, + receiver::{AnyReceiver, ReceiverTrait, SendError, TypedReceiver}, + Bus, LocalAsyncBatchHandler, Message, Untyped, +}; + +pub struct LocalBatchedAsyncSubscriber +where + T: LocalAsyncBatchHandler + 'static, + M: Message, +{ + cfg: LocalBatchedConfig, + _m: PhantomData<(T, M)>, +} + +impl ReceiverSubscriber for LocalBatchedAsyncSubscriber +where + T: LocalAsyncBatchHandler + 'static, + M: Message, +{ + fn subscribe( + self, + ) -> ( + Arc, + Box< + dyn FnOnce(Untyped) -> Box Pin + Send>>>, + >, + ) { + let cfg = self.cfg; + let (tx, rx) = mpsc::channel(cfg.buffer_size); + let stats = Arc::new(LocalBatchedStats { + buffer: AtomicU64::new(0), + buffer_total: AtomicU64::new(cfg.buffer_size as _), + batch: AtomicU64::new(0), + batch_size: AtomicU64::new(cfg.batch_size as _), + }); + + let arc = Arc::new(LocalBatchedAsync:: { + tx, + stats: stats.clone(), + }); + + let poller = Box::new(move |ut| { + Box::new(move |bus| { + Box::pin(buffer_unordered_poller::(rx, bus, ut, stats, cfg)) + as Pin + Send>> + }) as Box Pin + Send>>> + }); + + (arc, poller) + } +} + +async fn buffer_unordered_poller( + rx: mpsc::Receiver, + bus: Bus, + ut: Untyped, + stats: Arc, + cfg: LocalBatchedConfig, +) where + T: LocalAsyncBatchHandler + 'static, + M: Message, +{ + let ut = ut.downcast_local::().unwrap(); + + let rx = rx.inspect(|_| { + stats.buffer.fetch_sub(1, Ordering::Relaxed); + stats.batch.fetch_add(1, Ordering::Relaxed); + }); + + let mut rx = if cfg.when_ready { + rx.ready_chunks(cfg.batch_size).left_stream() + } else { + rx.chunks(cfg.batch_size).right_stream() + }; + + while let Some(msgs) = rx.next().await { + stats.batch.fetch_sub(msgs.len() as _, Ordering::Relaxed); + + let bus_clone = bus.clone(); + let ut = ut.clone(); + + ut.spawn_local(move |item| { + Box::pin(async move { + let _ = item.handle(msgs, &bus_clone).await; + }) + }) + .await; + } + + let bus_clone = bus.clone(); + ut.spawn_local(move |item| { + Box::pin(async move { + let _ = item.sync(&bus_clone).await; + }) + }) + .await; + + println!("[EXIT] LocalBatchedAsync<{}>", std::any::type_name::()); +} + +pub struct LocalBatchedAsync { + tx: mpsc::Sender, + stats: Arc, +} + +impl ReceiverSubscriberBuilder for LocalBatchedAsync +where + T: LocalAsyncBatchHandler + 'static, + M: Message, +{ + type Entry = LocalBatchedAsyncSubscriber; + type Config = LocalBatchedConfig; + + fn build(cfg: Self::Config) -> Self::Entry { + LocalBatchedAsyncSubscriber { + cfg, + _m: Default::default(), + } + } +} + +impl TypedReceiver for LocalBatchedAsync { + fn poll_ready(&self, ctx: &mut Context<'_>) -> Poll<()> { + match self.tx.poll_ready(ctx) { + Poll::Ready(_) => Poll::Ready(()), + Poll::Pending => Poll::Pending, + } + } + + fn try_send(&self, m: M) -> Result<(), SendError> { + match self.tx.try_send(m) { + Ok(_) => { + self.stats.buffer.fetch_add(1, Ordering::Relaxed); + + Ok(()) + } + Err(err) => Err(err), + } + } +} + +impl ReceiverTrait for LocalBatchedAsync { + fn typed(&self) -> AnyReceiver<'_> { + AnyReceiver::new(self) + } + + fn type_id(&self) -> TypeId { + TypeId::of::() + } + + fn close(&self) { + self.tx.close(); + } + + fn stats(&self) -> ReceiverStats { + ReceiverStats { + name: std::any::type_name::().into(), + fields: vec![ + ("buffer".into(), self.stats.buffer.load(Ordering::SeqCst)), + ( + "buffer_total".into(), + self.stats.buffer_total.load(Ordering::SeqCst), + ), + ("batch".into(), self.stats.batch.load(Ordering::SeqCst)), + ( + "batch_size".into(), + self.stats.batch_size.load(Ordering::SeqCst), + ), + ], + } + } + + fn sync(&self) { + self.tx.flush(); + } + + fn poll_synchronized(&self, _ctx: &mut Context<'_>) -> Poll<()> { + Poll::Ready(()) + } +} diff --git a/src/receivers/local_batched/mod.rs b/src/receivers/local_batched/mod.rs new file mode 100644 index 0000000..c2e1240 --- /dev/null +++ b/src/receivers/local_batched/mod.rs @@ -0,0 +1,33 @@ +mod r#async; +mod sync; + +use std::sync::atomic::AtomicU64; + +pub use sync::{LocalBatchedSync, LocalBatchedSyncSubscriber}; + +pub use r#async::{LocalBatchedAsync, LocalBatchedAsyncSubscriber}; + +#[derive(Debug)] +pub struct LocalBatchedStats { + pub buffer: AtomicU64, + pub buffer_total: AtomicU64, + pub batch: AtomicU64, + pub batch_size: AtomicU64, +} + +#[derive(Copy, Clone, Debug)] +pub struct LocalBatchedConfig { + pub buffer_size: usize, + pub batch_size: usize, + pub when_ready: bool, +} + +impl Default for LocalBatchedConfig { + fn default() -> Self { + Self { + buffer_size: 4, + batch_size: 16, + when_ready: false, + } + } +} diff --git a/src/receivers/local_batched/sync.rs b/src/receivers/local_batched/sync.rs new file mode 100644 index 0000000..541b4bb --- /dev/null +++ b/src/receivers/local_batched/sync.rs @@ -0,0 +1,193 @@ +use super::{LocalBatchedConfig, LocalBatchedStats}; +use crate::{ + builder::{ReceiverSubscriber, ReceiverSubscriberBuilder}, + receiver::{AnyReceiver, ReceiverTrait, SendError, TypedReceiver}, + Bus, LocalBatchHandler, Message, Untyped, +}; +use crate::{receiver::ReceiverStats, receivers::mpsc}; +use futures::{Future, StreamExt}; +use std::{ + any::TypeId, + marker::PhantomData, + pin::Pin, + sync::{ + atomic::{AtomicU64, Ordering}, + Arc, + }, + task::{Context, Poll}, +}; + +pub struct LocalBatchedSyncSubscriber +where + T: LocalBatchHandler + 'static, + M: Message, +{ + cfg: LocalBatchedConfig, + _m: PhantomData<(M, T)>, +} + +impl ReceiverSubscriber for LocalBatchedSyncSubscriber +where + T: LocalBatchHandler + 'static, + M: Message, +{ + fn subscribe( + self, + ) -> ( + Arc, + Box< + dyn FnOnce(Untyped) -> Box Pin + Send>>>, + >, + ) { + let cfg = self.cfg; + let (tx, rx) = mpsc::channel(cfg.buffer_size); + let stats = Arc::new(LocalBatchedStats { + buffer: AtomicU64::new(0), + buffer_total: AtomicU64::new(cfg.buffer_size as _), + batch: AtomicU64::new(0), + batch_size: AtomicU64::new(cfg.batch_size as _), + }); + + let arc = Arc::new(LocalBatchedSync:: { + tx, + stats: stats.clone(), + }); + let poller = Box::new(move |ut| { + Box::new(move |bus| { + Box::pin(buffer_unordered_poller::(rx, bus, ut, stats, cfg)) + as Pin + Send>> + }) as Box Pin + Send>>> + }); + + (arc, poller) + } +} + +async fn buffer_unordered_poller( + rx: mpsc::Receiver, + bus: Bus, + ut: Untyped, + stats: Arc, + cfg: LocalBatchedConfig, +) where + T: LocalBatchHandler + 'static, + M: Message, +{ + let ut = ut.downcast_local::().unwrap(); + + let rx = rx.inspect(|_| { + stats.buffer.fetch_sub(1, Ordering::Relaxed); + stats.batch.fetch_add(1, Ordering::Relaxed); + }); + + let mut rx = if cfg.when_ready { + rx.ready_chunks(cfg.batch_size).left_stream() + } else { + rx.chunks(cfg.batch_size).right_stream() + }; + + while let Some(msgs) = rx.next().await { + stats.batch.fetch_sub(msgs.len() as _, Ordering::Relaxed); + + let bus_clone = bus.clone(); + let ut = ut.clone(); + + ut.spawn_local(move |item| { + let _ = item.handle(msgs, &bus_clone); + + Box::pin(async move {}) + }) + .await; + } + + let ut = ut.clone(); + let bus_clone = bus.clone(); + ut.spawn_local(move |item| { + let _ = item.sync(&bus_clone); + Box::pin(async move {}) + }) + .await; + + println!("[EXIT] LocalBatchedSync<{}>", std::any::type_name::()); +} + +pub struct LocalBatchedSync { + tx: mpsc::Sender, + stats: Arc, +} + +impl ReceiverSubscriberBuilder for LocalBatchedSync +where + T: LocalBatchHandler + 'static, + M: Message, +{ + type Entry = LocalBatchedSyncSubscriber; + type Config = LocalBatchedConfig; + + fn build(cfg: Self::Config) -> Self::Entry { + LocalBatchedSyncSubscriber { + cfg, + _m: Default::default(), + } + } +} + +impl TypedReceiver for LocalBatchedSync { + fn poll_ready(&self, ctx: &mut Context<'_>) -> Poll<()> { + match self.tx.poll_ready(ctx) { + Poll::Ready(_) => Poll::Ready(()), + Poll::Pending => Poll::Pending, + } + } + + fn try_send(&self, m: M) -> Result<(), SendError> { + match self.tx.try_send(m) { + Ok(_) => { + self.stats.buffer.fetch_add(1, Ordering::Relaxed); + + Ok(()) + } + Err(err) => Err(err), + } + } +} + +impl ReceiverTrait for LocalBatchedSync { + fn typed(&self) -> AnyReceiver<'_> { + AnyReceiver::new(self) + } + + fn type_id(&self) -> TypeId { + TypeId::of::() + } + + fn stats(&self) -> ReceiverStats { + ReceiverStats { + name: std::any::type_name::().into(), + fields: vec![ + ("buffer".into(), self.stats.buffer.load(Ordering::SeqCst)), + ( + "buffer_total".into(), + self.stats.buffer_total.load(Ordering::SeqCst), + ), + ("batch".into(), self.stats.batch.load(Ordering::SeqCst)), + ( + "batch_size".into(), + self.stats.batch_size.load(Ordering::SeqCst), + ), + ], + } + } + + fn close(&self) { + self.tx.close(); + } + + fn sync(&self) { + self.tx.flush(); + } + + fn poll_synchronized(&self, _ctx: &mut Context<'_>) -> Poll<()> { + Poll::Ready(()) + } +} diff --git a/src/receivers/mod.rs b/src/receivers/mod.rs index a8c8c25..66995cb 100644 --- a/src/receivers/mod.rs +++ b/src/receivers/mod.rs @@ -1,5 +1,7 @@ mod buffer_unordered; mod buffer_unordered_batched; +mod local; +mod local_batched; mod mpsc_futures; mod synchronize_batched; mod synchronized; @@ -14,8 +16,8 @@ pub use buffer_unordered::{ }; pub use buffer_unordered_batched::{ - BufferUnorderedBatchedAsync, BufferUnorderedBatchedAsyncSubscriber, BufferUnorderedBatchedConfig, - BufferUnorderedBatchedSync, BufferUnorderedBatchedSyncSubscriber, + BufferUnorderedBatchedAsync, BufferUnorderedBatchedAsyncSubscriber, + BufferUnorderedBatchedConfig, BufferUnorderedBatchedSync, BufferUnorderedBatchedSyncSubscriber, }; pub use synchronized::{ @@ -27,3 +29,10 @@ pub use synchronize_batched::{ SynchronizeBatchedAsync, SynchronizeBatchedAsyncSubscriber, SynchronizeBatchedConfig, SynchronizeBatchedSync, SynchronizeBatchedSyncSubscriber, }; + +pub use local::{LocalAsync, LocalAsyncSubscriber, LocalConfig, LocalSync, LocalSyncSubscriber}; + +pub use local_batched::{ + LocalBatchedAsync, LocalBatchedAsyncSubscriber, LocalBatchedConfig, LocalBatchedSync, + LocalBatchedSyncSubscriber, +}; diff --git a/src/receivers/mpsc_futures.rs b/src/receivers/mpsc_futures.rs index 7014bb8..350ef56 100644 --- a/src/receivers/mpsc_futures.rs +++ b/src/receivers/mpsc_futures.rs @@ -99,7 +99,7 @@ impl Stream for Receiver { } Poll::Pending - }, + } } } } diff --git a/src/receivers/synchronize_batched/async.rs b/src/receivers/synchronize_batched/async.rs index 1a802dc..17f95f8 100644 --- a/src/receivers/synchronize_batched/async.rs +++ b/src/receivers/synchronize_batched/async.rs @@ -11,7 +11,6 @@ use std::{ use crate::{receiver::ReceiverStats, receivers::mpsc}; use futures::{Future, StreamExt}; -use tokio::sync::Mutex; use super::{SynchronizeBatchedConfig, SynchronizeBatchedStats}; use crate::{ @@ -78,20 +77,17 @@ async fn buffer_unordered_poller( T: AsyncBatchSynchronizedHandler + 'static, M: Message, { - let ut = ut.downcast::>().unwrap(); + let ut = ut.downcast_send::().unwrap(); - let rx = rx - .inspect(|_|{ - stats.buffer.fetch_sub(1, Ordering::Relaxed); - stats.batch.fetch_add(1, Ordering::Relaxed); - }); + let rx = rx.inspect(|_| { + stats.buffer.fetch_sub(1, Ordering::Relaxed); + stats.batch.fetch_add(1, Ordering::Relaxed); + }); let mut rx = if cfg.when_ready { - rx.ready_chunks(cfg.batch_size) - .left_stream() + rx.ready_chunks(cfg.batch_size).left_stream() } else { - rx.chunks(cfg.batch_size) - .right_stream() + rx.chunks(cfg.batch_size).right_stream() }; while let Some(msgs) = rx.next().await { @@ -101,8 +97,7 @@ async fn buffer_unordered_poller( let ut = ut.clone(); let res = - tokio::task::spawn(async move { ut.lock().await.handle(msgs, &bus_clone).await }) - .await; + tokio::task::spawn(async move { ut.lock().await.handle(msgs, &bus_clone).await }).await; match res { Ok(Err(err)) => { diff --git a/src/receivers/synchronize_batched/sync.rs b/src/receivers/synchronize_batched/sync.rs index f50042c..a97ea81 100644 --- a/src/receivers/synchronize_batched/sync.rs +++ b/src/receivers/synchronize_batched/sync.rs @@ -6,7 +6,7 @@ use crate::{ BatchSynchronizedHandler, Bus, Message, Untyped, }; use crate::{receiver::ReceiverStats, receivers::mpsc}; -use futures::{Future, StreamExt}; +use futures::{executor::block_on, Future, StreamExt}; use std::{ any::TypeId, marker::PhantomData, @@ -17,7 +17,6 @@ use std::{ }, task::{Context, Poll}, }; -use tokio::sync::Mutex; pub struct SynchronizeBatchedSyncSubscriber where @@ -75,20 +74,17 @@ async fn buffer_unordered_poller( T: BatchSynchronizedHandler + 'static, M: Message, { - let ut = ut.downcast::>().unwrap(); + let ut = ut.downcast_send::().unwrap(); - let rx = rx - .inspect(|_|{ - stats.buffer.fetch_sub(1, Ordering::Relaxed); - stats.batch.fetch_add(1, Ordering::Relaxed); - }); + let rx = rx.inspect(|_| { + stats.buffer.fetch_sub(1, Ordering::Relaxed); + stats.batch.fetch_add(1, Ordering::Relaxed); + }); let mut rx = if cfg.when_ready { - rx.ready_chunks(cfg.batch_size) - .left_stream() + rx.ready_chunks(cfg.batch_size).left_stream() } else { - rx.chunks(cfg.batch_size) - .right_stream() + rx.chunks(cfg.batch_size).right_stream() }; while let Some(msgs) = rx.next().await { @@ -98,10 +94,10 @@ async fn buffer_unordered_poller( let ut = ut.clone(); let res = tokio::task::spawn_blocking(move || { - let mut uut = futures::executor::block_on(ut.lock()); - + let mut uut = block_on(ut.lock()); uut.handle(msgs, &bus_clone) - }).await; + }) + .await; match res { Ok(Err(err)) => { @@ -113,10 +109,7 @@ async fn buffer_unordered_poller( let ut = ut.clone(); let bus_clone = bus.clone(); - let res = tokio::task::spawn_blocking(move || { - futures::executor::block_on(ut.lock()).sync(&bus_clone) - }) - .await; + let res = tokio::task::spawn_blocking(move || block_on(ut.lock()).sync(&bus_clone)).await; match res { Ok(Err(err)) => { diff --git a/src/receivers/synchronized/async.rs b/src/receivers/synchronized/async.rs index e4952d5..03def70 100644 --- a/src/receivers/synchronized/async.rs +++ b/src/receivers/synchronized/async.rs @@ -11,7 +11,6 @@ use std::{ use crate::{receiver::ReceiverStats, receivers::mpsc}; use futures::{Future, StreamExt}; -use tokio::sync::Mutex; use super::{SynchronizedConfig, SynchronizedStats}; use crate::{ @@ -76,10 +75,13 @@ async fn buffer_unordered_poller( T: AsyncSynchronizedHandler + 'static, M: Message, { - let ut = ut.downcast::>().unwrap(); - let mut x = rx.then(|msg| { - let bus = bus.clone(); - let ut = ut.clone(); + let ut = ut.downcast_send::().unwrap(); + let ut1 = ut.clone(); + let bus1 = bus.clone(); + + let mut x = rx.then(move |msg| { + let bus = bus1.clone(); + let ut = ut1.clone(); tokio::task::spawn(async move { ut.lock().await.handle(msg, &bus).await }) }); diff --git a/src/receivers/synchronized/sync.rs b/src/receivers/synchronized/sync.rs index 82047ab..6235449 100644 --- a/src/receivers/synchronized/sync.rs +++ b/src/receivers/synchronized/sync.rs @@ -11,7 +11,6 @@ use std::{ }, task::{Context, Poll}, }; -use tokio::sync::Mutex; use crate::{ builder::{ReceiverSubscriber, ReceiverSubscriberBuilder}, @@ -74,12 +73,15 @@ async fn buffer_unordered_poller( T: SynchronizedHandler + 'static, M: Message, { - let ut = ut.downcast::>().unwrap(); + let ut = ut.downcast_send::().unwrap(); let mut x = rx.then(|msg| { let ut = ut.clone(); let bus = bus.clone(); - tokio::task::spawn_blocking(move || block_on(ut.lock()).handle(msg, &bus)) + tokio::task::spawn_blocking(move || { + let mut uut = block_on(ut.lock()); + uut.handle(msg, &bus) + }) }); while let Some(err) = x.next().await { @@ -95,10 +97,7 @@ async fn buffer_unordered_poller( let ut = ut.clone(); let bus_clone = bus.clone(); - let res = tokio::task::spawn_blocking(move || { - futures::executor::block_on(ut.lock()).sync(&bus_clone) - }) - .await; + let res = tokio::task::spawn_blocking(move || block_on(ut.lock()).sync(&bus_clone)).await; match res { Ok(Err(err)) => { @@ -107,10 +106,7 @@ async fn buffer_unordered_poller( _ => (), } - println!( - "[EXIT] BufferUnorderedSync<{}>", - std::any::type_name::() - ); + println!("[EXIT] BufferUnorderedSync<{}>", std::any::type_name::()); } pub struct SynchronizedSync { diff --git a/src/untyped.rs b/src/untyped.rs new file mode 100644 index 0000000..4cc0565 --- /dev/null +++ b/src/untyped.rs @@ -0,0 +1,174 @@ +use core::any::Any; +use core::future::Future; +use core::pin::Pin; +use std::sync::Arc; + +use tokio::sync::mpsc; +use tokio::sync::Notify; +use tokio::sync::{Mutex, MutexGuard, RwLock, RwLockReadGuard, RwLockWriteGuard}; + +#[derive(Debug)] +pub enum Lock<'a, T> { + ReadOnly(&'a T), + RwRead(RwLockReadGuard<'a, T>), + RwWrite(RwLockWriteGuard<'a, T>), + WriteOnly(MutexGuard<'a, T>), +} + +impl<'a, T> Lock<'a, T> { + pub fn get_ref(&self) -> &T { + match self { + Lock::ReadOnly(inner) => &inner, + Lock::RwRead(inner) => &inner, + Lock::RwWrite(inner) => &inner, + Lock::WriteOnly(inner) => &inner, + } + } + + pub fn get_mut(&mut self) -> &mut T { + match self { + Lock::ReadOnly(_) => panic!("!!"), + Lock::RwRead(_) => panic!("!!"), + Lock::RwWrite(inner) => &mut *inner, + Lock::WriteOnly(inner) => &mut *inner, + } + } +} + +pub enum Downcasted { + ReadOnly(Arc), + ReadWrite(Arc>), + WriteOnly(Arc>), +} + +impl Clone for Downcasted { + fn clone(&self) -> Self { + match self { + Downcasted::ReadOnly(inner) => Downcasted::ReadOnly(inner.clone()), + Downcasted::ReadWrite(inner) => Downcasted::ReadWrite(inner.clone()), + Downcasted::WriteOnly(inner) => Downcasted::WriteOnly(inner.clone()), + } + } +} + +impl Downcasted { + pub async fn lock_read(&self) -> Lock<'_, T> { + match self { + Downcasted::ReadOnly(inner) => Lock::ReadOnly(&inner), + Downcasted::ReadWrite(inner) => Lock::RwRead(inner.read().await), + Downcasted::WriteOnly(inner) => Lock::WriteOnly(inner.lock().await), + } + } + pub async fn lock_write(&self) -> Lock<'_, T> { + match self { + Downcasted::ReadOnly(_) => unimplemented!(), + Downcasted::ReadWrite(inner) => Lock::RwWrite(inner.write().await), + Downcasted::WriteOnly(inner) => Lock::WriteOnly(inner.lock().await), + } + } +} + +#[derive(Clone)] +pub struct Untyped { + inner: Arc, +} + +impl Untyped { + pub fn new_readonly(item: T) -> Self { + Self { + inner: Arc::new(item), + } + } + + pub fn new_rwlock(item: T) -> Self { + Self { + inner: Arc::new(RwLock::new(item)), + } + } + + pub fn new_mutex(item: T) -> Self { + Self { + inner: Arc::new(Mutex::new(item)), + } + } + + pub fn new_local T + Send + 'static>(f: F) -> Self { + Self { + inner: Arc::new(ThreadDedicated::new(f)), + } + } + + pub fn downcast_sync(self) -> Option> { + let item = match self.inner.clone().downcast::>() { + Ok(inner) => Downcasted::ReadWrite(inner), + Err(_) => return None, + }; + + Some(item) + } + + pub fn downcast_send1(self) -> Option> { + let item = match self.inner.clone().downcast::>() { + Ok(inner) => Downcasted::WriteOnly(inner), + Err(_) => return None, + }; + + Some(item) + } + + pub fn downcast_send(self) -> Option>> { + self.inner.clone().downcast::>().ok() + } + + #[inline] + pub fn downcast_local(self) -> Option>> { + self.inner.clone().downcast::>().ok() + } +} + +pub struct ThreadDedicated { + sender: mpsc::Sender< + Box FnOnce(&'a mut T) -> Pin + 'a>> + Send>, + >, + notify: Arc, +} + +impl ThreadDedicated { + pub fn new T + Send + 'static>(builder: F) -> Self { + let notify = Arc::new(Notify::new()); + let (sender, mut receiver) = mpsc::channel(1); + + let sender: mpsc::Sender< + Box FnOnce(&'a mut T) -> Pin + 'a>> + Send>, + > = sender; + let notify_clone = notify.clone(); + std::thread::spawn(move || { + futures::executor::block_on(async move { + let mut obj = builder(); + + loop { + let cb = match receiver.recv().await { + Some(x) => x, + None => break, + }; + + cb(&mut obj).await; + notify_clone.notify_one(); + } + }); + }); + + Self { sender, notify } + } + + pub async fn spawn_local< + F: for<'a> FnOnce(&'a mut T) -> Pin + 'a>> + Send + 'static, + >( + &self, + cb: F, + ) { + self.sender.send(Box::new(cb)).await.ok().unwrap(); + + self.notify.notified().await + } +}