tasks shared queue

This commit is contained in:
Andrey Tkachenko 2023-10-27 12:57:09 +04:00
parent 2ae4e8bfed
commit 46b63aa1f8
6 changed files with 198 additions and 95 deletions

View File

@ -4,28 +4,57 @@ use futures::Future;
use crate::{Error, Message}; use crate::{Error, Message};
#[derive(Clone, Copy, Debug)]
pub struct Config {
pub queue_size: usize,
pub queue_per_task: bool,
pub ordered: bool,
pub task_count: u32,
pub lazy_task_creation: bool,
}
impl Default for Config {
fn default() -> Self {
Self {
queue_size: 4,
queue_per_task: false,
ordered: false,
task_count: 1,
lazy_task_creation: true,
}
}
}
pub trait Builder<M: Message>: Send + Sync + 'static { pub trait Builder<M: Message>: Send + Sync + 'static {
type Context: 'static; type Context: 'static;
type BuildFut<'a>: Future<Output = Result<Self::Context, Error>> + Send + 'a type BuildFut<'a>: Future<Output = Result<Self::Context, Error>> + Send + 'a
where where
Self: 'a; Self: 'a;
fn parallel(&self, _stream_id: u32) -> (u32, bool) { fn config(&self, _stream_id: u32) -> Config {
(1, false) Default::default()
}
fn queue_size(&self, _stream_id: u32, _task_id: u32) -> usize {
4
} }
fn build(&self, stream_id: u32, _task_id: u32) -> Self::BuildFut<'_>; fn build(&self, stream_id: u32, _task_id: u32) -> Self::BuildFut<'_>;
} }
pub struct DefaultBuilder<M: Message, H>(usize, PhantomData<(M, H)>); pub struct DefaultBuilder<M: Message, H> {
config: Config,
_m: PhantomData<(M, H)>,
}
impl<M: Message, H> DefaultBuilder<M, H> { impl<M: Message, H> DefaultBuilder<M, H> {
pub fn new(queue_size: usize) -> Self { pub fn new(queue_size: usize) -> Self {
Self(queue_size, Default::default()) Self {
config: Config {
queue_size,
queue_per_task: false,
ordered: false,
task_count: 1,
lazy_task_creation: true,
},
_m: PhantomData,
}
} }
} }
@ -37,17 +66,15 @@ impl<M: Message, H: Sync + Send + Default + 'static> Builder<M> for DefaultBuild
async move { Ok(<Self::Context as Default>::default()) } async move { Ok(<Self::Context as Default>::default()) }
} }
fn queue_size(&self, _stream_id: u32, _task_id: u32) -> usize { fn config(&self, _stream_id: u32) -> Config {
self.0 self.config
} }
} }
pub struct SharedBuilder<M, H, C, F> { pub struct SharedBuilder<M, H, C, F> {
queue_size: usize, config: Config,
parallel: u32,
stream_handlers: dashmap::DashMap<u32, Arc<H>>, stream_handlers: dashmap::DashMap<u32, Arc<H>>,
callback: C, callback: C,
ordered: bool,
_m: PhantomData<(M, F)>, _m: PhantomData<(M, F)>,
} }
@ -58,13 +85,41 @@ where
F: Sync + Send + Future<Output = Result<H, Error>> + 'static, F: Sync + Send + Future<Output = Result<H, Error>> + 'static,
C: Sync + Send + Fn(u32, u32) -> F + 'static, C: Sync + Send + Fn(u32, u32) -> F + 'static,
{ {
pub fn new(queue_size: usize, parallel: u32, ordered: bool, callback: C) -> Self { pub fn new(queue_size: usize, task_count: u32, callback: C) -> Self {
Self { Self {
config: Config {
queue_size, queue_size,
parallel, queue_per_task: false,
ordered: false,
task_count,
lazy_task_creation: true,
},
stream_handlers: Default::default(), stream_handlers: Default::default(),
callback, callback,
ordered, _m: PhantomData,
}
}
pub fn ordered(self) -> Self {
let mut config = self.config;
config.ordered = true;
Self {
config,
stream_handlers: self.stream_handlers,
callback: self.callback,
_m: PhantomData,
}
}
pub fn queue_per_task(self) -> Self {
let mut config = self.config;
config.queue_per_task = true;
Self {
config,
stream_handlers: self.stream_handlers,
callback: self.callback,
_m: PhantomData, _m: PhantomData,
} }
} }
@ -92,11 +147,7 @@ where
} }
} }
fn queue_size(&self, _stream_id: u32, _task_id: u32) -> usize { fn config(&self, _stream_id: u32) -> Config {
self.queue_size self.config
}
fn parallel(&self, _stream_id: u32) -> (u32, bool) {
(self.parallel, self.ordered)
} }
} }

