Reordering queue part 2

This commit is contained in:
Andrey Tkachenko 2023-11-13 17:57:01 +04:00
parent be2c24e349
commit 4d1cced70e
11 changed files with 571 additions and 464 deletions

View File

@ -11,6 +11,7 @@ dashmap = "5.5.0"
futures = "0.3.28"
kanal = "0.1.0-pre8"
log = "0.4.20"
pin-project-lite = "0.2.13"
priority-queue = "1.3.2"
rand = { version = "0.8.5", default-features = false, features = ["std_rng", "std"] }
rand_xorshift = "0.3.0"

View File

@ -3,7 +3,7 @@
use std::sync::Arc;
use messagebus::{Builder, Bus, Error, Handler, IntoMessage, Message};
use messagebus::{Builder, Bus, Error, Handler, IntoMessageStream, Message};
#[derive(Debug, Clone)]
pub struct Msg(pub i32);
@ -20,7 +20,7 @@ impl Handler<Msg> for Processor {
_msg: Msg,
_stream_id: u32,
_task_id: u32,
) -> Result<impl IntoMessage<Self::Result>, Error> {
) -> Result<impl IntoMessageStream<Self::Result>, Error> {
Ok(())
}
@ -33,7 +33,7 @@ impl Handler<Msg> for Processor {
_err: Error,
_stream_id: u32,
_task_id: u32,
) -> Result<impl IntoMessage<Self::Result>, Error> {
) -> Result<impl IntoMessageStream<Self::Result>, Error> {
Ok(None)
}
}

View File

@ -57,6 +57,16 @@ impl<M: Message, H> DefaultBuilder<M, H> {
_m: PhantomData,
}
}
pub fn ordered(self) -> Self {
let mut config = self.config;
config.ordered = true;
Self {
config,
_m: PhantomData,
}
}
}
impl<M: Message, H: Sync + Send + Default + 'static> Builder<M> for DefaultBuilder<M, H> {

View File

@ -1,4 +1,7 @@
use std::{any::Any, pin::Pin};
use std::{
any::{Any, TypeId},
pin::Pin,
};
use futures::Future;
@ -9,6 +12,26 @@ enum ChannelItem<T> {
Close,
}
pub(crate) trait AbstractSender: Any + Send + Sync {
fn upcast(&self) -> &(dyn Any + Send + Sync);
fn close(&self);
fn message_type_id(&self) -> TypeId;
}
impl<M: Any + Send + Sync> AbstractSender for kanal::AsyncSender<M> {
fn upcast(&self) -> &(dyn Any + Send + Sync) {
self
}
fn close(&self) {
self.close();
}
fn message_type_id(&self) -> TypeId {
TypeId::of::<M>()
}
}
pub(crate) trait BusSenderClose: Any + Send + Sync {
fn upcast(&self) -> &(dyn Any + Send + Sync);
fn is_producer(&self) -> bool;

View File

@ -1,3 +1,5 @@
use std::fmt;
use kanal::ReceiveError;
#[derive(Debug, PartialEq, Eq)]
@ -6,9 +8,29 @@ pub enum Error {
Aborted,
SendError(String),
ReceiveError(kanal::ReceiveError),
ReorderingMissedMessage(u64),
ReorderingDropMessage(u64),
}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Error::HandlerIsNotRegistered => writeln!(f, "Handle is not registered!")?,
Error::Aborted => writeln!(f, "Operation Aborted!")?,
Error::SendError(reason) => writeln!(f, "Channel send error; reason {}", reason)?,
Error::ReceiveError(reason) => writeln!(f, "Channel receive error; reason {}", reason)?,
Error::ReorderingDropMessage(index) => writeln!(
f,
"Reordering drop message #{} (out of bound the queue)",
index
)?,
}
Ok(())
}
}
impl std::error::Error for Error {}
impl Clone for Error {
fn clone(&self) -> Self {
match self {
@ -19,7 +41,7 @@ impl Clone for Error {
},
Error::HandlerIsNotRegistered => Error::HandlerIsNotRegistered,
Error::Aborted => Error::Aborted,
Error::ReorderingMissedMessage(idx) => Error::ReorderingMissedMessage(*idx),
Error::ReorderingDropMessage(idx) => Error::ReorderingDropMessage(*idx),
}
}
}

View File

