Add handle_error; continue ReorderingQueue feature

This commit is contained in:
Andrey Tkachenko 2023-11-09 17:09:48 +04:00
parent e5497ac76c
commit be2c24e349
8 changed files with 347 additions and 350 deletions

View File

@ -1,4 +1,5 @@
#![feature(impl_trait_in_assoc_type)] #![feature(return_position_impl_trait_in_trait)]
#![feature(async_fn_in_trait)]
use std::sync::Arc; use std::sync::Arc;
@ -9,48 +10,60 @@ pub struct Msg(pub i32);
impl Message for Msg {} impl Message for Msg {}
pub struct Processor { pub struct Processor {
state: i32, _state: i32,
} }
impl Handler<Msg> for Processor { impl Handler<Msg> for Processor {
type Result = (); type Result = ();
type IntoMessage = impl IntoMessage<Self::Result>; async fn handle(
type HandleFut<'a> = impl futures::Future<Output = Result<Self::IntoMessage, Error>> + 'a; &mut self,
type FinalizeFut<'a> = impl futures::Future<Output = Result<(), Error>> + 'a; _msg: Msg,
_stream_id: u32,
fn handle(&mut self, msg: Msg, _stream_id: u32, _task_id: u32) -> Self::HandleFut<'_> { _task_id: u32,
async move { Ok(()) } ) -> Result<impl IntoMessage<Self::Result>, Error> {
Ok(())
} }
fn finalize<'a>(self) -> Self::FinalizeFut<'a> { async fn finalize(self) -> Result<(), Error> {
async move { Ok(()) } Ok(())
}
async fn handle_error(
&mut self,
_err: Error,
_stream_id: u32,
_task_id: u32,
) -> Result<impl IntoMessage<Self::Result>, Error> {
Ok(None)
} }
} }
struct ProcSpawner; struct ProcSpawner;
impl Builder<Msg> for ProcSpawner { impl Builder<Msg> for ProcSpawner {
type Context = Processor; type Context = Processor;
type BuildFut<'a> = impl futures::Future<Output = Result<Self::Context, Error>> + 'a;
fn build(&self, stream_id: u32, _task_id: u32) -> Self::BuildFut<'_> { async fn build(&self, stream_id: u32, _task_id: u32) -> Result<Self::Context, Error> {
async move {
Ok(Processor { Ok(Processor {
state: stream_id as _, _state: stream_id as _,
}) })
} }
}
} }
impl Processor { impl Processor {
pub async fn spawn(sid: u32) -> Result<(usize, Self), Error> { pub async fn spawn(_sid: u32) -> Result<(usize, Self), Error> {
Ok((4, Self { state: 0 })) Ok((4, Self { _state: 0 }))
} }
pub async fn handler_msg(self: Arc<Self>, sid: u32, tid: u32, msg: Msg) -> Result<(), Error> { pub async fn handler_msg(
self: Arc<Self>,
_sid: u32,
_tid: u32,
_msg: Msg,
) -> Result<(), Error> {
Ok(()) Ok(())
} }
pub async fn finalize_msg_handler(self: Arc<Self>, sid: u32) -> Result<(), Error> { pub async fn finalize_msg_handler(self: Arc<Self>, _sid: u32) -> Result<(), Error> {
Ok(()) Ok(())
} }
} }

View File