View File

@ -58,6 +58,7 @@ impl<M: Message> BusSender<M> {
} }
} }
#[derive(Clone)]
pub(crate) struct Sender<T> { pub(crate) struct Sender<T> {
inner: kanal::AsyncSender<ChannelItem<T>>, inner: kanal::AsyncSender<ChannelItem<T>>,
} }
@ -83,6 +84,7 @@ impl<T> Sender<T> {
} }
} }
#[derive(Clone)]
pub(crate) struct Receiver<T> { pub(crate) struct Receiver<T> {
inner: kanal::AsyncReceiver<ChannelItem<T>>, inner: kanal::AsyncReceiver<ChannelItem<T>>,
} }

View File

@ -1,6 +1,7 @@
use std::{ use std::{
any::{Any, TypeId}, any::{Any, TypeId},
marker::PhantomData, marker::PhantomData,
pin::Pin,
sync::Arc, sync::Arc,
}; };
@ -9,7 +10,7 @@ use tokio::sync::Notify;
use crate::{ use crate::{
builder::Builder, builder::Builder,
chan::{channel, Sender}, chan::Receiver,
message::Msg, message::Msg,
task::{TaskCounter, TaskSpawner}, task::{TaskCounter, TaskSpawner},
BusInner, Error, IntoMessage, Message, BusInner, Error, IntoMessage, Message,
@ -50,15 +51,15 @@ where
{ {
fn spawn_task( fn spawn_task(
&self, &self,
rx: Receiver<Msg<M>>,
stream_id: u32, stream_id: u32,
task_id: u32, task_id: u32,
_abort: Arc<Notify>, _abort: Arc<Notify>,
task_counter: Arc<TaskCounter>, task_counter: Arc<TaskCounter>,
bus: Arc<BusInner>, bus: Arc<BusInner>,
) -> Box<dyn Future<Output = Result<Sender<Msg<M>>, Error>> + Send + '_> { ) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send + '_>> {
Box::new(async move { Box::pin(async move {
let bus = bus.clone(); let bus = bus.clone();
let (tx, rx) = channel::<Msg<M>>(self.builder.queue_size(stream_id, task_id));
let mut ctx = self.builder.build(stream_id, task_id).await?; let mut ctx = self.builder.build(stream_id, task_id).await?;
let _handle = tokio::spawn(async move { let _handle = tokio::spawn(async move {
@ -96,7 +97,7 @@ where
println!("TASK FINALIZE ERROR: {:?}", err); println!("TASK FINALIZE ERROR: {:?}", err);
} }
}); });
Ok(tx) Ok(())
}) })
} }
@ -104,7 +105,7 @@ where
false false
} }
fn parallel(&self, stream_id: u32) -> (u32, bool) { fn config(&self, stream_id: u32) -> crate::builder::Config {
self.builder.parallel(stream_id) self.builder.config(stream_id)
} }
} }

View File