@ -1,6 +1,14 @@
use std::{any::Any, marker::PhantomData, pin::Pin, sync::Arc};
use std::{
any::{type_name, Any},
marker::PhantomData,
pin::{pin, Pin},
sync::{
atomic::{AtomicU64, Ordering},
Arc,
},
};
use futures::Future;
use futures::{future, Future, Stream, StreamExt};
use tokio::sync::Notify;
use crate::{
@ -8,7 +16,7 @@ use crate::{
chan::Receiver,
message::Msg,
task::{TaskCounter, TaskSpawner},
BusInner, Error, IntoMessage, Message,
BusInner, Error, IntoMessageStream, Message,
};
pub trait Handler<M: Message>: Send + Sync + 'static {
@ -19,14 +27,17 @@ pub trait Handler<M: Message>: Send + Sync + 'static {
msg: M,
stream_id: u32,
task_id: u32,
) -> impl Future<Output = Result<impl IntoMessage<Self::Result>, Error>> + Send + '_;
) -> impl Future<Output = Result<impl IntoMessageStream<Self::Result> + '_, Error>> + Send + '_;
fn handle_error(
&mut self,
err: Error,
stream_id: u32,
task_id: u32,
) -> impl Future<Output = Result<impl IntoMessage<Self::Result>, Error>> + Send + '_;
_err: Error,
_stream_id: u32,
_task_id: u32,
) -> impl Future<Output = Result<impl IntoMessageStream<Self::Result> + '_, Error>> + Send + '_
{
future::ready(Ok(None))
}
fn finalize(self) -> impl Future<Output = Result<(), Error>> + Send;
}
@ -56,47 +67,82 @@ where
task_id: u32,
_abort: Arc<Notify>,
task_counter: Arc<TaskCounter>,
spawn_counter: Arc<TaskCounter>,
index_counter: Arc<AtomicU64>,
bus: Arc<BusInner>,
) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send + '_>> {
Box::pin(async move {
let bus = bus.clone();
let config = self.builder.config(stream_id);
let mut ctx = self.builder.build(stream_id, task_id).await?;
let _handle = tokio::spawn(async move {
let _test = spawn_counter.clone().lease_unit(|| true);
while let Some(msg) = rx.recv().await {
task_counter.inc_running();
let _test = task_counter.clone().lease_unit(|| rx.is_empty());
let resp = match msg.inner {
Ok(m) => ctx
.handle(m, stream_id, task_id)
let res = match msg.inner {
Some(Ok(m)) => {
send_result(
&bus,
&index_counter,
msg.index,
stream_id,
&config,
Some(
ctx.handle(m, stream_id, task_id)
.await
.map(IntoMessageStream::into_message_stream),
),
)
.await
.map(IntoMessage::into_message),
Err(err) => ctx
.handle_error(err, stream_id, task_id)
}
Some(Err(err)) => {
send_result(
&bus,
&index_counter,
msg.index,
stream_id,
&config,
Some(
ctx.handle_error(err, stream_id, task_id)
.await
.map(IntoMessageStream::into_message_stream),
),
)
.await
.map(IntoMessage::into_message),
}
None => {
send_result::<<B::Context as Handler<M>>::Result>(
&bus,
&index_counter,
msg.index,
stream_id,
&config,
None::<Result<futures::stream::Empty<_>, _>>,
)
.await
}
};
let Some(inner) = resp.transpose() else {
let _ = bus.send_skip::<M>(stream_id, msg.index).await;
continue;
};
if let Err(err) = bus
.send(Msg {
inner,
index: msg.index,
if let Err(err) = res {
println!(
"Messagebus Send Error: {}/{} {}: {}",
stream_id,
})
.await
{
println!("BUS SEND ERROR: {:?}", err);
task_id,
type_name::<<B::Context as Handler<M>>::Result>(),
err,
);
}
task_counter.dec_running(rx.is_empty());
}
println!(
"TASK #{} of type `{}` ENDED",
task_id,
std::any::type_name::<B>()
);
if let Err(err) = ctx.finalize().await {
println!("TASK FINALIZE ERROR: {:?}", err);
}
@ -113,3 +159,47 @@ where
self.builder.config(stream_id)
}
}
async fn send_result<'a, M: Message>(
bus: &Arc<BusInner>,
index_counter: &AtomicU64,
index: u64,
stream_id: u32,
config: &crate::builder::Config,
res: Option<Result<impl Stream<Item = Result<M, Error>> + Send + 'a, Error>>,
) -> Result<(), Error> {
let reorder_buff = if config.ordered && config.task_count > 1 {
config.task_count
} else {
0
};
let one = match res {
Some(Ok(stream)) => {
let hint = stream.size_hint();
match hint {
(0, Some(0)) => None,
(1, Some(1)) => {
let mut stream = pin!(stream);
stream.next().await
}
_ => {
let mut stream = pin!(stream);
while let Some(item) = stream.next().await {
let index = index_counter.fetch_add(1, Ordering::Relaxed);
bus.send::<M>(Some(item), index, stream_id, reorder_buff)
.await?;
}
return Ok(());
}
}
}
Some(Err(err)) => Some(Err(err)),
None => None,
};
bus.send(one, index, stream_id, reorder_buff).await?;
Ok(())
}

View File

@ -6,7 +6,6 @@ mod chan;
mod error;
mod handler;
mod message;
mod producer;
mod rand;
mod reorder_queue;
mod task;
@ -15,38 +14,40 @@ use std::{
any::{Any, TypeId},
collections::HashMap,
sync::{
atomic::{AtomicU64, Ordering},
atomic::{AtomicBool, AtomicU64, Ordering},
Arc,
},
};
use builder::Config;
use chan::{channel, BusSender, BusSenderClose, Receiver, Sender};
use chan::{channel, AbstractSender, BusSender, BusSenderClose, Receiver, Sender};
use dashmap::DashMap;
use futures::Future;
use message::Msg;
use rand::RndGen;
use reorder_queue::ReorderQueue;
use reorder_queue::ReorderQueueInner;
use task::{TaskCounter, TaskSpawnerWrapper};
use tokio::sync::{Notify, RwLock};
pub use builder::{Builder, DefaultBuilder, SharedBuilder};
pub use error::Error;
pub use handler::Handler;
pub use message::{IntoMessage, Message};
pub use producer::Producer;
pub use message::{IntoMessageStream, Message, MessageIterator, MessageStream, MessageTryIterator};
pub const DEFAUL_STREAM_ID: u32 = u32::MAX;
pub const DEFAUL_TASK_ID: u32 = 0;
#[derive(Default)]
struct BusInner {
senders: DashMap<(u32, u32, TypeId), Arc<dyn BusSenderClose + 'static>>,
spawners: RwLock<HashMap<TypeId, Box<dyn Any + Send + Sync + 'static>>>,
reordering: DashMap<(u32, TypeId), Box<dyn Any + Send + Sync + 'static>>,
counters: DashMap<TypeId, AtomicU64>,
senders: DashMap<(u32, u32, TypeId), Arc<dyn BusSenderClose>>,
spawners: RwLock<HashMap<TypeId, Box<dyn Any + Send + Sync>>>,
reordering: DashMap<(u32, TypeId), Arc<dyn AbstractSender>>,
counters: DashMap<(u32, TypeId), Arc<AtomicU64>>,
abort_notify: Arc<Notify>,
task_counter: Arc<TaskCounter>,
spawn_counter: Arc<TaskCounter>,
stopping: AtomicBool,
stopping_notify: Arc<Notify>,
rng: RndGen,
}
@ -91,6 +92,7 @@ impl BusInner {
stream_id: u32,
task_id: u32,
) -> Result<(), Error> {
let type_id = TypeId::of::<M>();
let sender = spawner
.spawn_task(
chan,
@ -98,21 +100,58 @@ impl BusInner {
task_id,
self.abort_notify.clone(),
self.task_counter.clone(),
self.spawn_counter.clone(),
self.counters
.entry((stream_id, type_id))
.or_default()
.value()
.clone(),
self.clone(),
)
.await?;
self.senders.insert(
(stream_id, task_id, TypeId::of::<M>()),
Arc::new(sender) as _,
);
self.senders
.insert((stream_id, task_id, type_id), Arc::new(sender) as _);
Ok(())
}
async fn send_inner<M: Message>(
pub async fn send<M: Message>(
self: &Arc<Self>,
msg: Result<M, Error>,
msg: Option<Result<M, Error>>,
index: u64,
stream_id: u32,
reorder_buff: u32,
) -> Result<(), Error> {
let type_id = TypeId::of::<M>();
if type_id == TypeId::of::<()>() {
return Ok(());
}
let config = if let Some(spawner) = self
.spawners
.read()
.await
.get(&type_id)
.and_then(|x| x.downcast_ref::<TaskSpawnerWrapper<M>>())
{
spawner.config(stream_id)
} else {
Config::default()
};
if reorder_buff > 1 {
self.send_ordered(msg, index, stream_id, config, reorder_buff)
.await
} else {
self.send_inner(msg, index, stream_id, config).await
}
}
pub async fn send_inner<M: Message>(
self: &Arc<Self>,
msg: Option<Result<M, Error>>,
index: u64,
stream_id: u32,
config: Config,
@ -182,159 +221,132 @@ impl BusInner {
Ok(())
}
pub async fn send<M: Message>(self: &Arc<Self>, msg: Msg<M>) -> Result<(), Error> {
let type_id = TypeId::of::<M>();
if type_id == TypeId::of::<()>() {
return Ok(());
}
let stream_id = msg.stream_id;
let config = if let Some(spawner) = self
.spawners
.read()
.await
.get(&type_id)
.and_then(|x| x.downcast_ref::<TaskSpawnerWrapper<M>>())
{
spawner.config(stream_id)
} else {
Config::default()
};
if config.ordered {
let mut queue = self
.reordering
.entry((stream_id, type_id))
.or_insert_with(|| {
Box::new(ReorderQueue::<Option<Result<M, Error>>>::new(
config.task_count as _,
))
});
let queue = queue
.downcast_mut::<ReorderQueue<Option<Result<M, Error>>>>()
.unwrap();
if let Some(index) = queue.push(msg.index, Some(msg.inner)) {
self.send_inner(
Err::<M, _>(Error::ReorderingMissedMessage(index)),
index,
stream_id,
config,
)
.await?;
}
while let Some(msg) = queue.pop() {
if let (index, Some(Some(msg))) = msg {
self.send_inner(msg, index, stream_id, config).await?;
}
}
Ok(())
} else {
self.send_inner(msg.inner, msg.index, stream_id, config)
.await
}
}
pub async fn send_skip<M: Message>(
pub async fn send_ordered<M: Message>(
self: &Arc<Self>,
stream_id: u32,
msg: Option<Result<M, Error>>,
index: u64,
stream_id: u32,
config: Config,
reorder_buff: u32,
) -> Result<(), Error> {
let type_id = TypeId::of::<M>();
if type_id == TypeId::of::<()>() {
return Ok(());
}
let config = if let Some(spawner) = self
.spawners
.read()
.await
.get(&type_id)
.and_then(|x| x.downcast_ref::<TaskSpawnerWrapper<M>>())
{
spawner.config(stream_id)
} else {
Config::default()
};
let sender = self
.reordering
.entry((stream_id, type_id))
.or_insert_with(|| {
let (tx, rx) = kanal::bounded_async::<Msg<M>>(4);
if config.ordered {
let mut queue = self
.reordering
.entry((stream_id, type_id))
.or_insert_with(|| {
Box::new(ReorderQueue::<Option<Result<M, Error>>>::new(
config.task_count as _,
))
let bus = self.clone();
tokio::spawn(async move {
let mut queue = ReorderQueueInner::new(reorder_buff as _);
while let Ok(msg) = rx.recv().await {
println!(" > {:?}", msg);
if let Some(index) = queue.push(msg.index, msg.inner) {
bus.send_inner(
Some(Err::<M, _>(Error::ReorderingDropMessage(index))),
index,
stream_id,
config,
)
.await
.unwrap();
}
while let Some((index, msg)) = queue.pop() {
if let Err(err) = bus.send_inner(msg, index, stream_id, config).await {
println!("Err: {}", err);
}
}
}
while let Some((index, msg)) = queue.force_pop() {
if let Err(err) = bus.send_inner(msg, index, stream_id, config).await {
println!("Err: {}", err);
}
}
});
let queue = queue
.downcast_mut::<ReorderQueue<Option<Result<M, Error>>>>()
.unwrap();
Arc::new(tx) as _
})
.downgrade()
.clone();
if let Some(index) = queue.push(index, None) {
self.send_inner(
Err::<M, _>(Error::ReorderingMissedMessage(index)),
index,
stream_id,
config,
)
.await?;
}
while let Some(msg) = queue.pop() {
if let (index, Some(Some(msg))) = msg {
self.send_inner(msg, index, stream_id, config).await?;
}
}
}
sender
.upcast()
.downcast_ref::<kanal::AsyncSender<Msg<M>>>()
.unwrap()
.send(Msg {
inner: msg,
index,
stream_id,
})
.await?;
Ok(())
}
pub async fn register_dispatcher<M: Message, B: Builder<M>>(self: Arc<Self>, builder: B)
pub async fn register<M: Message, B: Builder<M>>(self: Arc<Self>, builder: B)
where
B::Context: Handler<M>,
{
let type_id = TypeId::of::<M>();
self.counters.insert(type_id, AtomicU64::new(0));
self.spawners.write().await.insert(
type_id,
Box::new(TaskSpawnerWrapper::from_handler(builder)) as _,
);
}
pub async fn register_producer<M: Message, B: Builder<M>>(self: Arc<Self>, builder: B)
where
B::Context: Producer<M>,
{
self.spawners.write().await.insert(
TypeId::of::<M>(),
Box::new(TaskSpawnerWrapper::from_producer(builder)) as _,
);
}
#[inline]
pub async fn close(&self, force: bool) {
if force {
self.abort_notify.notify_waiters();
}
for item in self.senders.iter() {
if force {
let _ = item.value().terminate();
} else if item.is_producer() {
let _ = item.value().stop().await;
}
}
self.stopping.store(true, Ordering::Relaxed);
self.stopping_notify.notify_waiters();
}
#[inline]
pub async fn wait(&self) {
self.task_counter.wait().await;
while !self.stopping.load(Ordering::Relaxed) {
self.stopping_notify.notified().await;
}
loop {
self.task_counter.wait().await;
for queue in self.reordering.iter() {
queue.value().close();
}
if self.check_stopped() {
break;
}
}
for sender in self.senders.iter() {
let _ = sender.value().stop().await;
}
self.spawn_counter.wait().await;
}
fn check_stopped(&self) -> bool {
println!("Checking Stopped:");
for sender in self.senders.iter() {
let load = sender.value().load();
println!("{:?}: {}/{}", sender.key(), load.0, load.1);
if sender.value().load().0 > 0 {
return false;
}
}
println!("Checking Reordering Queues:");
true
}
}
@ -355,35 +367,13 @@ impl Bus {
where
B::Context: Handler<M>,
{
self.inner.clone().register_dispatcher(builder).await;
self
}
#[inline]
pub async fn register_producer<M: Message, B: Builder<M>>(&self, builder: B) -> &Self
where
B::Context: Producer<M>,
{
self.inner.clone().register_producer(builder).await;
self.inner.clone().register(builder).await;
self
}
#[inline]
pub async fn send<M: Message>(&self, inner: M) -> Result<(), Error> {
let index = self
.inner
.counters
.get(&TypeId::of::<M>())
.map(|x| x.fetch_add(1, Ordering::Relaxed))
.unwrap_or(0);
self.inner
.send(Msg {
inner: Ok(inner),
index,
stream_id: DEFAUL_STREAM_ID,
})
.await
self.send_with_stream(DEFAUL_STREAM_ID, inner).await
}
#[inline]
@ -392,20 +382,16 @@ impl Bus {
stream_id: u32,
inner: M,
) -> Result<(), Error> {
let type_id = TypeId::of::<M>();
let index = self
.inner
.counters
.get(&TypeId::of::<M>())
.map(|x| x.fetch_add(1, Ordering::Relaxed))
.unwrap_or(0);
.entry((stream_id, type_id))
.or_default()
.value()
.fetch_add(1, Ordering::Relaxed);
self.inner
.send(Msg {
inner: Ok(inner),
index,
stream_id,
})
.await
self.inner.send(Some(Ok(inner)), index, stream_id, 0).await
}
///
@ -435,41 +421,32 @@ mod tests {
use std::{sync::Arc, time::Duration};
use async_stream::stream;
use futures::Stream;
use rand::RngCore;
use crate::{
handler::Handler, producer::Producer, Bus, DefaultBuilder, Error, IntoMessage, Message,
SharedBuilder,
handler::Handler, Bus, DefaultBuilder, Error, IntoMessageStream, Message, SharedBuilder,
};
impl Message for u64 {}
impl Message for u32 {}
impl Message for i16 {}
impl Message for u16 {}
#[derive(Default)]
struct TestProducer;
impl Producer<u32> for TestProducer {
type Item = u64;
impl Handler<u32> for TestProducer {
type Result = u64;
fn stream(
async fn handle(
&mut self,
_msg: u32,
_stream_id: u32,
_task_id: u32,
) -> impl Stream<Item = Result<impl IntoMessage<Self::Item>, Error>> + Send + '_ {
stream! {
) -> Result<impl IntoMessageStream<Self::Result>, Error> {
Ok(crate::message::MessageStream(stream! {
for i in 0u64..10 {
yield Ok(i)
}
}
}
async fn handle_error(
&mut self,
_err: Error,
_stream_id: u32,
_task_id: u32,
) -> Result<impl IntoMessage<Self::Item>, Error> {
Ok(None)
}))
}
async fn finalize(self) -> Result<(), Error> {
@ -478,10 +455,10 @@ mod tests {
}
}
struct TestConsumer(u32);
struct TestConsumer(u16);
impl Default for TestConsumer {
fn default() -> Self {
Self(rand::thread_rng().next_u32())
Self(rand::thread_rng().next_u32() as _)
}
}
@ -493,7 +470,7 @@ mod tests {
msg: u64,
stream_id: u32,
task_id: u32,
) -> Result<impl IntoMessage<Self::Result>, Error> {
) -> Result<impl IntoMessageStream<Self::Result>, Error> {
tokio::time::sleep(Duration::from_millis(1000)).await;
println!(
"[{}] shared consumer handle {}u64 ({}:{})",
@ -501,14 +478,6 @@ mod tests {
);
Ok(())
}
async fn handle_error(
&mut self,
_err: Error,
_stream_id: u32,
_task_id: u32,
) -> Result<impl IntoMessage<Self::Result>, Error> {
Ok(None)
}
async fn finalize(self) -> Result<(), Error> {
println!("[{}] shared consumer finalized", self.0);
@ -524,8 +493,8 @@ mod tests {
msg: u64,
stream_id: u32,
task_id: u32,
) -> Result<impl IntoMessage<Self::Result>, Error> {
tokio::time::sleep(Duration::from_millis(1000)).await;
) -> Result<impl IntoMessageStream<Self::Result>, Error> {
tokio::time::sleep(Duration::from_millis(100)).await;
println!(
"[{}] consumer handle {}u64 ({}:{})",
self.0, msg, stream_id, task_id
@ -533,30 +502,85 @@ mod tests {
Ok(())
}
async fn handle_error(
&mut self,
_err: Error,
_stream_id: u32,
_task_id: u32,
) -> Result<impl IntoMessage<Self::Result>, Error> {
Ok(None)
}
async fn finalize(self) -> Result<(), Error> {
println!("[{}] consumer finalized", self.0);
Ok(())
}
}
struct TestHandler {}
impl Handler<i16> for Arc<TestHandler> {
type Result = u16;
async fn handle(
&mut self,
msg: i16,
_stream_id: u32,
task_id: u32,
) -> Result<impl IntoMessageStream<Self::Result>, Error> {
if task_id % 2 == 0 {
tokio::time::sleep(Duration::from_millis(13)).await;
} else {
tokio::time::sleep(Duration::from_millis(22)).await;
}
println!("handle {}", msg);
Ok(msg as u16)
}
async fn finalize(self) -> Result<(), Error> {
Ok(())
}
}
#[derive(Default)]
struct TestCollector {
inner: Vec<u16>,
}
impl Handler<u16> for TestCollector {
type Result = ();
async fn handle(
&mut self,
msg: u16,
_stream_id: u32,
_task_id: u32,
) -> Result<impl IntoMessageStream<Self::Result>, Error> {
println!("{}", msg);
self.inner.push(msg);
Ok(None)
}
async fn finalize(self) -> Result<(), Error> {
println!("Checking");
assert_eq!(self.inner, (0u16..1024).collect::<Vec<_>>());
Ok(())
}
async fn handle_error(
&mut self,
err: Error,
_stream_id: u32,
_task_id: u32,
) -> Result<impl IntoMessageStream<Self::Result>, Error> {
println!("{:?}", err);
Ok(None)
}
}
#[tokio::test]
#[ignore = ""]
#[ignore]
async fn test_streams() {
let bus = Bus::default();
bus.register(DefaultBuilder::<u64, TestConsumer>::new(2))
.await;
bus.register_producer(DefaultBuilder::<u32, TestProducer>::new(2))
bus.register(DefaultBuilder::<u32, TestProducer>::new(2))
.await;
for start in 0u32..10 {
@ -568,16 +592,16 @@ mod tests {
}
#[tokio::test]
#[ignore = ""]
#[ignore]
async fn test_tasks_shared() {
let bus = Bus::default();
bus.register(SharedBuilder::new(2, 5, |sid, _tid| async move {
Ok(TestConsumer(sid))
bus.register(SharedBuilder::new(2, 5, |_sid, _tid| async move {
Ok(TestConsumer::default())
}))
.await;
bus.register_producer(DefaultBuilder::<u32, TestProducer>::new(2))
bus.register(DefaultBuilder::<u32, TestProducer>::new(2))
.await;
for start in 0u32..10 {
@ -587,4 +611,29 @@ mod tests {
bus.close().await;
bus.wait().await;
}
#[tokio::test]
// #[ignore]
async fn test_reordering() {
let bus = Bus::default();
bus.register(
SharedBuilder::new(4, 128, |sid, tid| async move {
println!("NEW HANDLER {}/{}", sid, tid);
Ok(TestHandler {})
})
.ordered(),
)
.await;
bus.register(DefaultBuilder::<_, TestCollector>::new(4))
.await;
for i in 0i16..1024 {
bus.send(i).await.unwrap();
}
bus.close().await;
bus.wait().await;
}
}

View File

@ -1,40 +1,70 @@
use core::fmt;
use std::any::Any;
use futures::Stream;
use crate::Error;
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct Msg<M: Message> {
pub(crate) inner: Result<M, Error>,
pub(crate) inner: Option<Result<M, Error>>,
pub(crate) index: u64,
pub(crate) stream_id: u32,
}
// impl<M: Message> Msg<M> {
// pub(crate) fn new(m: M, index: u64, stream_id: u32) -> Self {
// Self {
// inner: Ok(m),
// index,
// stream_id,
// }
// }
// }
pub trait Message: fmt::Debug + Clone + Send + Sync + 'static {}
pub struct MessageStream<S>(pub S);
pub struct MessageIterator<I>(pub I);
pub struct MessageTryIterator<I>(pub I);
pub trait Message: Any + fmt::Debug + Clone + Send + Sync + 'static {}
impl Message for () {}
pub trait IntoMessage<M>: Send {
fn into_message(self) -> Option<M>;
pub trait IntoMessageStream<M: Message>: Send {
fn into_message_stream(self) -> impl Stream<Item = Result<M, Error>> + Send;
}
impl<M: Message> IntoMessage<M> for Option<M> {
fn into_message(self) -> Option<M> {
self
impl<M: Message, S: Stream<Item = Result<M, Error>> + Send> IntoMessageStream<M>
for MessageStream<S>
{
fn into_message_stream(self) -> impl Stream<Item = Result<M, Error>> + Send {
self.0
}
}
impl<M: Message> IntoMessage<M> for M {
fn into_message(self) -> Option<M> {
Some(self)
impl<M: Message, I: Iterator<Item = M> + Send> IntoMessageStream<M> for MessageIterator<I> {
fn into_message_stream(self) -> impl Stream<Item = Result<M, Error>> {
futures::stream::iter(self.0.map(Ok))
}
}
impl<M: Message, I: Iterator<Item = Result<M, Error>> + Send> IntoMessageStream<M>
for MessageTryIterator<I>
{
fn into_message_stream(self) -> impl Stream<Item = Result<M, Error>> + Send {
futures::stream::iter(self.0)
}
}
impl<M: Message> IntoMessageStream<M> for Option<M> {
fn into_message_stream(self) -> impl Stream<Item = Result<M, Error>> + Send {
futures::stream::iter(self.into_iter().map(Ok))
}
}
impl<M: Message> IntoMessageStream<M> for Vec<M> {
fn into_message_stream(self) -> impl Stream<Item = Result<M, Error>> + Send {
futures::stream::iter(self.into_iter().map(Ok))
}
}
impl<const N: usize, M: Message> IntoMessageStream<M> for [M; N] {
fn into_message_stream(self) -> impl Stream<Item = Result<M, Error>> + Send {
futures::stream::iter(self.map(Ok))
}
}
impl<M: Message> IntoMessageStream<M> for M {
fn into_message_stream(self) -> impl Stream<Item = Result<M, Error>> + Send {
futures::stream::iter([Ok(self)])
}
}

View File

@ -1,145 +0,0 @@
use std::{
marker::PhantomData,
pin::{pin, Pin},
sync::Arc,
};
use futures::{Future, Stream, StreamExt};
use tokio::sync::Notify;
use crate::{
builder::Builder,
chan::Receiver,
message::Msg,
task::{TaskCounter, TaskSpawner},
BusInner, Error, IntoMessage, Message,
};
pub trait Producer<M: Message>: Send + Sync + 'static {
type Item: Message;
fn stream(
&mut self,
msg: M,
stream_id: u32,
task_id: u32,
) -> impl Stream<Item = Result<impl IntoMessage<Self::Item>, Error>> + Send + '_;
fn handle_error(
&mut self,
err: Error,
stream_id: u32,
task_id: u32,
) -> impl Future<Output = Result<impl IntoMessage<Self::Item>, Error>> + Send + '_;
fn finalize(self) -> impl Future<Output = Result<(), Error>> + Send;
}
pub(crate) struct ProducerSpawner<M, B> {
pub(crate) builder: B,
_m: PhantomData<M>,
}
impl<M, B> ProducerSpawner<M, B> {
pub(crate) fn new(builder: B) -> Self {
Self {
builder,
_m: PhantomData,
}
}
}
impl<M: Message, B: Builder<M>> TaskSpawner<M> for ProducerSpawner<M, B>
where
B::Context: Producer<M>,
{
fn spawn_task(
&self,
rx: Receiver<Msg<M>>,
stream_id: u32,
task_id: u32,
abort: Arc<Notify>,
task_counter: Arc<TaskCounter>,
bus: Arc<BusInner>,
) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send + '_>> {
Box::pin(async move {
let mut ctx = self.builder.build(stream_id, task_id).await?;
let _handle = tokio::spawn(async move {
while let Some(recv_msg) = rx.recv().await {
task_counter.inc_running();
match recv_msg.inner {
Ok(msg) => {
let mut stream = pin!(ctx
.stream(msg, stream_id, task_id)
.take_until(abort.notified()));
let mut index = 0;
while let Some(res) = stream.next().await {
if let Some(inner) = res.map(IntoMessage::into_message).transpose()
{
index += 1;
if let Err(err) = bus
.send(Msg {
inner,
index: index - 1,
stream_id,
})
.await
{
println!("BUS SEND ERROR: {:?}", err);
continue;
}
}
}
}
Err(err) => {
let Some(inner) = ctx
.handle_error(err, stream_id, task_id)
.await
.map(IntoMessage::into_message)
.transpose()
else {
if let Err(err) =
bus.send_skip::<M>(stream_id, recv_msg.index).await
{
println!("BUS SEND ERROR: {:?}", err);
}
continue;
};
if let Err(err) = bus
.send(Msg {
inner,
index: recv_msg.index,
stream_id,
})
.await
{
println!("BUS SEND ERROR: {:?}", err);
}
}
}
task_counter.dec_running(rx.is_empty());
}
if let Err(err) = ctx.finalize().await {
println!("TASK FINALIZE ERROR: {:?}", err);
}
});
Ok(())
})
}
fn is_producer(&self) -> bool {
true
}
fn config(&self, stream_id: u32) -> crate::builder::Config {
self.builder.config(stream_id)
}
}

View File

@ -1,7 +1,7 @@
use std::{cmp::Ordering, collections::BinaryHeap};
struct Entry<M> {
inner: Option<M>,
inner: M,
index: u64,
}
@ -25,13 +25,13 @@ impl<M> PartialEq for Entry<M> {
impl<M> Eq for Entry<M> {}
pub(crate) struct ReorderQueue<M> {
pub(crate) struct ReorderQueueInner<M> {
cap: usize,
recent_index: Option<u64>,
heap: BinaryHeap<Entry<M>>,
}
impl<M> ReorderQueue<M> {
impl<M> ReorderQueueInner<M> {
pub fn new(cap: usize) -> Self {
Self {
cap,
@ -40,11 +40,8 @@ impl<M> ReorderQueue<M> {
}
}
pub fn push(&mut self, index: u64, msg: M) -> Option<u64> {
self.heap.push(Entry {
inner: Some(msg),
index,
});
pub fn push(&mut self, index: u64, inner: M) -> Option<u64> {
self.heap.push(Entry { inner, index });
if self.heap.len() == self.cap {
self.recent_index = self.recent_index.map(|x| x + 1);
@ -54,7 +51,7 @@ impl<M> ReorderQueue<M> {
}
}
pub fn pop(&mut self) -> Option<(u64, Option<M>)> {
pub fn pop(&mut self) -> Option<(u64, M)> {
match self.recent_index {
None => {
let e = self.heap.pop()?;
@ -72,22 +69,28 @@ impl<M> ReorderQueue<M> {
}
}
}
pub fn force_pop(&mut self) -> Option<(u64, M)> {
let e = self.heap.pop()?;
self.recent_index = Some(e.index);
Some((e.index, e.inner))
}
}
#[cfg(test)]
mod tests {
use crate::Message;
use super::ReorderQueue;
use super::ReorderQueueInner;
impl Message for i32 {}
#[test]
fn test_reordering() {
let mut queue = ReorderQueue::new(8);
let mut queue = ReorderQueueInner::new(8);
assert_eq!(queue.push(0, 0), None);
assert_eq!(queue.pop(), Some((0, Some(0))));
assert_eq!(queue.pop(), Some((0, 0)));
assert_eq!(queue.pop(), None);
assert_eq!(queue.push(3, 3), None);
@ -100,18 +103,18 @@ mod tests {
assert_eq!(queue.pop(), None);
assert_eq!(queue.push(1, 1), None);
assert_eq!(queue.pop(), Some((1, Some(1))));
assert_eq!(queue.pop(), Some((2, Some(2))));
assert_eq!(queue.pop(), Some((3, Some(3))));
assert_eq!(queue.pop(), Some((4, Some(4))));
assert_eq!(queue.pop(), Some((1, 1)));
assert_eq!(queue.pop(), Some((2, 2)));
assert_eq!(queue.pop(), Some((3, 3)));
assert_eq!(queue.pop(), Some((4, 4)));
assert_eq!(queue.pop(), None);
}
#[test]
fn test_overflow() {
let mut queue = ReorderQueue::new(4);
let mut queue = ReorderQueueInner::new(4);
assert_eq!(queue.push(0, 0), None);
assert_eq!(queue.pop(), Some((0, Some(0))));
assert_eq!(queue.pop(), Some((0, 0)));
assert_eq!(queue.pop(), None);
assert_eq!(queue.push(4, 4), None);
@ -124,10 +127,10 @@ mod tests {
assert_eq!(queue.pop(), None);
assert_eq!(queue.push(5, 5), Some(1));
assert_eq!(queue.pop(), Some((2, Some(2))));
assert_eq!(queue.pop(), Some((3, Some(3))));
assert_eq!(queue.pop(), Some((4, Some(4))));
assert_eq!(queue.pop(), Some((5, Some(5))));
assert_eq!(queue.pop(), Some((2, 2)));
assert_eq!(queue.pop(), Some((3, 3)));
assert_eq!(queue.pop(), Some((4, 4)));
assert_eq!(queue.pop(), Some((5, 5)));
assert_eq!(queue.pop(), None);
}
}

View File

@ -1,7 +1,7 @@
use std::{
pin::Pin,
sync::{
atomic::{AtomicUsize, Ordering},
atomic::{AtomicU64, AtomicUsize, Ordering},
Arc,
},
};
@ -14,29 +14,46 @@ use crate::{
chan::{BusSender, Receiver, Sender},
handler::HandlerSpawner,
message::Msg,
producer::ProducerSpawner,
Builder, BusInner, Error, Handler, Message, Producer,
Builder, BusInner, Error, Handler, Message,
};
#[derive(Default)]
pub(crate) struct TaskCounter {
running: AtomicUsize,
pub(crate) running: AtomicUsize,
notify: Notify,
}
impl TaskCounter {
#[inline]
pub fn inc_running(&self) {
self.running.fetch_add(1, Ordering::Relaxed);
}
pub(crate) struct TaskCounterLease<S: Fn() -> bool> {
need_notify: S,
counter: Arc<TaskCounter>,
}
impl<S: Fn() -> bool> Drop for TaskCounterLease<S> {
fn drop(&mut self) {
let notify = (self.need_notify)();
let prev = self.counter.running.fetch_sub(1, Ordering::Relaxed);
#[inline]
pub fn dec_running(&self, notify: bool) {
let prev = self.running.fetch_sub(1, Ordering::Relaxed);
if notify && prev == 1 {
self.notify.notify_waiters();
self.counter.notify.notify_waiters();
}
}
}
impl<S: Fn() -> bool> TaskCounterLease<S> {
fn new(counter: Arc<TaskCounter>, need_notify: S) -> Self {
counter.running.fetch_add(1, Ordering::Relaxed);
Self {
counter,
need_notify,
}
}
}
impl TaskCounter {
pub fn lease_unit<S: Fn() -> bool>(self: Arc<Self>, need_notify: S) -> TaskCounterLease<S> {
TaskCounterLease::new(self, need_notify)
}
#[inline]
pub async fn wait(&self) {
@ -47,6 +64,8 @@ impl TaskCounter {
pub(crate) trait TaskSpawner<M: Message>: Send + Sync {
fn config(&self, stream_id: u32) -> Config;
fn is_producer(&self) -> bool;
#[allow(clippy::too_many_arguments)]
fn spawn_task(
&self,
rx: Receiver<Msg<M>>,
@ -54,6 +73,8 @@ pub(crate) trait TaskSpawner<M: Message>: Send + Sync {
task_id: u32,
abort: Arc<Notify>,
task_counter: Arc<TaskCounter>,
spawn_counter: Arc<TaskCounter>,
index_counter: Arc<AtomicU64>,
bus: Arc<BusInner>,
) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send + '_>>;
}
@ -80,16 +101,8 @@ impl<M: Message> TaskSpawnerWrapper<M> {
}
}
pub fn from_producer<B: Builder<M> + 'static>(builder: B) -> Self
where
B::Context: Producer<M>,
{
Self {
inner: Arc::new(ProducerSpawner::new(builder)) as _,
}
}
#[inline]
#[allow(clippy::too_many_arguments)]
pub async fn spawn_task(
&self,
(tx, rx): (Sender<Msg<M>>, Receiver<Msg<M>>),
@ -97,10 +110,21 @@ impl<M: Message> TaskSpawnerWrapper<M> {
task_id: u32,
abort: Arc<Notify>,
task_counter: Arc<TaskCounter>,
spawn_counter: Arc<TaskCounter>,
index_counter: Arc<AtomicU64>,
bus: Arc<BusInner>,
) -> Result<BusSender<M>, Error> {
self.inner
.spawn_task(rx, stream_id, task_id, abort, task_counter, bus)
.spawn_task(
rx,
stream_id,
task_id,
abort,
task_counter,
spawn_counter,
index_counter,
bus,
)
.await?;
Ok(BusSender::new(self.inner.is_producer(), tx))