Initial Commit

This commit is contained in:
Andrey Tkachenko 2023-10-24 14:23:22 +04:00
commit 0297882f47
4 changed files with 805 additions and 0 deletions

2
.gitignore vendored Normal file
View File

@ -0,0 +1,2 @@
/target
/Cargo.lock

18
Cargo.toml Normal file
View File

@ -0,0 +1,18 @@
[package]
name = "messagebus"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
boxcar = "0.2.3"
dashmap = "5.5.0"
futures = "0.3.28"
kanal = "0.1.0-pre8"
log = "0.4.20"
tokio = { version = "1.32.0", features = ["sync", "rt", "macros"] }
[dev-dependencies]
async-stream = "0.3.5"
tokio = { version = "1.32.0", features = ["full"] }

40
examples/demo.rs Normal file
View File

@ -0,0 +1,40 @@
use std::sync::Arc;
use messagebus::{Bus, Error, Message};
#[derive(Debug, Clone)]
pub struct Msg(pub i32);
impl Message for Msg {}
pub struct Processor {
state: i32,
}
impl Processor {
pub async fn spawn(sid: u32) -> Result<(usize, Self), Error> {
Ok((4, Self { state: 0 }))
}
pub async fn handler_msg(self: Arc<Self>, sid: u32, tid: u32, msg: Msg) -> Result<(), Error> {
Ok(())
}
pub async fn finalize_msg_handler(self: Arc<Self>, sid: u32) -> Result<(), Error> {
Ok(())
}
}
async fn run() {
let bus = Bus::new();
bus.register(
4,
Processor::spawn,
Processor::handler_msg,
Processor::finalize_msg_handler,
);
}
#[tokio::main]
async fn main() {
run().await
}

745
src/lib.rs Normal file
View File