@ -19,7 +19,8 @@ use std::{
}, },
}; };
use chan::{BusSender, BusSenderClose}; use builder::Config;
use chan::{channel, BusSender, BusSenderClose, Receiver, Sender};
use dashmap::DashMap; use dashmap::DashMap;
use futures::Future; use futures::Future;
use message::Msg; use message::Msg;
@ -47,13 +48,13 @@ struct BusInner {
} }
impl BusInner { impl BusInner {
fn get_task_id<M: Message>(&self, stream_id: u32, task_count: u32) -> u32 { fn get_task_id<M: Message>(&self, stream_id: u32, config: &Config) -> u32 {
if task_count == 1 { if !config.queue_per_task || config.task_count == 1 {
return DEFAUL_TASK_ID; return DEFAUL_TASK_ID;
} }
let type_id = TypeId::of::<M>(); let type_id = TypeId::of::<M>();
let (id1, id2) = self.rng.next_u32_pair(task_count); let (id1, id2) = self.rng.next_u32_pair(config.task_count);
let Some(l1) = self let Some(l1) = self
.senders .senders
@ -80,14 +81,40 @@ impl BusInner {
} }
} }
async fn spawn_task<M: Message>(
self: &Arc<Self>,
spawner: &TaskSpawnerWrapper<M>,
chan: (Sender<Msg<M>>, Receiver<Msg<M>>),
stream_id: u32,
task_id: u32,
) -> Result<(), Error> {
let sender = spawner
.spawn_task(
chan,
stream_id,
task_id,
self.abort_notify.clone(),
self.task_counter.clone(),
self.clone(),
)
.await?;
self.senders.insert(
(stream_id, task_id, TypeId::of::<M>()),
Arc::new(sender) as _,
);
Ok(())
}
async fn send_inner<M: Message>( async fn send_inner<M: Message>(
self: &Arc<Self>, self: &Arc<Self>,
msg: Msg<M>, msg: Msg<M>,
task_count: u32, config: Config,
) -> Result<(), Error> { ) -> Result<(), Error> {
let type_id = TypeId::of::<M>(); let type_id = TypeId::of::<M>();
let stream_id = msg.stream_id; let stream_id = msg.stream_id;
let task_id = self.get_task_id::<M>(stream_id, task_count); let task_id = self.get_task_id::<M>(stream_id, &config);
if !self.senders.contains_key(&(stream_id, task_id, type_id)) { if !self.senders.contains_key(&(stream_id, task_id, type_id)) {
let spawner = if let Some(spawner) = self.spawners.read().await.get(&type_id) { let spawner = if let Some(spawner) = self.spawners.read().await.get(&type_id) {
@ -99,21 +126,37 @@ impl BusInner {
return Err(Error::HandlerIsNotRegistered); return Err(Error::HandlerIsNotRegistered);
}; };
let tx = spawner if config.queue_per_task {
.spawn_task( if config.lazy_task_creation {
self.spawn_task(
&spawner,
channel::<Msg<M>>(config.queue_size),
stream_id, stream_id,
task_id, task_id,
self.abort_notify.clone(),
self.task_counter.clone(),
self.clone(),
) )
.await?; .await?;
tx.send(msg).await.unwrap();
self.senders
.insert((stream_id, task_id, type_id), Arc::new(tx) as _);
} else { } else {
for tid in 0..config.task_count {
self.spawn_task(
&spawner,
channel::<Msg<M>>(config.queue_size),
stream_id,
tid,
)
.await?;
}
}
} else {
let (tx, rx) = channel::<Msg<M>>(config.queue_size);
for tid in 0..config.task_count - 1 {
self.spawn_task(&spawner, (tx.clone(), rx.clone()), stream_id, tid)
.await?;
}
self.spawn_task(&spawner, (tx, rx), stream_id, config.task_count - 1)
.await?;
};
}
let senders = self let senders = self
.senders .senders
.get(&(stream_id, task_id, type_id)) .get(&(stream_id, task_id, type_id))
@ -127,7 +170,6 @@ impl BusInner {
.send(msg) .send(msg)
.await .await
.unwrap(); .unwrap();
}
Ok(()) Ok(())
} }
@ -136,19 +178,19 @@ impl BusInner {
let type_id = TypeId::of::<M>(); let type_id = TypeId::of::<M>();
let stream_id = msg.stream_id; let stream_id = msg.stream_id;
let (task_count, ordered) = if let Some(spawner) = self let config = if let Some(spawner) = self
.spawners .spawners
.read() .read()
.await .await
.get(&type_id) .get(&type_id)
.and_then(|x| x.downcast_ref::<TaskSpawnerWrapper<M>>()) .and_then(|x| x.downcast_ref::<TaskSpawnerWrapper<M>>())
{ {
spawner.parallel(stream_id) spawner.config(stream_id)
} else { } else {
(1, false) Config::default()
}; };
if ordered { if config.ordered {
// let queue = self // let queue = self
// .reordering_queue // .reordering_queue
// .get_or_insert(&(stream_id, type_id), task_count); // .get_or_insert(&(stream_id, type_id), task_count);
@ -159,9 +201,9 @@ impl BusInner {
// } // }
// Ok(()) // Ok(())
self.send_inner(msg, task_count).await self.send_inner(msg, config).await
} else { } else {
self.send_inner(msg, task_count).await self.send_inner(msg, config).await
} }
} }
@ -419,7 +461,7 @@ mod tests {
async fn test_tasks_shared() { async fn test_tasks_shared() {
let bus = Bus::default(); let bus = Bus::default();
bus.register(SharedBuilder::new(2, 5, false, |sid, _tid| async move { bus.register(SharedBuilder::new(2, 5, |sid, _tid| async move {
Ok(TestConsumer(sid)) Ok(TestConsumer(sid))
})) }))
.await; .await;

View File

@ -1,11 +1,15 @@
use std::{marker::PhantomData, pin::pin, sync::Arc}; use std::{
marker::PhantomData,
pin::{pin, Pin},
sync::Arc,
};
use futures::{Future, Stream, StreamExt}; use futures::{Future, Stream, StreamExt};
use tokio::sync::Notify; use tokio::sync::Notify;
use crate::{ use crate::{
builder::Builder, builder::Builder,
chan::{channel, Sender}, chan::Receiver,
message::Msg, message::Msg,
task::{TaskCounter, TaskSpawner}, task::{TaskCounter, TaskSpawner},
BusInner, Error, IntoMessage, Message, BusInner, Error, IntoMessage, Message,
@ -46,14 +50,14 @@ where
{ {
fn spawn_task( fn spawn_task(
&self, &self,
rx: Receiver<Msg<M>>,
stream_id: u32, stream_id: u32,
task_id: u32, task_id: u32,
abort: Arc<Notify>, abort: Arc<Notify>,
task_counter: Arc<TaskCounter>, task_counter: Arc<TaskCounter>,
bus: Arc<BusInner>, bus: Arc<BusInner>,
) -> Box<dyn Future<Output = Result<Sender<Msg<M>>, Error>> + Send + '_> { ) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send + '_>> {
Box::new(async move { Box::pin(async move {
let (tx, rx) = channel::<Msg<M>>(self.builder.queue_size(stream_id, task_id));
let mut ctx = self.builder.build(stream_id, task_id).await?; let mut ctx = self.builder.build(stream_id, task_id).await?;
let _handle = tokio::spawn(async move { let _handle = tokio::spawn(async move {
@ -102,7 +106,7 @@ where
} }
}); });
Ok(tx) Ok(())
}) })
} }
@ -110,7 +114,7 @@ where
true true
} }
fn parallel(&self, stream_id: u32) -> (u32, bool) { fn config(&self, stream_id: u32) -> crate::builder::Config {
self.builder.parallel(stream_id) self.builder.config(stream_id)
} }
} }

View File

@ -1,13 +1,17 @@
use std::sync::{ use std::{
pin::Pin,
sync::{
atomic::{AtomicUsize, Ordering}, atomic::{AtomicUsize, Ordering},
Arc, Arc,
},
}; };
use futures::Future; use futures::Future;
use tokio::sync::Notify; use tokio::sync::Notify;
use crate::{ use crate::{
chan::{BusSender, Sender}, builder::Config,
chan::{BusSender, Receiver, Sender},
handler::HandlerSpawner, handler::HandlerSpawner,
message::Msg, message::Msg,
producer::ProducerSpawner, producer::ProducerSpawner,
@ -41,16 +45,17 @@ impl TaskCounter {
} }
pub(crate) trait TaskSpawner<M: Message>: Send + Sync { pub(crate) trait TaskSpawner<M: Message>: Send + Sync {
fn parallel(&self, stream_id: u32) -> (u32, bool); fn config(&self, stream_id: u32) -> Config;
fn is_producer(&self) -> bool; fn is_producer(&self) -> bool;
fn spawn_task( fn spawn_task(
&self, &self,
rx: Receiver<Msg<M>>,
stream_id: u32, stream_id: u32,
task_id: u32, task_id: u32,
abort: Arc<Notify>, abort: Arc<Notify>,
task_counter: Arc<TaskCounter>, task_counter: Arc<TaskCounter>,
bus: Arc<BusInner>, bus: Arc<BusInner>,
) -> Box<dyn Future<Output = Result<Sender<Msg<M>>, Error>> + Send + '_>; ) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send + '_>>;
} }
pub(crate) struct TaskSpawnerWrapper<M: Message> { pub(crate) struct TaskSpawnerWrapper<M: Message> {
@ -87,24 +92,22 @@ impl<M: Message> TaskSpawnerWrapper<M> {
#[inline] #[inline]
pub async fn spawn_task( pub async fn spawn_task(
&self, &self,
(tx, rx): (Sender<Msg<M>>, Receiver<Msg<M>>),
stream_id: u32, stream_id: u32,
task_id: u32, task_id: u32,
abort: Arc<Notify>, abort: Arc<Notify>,
task_counter: Arc<TaskCounter>, task_counter: Arc<TaskCounter>,
bus: Arc<BusInner>, bus: Arc<BusInner>,
) -> Result<BusSender<M>, Error> { ) -> Result<BusSender<M>, Error> {
Ok(BusSender::new(
self.inner.is_producer(),
Box::into_pin(
self.inner self.inner
.spawn_task(stream_id, task_id, abort, task_counter, bus), .spawn_task(rx, stream_id, task_id, abort, task_counter, bus)
) .await?;
.await?,
)) Ok(BusSender::new(self.inner.is_producer(), tx))
} }
#[inline] #[inline]
pub(crate) fn parallel(&self, stream_id: u32) -> (u32, bool) { pub(crate) fn config(&self, stream_id: u32) -> Config {
self.inner.parallel(stream_id) self.inner.config(stream_id)
} }
} }