tasks shared queue

This commit is contained in:
Andrey Tkachenko 2023-10-27 12:57:09 +04:00
parent 2ae4e8bfed
commit 46b63aa1f8
6 changed files with 198 additions and 95 deletions

View File

@ -4,28 +4,57 @@ use futures::Future;
use crate::{Error, Message};
#[derive(Clone, Copy, Debug)]
pub struct Config {
pub queue_size: usize,
pub queue_per_task: bool,
pub ordered: bool,
pub task_count: u32,
pub lazy_task_creation: bool,
}
impl Default for Config {
fn default() -> Self {
Self {
queue_size: 4,
queue_per_task: false,
ordered: false,
task_count: 1,
lazy_task_creation: true,
}
}
}
pub trait Builder<M: Message>: Send + Sync + 'static {
type Context: 'static;
type BuildFut<'a>: Future<Output = Result<Self::Context, Error>> + Send + 'a
where
Self: 'a;
fn parallel(&self, _stream_id: u32) -> (u32, bool) {
(1, false)
}
fn queue_size(&self, _stream_id: u32, _task_id: u32) -> usize {
4
fn config(&self, _stream_id: u32) -> Config {
Default::default()
}
fn build(&self, stream_id: u32, _task_id: u32) -> Self::BuildFut<'_>;
}
pub struct DefaultBuilder<M: Message, H>(usize, PhantomData<(M, H)>);
pub struct DefaultBuilder<M: Message, H> {
config: Config,
_m: PhantomData<(M, H)>,
}
impl<M: Message, H> DefaultBuilder<M, H> {
pub fn new(queue_size: usize) -> Self {
Self(queue_size, Default::default())
Self {
config: Config {
queue_size,
queue_per_task: false,
ordered: false,
task_count: 1,
lazy_task_creation: true,
},
_m: PhantomData,
}
}
}
@ -37,17 +66,15 @@ impl<M: Message, H: Sync + Send + Default + 'static> Builder<M> for DefaultBuild
async move { Ok(<Self::Context as Default>::default()) }
}
fn queue_size(&self, _stream_id: u32, _task_id: u32) -> usize {
self.0
fn config(&self, _stream_id: u32) -> Config {
self.config
}
}
pub struct SharedBuilder<M, H, C, F> {
queue_size: usize,
parallel: u32,
config: Config,
stream_handlers: dashmap::DashMap<u32, Arc<H>>,
callback: C,
ordered: bool,
_m: PhantomData<(M, F)>,
}
@ -58,13 +85,41 @@ where
F: Sync + Send + Future<Output = Result<H, Error>> + 'static,
C: Sync + Send + Fn(u32, u32) -> F + 'static,
{
pub fn new(queue_size: usize, parallel: u32, ordered: bool, callback: C) -> Self {
pub fn new(queue_size: usize, task_count: u32, callback: C) -> Self {
Self {
queue_size,
parallel,
config: Config {
queue_size,
queue_per_task: false,
ordered: false,
task_count,
lazy_task_creation: true,
},
stream_handlers: Default::default(),
callback,
ordered,
_m: PhantomData,
}
}
pub fn ordered(self) -> Self {
let mut config = self.config;
config.ordered = true;
Self {
config,
stream_handlers: self.stream_handlers,
callback: self.callback,
_m: PhantomData,
}
}
pub fn queue_per_task(self) -> Self {
let mut config = self.config;
config.queue_per_task = true;
Self {
config,
stream_handlers: self.stream_handlers,
callback: self.callback,
_m: PhantomData,
}
}
@ -92,11 +147,7 @@ where
}
}
fn queue_size(&self, _stream_id: u32, _task_id: u32) -> usize {
self.queue_size
}
fn parallel(&self, _stream_id: u32) -> (u32, bool) {
(self.parallel, self.ordered)
fn config(&self, _stream_id: u32) -> Config {
self.config
}
}

View File

@ -58,6 +58,7 @@ impl<M: Message> BusSender<M> {
}
}
#[derive(Clone)]
pub(crate) struct Sender<T> {
inner: kanal::AsyncSender<ChannelItem<T>>,
}
@ -83,6 +84,7 @@ impl<T> Sender<T> {
}
}
#[derive(Clone)]
pub(crate) struct Receiver<T> {
inner: kanal::AsyncReceiver<ChannelItem<T>>,
}

View File

