register_mapper

This commit is contained in:
Andrey Tkachenko 2023-11-23 09:58:39 +04:00
parent 33db3cb15d
commit d57dd8fed9
4 changed files with 155 additions and 32 deletions

View File

@ -39,13 +39,25 @@ pub trait Builder<M: Message>: Send + Sync + 'static {
) -> impl Future<Output = Result<Self::Context, Error>> + Send + '_; ) -> impl Future<Output = Result<Self::Context, Error>> + Send + '_;
} }
pub struct DefaultBuilder<M: Message, H> { pub struct DefaultBuilder<M: Send + Sync, H: Send + Sync, C: Send + Sync, F: Send> {
config: Config, config: Config,
_m: PhantomData<(M, H)>, callback: C,
_m: PhantomData<(M, H, F)>,
} }
impl<M: Message, H> DefaultBuilder<M, H> { unsafe impl<M: Send + Sync, H: Send + Sync, C: Send + Sync, F: Send> Sync
pub fn new(queue_size: usize) -> Self { for DefaultBuilder<M, H, C, F>
{
}
impl<M, H, C, F> DefaultBuilder<M, H, C, F>
where
M: Message,
H: Sync + Send + 'static,
F: Send + Future<Output = Result<H, Error>> + 'static,
C: Sync + Send + Fn(u32, u32) -> F + 'static,
{
pub fn new(queue_size: usize, callback: C) -> Self {
Self { Self {
config: Config { config: Config {
queue_size, queue_size,
@ -54,6 +66,7 @@ impl<M: Message, H> DefaultBuilder<M, H> {
task_count: 1, task_count: 1,
lazy_task_creation: true, lazy_task_creation: true,
}, },
callback,
_m: PhantomData, _m: PhantomData,
} }
} }
@ -64,16 +77,34 @@ impl<M: Message, H> DefaultBuilder<M, H> {
Self { Self {
config, config,
callback: self.callback,
_m: PhantomData,
}
}
pub fn tasks(self, tasks: u32) -> Self {
let mut config = self.config;
config.task_count = tasks;
Self {
config,
callback: self.callback,
_m: PhantomData, _m: PhantomData,
} }
} }
} }
impl<M: Message, H: Sync + Send + Default + 'static> Builder<M> for DefaultBuilder<M, H> { impl<M, H, C, F> Builder<M> for DefaultBuilder<M, H, C, F>
where
M: Message,
H: Sync + Send + 'static,
F: Send + Future<Output = Result<H, Error>> + 'static,
C: Sync + Send + Fn(u32, u32) -> F + 'static,
{
type Context = H; type Context = H;
async fn build(&self, _stream_id: u32, _task_id: u32) -> Result<Self::Context, Error> { async fn build(&self, stream_id: u32, task_id: u32) -> Result<Self::Context, Error> {
Ok(<Self::Context as Default>::default()) (self.callback)(stream_id, task_id).await
} }
fn config(&self, _stream_id: u32) -> Config { fn config(&self, _stream_id: u32) -> Config {
@ -88,11 +119,16 @@ pub struct SharedBuilder<M, H, C, F> {
_m: PhantomData<(M, F)>, _m: PhantomData<(M, F)>,
} }
unsafe impl<M: Send + Sync, H: Send + Sync, C: Send + Sync, F: Send> Sync
for SharedBuilder<M, H, C, F>
{
}
impl<M, H, C, F> SharedBuilder<M, H, C, F> impl<M, H, C, F> SharedBuilder<M, H, C, F>
where where
M: Message, M: Message,
H: Sync + Send + 'static, H: Sync + Send + 'static,
F: Sync + Send + Future<Output = Result<H, Error>> + 'static, F: Send + Future<Output = Result<H, Error>> + 'static,
C: Sync + Send + Fn(u32, u32) -> F + 'static, C: Sync + Send + Fn(u32, u32) -> F + 'static,
{ {
pub fn new(queue_size: usize, task_count: u32, callback: C) -> Self { pub fn new(queue_size: usize, task_count: u32, callback: C) -> Self {
@ -139,7 +175,7 @@ impl<M, H, C, F> Builder<M> for SharedBuilder<M, H, C, F>
where where
M: Message, M: Message,
H: Sync + Send + 'static, H: Sync + Send + 'static,
F: Sync + Send + Future<Output = Result<H, Error>> + 'static, F: Send + Future<Output = Result<H, Error>> + 'static,
C: Sync + Send + Fn(u32, u32) -> F + 'static, C: Sync + Send + Fn(u32, u32) -> F + 'static,
{ {
type Context = Arc<H>; type Context = Arc<H>;

View File

@ -14,7 +14,16 @@ pub enum Error {
HandlerError(Arc<dyn ErrorMessage>), HandlerError(Arc<dyn ErrorMessage>),
} }
#[derive(Debug)]
pub enum VoidError {}
impl std::fmt::Display for VoidError {
fn fmt(&self, _f: &mut fmt::Formatter<'_>) -> fmt::Result {
Ok(())
}
}
impl std::error::Error for Error {} impl std::error::Error for Error {}
impl ErrorMessage for VoidError {}
impl Eq for Error {} impl Eq for Error {}

View File

@ -14,6 +14,7 @@ mod task;
use std::{ use std::{
any::{Any, TypeId}, any::{Any, TypeId},
collections::HashMap, collections::HashMap,
marker::PhantomData,
sync::{ sync::{
atomic::{AtomicBool, AtomicU64, Ordering}, atomic::{AtomicBool, AtomicU64, Ordering},
Arc, Arc,
@ -32,7 +33,7 @@ use tokio::sync::{Notify, RwLock};
pub use async_iter::*; pub use async_iter::*;
pub use builder::{Builder, DefaultBuilder, SharedBuilder}; pub use builder::{Builder, DefaultBuilder, SharedBuilder};
pub use error::Error; pub use error::{Error, VoidError};
pub use handler::Handler; pub use handler::Handler;
pub use message::{ErrorMessage, IntoMessages, Message}; pub use message::{ErrorMessage, IntoMessages, Message};
@ -193,10 +194,12 @@ impl BusInner {
} }
} else { } else {
let (tx, rx) = channel::<Msg<M>>(config.queue_size); let (tx, rx) = channel::<Msg<M>>(config.queue_size);
for tid in 0..config.task_count - 1 { for tid in 0..config.task_count - 1 {
self.spawn_task(&spawner, (tx.clone(), rx.clone()), stream_id, tid) self.spawn_task(&spawner, (tx.clone(), rx.clone()), stream_id, tid)
.await?; .await?;
} }
self.spawn_task(&spawner, (tx, rx), stream_id, config.task_count - 1) self.spawn_task(&spawner, (tx, rx), stream_id, config.task_count - 1)
.await?; .await?;
}; };
@ -244,8 +247,6 @@ impl BusInner {
let mut queue = ReorderQueueInner::new(reorder_buff as _); let mut queue = ReorderQueueInner::new(reorder_buff as _);
while let Ok(msg) = rx.recv().await { while let Ok(msg) = rx.recv().await {
println!(" > {:?}", msg);
if let Some(index) = queue.push(msg.index, msg.inner) { if let Some(index) = queue.push(msg.index, msg.inner) {
bus.send_inner( bus.send_inner(
Some(Err::<M, _>(Error::ReorderingDropMessage(index))), Some(Err::<M, _>(Error::ReorderingDropMessage(index))),
@ -372,6 +373,24 @@ impl Bus {
self self
} }
#[inline]
pub async fn register_mapper<
M: Message,
R: Message,
E: ErrorMessage,
C: Send + Clone + Sync + FnMut(u32, u32, M) -> Result<R, E> + 'static,
>(
&self,
cb: C,
) -> &Self {
let mapper = DefaultBuilder::new(0, move |_, _| {
let cb = cb.clone();
async move { Ok(Mapper { cb, m: PhantomData }) }
});
self.inner.clone().register(mapper).await;
self
}
#[inline] #[inline]
pub async fn send<M: Message>(&self, inner: M) -> Result<(), Error> { pub async fn send<M: Message>(&self, inner: M) -> Result<(), Error> {
self.send_with_stream(DEFAUL_STREAM_ID, inner).await self.send_with_stream(DEFAUL_STREAM_ID, inner).await
@ -417,6 +436,46 @@ impl Bus {
} }
} }
struct Mapper<M, R, E, C> {
cb: C,
m: PhantomData<(M, R, E)>,
}
impl<M: Message, R: Message, E: ErrorMessage, C> Handler<M> for Mapper<M, R, E, C>
where
M: Message,
R: Message,
E: ErrorMessage,
C: Send + Sync + FnMut(u32, u32, M) -> Result<R, E> + 'static,
{
type Result = R;
type Error = E;
async fn handle(
&mut self,
_msg: M,
_stream_id: u32,
_task_id: u32,
_bus: Bus,
) -> Result<impl IntoMessages<Self::Result, Self::Error>, Self::Error> {
(self.cb)(_stream_id, _task_id, _msg).map(|x| [x])
}
async fn handle_error(
&mut self,
_err: Error,
_stream_id: u32,
_task_id: u32,
_bus: Bus,
) -> Result<impl IntoMessages<Self::Result, Self::Error>, Self::Error> {
Ok(None)
}
async fn finalize(self, _bus: Bus) -> Result<(), Self::Error> {
Ok(())
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::{sync::Arc, time::Duration}; use std::{sync::Arc, time::Duration};
@ -631,11 +690,15 @@ mod tests {
async fn test_streams() { async fn test_streams() {
let bus = Bus::default(); let bus = Bus::default();
bus.register(DefaultBuilder::<u64, TestConsumer>::new(2)) bus.register(DefaultBuilder::<u64, _, _, _>::new(2, |_, _| async move {
.await; Ok(TestConsumer::default())
}))
.await;
bus.register(DefaultBuilder::<u32, TestProducer>::new(2)) bus.register(DefaultBuilder::<u32, _, _, _>::new(2, |_, _| async move {
.await; Ok(TestProducer)
}))
.await;
for start in 0u32..10 { for start in 0u32..10 {
bus.send_with_stream(start, start).await.unwrap(); bus.send_with_stream(start, start).await.unwrap();
@ -655,8 +718,10 @@ mod tests {
})) }))
.await; .await;
bus.register(DefaultBuilder::<u32, TestProducer>::new(2)) bus.register(DefaultBuilder::<u32, _, _, _>::new(2, |_, _| async move {
.await; Ok(TestProducer)
}))
.await;
for start in 0u32..10 { for start in 0u32..10 {
bus.send_with_stream(start, start).await.unwrap(); bus.send_with_stream(start, start).await.unwrap();
@ -680,8 +745,10 @@ mod tests {
) )
.await; .await;
bus.register(DefaultBuilder::<_, TestCollector>::new(4)) bus.register(DefaultBuilder::<_, _, _, _>::new(4, |_, _| async move {
.await; Ok(TestCollector::default())
}))
.await;
for i in 0i16..1024 { for i in 0i16..1024 {
bus.send(i).await.unwrap(); bus.send(i).await.unwrap();

View File

@ -41,9 +41,17 @@ impl<M> ReorderQueueInner<M> {
} }
pub fn push(&mut self, index: u64, inner: M) -> Option<u64> { pub fn push(&mut self, index: u64, inner: M) -> Option<u64> {
if let Some(ri) = self.recent_index {
if index <= ri {
return Some(index);
}
}
self.heap.push(Entry { inner, index }); self.heap.push(Entry { inner, index });
if self.heap.len() == self.cap { if self.heap.len() > self.cap {
let _ = self.heap.pop();
self.recent_index = self.recent_index.map(|x| x + 1); self.recent_index = self.recent_index.map(|x| x + 1);
self.recent_index self.recent_index
} else { } else {
@ -52,20 +60,23 @@ impl<M> ReorderQueueInner<M> {
} }
pub fn pop(&mut self) -> Option<(u64, M)> { pub fn pop(&mut self) -> Option<(u64, M)> {
match self.recent_index { if let Some(ri) = self.recent_index {
None => { let e = self.heap.peek()?;
if e.index == ri + 1 {
self.recent_index = Some(e.index);
Some((e.index, self.heap.pop()?.inner))
} else {
None
}
} else {
let e = self.heap.peek()?;
if e.index == 0 {
let e = self.heap.pop()?; let e = self.heap.pop()?;
self.recent_index = Some(e.index); self.recent_index = Some(e.index);
Some((e.index, e.inner)) Some((e.index, e.inner))
} } else {
Some(ri) => { None
let e = self.heap.peek()?;
if e.index == ri + 1 {
self.recent_index = Some(e.index);
Some((e.index, self.heap.pop()?.inner))
} else {
None
}
} }
} }
} }