From 0297882f471582303f193991a57e010f2a3f1fba Mon Sep 17 00:00:00 2001 From: Andrey Tkachenko Date: Tue, 24 Oct 2023 14:23:22 +0400 Subject: [PATCH] Initial Commit --- .gitignore | 2 + Cargo.toml | 18 ++ examples/demo.rs | 40 +++ src/lib.rs | 745 +++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 805 insertions(+) create mode 100644 .gitignore create mode 100644 Cargo.toml create mode 100644 examples/demo.rs create mode 100644 src/lib.rs diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..4fffb2f --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +/target +/Cargo.lock diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..0d8556d --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "messagebus" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +boxcar = "0.2.3" +dashmap = "5.5.0" +futures = "0.3.28" +kanal = "0.1.0-pre8" +log = "0.4.20" +tokio = { version = "1.32.0", features = ["sync", "rt", "macros"] } + +[dev-dependencies] +async-stream = "0.3.5" +tokio = { version = "1.32.0", features = ["full"] } diff --git a/examples/demo.rs b/examples/demo.rs new file mode 100644 index 0000000..96d82f9 --- /dev/null +++ b/examples/demo.rs @@ -0,0 +1,40 @@ +use std::sync::Arc; + +use messagebus::{Bus, Error, Message}; + +#[derive(Debug, Clone)] +pub struct Msg(pub i32); +impl Message for Msg {} + +pub struct Processor { + state: i32, +} + +impl Processor { + pub async fn spawn(sid: u32) -> Result<(usize, Self), Error> { + Ok((4, Self { state: 0 })) + } + + pub async fn handler_msg(self: Arc, sid: u32, tid: u32, msg: Msg) -> Result<(), Error> { + Ok(()) + } + + pub async fn finalize_msg_handler(self: Arc, sid: u32) -> Result<(), Error> { + Ok(()) + } +} + +async fn run() { + let bus = Bus::new(); + bus.register( + 4, + Processor::spawn, + Processor::handler_msg, + Processor::finalize_msg_handler, + ); +} + +#[tokio::main] +async fn main() { + run().await +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..0734343 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,745 @@ +#![feature(impl_trait_in_assoc_type)] + +use std::{ + any::{Any, TypeId}, + fmt, + pin::{pin, Pin}, + sync::{ + atomic::{AtomicI64, Ordering}, + Arc, + }, +}; + +use dashmap::DashMap; +use futures::{future::Either, Future, Stream, StreamExt}; +use tokio::sync::{mpsc, Notify}; + +trait AsyncReceiver { + type Item: Send + 'static; + type Error: std::error::Error + Send + Sync + 'static; + type Fut<'a>: Future> + Send + 'a + where + Self: 'a; + + fn recv(&mut self) -> Self::Fut<'_>; +} + +trait AsyncSender: Clone { + type Item: Send + 'static; + type Error: std::error::Error + Send + Sync + 'static; + type Fut<'a>: Future> + Send + 'a + where + Self: 'a; + + fn send(&self, item: Self::Item) -> Self::Fut<'_>; + fn try_send(&self, item: Self::Item) -> Result, Self::Error>; +} + +impl AsyncReceiver for kanal::AsyncReceiver { + type Item = I; + type Error = kanal::ReceiveError; + type Fut<'a> = kanal::ReceiveFuture<'a, I>; + + fn recv(&mut self) -> Self::Fut<'_> { + kanal::AsyncReceiver::recv(self) + } +} + +impl AsyncSender for kanal::AsyncSender { + type Item = I; + type Error = kanal::SendError; + type Fut<'a> = kanal::SendFuture<'a, Self::Item>; + + fn send(&self, item: Self::Item) -> Self::Fut<'_> { + kanal::AsyncSender::send(self, item) + } + + fn try_send(&self, item: Self::Item) -> Result, Self::Error> { + let mut item = Some(item); + kanal::AsyncSender::try_send_option(self, &mut item)?; + Ok(item) + } +} + +// type Sender = kanal::AsyncSender; +// type Receiver = kanal::AsyncReceiver; + +// fn channel(cap: usize) -> (Sender, Receiver) { +// kanal::bounded_async(cap) +// } + +type Sender = mpsc::Sender; +type Receiver = mpsc::Receiver; + +fn channel(cap: usize) -> (Sender, Receiver) { + mpsc::channel(cap) +} + +const DISPATCHER_STREAM_ID: u32 = u32::MAX; +const DEFAUL_STREAM_ID: u32 = u32::MAX - 1; + +const DISPATCHER_TASK_ID: u32 = u32::MAX; +const DEFAUL_TASK_ID: u32 = 0; + +#[derive(Debug)] +pub enum Error { + HandlerIsNotRegistered, + Aborted, + SendError(String), + // SendError(kanal::SendError), +} + +impl From for Error { + fn from(value: kanal::SendError) -> Self { + Self::SendError(format!("{}", value)) + } +} + +impl From> for Error { + fn from(value: mpsc::error::SendError) -> Self { + Self::SendError(format!("{}", value)) + } +} + +pub trait Message: fmt::Debug + Clone + Send + Sync + 'static {} + +impl Message for u64 {} +impl Message for u32 {} +impl Message for () {} + +#[derive(Debug, Clone)] +pub struct Msg { + stream_id: u32, + task_id: u32, + inner: M, +} + +impl Msg { + pub fn stream_id(&self) -> u32 { + self.stream_id + } + + pub fn task_id(&self) -> u32 { + self.task_id + } + + pub fn into_inner(self) -> M { + self.inner + } +} + +#[derive(Default)] +struct TaskCounter { + counter: AtomicI64, + notify: Notify, +} + +impl TaskCounter { + #[inline] + pub async fn wait_last_one(&self) { + self.notify.notified().await + } +} + +struct BusSenders { + senders: boxcar::Vec>>, +} + +impl BusSenders { + pub fn new(sender: Sender>) -> Self { + Self { + senders: boxcar::vec![sender], + } + } + + async fn send(&self, msg: Msg) -> Result<(), Error> { + match self.senders.count() { + 0 => (), + 1 => self.senders[0].send(msg).await?, + 2 => { + let (r1, r2) = futures::future::join( + self.senders[0].send(msg.clone()), + self.senders[1].send(msg), + ) + .await; + + r1.or(r2)? + } + 3 => { + let (r1, r2, r3) = futures::future::join3( + self.senders[0].send(msg.clone()), + self.senders[1].send(msg.clone()), + self.senders[2].send(msg), + ) + .await; + + r1.or(r2).or(r3)? + } + 4 => { + let (r1, r2, r3, r4) = futures::future::join4( + self.senders[0].send(msg.clone()), + self.senders[1].send(msg.clone()), + self.senders[2].send(msg.clone()), + self.senders[3].send(msg), + ) + .await; + + r1.or(r2).or(r3).or(r4)? + } + _ => { + let vec = futures::future::join_all( + self.senders.iter().map(|(_, s)| s.send(msg.clone())), + ) + .await; + + vec.into_iter().find(Result::is_err).unwrap_or(Ok(()))? + } + }; + + Ok(()) + } +} + +pub trait IntoMessage: Send { + fn into_message(self) -> Option; +} + +impl IntoMessage for Option { + fn into_message(self) -> Option { + self + } +} + +impl IntoMessage for M { + fn into_message(self) -> Option { + Some(self) + } +} +pub trait ProducerBuilder { + type Context: Producer + 'static; + type BuildFut<'a>: Future> + Send + 'a + where + Self: 'a; + + fn build<'a>(stream_id: u32) -> Self::BuildFut<'a>; +} + +pub trait HandlerBuilder { + type Context: Handler + 'static; + type BuildFut<'a>: Future> + Send + 'a + where + Self: 'a; + + fn build<'a>(stream_id: u32) -> Self::BuildFut<'a>; +} + +pub trait Handler: Send + Sync + 'static { + type Result: Message; + type IntoMessage: IntoMessage; + type HandleFut<'a>: Future> + Send + 'a + where + Self: 'a; + + type FinalizeFut<'a>: Future> + Send + 'a + where + Self: 'a; + + fn handle<'a>(&self, msg: M, stream_id: u32, task_id: u32) -> Self::HandleFut<'a>; + fn finalize<'a>(self, stream_id: u32) -> Self::FinalizeFut<'a>; +} + +pub trait Producer: Send + 'static { + type Item: Message; + type IntoMessage: IntoMessage; + type Stream<'a>: Stream> + Send + Sync + 'a + where + Self: 'a; + + type FinalizeFut<'a>: Future> + Send + 'a + where + Self: 'a; + + fn stream(&mut self, msg: M, stream_id: u32) -> Self::Stream<'_>; + fn finalize<'a>(self, stream_id: u32) -> Self::FinalizeFut<'a>; +} + +#[derive(Default)] +struct BusInner { + senders: DashMap<(u32, u32, TypeId), Arc>, + contexts: DashMap<(u32, TypeId), (usize, Arc)>, + producers_tasks: TaskCounter, + handlers_tasks: TaskCounter, + producers_stop_notify: Notify, + producers_abort_notify: Notify, + abort_notify: Notify, +} + +impl BusInner { + async fn task_dispatcher< + M: Message, + R: Message, + C: Send + Sync + 'static, + F: Future> + Send + 'static, + H: Future, Error>> + Send + 'static, + T: Future> + Send + 'static, + >( + self: Arc, + mut rx: Receiver>, + builder: impl Fn(u32) -> F + Send + 'static, + handler: impl Fn(Arc, u32, u32, M) -> H + Clone + Send + 'static, + finalize: impl Fn(Arc, u32) -> T + Clone + Send + 'static, + ) -> Result<(), Error> { + let tid = TypeId::of::(); + let mut notified = pin!(self.abort_notify.notified()); + + while let Either::Left((Some(msg), _)) = + futures::future::select(pin!(rx.recv()), notified.as_mut()).await + { + if msg.stream_id == DISPATCHER_STREAM_ID { + log::warn!("ERROR: GOT GATEWAY IN STREAM_ID!!!"); + continue; + } + + if !self + .senders + .contains_key(&(msg.stream_id, msg.task_id, tid)) + { + let (queue, ctx) = if let Some(rec) = self.contexts.get(&(msg.stream_id, tid)) { + let (queue, ctx) = rec.value(); + + (*queue, ctx.clone().downcast::().unwrap()) + } else { + let fut = (builder)(msg.stream_id); + + let (queue, ctx) = match fut.await { + Ok(tpl) => tpl, + Err(err) => { + println!("BUILDER ERROR: {:?}", err); + continue; + } + }; + let ctx = Arc::new(ctx); + + self.contexts + .insert((msg.stream_id, tid), (queue, ctx.clone())); + + (queue, ctx) + }; + + let (tx, rx) = channel::>(queue); + + let stream_id = msg.stream_id; + let task_id = msg.task_id; + + let handler = handler.clone(); + let finalize = finalize.clone(); + let self_clone = self.clone(); + + tokio::spawn(async move { + let tsk_cnt = &self_clone.handlers_tasks; + tsk_cnt.counter.fetch_add(1, Ordering::Relaxed); + + if let Err(err) = self_clone + .clone() + .task_handler(rx, ctx, handler, finalize) + .await + { + println!("error: {:?}", err); + } + + if tsk_cnt.counter.fetch_sub(1, Ordering::Relaxed) == 1 { + tsk_cnt.notify.notify_one(); + } + }); + + tx.send(msg).await.unwrap(); + + self.senders.insert( + (stream_id, task_id, tid), + Arc::new(BusSenders::new(tx)) as _, + ); + } else { + let senders = self + .senders + .get(&(msg.stream_id, msg.task_id, tid)) + .unwrap() + .clone(); + + senders + .downcast_ref::>() + .unwrap() + .send(msg) + .await + .unwrap(); + } + } + + println!("DISPATCHER ENDED {}", std::any::type_name::()); + + Ok(()) + } + + async fn task_handler< + M: Message, + R: Message, + C: Send + Sync + 'static, + H: Future, Error>> + Send + 'static, + T: Future> + Send + 'static, + >( + self: Arc, + mut rx: Receiver>, + ctx: Arc, + handler: impl Fn(Arc, u32, u32, M) -> H + Clone + Send + 'static, + finalize: impl Fn(Arc, u32) -> T + Clone + Send + 'static, + ) -> Result<(), Error> { + let mut notified = pin!(self.abort_notify.notified()); + + while let Either::Left((Some(msg), _)) = + futures::future::select(pin!(rx.recv()), notified.as_mut()).await + { + let stream_id = msg.stream_id; + let fut = (handler)(ctx.clone(), stream_id, msg.task_id, msg.into_inner()); + let res = fut.await.unwrap(); + + if let Some(m) = res.into_message() { + if m.type_id() != TypeId::of::<()>() { + if let Err(err) = self.send(stream_id, m).await { + println!("BUS SEND ERROR: {:?}", err); + continue; + } + } + } + } + + println!("HANDLER ENDED {}", std::any::type_name::()); + + let fut = finalize(ctx, 0); + fut.await + } + + async fn producer_task< + M: Message, + R: Message, + S: Stream> + Send + 'static, + F: Future> + Send + 'static, + T: Future> + Send + 'static, + >( + self: Arc, + mut rx: Receiver>, + builder: impl Fn(u32, M) -> F + Send + 'static, + finalize: impl Fn(u32, Pin<&mut S>) -> T + Clone + Send + 'static, + ) -> Result<(), Error> { + let mut notified = pin!(self.abort_notify.notified()); + + while let Either::Left((Some(msg), _)) = + futures::future::select(pin!(rx.recv()), notified.as_mut()).await + { + let stream_id = msg.stream_id; + let fut = (builder)(stream_id, msg.into_inner()); + let mut stream = pin!(fut.await.unwrap()); + + let streaming_task = async { + loop { + match stream.next().await { + Some(Ok(msg)) => { + if let Err(err) = self.send(stream_id, msg).await { + println!("BUS SEND ERROR: {:?}", err); + continue; + } + } + Some(Err(err)) => { + println!("PRODUCER ERROR: {:?}", err); + continue; + } + None => { + println!( + "PRODUCER DRAINED THE STREAM {} of type {}", + stream_id, + std::any::type_name::() + ); + break; + } + } + } + }; + + let aborted = tokio::select! { + _ = streaming_task => false, + _ = self.producers_abort_notify.notified() => { + println!("ABORTED: {}", stream_id); + true + } + }; + + let fut = finalize(stream_id, stream); + if let Err(err) = fut.await { + println!("PRODUCER FINALIZE ERROR: {:?}", err); + } + + if aborted { + break; + } + } + + println!( + "PRODUCER TASK ENDED FOR TYPE {}", + std::any::type_name::() + ); + + Ok(()) + } + + pub async fn send(&self, stream_id: u32, msg: M) -> Result<(), Error> { + let task_id = DEFAUL_TASK_ID; + + let sender_ref = self + .senders + .get(&(stream_id, task_id, TypeId::of::())) + .or_else(|| { + self.senders + .get(&(DISPATCHER_STREAM_ID, DISPATCHER_TASK_ID, TypeId::of::())) + }); + + let Some(sender_ref) = sender_ref else { + return Err(Error::HandlerIsNotRegistered); + }; + + let sender = sender_ref.clone(); + drop(sender_ref); + + sender + .downcast_ref::>() + .unwrap() + .send(Msg { + stream_id, + task_id, + inner: msg, + }) + .await?; + + Ok(()) + } + + pub fn register_dispatcher< + C: Send + Sync + 'static, + M: Message, + R: Message, + F: Future> + Send + 'static, + H: Future, Error>> + Send + 'static, + T: Future> + Send + 'static, + >( + &self, + bus: &Bus, + queue_size: usize, + builder: impl Fn(u32) -> F + Send + 'static, + handler: impl Fn(Arc, u32, u32, M) -> H + Clone + Send + 'static, + finalize: impl Fn(Arc, u32) -> T + Clone + Send + 'static, + ) { + let (tx, rx) = channel::>(queue_size); + + self.senders + .entry((DISPATCHER_STREAM_ID, DISPATCHER_TASK_ID, TypeId::of::())) + .or_insert_with(|| Arc::new(BusSenders::new(tx))); + + let bus = bus.clone(); + tokio::spawn(async move { + let tsk_cnt = &bus.inner.handlers_tasks; + tsk_cnt.counter.fetch_add(1, Ordering::Relaxed); + + if let Err(err) = bus + .inner + .clone() + .task_dispatcher(rx, builder, handler, finalize) + .await + { + println!("error: {:?}", err); + } + + if tsk_cnt.counter.fetch_sub(1, Ordering::Relaxed) == 1 { + tsk_cnt.notify.notify_one(); + } + }); + } + + pub fn register_producer< + M: Message, + I: Message, + S: Stream> + Send + 'static, + F: Future> + Send + 'static, + T: Future> + Send + 'static, + >( + &self, + bus: &Bus, + queue_size: usize, + builder: impl Fn(u32, M) -> F + Send + 'static, + finalize: impl Fn(u32, Pin<&mut S>) -> T + Clone + Send + 'static, + ) { + let (tx, rx) = channel::>(queue_size); + + self.senders + .entry((DISPATCHER_STREAM_ID, DISPATCHER_TASK_ID, TypeId::of::())) + .or_insert_with(|| Arc::new(BusSenders::new(tx))); + + let bus = bus.clone(); + tokio::spawn(async move { + let tsk_cnt = &bus.inner.producers_tasks; + + tsk_cnt.counter.fetch_add(1, Ordering::Relaxed); + + if let Err(err) = bus.inner.clone().producer_task(rx, builder, finalize).await { + println!("error: {:?}", err); + } + + if tsk_cnt.counter.fetch_sub(1, Ordering::Relaxed) == 1 { + tsk_cnt.notify.notify_one(); + } + }); + } + + pub async fn close(&self, force: bool) { + self.producers_abort_notify.notify_waiters(); + + if force { + self.abort_notify.notify_waiters(); + } + } + + pub async fn wait(&self) { + self.producers_stop_notify.notify_waiters(); + self.producers_tasks.wait_last_one().await; + println!("producers done"); + + self.abort_notify.notify_waiters(); + self.handlers_tasks.wait_last_one().await; + println!("handlers done"); + } +} + +#[derive(Default, Clone)] +pub struct Bus { + inner: Arc, +} + +impl Bus { + pub fn new() -> Self { + Self { + inner: Arc::new(BusInner::default()), + } + } + + #[inline] + pub fn register< + C: Send + Sync + 'static, + M: Message, + R: Message, + F: Future> + Send + 'static, + H: Future, Error>> + Send + 'static, + T: Future> + Send + 'static, + >( + &self, + queue: usize, + builder: impl Fn(u32) -> F + Send + 'static, + handler: impl Fn(Arc, u32, u32, M) -> H + Clone + Send + 'static, + finalize: impl Fn(Arc, u32) -> T + Clone + Send + 'static, + ) -> &Self { + self.inner + .register_dispatcher(self, queue, builder, handler, finalize); + + self + } + + #[inline] + pub fn register_producer< + M: Message, + I: Message, + S: Stream> + Send + 'static, + F: Future> + Send + 'static, + T: Future> + Send + 'static, + >( + &self, + queue_size: usize, + builder: impl Fn(u32, M) -> F + Send + 'static, + finalize: impl Fn(u32, Pin<&mut S>) -> T + Clone + Send + 'static, + ) -> &Self { + self.inner + .register_producer(self, queue_size, builder, finalize); + self + } + + #[inline] + pub async fn send(&self, msg: M) -> Result<(), Error> { + self.inner.send(DEFAUL_STREAM_ID, msg).await + } + + #[inline] + pub async fn send_with_stream(&self, stream_id: u32, msg: M) -> Result<(), Error> { + self.inner.send(stream_id, msg).await + } + + /// + /// Terminating everything ASAP + #[inline] + pub async fn shutdown(&self) { + self.inner.close(true).await; + } + + /// + /// Closing providers and waiting when queues were drained + #[inline] + pub async fn close(&self) { + self.inner.close(false).await; + } + + /// + /// Waiting the processing been done + #[inline] + pub fn wait(&self) -> impl Future + '_ { + self.inner.wait() + } +} + +#[cfg(test)] +mod tests { + use crate::Bus; + use async_stream::stream; + + #[tokio::test] + async fn test() { + let bus = Bus::default(); + + bus.register( + 1, + move |_| async move { Ok((1, ())) }, + move |_, sid, _, msg: u64| async move { + println!("MSG: {} {}", sid, msg); + Ok(()) + }, + |_, s| async move { + println!("handler {} finalized", s); + + Ok(()) + }, + ); + + bus.register_producer( + 1, + move |_, _msg: u32| async move { + Ok(stream! { + for i in 0u64..10 { + yield Ok( i) + } + }) + }, + |s, _| async move { + println!("producer {} finalized", s); + Ok(()) + }, + ); + + for start in 0u32..10 { + bus.send_with_stream(start, start).await.unwrap(); + } + + bus.wait().await; + } +}