refactoring receivers +speed ++clean

Andrey Tkachenko 2021-09-15 19:54:59 +04:00
parent da6dc04f71
commit 81e5376fe2
20 changed files with 494 additions and 580 deletions


@ -1,6 +1,6 @@
[package]
name = "messagebus"
version = "0.9.2"
version = "0.9.3"
authors = ["Andrey Tkachenko <andrey@aidev.ru>"]
repository = "https://github.com/andreytkachenko/messagebus.git"
keywords = ["futures", "async", "tokio", "message", "bus"]


@ -167,7 +167,6 @@ impl Bus {
}
pub async fn flush_and_sync(&self) {
self.flush().await;
let _handle = self.inner.maintain.lock().await;
for (_, rs) in &self.inner.receivers {


@ -125,6 +125,7 @@ pub trait ReceiverTrait: TypeTagAccept + Send + Sync {
fn try_reserve(&self, tt: &TypeTag) -> Option<Permit>;
fn reserve_notify(&self, tt: &TypeTag) -> Arc<Notify>;
fn increment_processing(&self, tt: &TypeTag);
fn start_polling(
self: Arc<Self>,
@ -466,6 +467,10 @@ where
) -> Box<dyn FnOnce(Bus) -> Pin<Box<dyn Future<Output = ()> + Send>>> {
self.start_polling_events()
}
fn increment_processing(&self, _tt: &TypeTag) {
self.context.processing.fetch_add(1, Ordering::SeqCst);
}
}
pub struct Permit {
@ -799,6 +804,8 @@ impl Receiver {
msg: M,
req: bool,
) -> Result<(), Error<M>> {
self.inner.increment_processing(&M::type_tag_());
let res = if let Some(any_receiver) = self.inner.typed() {
any_receiver
.cast_send_typed::<M>()
@ -826,8 +833,8 @@ impl Receiver {
mut permit: Permit,
) -> Result<(), Error<Box<dyn Message>>> {
let res = self.inner.send_boxed(mid, msg, req, bus);
self.inner.set_need_flush();
permit.fuse = true;
self.inner.set_need_flush();
res
}


@ -12,20 +12,20 @@ use crate::{
builder::ReceiverSubscriberBuilder,
error::{Error, SendError, StdSyncSendError},
receiver::{Action, Event, ReciveTypedReceiver, SendTypedReceiver, SendUntypedReceiver},
receivers::{fix_type, Request},
receivers::Request,
AsyncHandler, Bus, Message, Untyped,
};
use super::{BufferUnorderedConfig, BufferUnorderedStats};
use futures::{stream::FuturesUnordered, Future, StreamExt};
use futures::Future;
use parking_lot::Mutex;
use tokio::sync::mpsc;
buffer_unordered_poller_macro!(
T,
AsyncHandler,
|msg, bus, ut: Arc<T>| tokio::spawn(async move { ut.handle(msg, &bus).await }),
|msg, bus, ut: Arc<T>| async move { ut.handle(msg, &bus).await },
|bus, ut: Arc<T>| { async move { ut.sync(&bus).await } }
);


@ -1,11 +1,12 @@
mod r#async;
mod sync;
use std::sync::atomic::AtomicU64;
use std::sync::{Arc, atomic::AtomicU64};
pub use r#async::BufferUnorderedAsync;
use serde_derive::{Deserialize, Serialize};
pub use sync::BufferUnorderedSync;
use tokio::sync::{RwLock, Semaphore};
#[derive(Debug)]
pub struct BufferUnorderedStats {
@ -30,6 +31,12 @@ impl Default for BufferUnorderedConfig {
}
}
#[derive(Clone)]
struct ConcurrentState {
flush_lock: Arc<RwLock<()>>,
semaphore: Arc<Semaphore>,
}
#[macro_export]
macro_rules! buffer_unordered_poller_macro {
($t: tt, $h: tt, $st1: expr, $st2: expr) => {
@ -37,7 +44,7 @@ macro_rules! buffer_unordered_poller_macro {
mut rx: mpsc::UnboundedReceiver<Request<M>>,
bus: Bus,
ut: Untyped,
stats: Arc<BufferUnorderedStats>,
_stats: Arc<BufferUnorderedStats>,
cfg: BufferUnorderedConfig,
stx: mpsc::UnboundedSender<Event<R, E>>,
) -> impl Future<Output = ()>
@ -48,132 +55,54 @@ macro_rules! buffer_unordered_poller_macro {
E: StdSyncSendError,
{
let ut = ut.downcast::<$t>().unwrap();
let mut queue = FuturesUnordered::new();
let mut sync_future = None;
let mut need_sync = false;
let mut rx_closed = false;
let mut need_flush = false;
futures::future::poll_fn(move |cx| 'main: loop {
if !rx_closed && !need_flush && !need_sync {
while queue.len() < cfg.max_parallel {
match rx.poll_recv(cx) {
Poll::Ready(Some(a)) => match a {
Request::Request(mid, msg, req) => {
stats.buffer.fetch_sub(1, Ordering::Relaxed);
stats.parallel.fetch_add(1, Ordering::Relaxed);
let state = super::ConcurrentState {
flush_lock: Arc::new(tokio::sync::RwLock::new(())),
semaphore: Arc::new(tokio::sync::Semaphore::new(cfg.max_parallel)),
};
let bus = bus.clone();
let ut = ut.clone();
async move {
while let Some(msg) = rx.recv().await {
match msg {
Request::Request(mid, msg, _req) => {
let bus = bus.clone();
let ut = ut.clone();
let state = state.clone();
let stx = stx.clone();
queue.push(
async move {
let resp = ($st1)(
msg, bus, ut,
).await.unwrap();
let task_permit = state.semaphore.acquire_owned().await;
let flush_permit = state.flush_lock.read_owned().await;
(mid, req, resp)
}
);
}
Request::Action(Action::Init) => {
stx.send(Event::Ready).ok();
}
Request::Action(Action::Flush) => {
need_flush = true
},
Request::Action(Action::Close) => {
rx.close()
}
Request::Action(Action::Sync) => {
need_sync = true;
break;
}
_ => unimplemented!(),
},
Poll::Ready(None) => {
need_sync = true;
rx_closed = true;
break;
}
Poll::Pending => break,
}
}
}
let queue_len = queue.len();
loop {
if queue_len != 0 {
let mut finished = 0;
loop {
match queue.poll_next_unpin(cx) {
Poll::Ready(Some((mid, req, resp))) => {
if req {
let resp: Result<_, $t::Error> = resp;
stx.send(Event::Response(mid, resp.map_err(Error::Other)))
.ok();
} else {
finished += 1;
}
}
Poll::Pending => {
if finished > 0 {
stx.send(Event::Finished(finished)).ok();
}
return Poll::Pending
},
let _ = tokio::spawn(async move {
let resp = ($st1)(msg, bus, ut)
.await;
Poll::Ready(None) => break,
}
drop(task_permit);
drop(flush_permit);
stx.send(Event::Response(mid, resp.map_err(Error::Other)))
.unwrap();
});
}
Request::Action(Action::Init) => { stx.send(Event::Ready).unwrap(); }
Request::Action(Action::Close) => { rx.close(); }
Request::Action(Action::Flush) => {
state.flush_lock.write().await;
stx.send(Event::Flushed).unwrap();
}
if finished > 0 {
stx.send(Event::Finished(finished)).ok();
Request::Action(Action::Sync) => {
let lock = state.flush_lock.write().await;
let resp = ($st2)(bus.clone(), ut.clone()).await;
drop(lock);
stx.send(Event::Synchronized(resp.map_err(Error::Other)))
.unwrap();
}
}
if need_flush {
need_flush = false;
stx.send(Event::Flushed).ok();
}
if need_sync {
if let Some(fut) = sync_future.as_mut() {
match unsafe { fix_type(fut) }.poll(cx) {
Poll::Pending => {
return Poll::Pending
},
Poll::Ready(resp) => {
let resp: Result<_, E> = resp;
stx.send(Event::Synchronized(resp.map_err(Error::Other)))
.ok();
}
}
need_sync = false;
sync_future = None;
continue 'main;
} else {
sync_future.replace(($st2)(bus.clone(), ut.clone()));
}
} else {
break;
_ => unimplemented!(),
}
}
if queue_len == queue.len() {
return if rx_closed {
Poll::Ready(())
} else {
Poll::Pending
};
}
})
}
}
};
}
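The macro body above replaces the old `poll_fn`/`FuturesUnordered` state machine with a plain `async move` receive loop: every incoming `Request::Request` is handled on a spawned task that carries a permit from `semaphore` (bounding concurrency at `cfg.max_parallel`) and a read permit on `flush_lock`, so the `Flush` and `Sync` branches only need to take the write half in order to wait for all in-flight handlers. Below is a minimal standalone sketch of that pattern, assuming only tokio's sync primitives; `dispatch`, `flush` and the handler body are illustrative and not part of the crate.

use std::sync::Arc;
use tokio::sync::{RwLock, Semaphore};

// Stand-in for the new ConcurrentState shared by the poller and its tasks.
#[derive(Clone)]
struct ConcurrentState {
    flush_lock: Arc<RwLock<()>>,
    semaphore: Arc<Semaphore>,
}

// Permits are acquired before spawning (mirroring the refactored pollers),
// so a later flush always waits for this task to finish.
async fn dispatch(state: &ConcurrentState, id: u32) {
    let task_permit = Arc::clone(&state.semaphore).acquire_owned().await.unwrap();
    let flush_permit = Arc::clone(&state.flush_lock).read_owned().await;
    tokio::spawn(async move {
        // ... handler work would run here ...
        println!("handler {} done", id);
        drop(task_permit);  // frees a parallelism slot
        drop(flush_permit); // lets a pending flush proceed
    });
}

// Taking the write half blocks until every read permit is released,
// i.e. until all dispatched handlers have finished.
async fn flush(state: &ConcurrentState) {
    let _guard = state.flush_lock.write().await;
}

#[tokio::main]
async fn main() {
    let state = ConcurrentState {
        flush_lock: Arc::new(RwLock::new(())),
        semaphore: Arc::new(Semaphore::new(4)),
    };
    for i in 0..8 {
        dispatch(&state, i).await;
    }
    flush(&state).await;
    println!("all handlers finished");
}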


@ -13,18 +13,22 @@ use crate::{
builder::ReceiverSubscriberBuilder,
error::{Error, SendError, StdSyncSendError},
receiver::{Action, Event, ReciveTypedReceiver, SendTypedReceiver, SendUntypedReceiver},
receivers::{fix_type, Request},
receivers::Request,
Bus, Handler, Message, Untyped,
};
use futures::{stream::FuturesUnordered, Future, StreamExt};
use futures::Future;
use parking_lot::Mutex;
use tokio::sync::mpsc;
buffer_unordered_poller_macro!(
T,
Handler,
|msg, bus, ut: Arc<T>| tokio::task::spawn_blocking(move || ut.handle(msg, &bus)),
|msg, bus, ut: Arc<T>| async move {
tokio::task::spawn_blocking(move || ut.handle(msg, &bus))
.await
.unwrap()
},
|bus, ut: Arc<T>| async move {
tokio::task::spawn_blocking(move || ut.sync(&bus))
.await


@ -12,26 +12,19 @@ use crate::{
builder::ReceiverSubscriberBuilder,
error::{Error, SendError, StdSyncSendError},
receiver::{Action, Event, ReciveTypedReceiver, SendTypedReceiver, SendUntypedReceiver},
receivers::{fix_type, Request},
receivers::Request,
AsyncBatchHandler, Bus, Message, Untyped,
};
use super::{BufferUnorderedBatchedConfig, BufferUnorderedBatchedStats};
use futures::{stream::FuturesUnordered, Future, StreamExt};
use futures::Future;
use parking_lot::Mutex;
use tokio::sync::mpsc;
buffer_unordered_batch_poller_macro!(
T,
AsyncBatchHandler,
|buffer_clone, bus, ut: Arc<T>, stats: Arc<BufferUnorderedBatchedStats>, buffer_mid_clone| {
async move {
let resp = ut.handle(buffer_clone, &bus).await;
stats.parallel.fetch_sub(1, Ordering::Relaxed);
(buffer_mid_clone, resp)
}
},
|msgs, bus, ut: Arc<T>| async move { ut.handle(msgs, &bus).await },
|bus, ut: Arc<T>| { async move { ut.sync(&bus).await } }
);


@ -1,11 +1,12 @@
mod r#async;
mod sync;
use std::sync::atomic::AtomicU64;
use std::sync::{Arc, atomic::AtomicU64};
pub use r#async::BufferUnorderedBatchedAsync;
use serde_derive::{Deserialize, Serialize};
pub use sync::BufferUnorderedBatchedSync;
use tokio::sync::{RwLock, Semaphore};
#[derive(Debug)]
pub struct BufferUnorderedBatchedStats {
@ -36,6 +37,12 @@ impl Default for BufferUnorderedBatchedConfig {
}
}
#[derive(Clone)]
struct ConcurrentState {
flush_lock: Arc<RwLock<()>>,
semaphore: Arc<Semaphore>,
}
#[macro_export]
macro_rules! buffer_unordered_batch_poller_macro {
($t: tt, $h: tt, $st1: expr, $st2: expr) => {
@ -43,7 +50,7 @@ macro_rules! buffer_unordered_batch_poller_macro {
mut rx: mpsc::UnboundedReceiver<Request<M>>,
bus: Bus,
ut: Untyped,
stats: Arc<BufferUnorderedBatchedStats>,
_stats: Arc<BufferUnorderedBatchedStats>,
cfg: BufferUnorderedBatchedConfig,
stx: mpsc::UnboundedSender<Event<R, $t::Error>>,
) -> impl Future<Output = ()>
@ -53,169 +60,137 @@ macro_rules! buffer_unordered_batch_poller_macro {
R: Message,
{
let ut = ut.downcast::<$t>().unwrap();
let mut buffer_mid = Vec::with_capacity(cfg.batch_size);
let mut buffer = Vec::with_capacity(cfg.batch_size);
let mut queue = FuturesUnordered::new();
let mut sync_future = None;
let mut need_sync = false;
let mut rx_closed = false;
let mut need_flush = false;
futures::future::poll_fn(move |cx| 'main: loop {
if !rx_closed && !need_flush && !need_sync {
while queue.len() < cfg.max_parallel {
let mut do_break = false;
let mut drained = false;
let state = super::ConcurrentState {
flush_lock: Arc::new(tokio::sync::RwLock::new(())),
semaphore: Arc::new(tokio::sync::Semaphore::new(cfg.max_parallel)),
};
match rx.poll_recv(cx) {
Poll::Ready(Some(a)) => match a {
Request::Request(mid, msg, req) => {
stats.buffer.fetch_sub(1, Ordering::Relaxed);
stats.batch.fetch_add(1, Ordering::Relaxed);
async move {
let mut buffer_mid = Vec::with_capacity(cfg.batch_size);
let mut buffer = Vec::with_capacity(cfg.batch_size);
buffer_mid.push((mid, req));
buffer.push(msg);
}
Request::Action(Action::Init) => {
stx.send(Event::Ready).ok();
}
Request::Action(Action::Flush) => need_flush = true,
Request::Action(Action::Close) => rx.close(),
Request::Action(Action::Sync) => {
need_sync = true;
do_break = true;
}
_ => unimplemented!(),
},
Poll::Ready(None) => {
need_sync = true;
rx_closed = true;
do_break = true;
}
Poll::Pending => {
drained = true;
do_break = true;
}
}
while let Some(msg) = rx.recv().await {
let bus = bus.clone();
let ut = ut.clone();
let state = state.clone();
let stx = stx.clone();
match msg {
Request::Request(mid, msg, req) => {
buffer_mid.push((mid, req));
buffer.push(msg);
if !buffer.is_empty()
&& (rx_closed
|| need_flush
|| need_sync
|| (drained && cfg.when_ready)
|| buffer.len() >= cfg.batch_size)
{
stats.batch.store(0, Ordering::Relaxed);
stats.parallel.fetch_add(1, Ordering::Relaxed);
if buffer_mid.len() >= cfg.batch_size {
let task_permit = state.semaphore.acquire_owned().await;
let flush_permit = state.flush_lock.read_owned().await;
let bus = bus.clone();
let ut = ut.clone();
let buffer_mid_clone = buffer_mid.drain(..).collect::<Vec<_>>();
let buffer_clone = buffer.drain(..).collect();
let stats = stats.clone();
let buffer_mid_clone = buffer_mid.drain(..).collect::<Vec<_>>();
let buffer_clone = buffer.drain(..).collect();
queue.push(($st1)(buffer_clone, bus, ut, stats, buffer_mid_clone));
}
let _ = tokio::spawn(async move {
let resp = ($st1)(buffer_clone, bus, ut).await;
drop(task_permit);
drop(flush_permit);
if do_break {
break;
}
}
}
let mids = buffer_mid_clone.into_iter();
match resp {
Ok(re) => {
let mut mids = mids.into_iter();
let mut re = re.into_iter();
let queue_len = queue.len();
loop {
if queue_len != 0 {
let mut finished = 0;
loop {
match queue.poll_next_unpin(cx) {
Poll::Ready(Some((mids, res))) => match res {
Ok(re) => {
let mut mids = mids.into_iter();
let mut re = re.into_iter();
while let Some((mid, req)) = mids.next() {
if req {
while let Some((mid, _req)) = mids.next() {
if let Some(r) = re.next() {
stx.send(Event::Response(mid, Ok(r))).ok();
stx.send(Event::Response(mid, Ok(r)))
.unwrap();
} else {
stx.send(Event::Response(
mid,
Err(Error::NoResponse),
))
.ok();
stx.send(Event::Response(mid, Err(Error::NoResponse)))
.unwrap();
}
} else {
finished += 1;
}
}
}
Err(er) => {
for (mid, req) in mids {
if req {
Err(er) => {
for (mid, _req) in mids {
stx.send(Event::Response(
mid,
Err(Error::Other(er.clone())),
))
.ok();
} else {
finished += 1
)).unwrap();
}
stx.send(Event::Error(er)).unwrap();
}
}
});
}
}
Request::Action(Action::Init) => { stx.send(Event::Ready).unwrap(); }
Request::Action(Action::Close) => { rx.close(); }
Request::Action(Action::Flush) => {
let stx_clone = stx.clone();
if !buffer_mid.is_empty() {
let task_permit = state.semaphore.acquire_owned().await;
let flush_permit = state.flush_lock.clone().read_owned().await;
let buffer_mid_clone = buffer_mid.drain(..).collect::<Vec<_>>();
let buffer_clone = buffer.drain(..).collect();
let _ = tokio::spawn(async move {
let resp = ($st1)(buffer_clone, bus, ut).await;
drop(task_permit);
drop(flush_permit);
let mids = buffer_mid_clone.into_iter();
match resp {
Ok(re) => {
let mut mids = mids.into_iter();
let mut re = re.into_iter();
while let Some((mid, _req)) = mids.next() {
if let Some(r) = re.next() {
stx.send(Event::Response(mid, Ok(r)))
.unwrap();
} else {
stx.send(Event::Response(mid, Err(Error::NoResponse)))
.unwrap();
}
}
}
stx.send(Event::Error(er)).ok();
Err(er) => {
for (mid, _req) in mids {
stx.send(Event::Response(
mid,
Err(Error::Other(er.clone())),
)).unwrap();
}
stx.send(Event::Error(er)).unwrap();
}
}
},
Poll::Ready(None) => break,
Poll::Pending => {
if finished > 0 {
stx.send(Event::Finished(finished)).ok();
}
return Poll::Pending
},
});
}
}
if finished > 0 {
stx.send(Event::Finished(finished)).ok();
}
}
if need_flush {
need_flush = false;
stx.send(Event::Flushed).ok();
}
if need_sync {
if let Some(fut) = sync_future.as_mut() {
match unsafe { fix_type(fut) }.poll(cx) {
Poll::Pending => return Poll::Pending,
Poll::Ready(resp) => {
let resp: Result<_, $t::Error> = resp;
stx.send(Event::Synchronized(resp.map_err(Error::Other)))
.ok();
}
}
need_sync = false;
sync_future = None;
continue 'main;
} else {
sync_future.replace(($st2)(bus.clone(), ut.clone()));
state.flush_lock.write().await;
stx_clone.send(Event::Flushed).unwrap();
}
} else {
break;
Request::Action(Action::Sync) => {
let lock = state.flush_lock.write().await;
let resp = ($st2)(bus.clone(), ut.clone()).await;
drop(lock);
stx.send(Event::Synchronized(resp.map_err(Error::Other)))
.unwrap();
}
_ => unimplemented!(),
}
}
if queue_len == queue.len() {
return if rx_closed {
Poll::Ready(())
} else {
Poll::Pending
};
}
})
}
}
};
}
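The batched variant follows the same permit scheme but first accumulates message ids and payloads until `cfg.batch_size` is reached (or a `Flush` arrives), then drains the buffers and spawns one task per batch. A reduced sketch of just that accumulate-and-drain flow; `handle_batch` and the `i32` payload are placeholders, not crate APIs.

use tokio::sync::mpsc;

// Stand-in for ut.handle(batch, &bus) in the real poller.
async fn handle_batch(batch: Vec<i32>) {
    println!("handling batch of {}: {:?}", batch.len(), batch);
}

// Accumulate incoming items and hand off full batches; the real poller
// spawns each drained batch while holding the semaphore / flush-lock
// permits shown in the sketch above.
async fn batch_loop(mut rx: mpsc::UnboundedReceiver<i32>, batch_size: usize) {
    let mut buffer = Vec::with_capacity(batch_size);
    while let Some(msg) = rx.recv().await {
        buffer.push(msg);
        if buffer.len() >= batch_size {
            handle_batch(buffer.drain(..).collect()).await;
        }
    }
    // Channel closed (or, in the poller, a Flush arrived): drain the
    // remaining partial batch as well.
    if !buffer.is_empty() {
        handle_batch(buffer).await;
    }
}

#[tokio::main]
async fn main() {
    let (tx, rx) = mpsc::unbounded_channel();
    for i in 0..20 {
        tx.send(i).unwrap();
    }
    drop(tx); // close the channel so batch_loop drains the remainder
    batch_loop(rx, 8).await; // two batches of 8, then one of 4
}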


@ -13,27 +13,21 @@ use crate::{
builder::ReceiverSubscriberBuilder,
error::{Error, SendError, StdSyncSendError},
receiver::{Action, Event, ReciveTypedReceiver, SendTypedReceiver, SendUntypedReceiver},
receivers::{fix_type, Request},
receivers::Request,
BatchHandler, Bus, Message, Untyped,
};
use futures::{stream::FuturesUnordered, Future, StreamExt};
use futures::Future;
use parking_lot::Mutex;
use tokio::sync::mpsc;
buffer_unordered_batch_poller_macro!(
T,
BatchHandler,
|buffer_clone, bus, ut: Arc<T>, stats: Arc<BufferUnorderedBatchedStats>, buffer_mid_clone| {
async move {
let resp = tokio::task::spawn_blocking(move || {
(buffer_mid_clone, ut.handle(buffer_clone, &bus))
})
|msgs, bus, ut: Arc<T>| async move {
tokio::task::spawn_blocking(move || ut.handle(msgs, &bus))
.await
.unwrap();
stats.parallel.fetch_sub(1, Ordering::Relaxed);
resp
}
.unwrap()
},
|bus, ut: Arc<T>| {
async move {


@ -3,9 +3,6 @@ mod buffer_unordered_batched;
mod synchronize_batched;
mod synchronized;
use futures::Future;
use std::pin::Pin;
pub use buffer_unordered::{BufferUnorderedAsync, BufferUnorderedConfig, BufferUnorderedSync};
pub use buffer_unordered_batched::{
BufferUnorderedBatchedAsync, BufferUnorderedBatchedConfig, BufferUnorderedBatchedSync,
@ -18,23 +15,6 @@ pub use synchronize_batched::{
use crate::receiver::Action;
#[inline(always)]
pub(crate) unsafe fn fix_type<'a, F, T>(
x: &'a mut F,
) -> Pin<&'a mut (impl Future<Output = T> + Send)>
where
F: Future<Output = T> + Send,
{
Pin::new_unchecked(x)
}
#[inline(always)]
pub(crate) fn fix_into_iter<I, T: IntoIterator<Item = I> + Send>(
x: T,
) -> impl IntoIterator<Item = I> + Send {
x
}
#[derive(Debug)]
pub(crate) enum Request<M> {
Action(Action),


@ -9,7 +9,7 @@ use crate::{
builder::ReceiverSubscriberBuilder,
error::{Error, SendError, StdSyncSendError},
receiver::{Action, Event, ReciveTypedReceiver, SendTypedReceiver, SendUntypedReceiver},
receivers::{fix_into_iter, fix_type, Request},
receivers::Request,
AsyncBatchSynchronizedHandler, Bus, Message, Untyped,
};
@ -22,10 +22,8 @@ use tokio::sync::{mpsc, Mutex};
batch_synchronized_poller_macro! {
T,
AsyncBatchSynchronizedHandler,
|mids, buffer, bus, ut: Arc<Mutex<T>>| {
async move {
(mids, ut.lock().await.handle(buffer, &bus).await)
}
|msgs, bus, ut: Arc<Mutex<T>>| async move {
ut.lock().await.handle(msgs, &bus).await
},
|bus, ut: Arc<Mutex<T>>| { async move { ut.lock().await.sync(&bus).await } }
}


@ -50,159 +50,113 @@ macro_rules! batch_synchronized_poller_macro {
R: Message,
{
let ut = ut.downcast::<Mutex<T>>().unwrap();
let mut buffer_mid = Vec::with_capacity(cfg.batch_size);
let mut buffer = Vec::with_capacity(cfg.batch_size);
let mut handle_future = None;
let mut sync_future = None;
let mut need_sync = false;
let mut rx_closed = false;
let mut need_flush = false;
futures::future::poll_fn(move |cx| loop {
if let Some(fut) = handle_future.as_mut() {
// SAFETY: safe because the future is pinned to the async generator's stack, which itself stays pinned
match unsafe { fix_type(fut) }.poll(cx) {
Poll::Ready((mids, res)) => match res {
Ok(re) => {
let mids: Vec<(u64, bool)> = mids;
let re = fix_into_iter::<R, _>(re);
async move {
let mut buffer_mid = Vec::with_capacity(cfg.batch_size);
let mut buffer = Vec::with_capacity(cfg.batch_size);
let mut mids = mids.into_iter();
let mut re = re.into_iter();
let mut finished = 0;
while let Some(msg) = rx.recv().await {
let bus = bus.clone();
let ut = ut.clone();
let stx = stx.clone();
match msg {
Request::Request(mid, msg, req) => {
buffer_mid.push((mid, req));
buffer.push(msg);
while let Some((mid, req)) = mids.next() {
if req {
if let Some(r) = re.next() {
stx.send(Event::Response(mid, Ok(r))).ok();
} else {
stx.send(Event::Response(mid, Err(Error::NoResponse)))
.ok();
if buffer_mid.len() >= cfg.batch_size {
let buffer_mid_clone = buffer_mid.drain(..).collect::<Vec<_>>();
let buffer_clone = buffer.drain(..).collect();
tokio::spawn(async move {
let resp = ($st1)(buffer_clone, bus, ut).await;
let mids = buffer_mid_clone.into_iter();
match resp {
Ok(re) => {
let mut mids = mids.into_iter();
let mut re = re.into_iter();
while let Some((mid, _req)) = mids.next() {
if let Some(r) = re.next() {
stx.send(Event::Response(mid, Ok(r)))
.unwrap();
} else {
stx.send(Event::Response(mid, Err(Error::NoResponse)))
.unwrap();
}
}
}
Err(er) => {
for (mid, _req) in mids {
stx.send(Event::Response(
mid,
Err(Error::Other(er.clone())),
)).unwrap();
}
stx.send(Event::Error(er)).unwrap();
}
} else {
finished += 1;
}
}
if finished > 0 {
stx.send(Event::Finished(finished)).ok();
}
}).await.unwrap();
}
}
Request::Action(Action::Init) => { stx.send(Event::Ready).unwrap(); }
Request::Action(Action::Close) => { rx.close(); }
Request::Action(Action::Flush) => {
let stx_clone = stx.clone();
Err(er) => {
let er: $t::Error = er;
let mut finished = 0;
if !buffer_mid.is_empty() {
let buffer_mid_clone = buffer_mid.drain(..).collect::<Vec<_>>();
let buffer_clone = buffer.drain(..).collect();
for (mid, req) in mids {
if req {
stx.send(Event::Response(
mid,
Err(Error::Other(er.clone())),
))
.ok();
} else {
finished += 1;
let _ = tokio::spawn(async move {
let resp = ($st1)(buffer_clone, bus, ut).await;
let mids = buffer_mid_clone.into_iter();
match resp {
Ok(re) => {
let mut mids = mids.into_iter();
let mut re = re.into_iter();
while let Some((mid, _req)) = mids.next() {
if let Some(r) = re.next() {
stx.send(Event::Response(mid, Ok(r)))
.unwrap();
} else {
stx.send(Event::Response(mid, Err(Error::NoResponse)))
.unwrap();
}
}
}
Err(er) => {
for (mid, _req) in mids {
stx.send(Event::Response(
mid,
Err(Error::Other(er.clone())),
)).unwrap();
}
stx.send(Event::Error(er)).unwrap();
}
}
}
stx.send(Event::Error(er)).ok();
if finished > 0 {
stx.send(Event::Finished(finished)).ok();
}
});
}
stx_clone.send(Event::Flushed).unwrap();
}
Poll::Pending => return Poll::Pending
}
}
handle_future = None;
while !rx_closed && !need_sync && !need_flush {
match rx.poll_recv(cx) {
Poll::Ready(Some(a)) => match a {
Request::Request(mid, msg, req) => {
// stats.buffer.fetch_sub(1, Ordering::Relaxed);
// stats.batch.fetch_add(1, Ordering::Relaxed);
buffer_mid.push((mid, req));
buffer.push(msg);
}
Request::Action(Action::Init) => {
stx.send(Event::Ready).ok();
}
Request::Action(Action::Flush) => need_flush = true,
Request::Action(Action::Sync) => need_sync = true,
Request::Action(Action::Close) => rx.close(),
_ => unimplemented!(),
},
Poll::Ready(None) => {
need_sync = true;
rx_closed = true;
Request::Action(Action::Sync) => {
let resp = ($st2)(bus.clone(), ut.clone()).await;
stx.send(Event::Synchronized(resp.map_err(Error::Other)))
.unwrap();
}
Poll::Pending => {
if cfg.when_ready {
break;
} else {
return Poll::Pending;
}
}
}
if buffer_mid.len() >= cfg.batch_size {
break;
_ => unimplemented!(),
}
}
if !buffer.is_empty() {
// stats.batch.store(0, Ordering::Relaxed);
let buffer_mid_clone = buffer_mid.drain(..).collect::<Vec<_>>();
let buffer_clone = buffer.drain(..).collect();
handle_future.replace(($st1)(
buffer_mid_clone,
buffer_clone,
bus.clone(),
ut.clone(),
));
continue;
}
if need_flush {
need_flush = false;
stx.send(Event::Flushed).ok();
continue;
}
if need_sync {
if let Some(fut) = sync_future.as_mut() {
// SAFETY: safe because the future is pinned to the async generator's stack, which itself stays pinned
match unsafe { fix_type(fut) }.poll(cx) {
Poll::Pending => return Poll::Pending,
Poll::Ready(resp) => {
need_sync = false;
let resp: Result<_, $t::Error> = resp;
stx.send(Event::Synchronized(resp.map_err(Error::Other)))
.ok();
}
}
sync_future = None;
} else {
sync_future.replace(($st2)(bus.clone(), ut.clone()));
}
}
return if rx_closed {
Poll::Ready(())
} else {
Poll::Pending
};
})
}
}
};
}


@ -9,7 +9,7 @@ use crate::{
builder::ReceiverSubscriberBuilder,
error::{Error, SendError, StdSyncSendError},
receiver::{Action, Event, ReciveTypedReceiver, SendTypedReceiver, SendUntypedReceiver},
receivers::{fix_into_iter, fix_type, Request},
receivers::Request,
BatchSynchronizedHandler, Bus, Message, Untyped,
};
@ -20,17 +20,15 @@ use tokio::sync::{mpsc, Mutex};
batch_synchronized_poller_macro! {
T,
BatchSynchronizedHandler,
|mids, buffer, bus, ut: Arc<Mutex<T>>| async move {
(mids, tokio::task::spawn_blocking(move || {
block_on(ut.lock()).handle(buffer, &bus)
})
.await
.unwrap())
|msgs, bus, ut: Arc<Mutex<T>>| async move {
tokio::task::spawn_blocking(move || block_on(ut.lock()).handle(msgs, &bus))
.await
.unwrap()
},
|bus, ut: Arc<Mutex<T>>| async move {
tokio::task::spawn_blocking(move || block_on(ut.lock()).sync(&bus))
.await
.unwrap()
.await
.unwrap()
}
}


@ -12,7 +12,7 @@ use crate::{
builder::ReceiverSubscriberBuilder,
error::{Error, SendError, StdSyncSendError},
receiver::{Action, Event, ReciveTypedReceiver, SendTypedReceiver, SendUntypedReceiver},
receivers::{fix_type, Request},
receivers::Request,
AsyncSynchronizedHandler, Bus, Message, Untyped,
};
use tokio::sync::{mpsc, Mutex};
@ -20,8 +20,8 @@ use tokio::sync::{mpsc, Mutex};
synchronized_poller_macro! {
T,
AsyncSynchronizedHandler,
|mid, msg, bus, ut: Arc<Mutex<T>>, req: bool| async move {
(mid, req, ut.lock().await.handle(msg, &bus).await)
|msg, bus, ut: Arc<Mutex<T>>| async move {
ut.lock().await.handle(msg, &bus).await
},
|bus, ut: Arc<Mutex<T>>| async move {
ut.lock().await.sync(&bus).await


@ -40,84 +40,38 @@ macro_rules! synchronized_poller_macro {
R: Message,
{
let ut = ut.downcast::<Mutex<T>>().unwrap();
let mut handle_future = None;
let mut sync_future = None;
let mut need_sync = false;
let mut rx_closed = false;
futures::future::poll_fn(move |cx| loop {
if let Some(fut) = handle_future.as_mut() {
// SAFETY: safe because the future is pinned to the async generator's stack, which itself stays pinned
match unsafe { fix_type(fut) }.poll(cx) {
Poll::Pending => return Poll::Pending,
Poll::Ready((mid, _req, resp)) => {
let resp: Result<_, $t::Error> = resp;
stx.send(Event::Response(mid, resp.map_err(Error::Other)))
.ok();
async move {
while let Some(msg) = rx.recv().await {
match msg {
Request::Request(mid, msg, _req) => {
let bus = bus.clone();
let ut = ut.clone();
let stx = stx.clone();
tokio::spawn(async move {
let resp = ($st1)(msg, bus, ut)
.await;
stx.send(Event::Response(mid, resp.map_err(Error::Other)))
.unwrap();
})
.await
.unwrap();
}
Request::Action(Action::Init) => { stx.send(Event::Ready).unwrap(); }
Request::Action(Action::Close) => { rx.close(); }
Request::Action(Action::Flush) => { stx.send(Event::Flushed).unwrap(); }
Request::Action(Action::Sync) => {
let resp = ($st2)(bus.clone(), ut.clone()).await;
stx.send(Event::Synchronized(resp.map_err(Error::Other)))
.unwrap();
}
_ => unimplemented!(),
}
}
handle_future = None;
if !rx_closed && !need_sync {
match rx.poll_recv(cx) {
Poll::Ready(Some(a)) => match a {
Request::Request(mid, msg, req) => {
handle_future.replace(($st1)(
mid,
msg,
bus.clone(),
ut.clone(),
req,
));
continue;
}
Request::Action(Action::Flush) => {
stx.send(Event::Flushed).ok();
continue;
}
Request::Action(Action::Init) => {
stx.send(Event::Ready).ok();
}
Request::Action(Action::Sync) => need_sync = true,
Request::Action(Action::Close) => {
rx.close();
continue;
}
_ => unimplemented!(),
},
Poll::Ready(None) => {
need_sync = true;
rx_closed = true;
}
Poll::Pending => {}
}
}
if need_sync {
if let Some(fut) = sync_future.as_mut() {
// SAFETY: safe because the future is pinned to the async generator's stack, which itself stays pinned
match unsafe { fix_type(fut) }.poll(cx) {
Poll::Pending => return Poll::Pending,
Poll::Ready(resp) => {
need_sync = false;
let resp: Result<_, $t::Error> = resp;
stx.send(Event::Synchronized(resp.map_err(Error::Other)))
.ok();
}
}
sync_future = None;
} else {
sync_future.replace(($st2)(bus.clone(), ut.clone()));
}
}
return if rx_closed {
Poll::Ready(())
} else {
Poll::Pending
};
})
}
}
};
}
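By contrast, the synchronized poller above keeps at most one handler running: each request is spawned and then immediately awaited, and the handler state lives behind an `Arc<Mutex<T>>` (tokio's async mutex), so a handler that mutates its state stays serialized. A stripped-down sketch of that flow, with `Counter` and its `handle` method as illustrative stand-ins for a real handler type.

use std::sync::Arc;
use tokio::sync::Mutex;

// Illustrative handler state; the real receiver keeps the user's handler
// behind the same Arc<Mutex<..>>.
struct Counter {
    total: i64,
}

impl Counter {
    fn handle(&mut self, msg: i64) -> i64 {
        self.total += msg;
        self.total
    }
}

#[tokio::main]
async fn main() {
    let ut = Arc::new(Mutex::new(Counter { total: 0 }));
    for msg in 1..=5 {
        let ut = ut.clone();
        // Spawn, then await immediately: requests are processed strictly
        // one at a time, matching the synchronized receiver's behaviour.
        let resp = tokio::spawn(async move { ut.lock().await.handle(msg) })
            .await
            .unwrap();
        println!("response: {}", resp);
    }
}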


@ -12,7 +12,7 @@ use crate::{
builder::ReceiverSubscriberBuilder,
error::{Error, SendError, StdSyncSendError},
receiver::{Action, Event, ReciveTypedReceiver, SendTypedReceiver, SendUntypedReceiver},
receivers::{fix_type, Request},
receivers::Request,
Bus, Message, SynchronizedHandler, Untyped,
};
use tokio::sync::{mpsc, Mutex};
@ -20,17 +20,15 @@ use tokio::sync::{mpsc, Mutex};
synchronized_poller_macro! {
T,
SynchronizedHandler,
|mid, msg, bus, ut: Arc<Mutex<T>>, req: bool| async move {
(mid, req, tokio::task::spawn_blocking(move || {
block_on(ut.lock()).handle(msg, &bus)
})
.await
.unwrap())
|msg, bus, ut: Arc<Mutex<T>>| async move {
tokio::task::spawn_blocking(move || block_on(ut.lock()).handle(msg, &bus))
.await
.unwrap()
},
|bus, ut: Arc<Mutex<T>>| async move {
tokio::task::spawn_blocking(move || block_on(ut.lock()).sync(&bus))
.await
.unwrap()
.await
.unwrap()
}
}


@ -216,6 +216,10 @@ where
self.context.init_sent.load(Ordering::SeqCst)
}
fn increment_processing(&self, tt: &TypeTag) {
self.context.receivers.get(tt).map(|r|r.processing.fetch_add(1, Ordering::SeqCst));
}
fn start_polling(
self: Arc<Self>,
) -> Box<dyn FnOnce(Bus) -> Pin<Box<dyn Future<Output = ()> + Send>>> {


@ -8,14 +8,14 @@ pub struct Stats {
pub err_type_tag: Option<Cow<'static, str>>,
pub has_queue: bool,
pub queue_capacity: u64,
pub queue_size: u64,
pub queue_capacity: i64,
pub queue_size: i64,
pub has_parallel: bool,
pub parallel_capacity: u64,
pub parallel_size: u64,
pub parallel_capacity: i64,
pub parallel_size: i64,
pub has_batch: bool,
pub batch_capacity: u64,
pub batch_size: u64,
pub batch_capacity: i64,
pub batch_size: i64,
}


@ -1,10 +1,8 @@
use std::sync::Arc;
use async_trait::async_trait;
use messagebus::{
derive::{Error as MbError, Message},
error, AsyncBatchHandler, BatchHandler, Bus, Message,
};
use messagebus::{AsyncBatchHandler, Bus, Message, derive::{Error as MbError, Message}, error, receivers::BufferUnorderedBatchedConfig};
use parking_lot::Mutex;
use thiserror::Error;
#[derive(Debug, Error, Clone, MbError)]
@ -27,7 +25,9 @@ struct MsgI32(i32);
#[message(clone)]
struct MsgI16(i16);
struct TmpReceiver;
struct TmpReceiver {
batches: Arc<Mutex<Vec<Vec<i32>>>>
}
#[async_trait]
impl AsyncBatchHandler<MsgI32> for TmpReceiver {
@ -41,48 +41,58 @@ impl AsyncBatchHandler<MsgI32> for TmpReceiver {
msg: Vec<MsgI32>,
_bus: &Bus,
) -> Result<Vec<Self::Response>, Self::Error> {
println!("---> [i32; {}] {:?}", msg.len(), msg);
self.batches.lock().push(msg.into_iter().map(|x|x.0).collect());
Ok(vec![])
}
}
impl BatchHandler<MsgI16> for TmpReceiver {
type Error = Error;
type Response = ();
type InBatch = Vec<MsgI16>;
type OutBatch = Vec<()>;
#[tokio::test]
async fn test_batch() {
let batches = Arc::new(Mutex::new(Vec::new()));
fn handle(&self, msg: Vec<MsgI16>, _bus: &Bus) -> Result<Vec<Self::Response>, Self::Error> {
println!("---> [i16; {}] {:?}", msg.len(), msg);
Ok(vec![])
}
}
#[tokio::main]
async fn main() {
let (b, poller) = Bus::build()
.register(TmpReceiver)
.subscribe_batch_async::<MsgI32>(16, Default::default())
.subscribe_batch_sync::<MsgI16>(16, Default::default())
.register(TmpReceiver { batches: batches.clone() })
.subscribe_batch_async::<MsgI32>(16, BufferUnorderedBatchedConfig {
batch_size: 8,
..Default::default()
})
.done()
.build();
for i in 1..100i32 {
b.send(MsgI32(i)).await.unwrap();
}
let mut re = Vec::new();
let mut counter = 1i32;
for _ in 1..100i32 {
let mut v = Vec::new();
for _ in 0..8 {
if counter >= 100 {
break;
}
b.send(MsgI16(1i16)).await.unwrap();
b.send(MsgI16(2i16)).await.unwrap();
b.send(MsgI16(3i16)).await.unwrap();
b.send(MsgI16(4i16)).await.unwrap();
b.send(MsgI16(5i16)).await.unwrap();
b.send(MsgI16(6i16)).await.unwrap();
b.send(MsgI16(7i16)).await.unwrap();
v.push(counter);
counter += 1;
}
re.push(v);
if counter >= 100 {
break;
}
}
println!("flush");
b.flush().await;
let mut lock = batches.lock();
lock.sort_by(|a, b| a[0].cmp(&b[0]));
assert_eq!(lock.as_slice(), re.as_slice());
println!("close");
b.close().await;

tests/test_concurrency.rs (new file, 117 lines)

@ -0,0 +1,117 @@
use std::{sync::atomic::AtomicU32, time::Duration};
use messagebus::{AsyncHandler, Bus, Message, Module, derive::{Error as MbError, Message}, error, receivers::BufferUnorderedConfig};
use thiserror::Error;
use async_trait::async_trait;
#[derive(Debug, Error, MbError)]
enum Error {
#[error("Error({0})")]
Error(anyhow::Error),
#[error("MyError")]
MyError,
}
impl<M: Message> From<error::Error<M>> for Error {
fn from(err: error::Error<M>) -> Self {
Self::Error(err.into())
}
}
#[derive(Debug, Clone, Message)]
struct Req(pub u32);
#[derive(Debug, Clone, Message)]
struct Resp(pub u32);
#[derive(Debug, Clone, Message)]
struct GetCount;
#[derive(Debug, Clone, Message)]
struct CountResult(pub u32);
struct TmpReceiver {
counter: AtomicU32
}
#[async_trait]
impl AsyncHandler<Req> for TmpReceiver {
type Error = Error;
type Response = ();
async fn handle(&self, msg: Req, bus: &Bus) -> Result<Self::Response, Self::Error> {
tokio::time::sleep(Duration::from_millis((msg.0 % 20) as _))
.await;
if msg.0 % 128 == 0 {
return Err(Error::MyError);
} else {
bus.send(Resp(msg.0)).await?;
}
Ok(())
}
}
#[async_trait]
impl AsyncHandler<Resp> for TmpReceiver {
type Error = Error;
type Response = ();
async fn handle(&self, _msg: Resp, _bus: &Bus) -> Result<Self::Response, Self::Error> {
self.counter.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
Ok(())
}
}
#[async_trait]
impl AsyncHandler<GetCount> for TmpReceiver {
type Error = Error;
type Response = CountResult;
async fn handle(&self, _: GetCount, _bus: &Bus) -> Result<Self::Response, Self::Error> {
Ok(CountResult(self.counter.load(std::sync::atomic::Ordering::SeqCst)))
}
}
fn module() -> Module {
Module::new()
.register(TmpReceiver { counter: AtomicU32::new(0) })
.subscribe_async::<Req>(1024, BufferUnorderedConfig {
buffer_size: 1024,
max_parallel: 1024,
})
.subscribe_async::<Resp>(1024, Default::default())
.subscribe_async::<GetCount>(8, Default::default())
.done()
}
#[tokio::test]
async fn test_sync() {
let (b, poller) = Bus::build().add_module(module()).build();
let cnt = 4u32;
for i in 0..cnt{
for j in 0..32768 {
b.send(Req(i * 128 + j)).await.unwrap();
}
println!("{} sent", i);
}
println!("sent");
b.flush().await;
println!("flushed");
assert_eq!(b.request_we::<_, CountResult, Error>(GetCount, Default::default()).await.unwrap().0, cnt * 32768 - cnt * 256);
b.close().await;
println!("closed");
poller.await;
}