@ -27,15 +27,16 @@ impl Default for Config {
pub trait Builder<M: Message>: Send + Sync + 'static { pub trait Builder<M: Message>: Send + Sync + 'static {
type Context: 'static; type Context: 'static;
type BuildFut<'a>: Future<Output = Result<Self::Context, Error>> + Send + 'a
where
Self: 'a;
fn config(&self, _stream_id: u32) -> Config { fn config(&self, _stream_id: u32) -> Config {
Default::default() Default::default()
} }
fn build(&self, stream_id: u32, _task_id: u32) -> Self::BuildFut<'_>; fn build(
&self,
stream_id: u32,
_task_id: u32,
) -> impl Future<Output = Result<Self::Context, Error>> + Send + '_;
} }
pub struct DefaultBuilder<M: Message, H> { pub struct DefaultBuilder<M: Message, H> {
@ -60,10 +61,9 @@ impl<M: Message, H> DefaultBuilder<M, H> {
impl<M: Message, H: Sync + Send + Default + 'static> Builder<M> for DefaultBuilder<M, H> { impl<M: Message, H: Sync + Send + Default + 'static> Builder<M> for DefaultBuilder<M, H> {
type Context = H; type Context = H;
type BuildFut<'a> = impl Future<Output = Result<Self::Context, Error>> + Send + 'a;
fn build(&self, _stream_id: u32, _task_id: u32) -> Self::BuildFut<'_> { async fn build(&self, _stream_id: u32, _task_id: u32) -> Result<Self::Context, Error> {
async move { Ok(<Self::Context as Default>::default()) } Ok(<Self::Context as Default>::default())
} }
fn config(&self, _stream_id: u32) -> Config { fn config(&self, _stream_id: u32) -> Config {
@ -133,10 +133,8 @@ where
C: Sync + Send + Fn(u32, u32) -> F + 'static, C: Sync + Send + Fn(u32, u32) -> F + 'static,
{ {
type Context = Arc<H>; type Context = Arc<H>;
type BuildFut<'a> = impl Future<Output = Result<Self::Context, Error>> + Send + 'a;
fn build(&self, stream_id: u32, task_id: u32) -> Self::BuildFut<'_> { async fn build(&self, stream_id: u32, task_id: u32) -> Result<Self::Context, Error> {
async move {
if self.stream_handlers.contains_key(&stream_id) { if self.stream_handlers.contains_key(&stream_id) {
return Ok(self.stream_handlers.get(&stream_id).unwrap().clone()); return Ok(self.stream_handlers.get(&stream_id).unwrap().clone());
} }
@ -145,7 +143,6 @@ where
self.stream_handlers.insert(stream_id, val.clone()); self.stream_handlers.insert(stream_id, val.clone());
Ok(val.clone()) Ok(val.clone())
} }
}
fn config(&self, _stream_id: u32) -> Config { fn config(&self, _stream_id: u32) -> Config {
self.config self.config

View File

@ -1,9 +1,27 @@
#[derive(Debug)] use kanal::ReceiveError;
#[derive(Debug, PartialEq, Eq)]
pub enum Error { pub enum Error {
HandlerIsNotRegistered, HandlerIsNotRegistered,
Aborted, Aborted,
SendError(String), SendError(String),
ReceiveError(kanal::ReceiveError), ReceiveError(kanal::ReceiveError),
ReorderingMissedMessage(u64),
}
impl Clone for Error {
fn clone(&self) -> Self {
match self {
Error::SendError(err) => Error::SendError(err.clone()),
Error::ReceiveError(err) => match err {
ReceiveError::Closed => Error::ReceiveError(ReceiveError::Closed),
ReceiveError::SendClosed => Error::ReceiveError(ReceiveError::SendClosed),
},
Error::HandlerIsNotRegistered => Error::HandlerIsNotRegistered,
Error::Aborted => Error::Aborted,
Error::ReorderingMissedMessage(idx) => Error::ReorderingMissedMessage(*idx),
}
}
} }
impl From<kanal::SendError> for Error { impl From<kanal::SendError> for Error {

View File

@ -1,9 +1,4 @@
use std::{ use std::{any::Any, marker::PhantomData, pin::Pin, sync::Arc};
any::{Any, TypeId},
marker::PhantomData,
pin::Pin,
sync::Arc,
};
use futures::Future; use futures::Future;
use tokio::sync::Notify; use tokio::sync::Notify;
@ -18,17 +13,22 @@ use crate::{
pub trait Handler<M: Message>: Send + Sync + 'static { pub trait Handler<M: Message>: Send + Sync + 'static {
type Result: Message; type Result: Message;
type IntoMessage: IntoMessage<Self::Result>;
type HandleFut<'a>: Future<Output = Result<Self::IntoMessage, Error>> + Send + 'a
where
Self: 'a;
type FinalizeFut<'a>: Future<Output = Result<(), Error>> + Send + 'a fn handle(
where &mut self,
Self: 'a; msg: M,
stream_id: u32,
task_id: u32,
) -> impl Future<Output = Result<impl IntoMessage<Self::Result>, Error>> + Send + '_;
fn handle(&mut self, msg: M, stream_id: u32, task_id: u32) -> Self::HandleFut<'_>; fn handle_error(
fn finalize<'a>(self) -> Self::FinalizeFut<'a>; &mut self,
err: Error,
stream_id: u32,
task_id: u32,
) -> impl Future<Output = Result<impl IntoMessage<Self::Result>, Error>> + Send + '_;
fn finalize(self) -> impl Future<Output = Result<(), Error>> + Send;
} }
pub(crate) struct HandlerSpawner<M, B> { pub(crate) struct HandlerSpawner<M, B> {
@ -66,16 +66,23 @@ where
while let Some(msg) = rx.recv().await { while let Some(msg) = rx.recv().await {
task_counter.inc_running(); task_counter.inc_running();
let res_msg = match ctx.handle(msg.inner, stream_id, task_id).await { let resp = match msg.inner {
Ok(res) => res.into_message(), Ok(m) => ctx
Err(err) => { .handle(m, stream_id, task_id)
println!("TASK HANDLE ERROR: {:?}", err); .await
continue; .map(IntoMessage::into_message),
}
Err(err) => ctx
.handle_error(err, stream_id, task_id)
.await
.map(IntoMessage::into_message),
};
let Some(inner) = resp.transpose() else {
let _ = bus.send_skip::<M>(stream_id, msg.index).await;
continue;
}; };
if let Some(inner) = res_msg {
if inner.type_id() != TypeId::of::<()>() {
if let Err(err) = bus if let Err(err) = bus
.send(Msg { .send(Msg {
inner, inner,
@ -85,9 +92,6 @@ where
.await .await
{ {
println!("BUS SEND ERROR: {:?}", err); println!("BUS SEND ERROR: {:?}", err);
continue;
}
}
} }
task_counter.dec_running(rx.is_empty()); task_counter.dec_running(rx.is_empty());

View File

@ -1,4 +1,5 @@
#![feature(impl_trait_in_assoc_type)] #![feature(return_position_impl_trait_in_trait)]
#![feature(async_fn_in_trait)]
mod builder; mod builder;
mod chan; mod chan;
@ -111,11 +112,12 @@ impl BusInner {
async fn send_inner<M: Message>( async fn send_inner<M: Message>(
self: &Arc<Self>, self: &Arc<Self>,
msg: Msg<M>, msg: Result<M, Error>,
index: u64,
stream_id: u32,
config: Config, config: Config,
) -> Result<(), Error> { ) -> Result<(), Error> {
let type_id = TypeId::of::<M>(); let type_id = TypeId::of::<M>();
let stream_id = msg.stream_id;
let task_id = self.get_task_id::<M>(stream_id, &config); let task_id = self.get_task_id::<M>(stream_id, &config);
if !self.senders.contains_key(&(stream_id, task_id, type_id)) { if !self.senders.contains_key(&(stream_id, task_id, type_id)) {
@ -169,7 +171,11 @@ impl BusInner {
.upcast() .upcast()
.downcast_ref::<BusSender<M>>() .downcast_ref::<BusSender<M>>()
.unwrap() .unwrap()
.send(msg) .send(Msg {
inner: msg,
index,
stream_id,
})
.await .await
.unwrap(); .unwrap();
@ -178,6 +184,10 @@ impl BusInner {
pub async fn send<M: Message>(self: &Arc<Self>, msg: Msg<M>) -> Result<(), Error> { pub async fn send<M: Message>(self: &Arc<Self>, msg: Msg<M>) -> Result<(), Error> {
let type_id = TypeId::of::<M>(); let type_id = TypeId::of::<M>();
if type_id == TypeId::of::<()>() {
return Ok(());
}
let stream_id = msg.stream_id; let stream_id = msg.stream_id;
let config = if let Some(spawner) = self let config = if let Some(spawner) = self
@ -196,36 +206,48 @@ impl BusInner {
let mut queue = self let mut queue = self
.reordering .reordering
.entry((stream_id, type_id)) .entry((stream_id, type_id))
.or_insert_with(|| Box::new(ReorderQueue::<M>::new(config.task_count as _))); .or_insert_with(|| {
Box::new(ReorderQueue::<Option<Result<M, Error>>>::new(
config.task_count as _,
))
});
let queue = queue.downcast_mut::<ReorderQueue<M>>().unwrap(); let queue = queue
.downcast_mut::<ReorderQueue<Option<Result<M, Error>>>>()
.unwrap();
if let Some(index) = queue.push(msg) { if let Some(index) = queue.push(msg.index, Some(msg.inner)) {
log::warn!( self.send_inner(
"!!! Reordering queue overflow: dropping message {} with index {}", Err::<M, _>(Error::ReorderingMissedMessage(index)),
std::any::type_name::<M>(), index,
index stream_id,
); config,
// self.send_error(); )
.await?;
} }
while let Some(msg) = queue.pop() { while let Some(msg) = queue.pop() {
self.send_inner(msg, config).await?; if let (index, Some(Some(msg))) = msg {
self.send_inner(msg, index, stream_id, config).await?;
}
} }
Ok(()) Ok(())
} else { } else {
self.send_inner(msg, config).await self.send_inner(msg.inner, msg.index, stream_id, config)
.await
} }
} }
pub async fn send_error<M: Message>( pub async fn send_skip<M: Message>(
self: &Arc<Self>, self: &Arc<Self>,
stream_id: u32, stream_id: u32,
index: u64, index: u64,
err: Error,
) -> Result<(), Error> { ) -> Result<(), Error> {
let type_id = TypeId::of::<M>(); let type_id = TypeId::of::<M>();
if type_id == TypeId::of::<()>() {
return Ok(());
}
let config = if let Some(spawner) = self let config = if let Some(spawner) = self
.spawners .spawners
@ -243,15 +265,31 @@ impl BusInner {
let mut queue = self let mut queue = self
.reordering .reordering
.entry((stream_id, type_id)) .entry((stream_id, type_id))
.or_insert_with(|| Box::new(ReorderQueue::<M>::new(config.task_count as _))); .or_insert_with(|| {
Box::new(ReorderQueue::<Option<Result<M, Error>>>::new(
config.task_count as _,
))
});
let queue = queue.downcast_mut::<ReorderQueue<M>>().unwrap(); let queue = queue
.downcast_mut::<ReorderQueue<Option<Result<M, Error>>>>()
.unwrap();
while let Some(msg) = queue.pop() { if let Some(index) = queue.push(index, None) {
self.send_inner(msg, config).await?; self.send_inner(
Err::<M, _>(Error::ReorderingMissedMessage(index)),
index,
stream_id,
config,
)
.await?;
} }
Ok(()) while let Some(msg) = queue.pop() {
if let (index, Some(Some(msg))) = msg {
self.send_inner(msg, index, stream_id, config).await?;
}
}
} }
Ok(()) Ok(())
@ -341,7 +379,7 @@ impl Bus {
self.inner self.inner
.send(Msg { .send(Msg {
inner, inner: Ok(inner),
index, index,
stream_id: DEFAUL_STREAM_ID, stream_id: DEFAUL_STREAM_ID,
}) })
@ -363,7 +401,7 @@ impl Bus {
self.inner self.inner
.send(Msg { .send(Msg {
inner, inner: Ok(inner),
index, index,
stream_id, stream_id,
}) })
@ -397,7 +435,7 @@ mod tests {
use std::{sync::Arc, time::Duration}; use std::{sync::Arc, time::Duration};
use async_stream::stream; use async_stream::stream;
use futures::{Future, Stream}; use futures::Stream;
use rand::RngCore; use rand::RngCore;
use crate::{ use crate::{
@ -411,12 +449,13 @@ mod tests {
struct TestProducer; struct TestProducer;
impl Producer<u32> for TestProducer { impl Producer<u32> for TestProducer {
type Item = u64; type Item = u64;
type IntoMessage = impl IntoMessage<Self::Item>;
type Stream<'a> = impl Stream<Item = Result<Self::IntoMessage, Error>> + Send + 'a; fn stream(
type FinalizeFut<'a> = impl Future<Output = Result<(), Error>> + Send + 'a; &mut self,
_msg: u32,
fn stream(&mut self, _msg: u32, _stream_id: u32, _task_id: u32) -> Self::Stream<'_> { _stream_id: u32,
_task_id: u32,
) -> impl Stream<Item = Result<impl IntoMessage<Self::Item>, Error>> + Send + '_ {
stream! { stream! {
for i in 0u64..10 { for i in 0u64..10 {
yield Ok(i) yield Ok(i)
@ -424,13 +463,20 @@ mod tests {
} }
} }
fn finalize<'a>(self) -> Self::FinalizeFut<'a> { async fn handle_error(
async move { &mut self,
_err: Error,
_stream_id: u32,
_task_id: u32,
) -> Result<impl IntoMessage<Self::Item>, Error> {
Ok(None)
}
async fn finalize(self) -> Result<(), Error> {
println!("producer finalized"); println!("producer finalized");
Ok(()) Ok(())
} }
} }
}
struct TestConsumer(u32); struct TestConsumer(u32);
impl Default for TestConsumer { impl Default for TestConsumer {
@ -441,12 +487,13 @@ mod tests {
impl Handler<u64> for Arc<TestConsumer> { impl Handler<u64> for Arc<TestConsumer> {
type Result = (); type Result = ();
type IntoMessage = impl IntoMessage<Self::Result>;
type HandleFut<'a> = impl Future<Output = Result<Self::IntoMessage, Error>> + Send + 'a;
type FinalizeFut<'a> = impl Future<Output = Result<(), Error>> + Send + 'a;
fn handle(&mut self, msg: u64, stream_id: u32, task_id: u32) -> Self::HandleFut<'_> { async fn handle(
async move { &mut self,
msg: u64,
stream_id: u32,
task_id: u32,
) -> Result<impl IntoMessage<Self::Result>, Error> {
tokio::time::sleep(Duration::from_millis(1000)).await; tokio::time::sleep(Duration::from_millis(1000)).await;
println!( println!(
"[{}] shared consumer handle {}u64 ({}:{})", "[{}] shared consumer handle {}u64 ({}:{})",
@ -454,24 +501,30 @@ mod tests {
); );
Ok(()) Ok(())
} }
async fn handle_error(
&mut self,
_err: Error,
_stream_id: u32,
_task_id: u32,
) -> Result<impl IntoMessage<Self::Result>, Error> {
Ok(None)
} }
fn finalize<'a>(self) -> Self::FinalizeFut<'a> { async fn finalize(self) -> Result<(), Error> {
async move {
println!("[{}] shared consumer finalized", self.0); println!("[{}] shared consumer finalized", self.0);
Ok(()) Ok(())
} }
} }
}
impl Handler<u64> for TestConsumer { impl Handler<u64> for TestConsumer {
type Result = (); type Result = ();
type IntoMessage = impl IntoMessage<Self::Result>;
type HandleFut<'a> = impl Future<Output = Result<Self::IntoMessage, Error>> + Send + 'a;
type FinalizeFut<'a> = impl Future<Output = Result<(), Error>> + Send + 'a;
fn handle(&mut self, msg: u64, stream_id: u32, task_id: u32) -> Self::HandleFut<'_> { async fn handle(
async move { &mut self,
msg: u64,
stream_id: u32,
task_id: u32,
) -> Result<impl IntoMessage<Self::Result>, Error> {
tokio::time::sleep(Duration::from_millis(1000)).await; tokio::time::sleep(Duration::from_millis(1000)).await;
println!( println!(
"[{}] consumer handle {}u64 ({}:{})", "[{}] consumer handle {}u64 ({}:{})",
@ -479,17 +532,24 @@ mod tests {
); );
Ok(()) Ok(())
} }
async fn handle_error(
&mut self,
_err: Error,
_stream_id: u32,
_task_id: u32,
) -> Result<impl IntoMessage<Self::Result>, Error> {
Ok(None)
} }
fn finalize<'a>(self) -> Self::FinalizeFut<'a> { async fn finalize(self) -> Result<(), Error> {
async move {
println!("[{}] consumer finalized", self.0); println!("[{}] consumer finalized", self.0);
Ok(()) Ok(())
} }
} }
}
// #[tokio::test] #[tokio::test]
#[ignore = ""]
async fn test_streams() { async fn test_streams() {
let bus = Bus::default(); let bus = Bus::default();
@ -507,7 +567,8 @@ mod tests {
bus.wait().await; bus.wait().await;
} }
// #[tokio::test] #[tokio::test]
#[ignore = ""]
async fn test_tasks_shared() { async fn test_tasks_shared() {
let bus = Bus::default(); let bus = Bus::default();

View File

@ -1,12 +1,24 @@
use core::fmt; use core::fmt;
use crate::Error;
#[derive(Debug, Clone, PartialEq, Eq)] #[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct Msg<M: Message> { pub(crate) struct Msg<M: Message> {
pub(crate) inner: M, pub(crate) inner: Result<M, Error>,
pub(crate) index: u64, pub(crate) index: u64,
pub(crate) stream_id: u32, pub(crate) stream_id: u32,
} }
// impl<M: Message> Msg<M> {
// pub(crate) fn new(m: M, index: u64, stream_id: u32) -> Self {
// Self {
// inner: Ok(m),
// index,
// stream_id,
// }
// }
// }
pub trait Message: fmt::Debug + Clone + Send + Sync + 'static {} pub trait Message: fmt::Debug + Clone + Send + Sync + 'static {}
impl Message for () {} impl Message for () {}

View File

@ -17,17 +17,22 @@ use crate::{
pub trait Producer<M: Message>: Send + Sync + 'static { pub trait Producer<M: Message>: Send + Sync + 'static {
type Item: Message; type Item: Message;
type IntoMessage: IntoMessage<Self::Item>;
type Stream<'a>: Stream<Item = Result<Self::IntoMessage, Error>> + Send + 'a
where
Self: 'a;
type FinalizeFut<'a>: Future<Output = Result<(), Error>> + Send + 'a fn stream(
where &mut self,
Self: 'a; msg: M,
stream_id: u32,
task_id: u32,
) -> impl Stream<Item = Result<impl IntoMessage<Self::Item>, Error>> + Send + '_;
fn stream(&mut self, msg: M, stream_id: u32, task_id: u32) -> Self::Stream<'_>; fn handle_error(
fn finalize<'a>(self) -> Self::FinalizeFut<'a>; &mut self,
err: Error,
stream_id: u32,
task_id: u32,
) -> impl Future<Output = Result<impl IntoMessage<Self::Item>, Error>> + Send + '_;
fn finalize(self) -> impl Future<Output = Result<(), Error>> + Send;
} }
pub(crate) struct ProducerSpawner<M, B> { pub(crate) struct ProducerSpawner<M, B> {
@ -64,18 +69,18 @@ where
while let Some(recv_msg) = rx.recv().await { while let Some(recv_msg) = rx.recv().await {
task_counter.inc_running(); task_counter.inc_running();
match recv_msg.inner {
Ok(msg) => {
let mut stream = pin!(ctx let mut stream = pin!(ctx
.stream(recv_msg.inner, stream_id, task_id) .stream(msg, stream_id, task_id)
.take_until(abort.notified())); .take_until(abort.notified()));
let mut index = 0; let mut index = 0;
loop { while let Some(res) = stream.next().await {
if let Some(inner) = res.map(IntoMessage::into_message).transpose()
{
index += 1; index += 1;
match stream.next().await {
Some(Ok(msg)) => {
if let Some(inner) = msg.into_message() {
if let Err(err) = bus if let Err(err) = bus
.send(Msg { .send(Msg {
inner, inner,
@ -89,12 +94,32 @@ where
} }
} }
} }
Some(Err(err)) => {
println!("PRODUCER ERROR: {:?}", err);
continue;
} }
Err(err) => {
let Some(inner) = ctx
.handle_error(err, stream_id, task_id)
.await
.map(IntoMessage::into_message)
.transpose()
else {
if let Err(err) =
bus.send_skip::<M>(stream_id, recv_msg.index).await
{
println!("BUS SEND ERROR: {:?}", err);
}
continue;
};
None => break, if let Err(err) = bus
.send(Msg {
inner,
index: recv_msg.index,
stream_id,
})
.await
{
println!("BUS SEND ERROR: {:?}", err);
}
} }
} }

View File

@ -1,36 +1,37 @@
use std::{cmp::Ordering, collections::BinaryHeap}; use std::{cmp::Ordering, collections::BinaryHeap};
use crate::{message::Msg, Message}; struct Entry<M> {
inner: Option<M>,
index: u64,
}
struct Entry<M: Message>(Msg<M>); impl<M> PartialOrd for Entry<M> {
impl<M: Message> PartialOrd for Entry<M> {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> { fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(other.0.index.cmp(&self.0.index)) Some(other.index.cmp(&self.index))
} }
} }
impl<M: Message> Ord for Entry<M> { impl<M> Ord for Entry<M> {
fn cmp(&self, other: &Self) -> Ordering { fn cmp(&self, other: &Self) -> Ordering {
other.0.index.cmp(&self.0.index) other.index.cmp(&self.index)
} }
} }
impl<M: Message> PartialEq for Entry<M> { impl<M> PartialEq for Entry<M> {
fn eq(&self, other: &Self) -> bool { fn eq(&self, other: &Self) -> bool {
other.0.index.eq(&self.0.index) other.index.eq(&self.index)
} }
} }
impl<M: Message> Eq for Entry<M> {} impl<M> Eq for Entry<M> {}
pub(crate) struct ReorderQueue<M: Message> { pub(crate) struct ReorderQueue<M> {
cap: usize, cap: usize,
recent_index: Option<u64>, recent_index: Option<u64>,
heap: BinaryHeap<Entry<M>>, heap: BinaryHeap<Entry<M>>,
} }
impl<M: Message> ReorderQueue<M> { impl<M> ReorderQueue<M> {
pub fn new(cap: usize) -> Self { pub fn new(cap: usize) -> Self {
Self { Self {
cap, cap,
@ -39,8 +40,11 @@ impl<M: Message> ReorderQueue<M> {
} }
} }
pub fn push(&mut self, msg: Msg<M>) -> Option<u64> { pub fn push(&mut self, index: u64, msg: M) -> Option<u64> {
self.heap.push(Entry(msg)); self.heap.push(Entry {
inner: Some(msg),
index,
});
if self.heap.len() == self.cap { if self.heap.len() == self.cap {
self.recent_index = self.recent_index.map(|x| x + 1); self.recent_index = self.recent_index.map(|x| x + 1);
@ -50,18 +54,18 @@ impl<M: Message> ReorderQueue<M> {
} }
} }
pub fn pop(&mut self) -> Option<Msg<M>> { pub fn pop(&mut self) -> Option<(u64, Option<M>)> {
match self.recent_index { match self.recent_index {
None => { None => {
let e = self.heap.pop()?; let e = self.heap.pop()?;
self.recent_index = Some(e.0.index); self.recent_index = Some(e.index);
Some(e.0) Some((e.index, e.inner))
} }
Some(ri) => { Some(ri) => {
let e = self.heap.peek()?; let e = self.heap.peek()?;
if e.0.index == ri + 1 { if e.index == ri + 1 {
self.recent_index = Some(e.0.index); self.recent_index = Some(e.index);
Some(self.heap.pop()?.0) Some((e.index, self.heap.pop()?.inner))
} else { } else {
None None
} }
@ -72,195 +76,58 @@ impl<M: Message> ReorderQueue<M> {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::Message;
use super::ReorderQueue; use super::ReorderQueue;
impl Message for i32 {}
#[test] #[test]
fn test_reordering() { fn test_reordering() {
let mut queue = ReorderQueue::new(8); let mut queue = ReorderQueue::new(8);
assert_eq!( assert_eq!(queue.push(0, 0), None);
queue.push(crate::message::Msg { assert_eq!(queue.pop(), Some((0, Some(0))));
inner: (),
index: 0,
stream_id: 0,
}),
None
);
assert_eq!(
queue.pop(),
Some(crate::message::Msg {
inner: (),
index: 0,
stream_id: 0,
})
);
assert_eq!(queue.pop(), None); assert_eq!(queue.pop(), None);
assert_eq!( assert_eq!(queue.push(3, 3), None);
queue.push(crate::message::Msg {
inner: (),
index: 3,
stream_id: 0,
}),
None
);
assert_eq!(queue.pop(), None); assert_eq!(queue.pop(), None);
assert_eq!( assert_eq!(queue.push(2, 2), None);
queue.push(crate::message::Msg {
inner: (),
index: 2,
stream_id: 0,
}),
None
);
assert_eq!(queue.pop(), None); assert_eq!(queue.pop(), None);
assert_eq!( assert_eq!(queue.push(4, 4), None);
queue.push(crate::message::Msg {
inner: (),
index: 4,
stream_id: 0,
}),
None
);
assert_eq!(queue.pop(), None); assert_eq!(queue.pop(), None);
assert_eq!( assert_eq!(queue.push(1, 1), None);
queue.push(crate::message::Msg { assert_eq!(queue.pop(), Some((1, Some(1))));
inner: (), assert_eq!(queue.pop(), Some((2, Some(2))));
index: 1, assert_eq!(queue.pop(), Some((3, Some(3))));
stream_id: 0, assert_eq!(queue.pop(), Some((4, Some(4))));
}),
None
);
assert_eq!(
queue.pop(),
Some(crate::message::Msg {
inner: (),
index: 1,
stream_id: 0,
})
);
assert_eq!(
queue.pop(),
Some(crate::message::Msg {
inner: (),
index: 2,
stream_id: 0,
})
);
assert_eq!(
queue.pop(),
Some(crate::message::Msg {
inner: (),
index: 3,
stream_id: 0,
})
);
assert_eq!(
queue.pop(),
Some(crate::message::Msg {
inner: (),
index: 4,
stream_id: 0,
})
);
assert_eq!(queue.pop(), None); assert_eq!(queue.pop(), None);
} }
#[test] #[test]
fn test_overflow() { fn test_overflow() {
let mut queue = ReorderQueue::new(4); let mut queue = ReorderQueue::new(4);
assert_eq!( assert_eq!(queue.push(0, 0), None);
queue.push(crate::message::Msg { assert_eq!(queue.pop(), Some((0, Some(0))));
inner: (),
index: 0,
stream_id: 0,
}),
None
);
assert_eq!(
queue.pop(),
Some(crate::message::Msg {
inner: (),
index: 0,
stream_id: 0,
})
);
assert_eq!(queue.pop(), None); assert_eq!(queue.pop(), None);
assert_eq!( assert_eq!(queue.push(4, 4), None);
queue.push(crate::message::Msg {
inner: (),
index: 4,
stream_id: 0,
}),
None
);
assert_eq!(queue.pop(), None); assert_eq!(queue.pop(), None);
assert_eq!( assert_eq!(queue.push(2, 2), None);
queue.push(crate::message::Msg {
inner: (),
index: 2,
stream_id: 0,
}),
None
);
assert_eq!(queue.pop(), None); assert_eq!(queue.pop(), None);
assert_eq!( assert_eq!(queue.push(3, 3), None);
queue.push(crate::message::Msg {
inner: (),
index: 3,
stream_id: 0,
}),
None
);
assert_eq!(queue.pop(), None); assert_eq!(queue.pop(), None);
assert_eq!( assert_eq!(queue.push(5, 5), Some(1));
queue.push(crate::message::Msg { assert_eq!(queue.pop(), Some((2, Some(2))));
inner: (), assert_eq!(queue.pop(), Some((3, Some(3))));
index: 5, assert_eq!(queue.pop(), Some((4, Some(4))));
stream_id: 0, assert_eq!(queue.pop(), Some((5, Some(5))));
}),
Some(1)
);
assert_eq!(
queue.pop(),
Some(crate::message::Msg {
inner: (),
index: 2,
stream_id: 0,
})
);
assert_eq!(
queue.pop(),
Some(crate::message::Msg {
inner: (),
index: 3,
stream_id: 0,
})
);
assert_eq!(
queue.pop(),
Some(crate::message::Msg {
inner: (),
index: 4,
stream_id: 0,
})
);
assert_eq!(
queue.pop(),
Some(crate::message::Msg {
inner: (),
index: 5,
stream_id: 0,
})
);
assert_eq!(queue.pop(), None); assert_eq!(queue.pop(), None);
} }
} }