@ -0,0 +1,745 @@
#![feature(impl_trait_in_assoc_type)]
use std::{
any::{Any, TypeId},
fmt,
pin::{pin, Pin},
sync::{
atomic::{AtomicI64, Ordering},
Arc,
},
};
use dashmap::DashMap;
use futures::{future::Either, Future, Stream, StreamExt};
use tokio::sync::{mpsc, Notify};
trait AsyncReceiver {
type Item: Send + 'static;
type Error: std::error::Error + Send + Sync + 'static;
type Fut<'a>: Future<Output = Result<Self::Item, Self::Error>> + Send + 'a
where
Self: 'a;
fn recv(&mut self) -> Self::Fut<'_>;
}
trait AsyncSender: Clone {
type Item: Send + 'static;
type Error: std::error::Error + Send + Sync + 'static;
type Fut<'a>: Future<Output = Result<(), Self::Error>> + Send + 'a
where
Self: 'a;
fn send(&self, item: Self::Item) -> Self::Fut<'_>;
fn try_send(&self, item: Self::Item) -> Result<Option<Self::Item>, Self::Error>;
}
impl<I: Send + 'static> AsyncReceiver for kanal::AsyncReceiver<I> {
type Item = I;
type Error = kanal::ReceiveError;
type Fut<'a> = kanal::ReceiveFuture<'a, I>;
fn recv(&mut self) -> Self::Fut<'_> {
kanal::AsyncReceiver::recv(self)
}
}
impl<I: Send + 'static> AsyncSender for kanal::AsyncSender<I> {
type Item = I;
type Error = kanal::SendError;
type Fut<'a> = kanal::SendFuture<'a, Self::Item>;
fn send(&self, item: Self::Item) -> Self::Fut<'_> {
kanal::AsyncSender::send(self, item)
}
fn try_send(&self, item: Self::Item) -> Result<Option<Self::Item>, Self::Error> {
let mut item = Some(item);
kanal::AsyncSender::try_send_option(self, &mut item)?;
Ok(item)
}
}
// type Sender<T> = kanal::AsyncSender<T>;
// type Receiver<T> = kanal::AsyncReceiver<T>;
// fn channel<T>(cap: usize) -> (Sender<T>, Receiver<T>) {
// kanal::bounded_async(cap)
// }
type Sender<T> = mpsc::Sender<T>;
type Receiver<T> = mpsc::Receiver<T>;
fn channel<T>(cap: usize) -> (Sender<T>, Receiver<T>) {
mpsc::channel(cap)
}
const DISPATCHER_STREAM_ID: u32 = u32::MAX;
const DEFAUL_STREAM_ID: u32 = u32::MAX - 1;
const DISPATCHER_TASK_ID: u32 = u32::MAX;
const DEFAUL_TASK_ID: u32 = 0;
#[derive(Debug)]
pub enum Error {
HandlerIsNotRegistered,
Aborted,
SendError(String),
// SendError(kanal::SendError),
}
impl From<kanal::SendError> for Error {
fn from(value: kanal::SendError) -> Self {
Self::SendError(format!("{}", value))
}
}
impl<M> From<mpsc::error::SendError<M>> for Error {
fn from(value: mpsc::error::SendError<M>) -> Self {
Self::SendError(format!("{}", value))
}
}
pub trait Message: fmt::Debug + Clone + Send + Sync + 'static {}
impl Message for u64 {}
impl Message for u32 {}
impl Message for () {}
#[derive(Debug, Clone)]
pub struct Msg<M: Message> {
stream_id: u32,
task_id: u32,
inner: M,
}
impl<M: Message> Msg<M> {
pub fn stream_id(&self) -> u32 {
self.stream_id
}
pub fn task_id(&self) -> u32 {
self.task_id
}
pub fn into_inner(self) -> M {
self.inner
}
}
#[derive(Default)]
struct TaskCounter {
counter: AtomicI64,
notify: Notify,
}
impl TaskCounter {
#[inline]
pub async fn wait_last_one(&self) {
self.notify.notified().await
}
}
struct BusSenders<M: Message> {
senders: boxcar::Vec<Sender<Msg<M>>>,
}
impl<M: Message> BusSenders<M> {
pub fn new(sender: Sender<Msg<M>>) -> Self {
Self {
senders: boxcar::vec![sender],
}
}
async fn send(&self, msg: Msg<M>) -> Result<(), Error> {
match self.senders.count() {
0 => (),
1 => self.senders[0].send(msg).await?,
2 => {
let (r1, r2) = futures::future::join(
self.senders[0].send(msg.clone()),
self.senders[1].send(msg),
)
.await;
r1.or(r2)?
}
3 => {
let (r1, r2, r3) = futures::future::join3(
self.senders[0].send(msg.clone()),
self.senders[1].send(msg.clone()),
self.senders[2].send(msg),
)
.await;
r1.or(r2).or(r3)?
}
4 => {
let (r1, r2, r3, r4) = futures::future::join4(
self.senders[0].send(msg.clone()),
self.senders[1].send(msg.clone()),
self.senders[2].send(msg.clone()),
self.senders[3].send(msg),
)
.await;
r1.or(r2).or(r3).or(r4)?
}
_ => {
let vec = futures::future::join_all(
self.senders.iter().map(|(_, s)| s.send(msg.clone())),
)
.await;
vec.into_iter().find(Result::is_err).unwrap_or(Ok(()))?
}
};
Ok(())
}
}
pub trait IntoMessage<M>: Send {
fn into_message(self) -> Option<M>;
}
impl<M: Message> IntoMessage<M> for Option<M> {
fn into_message(self) -> Option<M> {
self
}
}
impl<M: Message> IntoMessage<M> for M {
fn into_message(self) -> Option<M> {
Some(self)
}
}
pub trait ProducerBuilder<M: Message> {
type Context: Producer<M> + 'static;
type BuildFut<'a>: Future<Output = Result<Self::Context, Error>> + Send + 'a
where
Self: 'a;
fn build<'a>(stream_id: u32) -> Self::BuildFut<'a>;
}
pub trait HandlerBuilder<M: Message> {
type Context: Handler<M> + 'static;
type BuildFut<'a>: Future<Output = Result<Self::Context, Error>> + Send + 'a
where
Self: 'a;
fn build<'a>(stream_id: u32) -> Self::BuildFut<'a>;
}
pub trait Handler<M: Message>: Send + Sync + 'static {
type Result: Message;
type IntoMessage: IntoMessage<Self::Result>;
type HandleFut<'a>: Future<Output = Result<Self::IntoMessage, Error>> + Send + 'a
where
Self: 'a;
type FinalizeFut<'a>: Future<Output = Result<(), Error>> + Send + 'a
where
Self: 'a;
fn handle<'a>(&self, msg: M, stream_id: u32, task_id: u32) -> Self::HandleFut<'a>;
fn finalize<'a>(self, stream_id: u32) -> Self::FinalizeFut<'a>;
}
pub trait Producer<M: Message>: Send + 'static {
type Item: Message;
type IntoMessage: IntoMessage<Self::Item>;
type Stream<'a>: Stream<Item = Result<Self::IntoMessage, Error>> + Send + Sync + 'a
where
Self: 'a;
type FinalizeFut<'a>: Future<Output = Result<(), Error>> + Send + 'a
where
Self: 'a;
fn stream(&mut self, msg: M, stream_id: u32) -> Self::Stream<'_>;
fn finalize<'a>(self, stream_id: u32) -> Self::FinalizeFut<'a>;
}
#[derive(Default)]
struct BusInner {
senders: DashMap<(u32, u32, TypeId), Arc<dyn Any + Send + Sync>>,
contexts: DashMap<(u32, TypeId), (usize, Arc<dyn Any + Send + Sync + 'static>)>,
producers_tasks: TaskCounter,
handlers_tasks: TaskCounter,
producers_stop_notify: Notify,
producers_abort_notify: Notify,
abort_notify: Notify,
}
impl BusInner {
async fn task_dispatcher<
M: Message,
R: Message,
C: Send + Sync + 'static,
F: Future<Output = Result<(usize, C), Error>> + Send + 'static,
H: Future<Output = Result<impl IntoMessage<R>, Error>> + Send + 'static,
T: Future<Output = Result<(), Error>> + Send + 'static,
>(
self: Arc<Self>,
mut rx: Receiver<Msg<M>>,
builder: impl Fn(u32) -> F + Send + 'static,
handler: impl Fn(Arc<C>, u32, u32, M) -> H + Clone + Send + 'static,
finalize: impl Fn(Arc<C>, u32) -> T + Clone + Send + 'static,
) -> Result<(), Error> {
let tid = TypeId::of::<M>();
let mut notified = pin!(self.abort_notify.notified());
while let Either::Left((Some(msg), _)) =
futures::future::select(pin!(rx.recv()), notified.as_mut()).await
{
if msg.stream_id == DISPATCHER_STREAM_ID {
log::warn!("ERROR: GOT GATEWAY IN STREAM_ID!!!");
continue;
}
if !self
.senders
.contains_key(&(msg.stream_id, msg.task_id, tid))
{
let (queue, ctx) = if let Some(rec) = self.contexts.get(&(msg.stream_id, tid)) {
let (queue, ctx) = rec.value();
(*queue, ctx.clone().downcast::<C>().unwrap())
} else {
let fut = (builder)(msg.stream_id);
let (queue, ctx) = match fut.await {
Ok(tpl) => tpl,
Err(err) => {
println!("BUILDER ERROR: {:?}", err);
continue;
}
};
let ctx = Arc::new(ctx);
self.contexts
.insert((msg.stream_id, tid), (queue, ctx.clone()));
(queue, ctx)
};
let (tx, rx) = channel::<Msg<M>>(queue);
let stream_id = msg.stream_id;
let task_id = msg.task_id;
let handler = handler.clone();
let finalize = finalize.clone();
let self_clone = self.clone();
tokio::spawn(async move {
let tsk_cnt = &self_clone.handlers_tasks;
tsk_cnt.counter.fetch_add(1, Ordering::Relaxed);
if let Err(err) = self_clone
.clone()
.task_handler(rx, ctx, handler, finalize)
.await
{
println!("error: {:?}", err);
}
if tsk_cnt.counter.fetch_sub(1, Ordering::Relaxed) == 1 {
tsk_cnt.notify.notify_one();
}
});
tx.send(msg).await.unwrap();
self.senders.insert(
(stream_id, task_id, tid),
Arc::new(BusSenders::new(tx)) as _,
);
} else {
let senders = self
.senders
.get(&(msg.stream_id, msg.task_id, tid))
.unwrap()
.clone();
senders
.downcast_ref::<BusSenders<M>>()
.unwrap()
.send(msg)
.await
.unwrap();
}
}
println!("DISPATCHER ENDED {}", std::any::type_name::<M>());
Ok(())
}
async fn task_handler<
M: Message,
R: Message,
C: Send + Sync + 'static,
H: Future<Output = Result<impl IntoMessage<R>, Error>> + Send + 'static,
T: Future<Output = Result<(), Error>> + Send + 'static,
>(
self: Arc<Self>,
mut rx: Receiver<Msg<M>>,
ctx: Arc<C>,
handler: impl Fn(Arc<C>, u32, u32, M) -> H + Clone + Send + 'static,
finalize: impl Fn(Arc<C>, u32) -> T + Clone + Send + 'static,
) -> Result<(), Error> {
let mut notified = pin!(self.abort_notify.notified());
while let Either::Left((Some(msg), _)) =
futures::future::select(pin!(rx.recv()), notified.as_mut()).await
{
let stream_id = msg.stream_id;
let fut = (handler)(ctx.clone(), stream_id, msg.task_id, msg.into_inner());
let res = fut.await.unwrap();
if let Some(m) = res.into_message() {
if m.type_id() != TypeId::of::<()>() {
if let Err(err) = self.send(stream_id, m).await {
println!("BUS SEND ERROR: {:?}", err);
continue;
}
}
}
}
println!("HANDLER ENDED {}", std::any::type_name::<M>());
let fut = finalize(ctx, 0);
fut.await
}
async fn producer_task<
M: Message,
R: Message,
S: Stream<Item = Result<R, Error>> + Send + 'static,
F: Future<Output = Result<S, Error>> + Send + 'static,
T: Future<Output = Result<(), Error>> + Send + 'static,
>(
self: Arc<Self>,
mut rx: Receiver<Msg<M>>,
builder: impl Fn(u32, M) -> F + Send + 'static,
finalize: impl Fn(u32, Pin<&mut S>) -> T + Clone + Send + 'static,
) -> Result<(), Error> {
let mut notified = pin!(self.abort_notify.notified());
while let Either::Left((Some(msg), _)) =
futures::future::select(pin!(rx.recv()), notified.as_mut()).await
{
let stream_id = msg.stream_id;
let fut = (builder)(stream_id, msg.into_inner());
let mut stream = pin!(fut.await.unwrap());
let streaming_task = async {
loop {
match stream.next().await {
Some(Ok(msg)) => {
if let Err(err) = self.send(stream_id, msg).await {
println!("BUS SEND ERROR: {:?}", err);
continue;
}
}
Some(Err(err)) => {
println!("PRODUCER ERROR: {:?}", err);
continue;
}
None => {
println!(
"PRODUCER DRAINED THE STREAM {} of type {}",
stream_id,
std::any::type_name::<R>()
);
break;
}
}
}
};
let aborted = tokio::select! {
_ = streaming_task => false,
_ = self.producers_abort_notify.notified() => {
println!("ABORTED: {}", stream_id);
true
}
};
let fut = finalize(stream_id, stream);
if let Err(err) = fut.await {
println!("PRODUCER FINALIZE ERROR: {:?}", err);
}
if aborted {
break;
}
}
println!(
"PRODUCER TASK ENDED FOR TYPE {}",
std::any::type_name::<M>()
);
Ok(())
}
pub async fn send<M: Message>(&self, stream_id: u32, msg: M) -> Result<(), Error> {
let task_id = DEFAUL_TASK_ID;
let sender_ref = self
.senders
.get(&(stream_id, task_id, TypeId::of::<M>()))
.or_else(|| {
self.senders
.get(&(DISPATCHER_STREAM_ID, DISPATCHER_TASK_ID, TypeId::of::<M>()))
});
let Some(sender_ref) = sender_ref else {
return Err(Error::HandlerIsNotRegistered);
};
let sender = sender_ref.clone();
drop(sender_ref);
sender
.downcast_ref::<BusSenders<M>>()
.unwrap()
.send(Msg {
stream_id,
task_id,
inner: msg,
})
.await?;
Ok(())
}
pub fn register_dispatcher<
C: Send + Sync + 'static,
M: Message,
R: Message,
F: Future<Output = Result<(usize, C), Error>> + Send + 'static,
H: Future<Output = Result<impl IntoMessage<R>, Error>> + Send + 'static,
T: Future<Output = Result<(), Error>> + Send + 'static,
>(
&self,
bus: &Bus,
queue_size: usize,
builder: impl Fn(u32) -> F + Send + 'static,
handler: impl Fn(Arc<C>, u32, u32, M) -> H + Clone + Send + 'static,
finalize: impl Fn(Arc<C>, u32) -> T + Clone + Send + 'static,
) {
let (tx, rx) = channel::<Msg<M>>(queue_size);
self.senders
.entry((DISPATCHER_STREAM_ID, DISPATCHER_TASK_ID, TypeId::of::<M>()))
.or_insert_with(|| Arc::new(BusSenders::new(tx)));
let bus = bus.clone();
tokio::spawn(async move {
let tsk_cnt = &bus.inner.handlers_tasks;
tsk_cnt.counter.fetch_add(1, Ordering::Relaxed);
if let Err(err) = bus
.inner
.clone()
.task_dispatcher(rx, builder, handler, finalize)
.await
{
println!("error: {:?}", err);
}
if tsk_cnt.counter.fetch_sub(1, Ordering::Relaxed) == 1 {
tsk_cnt.notify.notify_one();
}
});
}
pub fn register_producer<
M: Message,
I: Message,
S: Stream<Item = Result<I, Error>> + Send + 'static,
F: Future<Output = Result<S, Error>> + Send + 'static,
T: Future<Output = Result<(), Error>> + Send + 'static,
>(
&self,
bus: &Bus,
queue_size: usize,
builder: impl Fn(u32, M) -> F + Send + 'static,
finalize: impl Fn(u32, Pin<&mut S>) -> T + Clone + Send + 'static,
) {
let (tx, rx) = channel::<Msg<M>>(queue_size);
self.senders
.entry((DISPATCHER_STREAM_ID, DISPATCHER_TASK_ID, TypeId::of::<M>()))
.or_insert_with(|| Arc::new(BusSenders::new(tx)));
let bus = bus.clone();
tokio::spawn(async move {
let tsk_cnt = &bus.inner.producers_tasks;
tsk_cnt.counter.fetch_add(1, Ordering::Relaxed);
if let Err(err) = bus.inner.clone().producer_task(rx, builder, finalize).await {
println!("error: {:?}", err);
}
if tsk_cnt.counter.fetch_sub(1, Ordering::Relaxed) == 1 {
tsk_cnt.notify.notify_one();
}
});
}
pub async fn close(&self, force: bool) {
self.producers_abort_notify.notify_waiters();
if force {
self.abort_notify.notify_waiters();
}
}
pub async fn wait(&self) {
self.producers_stop_notify.notify_waiters();
self.producers_tasks.wait_last_one().await;
println!("producers done");
self.abort_notify.notify_waiters();
self.handlers_tasks.wait_last_one().await;
println!("handlers done");
}
}
#[derive(Default, Clone)]
pub struct Bus {
inner: Arc<BusInner>,
}
impl Bus {
pub fn new() -> Self {
Self {
inner: Arc::new(BusInner::default()),
}
}
#[inline]
pub fn register<
C: Send + Sync + 'static,
M: Message,
R: Message,
F: Future<Output = Result<(usize, C), Error>> + Send + 'static,
H: Future<Output = Result<impl IntoMessage<R>, Error>> + Send + 'static,
T: Future<Output = Result<(), Error>> + Send + 'static,
>(
&self,
queue: usize,
builder: impl Fn(u32) -> F + Send + 'static,
handler: impl Fn(Arc<C>, u32, u32, M) -> H + Clone + Send + 'static,
finalize: impl Fn(Arc<C>, u32) -> T + Clone + Send + 'static,
) -> &Self {
self.inner
.register_dispatcher(self, queue, builder, handler, finalize);
self
}
#[inline]
pub fn register_producer<
M: Message,
I: Message,
S: Stream<Item = Result<I, Error>> + Send + 'static,
F: Future<Output = Result<S, Error>> + Send + 'static,
T: Future<Output = Result<(), Error>> + Send + 'static,
>(
&self,
queue_size: usize,
builder: impl Fn(u32, M) -> F + Send + 'static,
finalize: impl Fn(u32, Pin<&mut S>) -> T + Clone + Send + 'static,
) -> &Self {
self.inner
.register_producer(self, queue_size, builder, finalize);
self
}
#[inline]
pub async fn send<M: Message>(&self, msg: M) -> Result<(), Error> {
self.inner.send(DEFAUL_STREAM_ID, msg).await
}
#[inline]
pub async fn send_with_stream<M: Message>(&self, stream_id: u32, msg: M) -> Result<(), Error> {
self.inner.send(stream_id, msg).await
}
///
/// Terminating everything ASAP
#[inline]
pub async fn shutdown(&self) {
self.inner.close(true).await;
}
///
/// Closing providers and waiting when queues were drained
#[inline]
pub async fn close(&self) {
self.inner.close(false).await;
}
///
/// Waiting the processing been done
#[inline]
pub fn wait(&self) -> impl Future<Output = ()> + '_ {
self.inner.wait()
}
}
#[cfg(test)]
mod tests {
use crate::Bus;
use async_stream::stream;
#[tokio::test]
async fn test() {
let bus = Bus::default();
bus.register(
1,
move |_| async move { Ok((1, ())) },
move |_, sid, _, msg: u64| async move {
println!("MSG: {} {}", sid, msg);
Ok(())
},
|_, s| async move {
println!("handler {} finalized", s);
Ok(())
},
);
bus.register_producer(
1,
move |_, _msg: u32| async move {
Ok(stream! {
for i in 0u64..10 {
yield Ok( i)
}
})
},
|s, _| async move {
println!("producer {} finalized", s);
Ok(())
},
);
for start in 0u32..10 {
bus.send_with_stream(start, start).await.unwrap();
}
bus.wait().await;
}
}