@ -1,6 +1,7 @@
use std::{
any::{Any, TypeId},
marker::PhantomData,
pin::Pin,
sync::Arc,
};
@ -9,7 +10,7 @@ use tokio::sync::Notify;
use crate::{
builder::Builder,
chan::{channel, Sender},
chan::Receiver,
message::Msg,
task::{TaskCounter, TaskSpawner},
BusInner, Error, IntoMessage, Message,
@ -50,15 +51,15 @@ where
{
fn spawn_task(
&self,
rx: Receiver<Msg<M>>,
stream_id: u32,
task_id: u32,
_abort: Arc<Notify>,
task_counter: Arc<TaskCounter>,
bus: Arc<BusInner>,
) -> Box<dyn Future<Output = Result<Sender<Msg<M>>, Error>> + Send + '_> {
Box::new(async move {
) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send + '_>> {
Box::pin(async move {
let bus = bus.clone();
let (tx, rx) = channel::<Msg<M>>(self.builder.queue_size(stream_id, task_id));
let mut ctx = self.builder.build(stream_id, task_id).await?;
let _handle = tokio::spawn(async move {
@ -96,7 +97,7 @@ where
println!("TASK FINALIZE ERROR: {:?}", err);
}
});
Ok(tx)
Ok(())
})
}
@ -104,7 +105,7 @@ where
false
}
fn parallel(&self, stream_id: u32) -> (u32, bool) {
self.builder.parallel(stream_id)
fn config(&self, stream_id: u32) -> crate::builder::Config {
self.builder.config(stream_id)
}
}

View File

