diff --git a/src/builder.rs b/src/builder.rs index a8b79aa..ec01fe8 100644 --- a/src/builder.rs +++ b/src/builder.rs @@ -4,28 +4,57 @@ use futures::Future; use crate::{Error, Message}; +#[derive(Clone, Copy, Debug)] +pub struct Config { + pub queue_size: usize, + pub queue_per_task: bool, + pub ordered: bool, + pub task_count: u32, + pub lazy_task_creation: bool, +} + +impl Default for Config { + fn default() -> Self { + Self { + queue_size: 4, + queue_per_task: false, + ordered: false, + task_count: 1, + lazy_task_creation: true, + } + } +} + pub trait Builder: Send + Sync + 'static { type Context: 'static; type BuildFut<'a>: Future> + Send + 'a where Self: 'a; - fn parallel(&self, _stream_id: u32) -> (u32, bool) { - (1, false) - } - - fn queue_size(&self, _stream_id: u32, _task_id: u32) -> usize { - 4 + fn config(&self, _stream_id: u32) -> Config { + Default::default() } fn build(&self, stream_id: u32, _task_id: u32) -> Self::BuildFut<'_>; } -pub struct DefaultBuilder(usize, PhantomData<(M, H)>); +pub struct DefaultBuilder { + config: Config, + _m: PhantomData<(M, H)>, +} impl DefaultBuilder { pub fn new(queue_size: usize) -> Self { - Self(queue_size, Default::default()) + Self { + config: Config { + queue_size, + queue_per_task: false, + ordered: false, + task_count: 1, + lazy_task_creation: true, + }, + _m: PhantomData, + } } } @@ -37,17 +66,15 @@ impl Builder for DefaultBuild async move { Ok(::default()) } } - fn queue_size(&self, _stream_id: u32, _task_id: u32) -> usize { - self.0 + fn config(&self, _stream_id: u32) -> Config { + self.config } } pub struct SharedBuilder { - queue_size: usize, - parallel: u32, + config: Config, stream_handlers: dashmap::DashMap>, callback: C, - ordered: bool, _m: PhantomData<(M, F)>, } @@ -58,13 +85,41 @@ where F: Sync + Send + Future> + 'static, C: Sync + Send + Fn(u32, u32) -> F + 'static, { - pub fn new(queue_size: usize, parallel: u32, ordered: bool, callback: C) -> Self { + pub fn new(queue_size: usize, task_count: u32, callback: C) -> Self { Self { - queue_size, - parallel, + config: Config { + queue_size, + queue_per_task: false, + ordered: false, + task_count, + lazy_task_creation: true, + }, stream_handlers: Default::default(), callback, - ordered, + _m: PhantomData, + } + } + + pub fn ordered(self) -> Self { + let mut config = self.config; + config.ordered = true; + + Self { + config, + stream_handlers: self.stream_handlers, + callback: self.callback, + _m: PhantomData, + } + } + + pub fn queue_per_task(self) -> Self { + let mut config = self.config; + config.queue_per_task = true; + + Self { + config, + stream_handlers: self.stream_handlers, + callback: self.callback, _m: PhantomData, } } @@ -92,11 +147,7 @@ where } } - fn queue_size(&self, _stream_id: u32, _task_id: u32) -> usize { - self.queue_size - } - - fn parallel(&self, _stream_id: u32) -> (u32, bool) { - (self.parallel, self.ordered) + fn config(&self, _stream_id: u32) -> Config { + self.config } } diff --git a/src/chan.rs b/src/chan.rs index a29eaaf..20924ab 100644 --- a/src/chan.rs +++ b/src/chan.rs @@ -58,6 +58,7 @@ impl BusSender { } } +#[derive(Clone)] pub(crate) struct Sender { inner: kanal::AsyncSender>, } @@ -83,6 +84,7 @@ impl Sender { } } +#[derive(Clone)] pub(crate) struct Receiver { inner: kanal::AsyncReceiver>, } diff --git a/src/handler.rs b/src/handler.rs index ab22546..64aeb6e 100644 --- a/src/handler.rs +++ b/src/handler.rs @@ -1,6 +1,7 @@ use std::{ any::{Any, TypeId}, marker::PhantomData, + pin::Pin, sync::Arc, }; @@ -9,7 +10,7 @@ use tokio::sync::Notify; use crate::{ builder::Builder, - chan::{channel, Sender}, + chan::Receiver, message::Msg, task::{TaskCounter, TaskSpawner}, BusInner, Error, IntoMessage, Message, @@ -50,15 +51,15 @@ where { fn spawn_task( &self, + rx: Receiver>, stream_id: u32, task_id: u32, _abort: Arc, task_counter: Arc, bus: Arc, - ) -> Box>, Error>> + Send + '_> { - Box::new(async move { + ) -> Pin> + Send + '_>> { + Box::pin(async move { let bus = bus.clone(); - let (tx, rx) = channel::>(self.builder.queue_size(stream_id, task_id)); let mut ctx = self.builder.build(stream_id, task_id).await?; let _handle = tokio::spawn(async move { @@ -96,7 +97,7 @@ where println!("TASK FINALIZE ERROR: {:?}", err); } }); - Ok(tx) + Ok(()) }) } @@ -104,7 +105,7 @@ where false } - fn parallel(&self, stream_id: u32) -> (u32, bool) { - self.builder.parallel(stream_id) + fn config(&self, stream_id: u32) -> crate::builder::Config { + self.builder.config(stream_id) } } diff --git a/src/lib.rs b/src/lib.rs index 22801f8..e491716 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -19,7 +19,8 @@ use std::{ }, }; -use chan::{BusSender, BusSenderClose}; +use builder::Config; +use chan::{channel, BusSender, BusSenderClose, Receiver, Sender}; use dashmap::DashMap; use futures::Future; use message::Msg; @@ -47,13 +48,13 @@ struct BusInner { } impl BusInner { - fn get_task_id(&self, stream_id: u32, task_count: u32) -> u32 { - if task_count == 1 { + fn get_task_id(&self, stream_id: u32, config: &Config) -> u32 { + if !config.queue_per_task || config.task_count == 1 { return DEFAUL_TASK_ID; } let type_id = TypeId::of::(); - let (id1, id2) = self.rng.next_u32_pair(task_count); + let (id1, id2) = self.rng.next_u32_pair(config.task_count); let Some(l1) = self .senders @@ -80,14 +81,40 @@ impl BusInner { } } + async fn spawn_task( + self: &Arc, + spawner: &TaskSpawnerWrapper, + chan: (Sender>, Receiver>), + stream_id: u32, + task_id: u32, + ) -> Result<(), Error> { + let sender = spawner + .spawn_task( + chan, + stream_id, + task_id, + self.abort_notify.clone(), + self.task_counter.clone(), + self.clone(), + ) + .await?; + + self.senders.insert( + (stream_id, task_id, TypeId::of::()), + Arc::new(sender) as _, + ); + + Ok(()) + } + async fn send_inner( self: &Arc, msg: Msg, - task_count: u32, + config: Config, ) -> Result<(), Error> { let type_id = TypeId::of::(); let stream_id = msg.stream_id; - let task_id = self.get_task_id::(stream_id, task_count); + let task_id = self.get_task_id::(stream_id, &config); if !self.senders.contains_key(&(stream_id, task_id, type_id)) { let spawner = if let Some(spawner) = self.spawners.read().await.get(&type_id) { @@ -99,36 +126,51 @@ impl BusInner { return Err(Error::HandlerIsNotRegistered); }; - let tx = spawner - .spawn_task( - stream_id, - task_id, - self.abort_notify.clone(), - self.task_counter.clone(), - self.clone(), - ) - .await?; - - tx.send(msg).await.unwrap(); - - self.senders - .insert((stream_id, task_id, type_id), Arc::new(tx) as _); - } else { - let senders = self - .senders - .get(&(stream_id, task_id, type_id)) - .unwrap() - .clone(); - - senders - .upcast() - .downcast_ref::>() - .unwrap() - .send(msg) - .await - .unwrap(); + if config.queue_per_task { + if config.lazy_task_creation { + self.spawn_task( + &spawner, + channel::>(config.queue_size), + stream_id, + task_id, + ) + .await?; + } else { + for tid in 0..config.task_count { + self.spawn_task( + &spawner, + channel::>(config.queue_size), + stream_id, + tid, + ) + .await?; + } + } + } else { + let (tx, rx) = channel::>(config.queue_size); + for tid in 0..config.task_count - 1 { + self.spawn_task(&spawner, (tx.clone(), rx.clone()), stream_id, tid) + .await?; + } + self.spawn_task(&spawner, (tx, rx), stream_id, config.task_count - 1) + .await?; + }; } + let senders = self + .senders + .get(&(stream_id, task_id, type_id)) + .unwrap() + .clone(); + + senders + .upcast() + .downcast_ref::>() + .unwrap() + .send(msg) + .await + .unwrap(); + Ok(()) } @@ -136,19 +178,19 @@ impl BusInner { let type_id = TypeId::of::(); let stream_id = msg.stream_id; - let (task_count, ordered) = if let Some(spawner) = self + let config = if let Some(spawner) = self .spawners .read() .await .get(&type_id) .and_then(|x| x.downcast_ref::>()) { - spawner.parallel(stream_id) + spawner.config(stream_id) } else { - (1, false) + Config::default() }; - if ordered { + if config.ordered { // let queue = self // .reordering_queue // .get_or_insert(&(stream_id, type_id), task_count); @@ -159,9 +201,9 @@ impl BusInner { // } // Ok(()) - self.send_inner(msg, task_count).await + self.send_inner(msg, config).await } else { - self.send_inner(msg, task_count).await + self.send_inner(msg, config).await } } @@ -419,7 +461,7 @@ mod tests { async fn test_tasks_shared() { let bus = Bus::default(); - bus.register(SharedBuilder::new(2, 5, false, |sid, _tid| async move { + bus.register(SharedBuilder::new(2, 5, |sid, _tid| async move { Ok(TestConsumer(sid)) })) .await; diff --git a/src/producer.rs b/src/producer.rs index 51cc0a6..0eb1d2f 100644 --- a/src/producer.rs +++ b/src/producer.rs @@ -1,11 +1,15 @@ -use std::{marker::PhantomData, pin::pin, sync::Arc}; +use std::{ + marker::PhantomData, + pin::{pin, Pin}, + sync::Arc, +}; use futures::{Future, Stream, StreamExt}; use tokio::sync::Notify; use crate::{ builder::Builder, - chan::{channel, Sender}, + chan::Receiver, message::Msg, task::{TaskCounter, TaskSpawner}, BusInner, Error, IntoMessage, Message, @@ -46,14 +50,14 @@ where { fn spawn_task( &self, + rx: Receiver>, stream_id: u32, task_id: u32, abort: Arc, task_counter: Arc, bus: Arc, - ) -> Box>, Error>> + Send + '_> { - Box::new(async move { - let (tx, rx) = channel::>(self.builder.queue_size(stream_id, task_id)); + ) -> Pin> + Send + '_>> { + Box::pin(async move { let mut ctx = self.builder.build(stream_id, task_id).await?; let _handle = tokio::spawn(async move { @@ -102,7 +106,7 @@ where } }); - Ok(tx) + Ok(()) }) } @@ -110,7 +114,7 @@ where true } - fn parallel(&self, stream_id: u32) -> (u32, bool) { - self.builder.parallel(stream_id) + fn config(&self, stream_id: u32) -> crate::builder::Config { + self.builder.config(stream_id) } } diff --git a/src/task.rs b/src/task.rs index a6040c0..3106db0 100644 --- a/src/task.rs +++ b/src/task.rs @@ -1,13 +1,17 @@ -use std::sync::{ - atomic::{AtomicUsize, Ordering}, - Arc, +use std::{ + pin::Pin, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, }; use futures::Future; use tokio::sync::Notify; use crate::{ - chan::{BusSender, Sender}, + builder::Config, + chan::{BusSender, Receiver, Sender}, handler::HandlerSpawner, message::Msg, producer::ProducerSpawner, @@ -41,16 +45,17 @@ impl TaskCounter { } pub(crate) trait TaskSpawner: Send + Sync { - fn parallel(&self, stream_id: u32) -> (u32, bool); + fn config(&self, stream_id: u32) -> Config; fn is_producer(&self) -> bool; fn spawn_task( &self, + rx: Receiver>, stream_id: u32, task_id: u32, abort: Arc, task_counter: Arc, bus: Arc, - ) -> Box>, Error>> + Send + '_>; + ) -> Pin> + Send + '_>>; } pub(crate) struct TaskSpawnerWrapper { @@ -87,24 +92,22 @@ impl TaskSpawnerWrapper { #[inline] pub async fn spawn_task( &self, + (tx, rx): (Sender>, Receiver>), stream_id: u32, task_id: u32, abort: Arc, task_counter: Arc, bus: Arc, ) -> Result, Error> { - Ok(BusSender::new( - self.inner.is_producer(), - Box::into_pin( - self.inner - .spawn_task(stream_id, task_id, abort, task_counter, bus), - ) - .await?, - )) + self.inner + .spawn_task(rx, stream_id, task_id, abort, task_counter, bus) + .await?; + + Ok(BusSender::new(self.inner.is_producer(), tx)) } #[inline] - pub(crate) fn parallel(&self, stream_id: u32) -> (u32, bool) { - self.inner.parallel(stream_id) + pub(crate) fn config(&self, stream_id: u32) -> Config { + self.inner.config(stream_id) } }