From d57dd8fed97c5a6f92ce5ff201a80d38e65b7af4 Mon Sep 17 00:00:00 2001 From: Andrey Tkachenko Date: Thu, 23 Nov 2023 09:58:39 +0400 Subject: [PATCH] register_mapper --- src/builder.rs | 54 ++++++++++++++++++++++----- src/error.rs | 9 +++++ src/lib.rs | 89 ++++++++++++++++++++++++++++++++++++++------ src/reorder_queue.rs | 35 +++++++++++------ 4 files changed, 155 insertions(+), 32 deletions(-) diff --git a/src/builder.rs b/src/builder.rs index a241364..7383d60 100644 --- a/src/builder.rs +++ b/src/builder.rs @@ -39,13 +39,25 @@ pub trait Builder: Send + Sync + 'static { ) -> impl Future> + Send + '_; } -pub struct DefaultBuilder { +pub struct DefaultBuilder { config: Config, - _m: PhantomData<(M, H)>, + callback: C, + _m: PhantomData<(M, H, F)>, } -impl DefaultBuilder { - pub fn new(queue_size: usize) -> Self { +unsafe impl Sync + for DefaultBuilder +{ +} + +impl DefaultBuilder +where + M: Message, + H: Sync + Send + 'static, + F: Send + Future> + 'static, + C: Sync + Send + Fn(u32, u32) -> F + 'static, +{ + pub fn new(queue_size: usize, callback: C) -> Self { Self { config: Config { queue_size, @@ -54,6 +66,7 @@ impl DefaultBuilder { task_count: 1, lazy_task_creation: true, }, + callback, _m: PhantomData, } } @@ -64,16 +77,34 @@ impl DefaultBuilder { Self { config, + callback: self.callback, + _m: PhantomData, + } + } + + pub fn tasks(self, tasks: u32) -> Self { + let mut config = self.config; + config.task_count = tasks; + + Self { + config, + callback: self.callback, _m: PhantomData, } } } -impl Builder for DefaultBuilder { +impl Builder for DefaultBuilder +where + M: Message, + H: Sync + Send + 'static, + F: Send + Future> + 'static, + C: Sync + Send + Fn(u32, u32) -> F + 'static, +{ type Context = H; - async fn build(&self, _stream_id: u32, _task_id: u32) -> Result { - Ok(::default()) + async fn build(&self, stream_id: u32, task_id: u32) -> Result { + (self.callback)(stream_id, task_id).await } fn config(&self, _stream_id: u32) -> Config { @@ -88,11 +119,16 @@ pub struct SharedBuilder { _m: PhantomData<(M, F)>, } +unsafe impl Sync + for SharedBuilder +{ +} + impl SharedBuilder where M: Message, H: Sync + Send + 'static, - F: Sync + Send + Future> + 'static, + F: Send + Future> + 'static, C: Sync + Send + Fn(u32, u32) -> F + 'static, { pub fn new(queue_size: usize, task_count: u32, callback: C) -> Self { @@ -139,7 +175,7 @@ impl Builder for SharedBuilder where M: Message, H: Sync + Send + 'static, - F: Sync + Send + Future> + 'static, + F: Send + Future> + 'static, C: Sync + Send + Fn(u32, u32) -> F + 'static, { type Context = Arc; diff --git a/src/error.rs b/src/error.rs index 6518c54..bc99b93 100644 --- a/src/error.rs +++ b/src/error.rs @@ -14,7 +14,16 @@ pub enum Error { HandlerError(Arc), } +#[derive(Debug)] +pub enum VoidError {} +impl std::fmt::Display for VoidError { + fn fmt(&self, _f: &mut fmt::Formatter<'_>) -> fmt::Result { + Ok(()) + } +} + impl std::error::Error for Error {} +impl ErrorMessage for VoidError {} impl Eq for Error {} diff --git a/src/lib.rs b/src/lib.rs index 5f2831a..da7dac9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -14,6 +14,7 @@ mod task; use std::{ any::{Any, TypeId}, collections::HashMap, + marker::PhantomData, sync::{ atomic::{AtomicBool, AtomicU64, Ordering}, Arc, @@ -32,7 +33,7 @@ use tokio::sync::{Notify, RwLock}; pub use async_iter::*; pub use builder::{Builder, DefaultBuilder, SharedBuilder}; -pub use error::Error; +pub use error::{Error, VoidError}; pub use handler::Handler; pub use message::{ErrorMessage, IntoMessages, Message}; @@ -193,10 +194,12 @@ impl BusInner { } } else { let (tx, rx) = channel::>(config.queue_size); + for tid in 0..config.task_count - 1 { self.spawn_task(&spawner, (tx.clone(), rx.clone()), stream_id, tid) .await?; } + self.spawn_task(&spawner, (tx, rx), stream_id, config.task_count - 1) .await?; }; @@ -244,8 +247,6 @@ impl BusInner { let mut queue = ReorderQueueInner::new(reorder_buff as _); while let Ok(msg) = rx.recv().await { - println!(" > {:?}", msg); - if let Some(index) = queue.push(msg.index, msg.inner) { bus.send_inner( Some(Err::(Error::ReorderingDropMessage(index))), @@ -372,6 +373,24 @@ impl Bus { self } + #[inline] + pub async fn register_mapper< + M: Message, + R: Message, + E: ErrorMessage, + C: Send + Clone + Sync + FnMut(u32, u32, M) -> Result + 'static, + >( + &self, + cb: C, + ) -> &Self { + let mapper = DefaultBuilder::new(0, move |_, _| { + let cb = cb.clone(); + async move { Ok(Mapper { cb, m: PhantomData }) } + }); + self.inner.clone().register(mapper).await; + self + } + #[inline] pub async fn send(&self, inner: M) -> Result<(), Error> { self.send_with_stream(DEFAUL_STREAM_ID, inner).await @@ -417,6 +436,46 @@ impl Bus { } } +struct Mapper { + cb: C, + m: PhantomData<(M, R, E)>, +} + +impl Handler for Mapper +where + M: Message, + R: Message, + E: ErrorMessage, + C: Send + Sync + FnMut(u32, u32, M) -> Result + 'static, +{ + type Result = R; + type Error = E; + + async fn handle( + &mut self, + _msg: M, + _stream_id: u32, + _task_id: u32, + _bus: Bus, + ) -> Result, Self::Error> { + (self.cb)(_stream_id, _task_id, _msg).map(|x| [x]) + } + + async fn handle_error( + &mut self, + _err: Error, + _stream_id: u32, + _task_id: u32, + _bus: Bus, + ) -> Result, Self::Error> { + Ok(None) + } + + async fn finalize(self, _bus: Bus) -> Result<(), Self::Error> { + Ok(()) + } +} + #[cfg(test)] mod tests { use std::{sync::Arc, time::Duration}; @@ -631,11 +690,15 @@ mod tests { async fn test_streams() { let bus = Bus::default(); - bus.register(DefaultBuilder::::new(2)) - .await; + bus.register(DefaultBuilder::::new(2, |_, _| async move { + Ok(TestConsumer::default()) + })) + .await; - bus.register(DefaultBuilder::::new(2)) - .await; + bus.register(DefaultBuilder::::new(2, |_, _| async move { + Ok(TestProducer) + })) + .await; for start in 0u32..10 { bus.send_with_stream(start, start).await.unwrap(); @@ -655,8 +718,10 @@ mod tests { })) .await; - bus.register(DefaultBuilder::::new(2)) - .await; + bus.register(DefaultBuilder::::new(2, |_, _| async move { + Ok(TestProducer) + })) + .await; for start in 0u32..10 { bus.send_with_stream(start, start).await.unwrap(); @@ -680,8 +745,10 @@ mod tests { ) .await; - bus.register(DefaultBuilder::<_, TestCollector>::new(4)) - .await; + bus.register(DefaultBuilder::<_, _, _, _>::new(4, |_, _| async move { + Ok(TestCollector::default()) + })) + .await; for i in 0i16..1024 { bus.send(i).await.unwrap(); diff --git a/src/reorder_queue.rs b/src/reorder_queue.rs index dd9190e..b3e4153 100644 --- a/src/reorder_queue.rs +++ b/src/reorder_queue.rs @@ -41,9 +41,17 @@ impl ReorderQueueInner { } pub fn push(&mut self, index: u64, inner: M) -> Option { + if let Some(ri) = self.recent_index { + if index <= ri { + return Some(index); + } + } + self.heap.push(Entry { inner, index }); - if self.heap.len() == self.cap { + if self.heap.len() > self.cap { + let _ = self.heap.pop(); + self.recent_index = self.recent_index.map(|x| x + 1); self.recent_index } else { @@ -52,20 +60,23 @@ impl ReorderQueueInner { } pub fn pop(&mut self) -> Option<(u64, M)> { - match self.recent_index { - None => { + if let Some(ri) = self.recent_index { + let e = self.heap.peek()?; + if e.index == ri + 1 { + self.recent_index = Some(e.index); + Some((e.index, self.heap.pop()?.inner)) + } else { + None + } + } else { + let e = self.heap.peek()?; + + if e.index == 0 { let e = self.heap.pop()?; self.recent_index = Some(e.index); Some((e.index, e.inner)) - } - Some(ri) => { - let e = self.heap.peek()?; - if e.index == ri + 1 { - self.recent_index = Some(e.index); - Some((e.index, self.heap.pop()?.inner)) - } else { - None - } + } else { + None } } }