@ -19,7 +19,8 @@ use std::{
},
};
use chan::{BusSender, BusSenderClose};
use builder::Config;
use chan::{channel, BusSender, BusSenderClose, Receiver, Sender};
use dashmap::DashMap;
use futures::Future;
use message::Msg;
@ -47,13 +48,13 @@ struct BusInner {
}
impl BusInner {
fn get_task_id<M: Message>(&self, stream_id: u32, task_count: u32) -> u32 {
if task_count == 1 {
fn get_task_id<M: Message>(&self, stream_id: u32, config: &Config) -> u32 {
if !config.queue_per_task || config.task_count == 1 {
return DEFAUL_TASK_ID;
}
let type_id = TypeId::of::<M>();
let (id1, id2) = self.rng.next_u32_pair(task_count);
let (id1, id2) = self.rng.next_u32_pair(config.task_count);
let Some(l1) = self
.senders
@ -80,14 +81,40 @@ impl BusInner {
}
}
async fn spawn_task<M: Message>(
self: &Arc<Self>,
spawner: &TaskSpawnerWrapper<M>,
chan: (Sender<Msg<M>>, Receiver<Msg<M>>),
stream_id: u32,
task_id: u32,
) -> Result<(), Error> {
let sender = spawner
.spawn_task(
chan,
stream_id,
task_id,
self.abort_notify.clone(),
self.task_counter.clone(),
self.clone(),
)
.await?;
self.senders.insert(
(stream_id, task_id, TypeId::of::<M>()),
Arc::new(sender) as _,
);
Ok(())
}
async fn send_inner<M: Message>(
self: &Arc<Self>,
msg: Msg<M>,
task_count: u32,
config: Config,
) -> Result<(), Error> {
let type_id = TypeId::of::<M>();
let stream_id = msg.stream_id;
let task_id = self.get_task_id::<M>(stream_id, task_count);
let task_id = self.get_task_id::<M>(stream_id, &config);
if !self.senders.contains_key(&(stream_id, task_id, type_id)) {
let spawner = if let Some(spawner) = self.spawners.read().await.get(&type_id) {
@ -99,36 +126,51 @@ impl BusInner {
return Err(Error::HandlerIsNotRegistered);
};
let tx = spawner
.spawn_task(
stream_id,
task_id,
self.abort_notify.clone(),
self.task_counter.clone(),
self.clone(),
)
.await?;
tx.send(msg).await.unwrap();
self.senders
.insert((stream_id, task_id, type_id), Arc::new(tx) as _);
} else {
let senders = self
.senders
.get(&(stream_id, task_id, type_id))
.unwrap()
.clone();
senders
.upcast()
.downcast_ref::<BusSender<M>>()
.unwrap()
.send(msg)
.await
.unwrap();
if config.queue_per_task {
if config.lazy_task_creation {
self.spawn_task(
&spawner,
channel::<Msg<M>>(config.queue_size),
stream_id,
task_id,
)
.await?;
} else {
for tid in 0..config.task_count {
self.spawn_task(
&spawner,
channel::<Msg<M>>(config.queue_size),
stream_id,
tid,
)
.await?;
}
}
} else {
let (tx, rx) = channel::<Msg<M>>(config.queue_size);
for tid in 0..config.task_count - 1 {
self.spawn_task(&spawner, (tx.clone(), rx.clone()), stream_id, tid)
.await?;
}
self.spawn_task(&spawner, (tx, rx), stream_id, config.task_count - 1)
.await?;
};
}
let senders = self
.senders
.get(&(stream_id, task_id, type_id))
.unwrap()
.clone();
senders
.upcast()
.downcast_ref::<BusSender<M>>()
.unwrap()
.send(msg)
.await
.unwrap();
Ok(())
}
@ -136,19 +178,19 @@ impl BusInner {
let type_id = TypeId::of::<M>();
let stream_id = msg.stream_id;
let (task_count, ordered) = if let Some(spawner) = self
let config = if let Some(spawner) = self
.spawners
.read()
.await
.get(&type_id)
.and_then(|x| x.downcast_ref::<TaskSpawnerWrapper<M>>())
{
spawner.parallel(stream_id)
spawner.config(stream_id)
} else {
(1, false)
Config::default()
};
if ordered {
if config.ordered {
// let queue = self
// .reordering_queue
// .get_or_insert(&(stream_id, type_id), task_count);
@ -159,9 +201,9 @@ impl BusInner {
// }
// Ok(())
self.send_inner(msg, task_count).await
self.send_inner(msg, config).await
} else {
self.send_inner(msg, task_count).await
self.send_inner(msg, config).await
}
}
@ -419,7 +461,7 @@ mod tests {
async fn test_tasks_shared() {
let bus = Bus::default();
bus.register(SharedBuilder::new(2, 5, false, |sid, _tid| async move {
bus.register(SharedBuilder::new(2, 5, |sid, _tid| async move {
Ok(TestConsumer(sid))
}))
.await;

View File

@ -1,11 +1,15 @@
use std::{marker::PhantomData, pin::pin, sync::Arc};
use std::{
marker::PhantomData,
pin::{pin, Pin},
sync::Arc,
};
use futures::{Future, Stream, StreamExt};
use tokio::sync::Notify;
use crate::{
builder::Builder,
chan::{channel, Sender},
chan::Receiver,
message::Msg,
task::{TaskCounter, TaskSpawner},
BusInner, Error, IntoMessage, Message,
@ -46,14 +50,14 @@ where
{
fn spawn_task(
&self,
rx: Receiver<Msg<M>>,
stream_id: u32,
task_id: u32,
abort: Arc<Notify>,
task_counter: Arc<TaskCounter>,
bus: Arc<BusInner>,
) -> Box<dyn Future<Output = Result<Sender<Msg<M>>, Error>> + Send + '_> {
Box::new(async move {
let (tx, rx) = channel::<Msg<M>>(self.builder.queue_size(stream_id, task_id));
) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send + '_>> {
Box::pin(async move {
let mut ctx = self.builder.build(stream_id, task_id).await?;
let _handle = tokio::spawn(async move {
@ -102,7 +106,7 @@ where
}
});
Ok(tx)
Ok(())
})
}
@ -110,7 +114,7 @@ where
true
}
fn parallel(&self, stream_id: u32) -> (u32, bool) {
self.builder.parallel(stream_id)
fn config(&self, stream_id: u32) -> crate::builder::Config {
self.builder.config(stream_id)
}
}

View File

@ -1,13 +1,17 @@
use std::sync::{
atomic::{AtomicUsize, Ordering},
Arc,
use std::{
pin::Pin,
sync::{
atomic::{AtomicUsize, Ordering},
Arc,
},
};
use futures::Future;
use tokio::sync::Notify;
use crate::{
chan::{BusSender, Sender},
builder::Config,
chan::{BusSender, Receiver, Sender},
handler::HandlerSpawner,
message::Msg,
producer::ProducerSpawner,
@ -41,16 +45,17 @@ impl TaskCounter {
}
pub(crate) trait TaskSpawner<M: Message>: Send + Sync {
fn parallel(&self, stream_id: u32) -> (u32, bool);
fn config(&self, stream_id: u32) -> Config;
fn is_producer(&self) -> bool;
fn spawn_task(
&self,
rx: Receiver<Msg<M>>,
stream_id: u32,
task_id: u32,
abort: Arc<Notify>,
task_counter: Arc<TaskCounter>,
bus: Arc<BusInner>,
) -> Box<dyn Future<Output = Result<Sender<Msg<M>>, Error>> + Send + '_>;
) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send + '_>>;
}
pub(crate) struct TaskSpawnerWrapper<M: Message> {
@ -87,24 +92,22 @@ impl<M: Message> TaskSpawnerWrapper<M> {
#[inline]
pub async fn spawn_task(
&self,
(tx, rx): (Sender<Msg<M>>, Receiver<Msg<M>>),
stream_id: u32,
task_id: u32,
abort: Arc<Notify>,
task_counter: Arc<TaskCounter>,
bus: Arc<BusInner>,
) -> Result<BusSender<M>, Error> {
Ok(BusSender::new(
self.inner.is_producer(),
Box::into_pin(
self.inner
.spawn_task(stream_id, task_id, abort, task_counter, bus),
)
.await?,
))
self.inner
.spawn_task(rx, stream_id, task_id, abort, task_counter, bus)
.await?;
Ok(BusSender::new(self.inner.is_producer(), tx))
}
#[inline]
pub(crate) fn parallel(&self, stream_id: u32) -> (u32, bool) {
self.inner.parallel(stream_id)
pub(crate) fn config(&self, stream_id: u32) -> Config {
self.inner.config(stream_id)
}
}