TCP Relay

This commit is contained in:
Andrey Tkachenko 2021-10-08 13:03:24 +04:00
parent 1bff5cc70a
commit bc862cfac6
16 changed files with 1021 additions and 406 deletions

View File

@ -15,7 +15,7 @@ edition = "2018"
[dependencies]
thiserror = "1.0"
messagebus = "0.9"
messagebus = { path = "../../" }
tokio = { version = "1", features = ["parking_lot", "rt-multi-thread", "sync", "time", "io-util"] }
parking_lot = "0.11"
quinn = "0.7"

View File

@ -1,4 +1,5 @@
use messagebus::Bus;
use messagebus::error::GenericError;
use messagebus::{Bus, TypeTagged};
use messagebus_remote::relays::QuicClientRelay;
use serde_derive::{Serialize, Deserialize};
use messagebus::derive::Message;
@ -26,7 +27,7 @@ async fn main() {
"127.0.0.1:8083".parse().unwrap(),
"localhost".into(),
(vec![
("example::Req".into(), "example::Resp".into(), "GenericError".into())
(Req::type_tag_(), Some((Resp::type_tag_(), GenericError::type_tag_())))
],
vec![])
).unwrap();

View File

@ -1,4 +1,4 @@
use messagebus::{error, Message, derive::{Message, Error as MbError}, AsyncHandler, Bus};
use messagebus::{AsyncHandler, Bus, Message, TypeTagged, derive::{Message, Error as MbError}, error::{self, GenericError}};
use messagebus_remote::relays::{QuicServerRelay};
use serde_derive::{Serialize, Deserialize};
use async_trait::async_trait;
@ -65,7 +65,7 @@ async fn main() {
"0.0.0.0:8083".parse().unwrap(),
(vec![],
vec![
("example::Req".into(), "example::Resp".into(), "GenericError".into())
(Req::type_tag_(), Some((Resp::type_tag_(), GenericError::type_tag_())))
])
).unwrap();

View File

@ -0,0 +1,54 @@
use messagebus::error::GenericError;
use messagebus::{Bus, TypeTagged};
use messagebus_remote::relays::TcpRelay;
use serde_derive::{Serialize, Deserialize};
use messagebus::derive::Message;
#[derive(Serialize, Deserialize, Debug, Clone, Message)]
#[namespace("example")]
#[message(shared, clone)]
pub struct Req {
data: i32,
text: String
}
#[derive(Serialize, Deserialize, Debug, Clone, Message)]
#[namespace("example")]
#[message(shared, clone)]
pub struct Resp {
data: i32,
text: String
}
#[tokio::main]
async fn main() {
let relay = TcpRelay::new(false, "0.0.0.0:8083".parse().unwrap(),
(vec![
(Req::type_tag_(), Some((Resp::type_tag_(), GenericError::type_tag_())))
],
vec![])
);
let (b, poller) = Bus::build()
.register_shared_message::<Req>()
.register_shared_message::<Resp>()
.register_relay(relay)
.build();
b.ready().await;
println!("ready");
let resp: Resp = b.request(Req {
data: 12,
text: String::from("test")
}, Default::default())
.await
.unwrap();
println!("resp {:?}", resp);
b.flush().await;
b.close().await;
poller.await;
}

View File

@ -0,0 +1,83 @@
use messagebus::{AsyncHandler, Bus, Message, TypeTagged, derive::{Message, Error as MbError}, error::{self, GenericError}};
use messagebus_remote::relays::TcpRelay;
use serde_derive::{Serialize, Deserialize};
use async_trait::async_trait;
use thiserror::Error;
#[derive(Debug, Error, MbError)]
enum Error {
#[error("Error({0})")]
Error(anyhow::Error),
}
impl<M: Message> From<error::Error<M>> for Error {
fn from(err: error::Error<M>) -> Self {
Self::Error(err.into())
}
}
#[derive(Serialize, Deserialize, Debug, Clone, Message)]
#[namespace("example")]
#[message(shared, clone)]
pub struct Req {
data: i32,
text: String
}
#[derive(Serialize, Deserialize, Debug, Clone, Message)]
#[namespace("example")]
#[message(shared, clone)]
pub struct Resp {
data: i32,
text: String
}
struct TmpReceiver;
#[async_trait]
impl AsyncHandler<Req> for TmpReceiver {
type Error = Error;
type Response = Resp;
async fn handle(&self, msg: Req, _bus: &Bus) -> Result<Self::Response, Self::Error> {
println!("TmpReceiver::handle {:?}", msg);
Ok(Resp {
data: msg.data + 12,
text: format!("<< {} >>", msg.text),
})
}
async fn sync(&self, _bus: &Bus) -> Result<(), Self::Error> {
println!("TmpReceiver::sync");
Ok(())
}
}
#[tokio::main]
async fn main() {
let relay = TcpRelay::new(true, "0.0.0.0:8083".parse().unwrap(),
(vec![],
vec![
(Req::type_tag_(), Some((Resp::type_tag_(), GenericError::type_tag_())))
])
);
let (b, poller) = Bus::build()
.register_shared_message::<Req>()
.register_shared_message::<Resp>()
.register_relay(relay)
.register(TmpReceiver)
.subscribe_async::<Req>(8, Default::default())
.done()
.build();
b.ready().await;
println!("ready");
poller.await;
}

View File

@ -34,8 +34,8 @@ pub enum Error {
#[error("Redis: {0}")]
Redis(#[from] redis::RedisError),
#[error("ProtocolParseError")]
ProtocolParseError,
#[error("ProtocolParseError {0}")]
ProtocolParseError(String),
#[error("UnknownCodec")]
UnknownCodec,

View File

@ -224,8 +224,10 @@ impl ProtocolItem {
ProtocolHeaderActionKind::Error
}
Event::Finished(n) => {
Event::BatchComplete(tt, n) => {
argument = *n;
type_tag = Some(tt.clone());
flags.set(ProtocolHeaderFlags::TYPE_TAG, true);
flags.set(ProtocolHeaderFlags::ARGUMENT, true);
ProtocolHeaderActionKind::BatchComplete
},
@ -299,15 +301,15 @@ impl<'a> ProtocolPacket<'a> {
let (body, error) = if self.header.flags.contains(ProtocolHeaderFlags::ERROR) {
let error = messagebus::error::GenericError {
type_tag: type_tag.unwrap(),
type_tag: type_tag.clone().unwrap(),
description: self.body.map(|x|String::from_utf8_lossy(x.as_ref()).to_string()).unwrap_or_default(),
};
(None, Some(messagebus::error::Error::Other(error)))
} else if self.header.flags.contains(ProtocolHeaderFlags::TT_AND_BODY) {
let body = self.body.ok_or(crate::error::Error::ProtocolParseError)?;
let body = self.body.ok_or_else(|| crate::error::Error::ProtocolParseError("No body".to_string()))?;
let res = generic_deserialize(self.header.body_type, body.as_ref(), |de| {
bus.deserialize_message(type_tag.unwrap(), de)
bus.deserialize_message(type_tag.clone().unwrap(), de)
.map_err(|x| x.map_msg(|_| ()))
})?;
@ -328,19 +330,19 @@ impl<'a> ProtocolPacket<'a> {
Ok(ProtocolItem::Event(match self.header.kind {
ProtocolHeaderActionKind::Response => Event::Response(
argument
.ok_or(crate::error::Error::ProtocolParseError)?,
.ok_or_else(|| crate::error::Error::ProtocolParseError("Event::Response expected argument".into()))?,
error
.map(Err)
.or_else(|| body.map(Ok))
.ok_or(crate::error::Error::ProtocolParseError)?,
.ok_or_else(|| crate::error::Error::ProtocolParseError("Event::Response expected body".into()))?,
),
ProtocolHeaderActionKind::Synchronized => {
Event::Synchronized(error.map(Err).unwrap_or(Ok(())))
}
ProtocolHeaderActionKind::Error => {
Event::Error(error.ok_or(crate::error::Error::ProtocolParseError)?)
Event::Error(error.ok_or_else(|| crate::error::Error::ProtocolParseError("Event::Error expected body".into()))?)
}
ProtocolHeaderActionKind::BatchComplete => Event::Finished(self.header.argument),
ProtocolHeaderActionKind::BatchComplete => Event::BatchComplete(type_tag.unwrap(), self.header.argument),
ProtocolHeaderActionKind::Flushed => Event::Flushed,
ProtocolHeaderActionKind::Exited => Event::Exited,
ProtocolHeaderActionKind::Ready => Event::Ready,
@ -354,7 +356,9 @@ impl<'a> ProtocolPacket<'a> {
ProtocolHeaderActionKind::Send => {
let req = argument.is_some();
let mid = self.header.argument;
let body = body.ok_or(crate::error::Error::ProtocolParseError)?;
let body = body.ok_or_else(|| crate::error::Error::ProtocolParseError(
format!("Action::Send[{:?}] expected body", type_tag)
))?;
return Ok(ProtocolItem::Send(mid, body, req));
},

View File

@ -1,6 +1,7 @@
// #[cfg(feature = "quic")]
mod quic;
mod redis;
mod tcp;
use futures::Stream;
use messagebus::{error::GenericError, Event, Message, TypeTag};
@ -8,21 +9,34 @@ use std::{collections::HashMap, pin::Pin};
// #[cfg(feature = "quic")]
pub use quic::*;
pub use tcp::*;
pub(crate) type GenericEventStream =
Pin<Box<dyn Stream<Item = Event<Box<dyn Message>, GenericError>> + Send>>;
#[derive(Debug, Default)]
pub struct MessageTable {
table: HashMap<TypeTag, Vec<(TypeTag, TypeTag)>>,
table: HashMap<TypeTag, Vec<Option<(TypeTag, TypeTag)>>>,
}
impl MessageTable {
pub fn add(&mut self, req: TypeTag, resp: TypeTag, err: TypeTag) {
#[inline]
pub fn add_any(&mut self, req: TypeTag, resp_err: Option<(TypeTag, TypeTag)>) {
self.table
.entry(req)
.or_insert_with(Vec::new)
.push((resp, err));
.push(resp_err);
}
#[inline]
pub fn add_request(&mut self, req: TypeTag, resp: TypeTag, err: TypeTag) {
self.add_any(req, Some((resp, err)));
}
#[inline]
pub fn add_message(&mut self, msg: TypeTag) {
self.add_any(msg, None);
}
pub fn iter_keys(&self) -> impl Iterator<Item = &str> + '_ {
@ -30,28 +44,47 @@ impl MessageTable {
.map(|k|k.as_ref())
}
pub fn accept(&self, msg: &TypeTag, resp: Option<&TypeTag>, err: Option<&TypeTag>) -> bool {
self.table.get(msg).map_or(false, |v| {
v.iter().any(|(r, e)| {
resp.map_or(true, |resp| resp.as_ref() == r.as_ref())
&& err.map_or(true, |err| err.as_ref() == e.as_ref())
#[inline]
pub fn accept_any(&self, msg: &TypeTag) -> bool {
self.table.contains_key(msg)
}
#[inline]
pub fn accept_message(&self, msg: &TypeTag) -> bool {
self.table
.get(msg)
.map_or(false, |v| {
v
.iter()
.any(Option::is_none)
})
}
pub fn accept_request(&self, msg: &TypeTag, resp: Option<&TypeTag>, err: Option<&TypeTag>) -> bool {
self.table.get(msg).map_or(false, |v| {
v
.iter()
.filter_map(Option::as_ref)
.any(|(r, e)| {
resp.map_or(true, |resp| resp.as_ref() == r.as_ref())
&& err.map_or(true, |err| err.as_ref() == e.as_ref())
})
})
}
pub fn iter_types(&self) -> impl Iterator<Item = (&'_ TypeTag, &'_ TypeTag, &'_ TypeTag)> + '_ {
pub fn iter_types(&self) -> impl Iterator<Item = (&'_ TypeTag, Option<&'_ (TypeTag, TypeTag)>)> + '_ {
self.table
.iter()
.map(|(k, v)| v.iter().map(move |(e, r)| (k, r, e)))
.map(|(k, v)| v.iter().map(move |resp| (k, resp.as_ref())))
.flatten()
}
}
impl From<Vec<(TypeTag, TypeTag, TypeTag)>> for MessageTable {
fn from(table: Vec<(TypeTag, TypeTag, TypeTag)>) -> Self {
impl From<Vec<(TypeTag, Option<(TypeTag, TypeTag)>)>> for MessageTable {
fn from(table: Vec<(TypeTag, Option<(TypeTag, TypeTag)>)>) -> Self {
let mut outgoing_table = MessageTable::default();
for (x, y, z) in table {
outgoing_table.add(x, y, z);
for (key, val) in table {
outgoing_table.add_any(key, val);
}
outgoing_table
}

View File

@ -4,10 +4,10 @@ mod server;
use std::net::SocketAddr;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Duration;
pub use client::QuicClientEndpoint;
use messagebus::{Action, Bus, Event, Message, ReciveUntypedReceiver, SendOptions, SendUntypedReceiver, TypeTag, TypeTagAccept};
use messagebus::error::GenericError;
use messagebus::{Action, Bus, Event, EventBoxed, Message, ReciveUntypedReceiver, SendOptions, SendUntypedReceiver, TypeTag, TypeTagAccept};
use parking_lot::Mutex;
use quinn::{Connecting, IncomingBiStreams};
pub use server::QuicServerEndpoint;
@ -23,36 +23,34 @@ pub type QuicClientRelay = QuicRelay<QuicClientEndpoint>;
pub type QuicServerRelay = QuicRelay<QuicServerEndpoint>;
use super::{GenericEventStream, MessageTable};
pub type MessageList = Vec<(TypeTag, TypeTag, TypeTag)>;
pub type MessageList = Vec<(TypeTag, Option<(TypeTag, TypeTag)>)>;
pub trait WaitIdle<'a>: Sync {
type Fut: Future + Send + 'a;
fn wait_idle(&'a self) -> Self::Fut;
}
#[derive(Debug)]
enum RecvDo {
Pause,
Ready,
Closed,
Incoming(IncomingBiStreams),
}
pub struct QuicRelay<B> {
base: Mutex<Option<B>>,
self_id: Arc<AtomicU64>,
in_table: MessageTable,
_out_table: MessageTable,
item_sender: UnboundedSender<Option<ProtocolItem>>,
item_receiver: Mutex<Option<UnboundedReceiver<Option<ProtocolItem>>>>,
event_sender: UnboundedSender<RecvDo>,
event_receiver: Mutex<Option<UnboundedReceiver<RecvDo>>>,
event_sender: UnboundedSender<EventBoxed<GenericError>>,
event_receiver: Mutex<Option<UnboundedReceiver<EventBoxed<GenericError>>>>,
stream_sender: UnboundedSender<IncomingBiStreams>,
stream_receiver: Mutex<Option<UnboundedReceiver<IncomingBiStreams>>>,
}
impl QuicRelay<QuicClientEndpoint> {
pub fn new(cert: &str, addr: SocketAddr, host: String, table: (MessageList, MessageList)) -> Result<Self, crate::error::Error> {
let (item_sender, item_receiver) = mpsc::unbounded_channel();
let (event_sender, event_receiver) = mpsc::unbounded_channel();
let (stream_sender, stream_receiver) = mpsc::unbounded_channel();
Ok(QuicRelay {
base: Mutex::new(Some(QuicClientEndpoint::new(cert, addr, host)?)),
@ -63,6 +61,8 @@ impl QuicRelay<QuicClientEndpoint> {
item_receiver: Mutex::new(Some(item_receiver)),
event_sender,
event_receiver: Mutex::new(Some(event_receiver)),
stream_sender,
stream_receiver: Mutex::new(Some(stream_receiver)),
})
}
}
@ -71,6 +71,7 @@ impl QuicRelay<QuicServerEndpoint> {
pub fn new(key_path: &str, cert_path: &str, addr: SocketAddr, table: (MessageList, MessageList)) -> Result<Self, crate::error::Error> {
let (item_sender, item_receiver) = mpsc::unbounded_channel();
let (event_sender, event_receiver) = mpsc::unbounded_channel();
let (stream_sender, stream_receiver) = mpsc::unbounded_channel();
Ok(QuicRelay {
base: Mutex::new(Some(QuicServerEndpoint::new(key_path, cert_path, &addr )?)),
@ -81,6 +82,8 @@ impl QuicRelay<QuicServerEndpoint> {
item_receiver: Mutex::new(Some(item_receiver)),
event_sender,
event_receiver: Mutex::new(Some(event_receiver)),
stream_sender,
stream_receiver: Mutex::new(Some(stream_receiver)),
})
}
}
@ -88,18 +91,18 @@ impl QuicRelay<QuicServerEndpoint> {
impl<B> TypeTagAccept for QuicRelay<B>
where B: Stream<Item = Connecting> + Send + 'static
{
fn accept(&self, msg: &TypeTag, resp: Option<&TypeTag>, err: Option<&TypeTag>) -> bool {
self.in_table.accept(msg, resp, err)
fn iter_types(&self) -> Box<dyn Iterator<Item = (TypeTag, Option<(TypeTag, TypeTag)>)> + '_> {
let iter = self.in_table.iter_types();
Box::new(iter.map(|(x, y)| (x.clone(), y.cloned())))
}
fn iter_types(&self, cb: &mut dyn FnMut(&TypeTag, &TypeTag, &TypeTag) -> bool) {
let iter = self.in_table.iter_types();
fn accept_msg(&self, msg: &TypeTag) -> bool{
self.in_table.accept_message(msg)
}
for (m, r, e) in iter {
if cb(m, r, e) {
return;
}
}
fn accept_req(&self, req: &TypeTag, resp: Option<&TypeTag>, err: Option<&TypeTag>) -> bool {
self.in_table.accept_request(req, resp, err)
}
}
@ -109,8 +112,13 @@ impl<B> SendUntypedReceiver for QuicRelay<B>
fn send(&self, msg: Action, _bus: &Bus) -> Result<(), messagebus::error::Error<Action>> {
match msg {
Action::Init(self_id) => {
let mut rx = match self.item_receiver.lock().take() {
Some(x) => x,
None => return Ok(()),
};
let stream_sender = self.stream_sender.clone();
let event_sender = self.event_sender.clone();
let mut rx = self.item_receiver.lock().take().unwrap();
let incoming = self.base.lock().take().unwrap();
self.self_id.store(self_id, Ordering::SeqCst);
@ -124,10 +132,11 @@ impl<B> SendUntypedReceiver for QuicRelay<B>
loop {
println!("begin");
let conn = match incoming.next().await {
Some(x) => x,
None => {
log::warn!("No more connections. Message {:?} has been lost!", item);
println!("No more connections. Message {:?} has been lost!", item);
break;
}
};
@ -135,42 +144,51 @@ impl<B> SendUntypedReceiver for QuicRelay<B>
let conn = match conn.await {
Ok(conn) => conn,
Err(err) => {
log::warn!("connection dropped with err {}. waiting next connection", err);
println!("connection dropped with err {}. waiting next connection", err);
continue;
}
};
event_sender.send(RecvDo::Ready).unwrap();
event_sender.send(RecvDo::Incoming(conn.bi_streams)).unwrap();
stream_sender.send(conn.bi_streams).unwrap();
event_sender.send(Event::Ready).unwrap();
loop {
let r = if let Some(r) = item.take() {
r
} else {
match tokio::time::timeout(Duration::from_secs(1), rx.recv()).await {
Ok(Some(Some(r))) => r,
Ok(None) | Ok(Some(None)) => {
match rx.recv().await {
Some(Some(r)) => r,
None | Some(None) => {
println!("closing");
conn.connection.close(0u32.into(), b"done");
incoming.wait_idle().await;
break;
},
Err(_) => {
println!("PING");
let (mut send, _) = match conn.connection.open_bi().await {
Ok(x) => x,
Err(err) => {
println!("err {}", err);
break;
}
};
let _ = send.finish().await;
continue;
}
}
};
header_buff.clear();
body_buff.clear();
// match tokio::time::timeout(Duration::from_secs(5), rx.recv()).await {
// Ok(Some(Some(r))) => r,
// Ok(None) | Ok(Some(None)) => {
// conn.connection.close(0u32.into(), b"done");
// incoming.wait_idle().await;
// break;
// },
// Err(_) => {
// // println!("{:?}", err);
// // let (mut send, _) = match conn.connection.open_bi().await {
// // Ok(x) => x,
// // Err(err) => {
// // println!("err {}", err);
// // break;
// // }
// // };
// // println!("<< PING");
// // let _ = send.write_all(b"PING").await.unwrap();
// // let _ = send.finish().await;
// continue;
// }
// }
};
let (mut send, _) = match conn.connection.open_bi().await {
Ok(x) => x,
@ -180,33 +198,34 @@ impl<B> SendUntypedReceiver for QuicRelay<B>
}
};
body_buff.clear();
let pkt = r.serialize(BodyType::Cbor, &mut body_buff).unwrap();
header_buff.put(&b"MBUS"[..]);
header_buff.put_u16(1);
header_buff.put_u16(0);
header_buff.put_u64(header_buff.len() as _);
serde_cbor::to_writer(&mut header_buff, &pkt).unwrap();
let result = send.write_all(&header_buff).await;
let result = if result.is_ok() {
send.finish().await
} else {
result
};
if let Err(err) = result {
header_buff.resize(16, 0);
serde_cbor::to_writer(&mut header_buff, &pkt).unwrap();
let body_size = header_buff.len() - 16;
let mut head = &mut header_buff[0..16];
head.put(&b"MBUS"[..]);
head.put_u16(1);
head.put_u16(0);
head.put_u64(body_size as _);
if let Err(err) = send.write_all(&header_buff).await {
item = Some(r);
log::warn!("broken connection err {}. try with next connection", err);
println!("write broken connection err {}. try with next connection", err);
break;
}
println!("1");
if let Err(err) = send.finish().await {
item = Some(r);
println!("finish broken connection err {}. try with next connection", err);
break;
}
}
println!("2");
event_sender.send(RecvDo::Pause).unwrap();
event_sender.send(Event::Pause).unwrap();
}
println!("exit main loop");
@ -215,7 +234,7 @@ impl<B> SendUntypedReceiver for QuicRelay<B>
Action::Close => {
self.item_sender.send(None).unwrap();
self.event_sender.send(RecvDo::Closed).unwrap();
self.event_sender.send(Event::Exited).unwrap();
}
other => self.item_sender.send(Some(other.into())).unwrap(),
@ -252,115 +271,159 @@ impl<B> ReciveUntypedReceiver for QuicRelay<B>
fn event_stream(&self, bus: Bus) -> Self::Stream {
let self_id = self.self_id.clone();
let mut recv_stream = self.stream_receiver.lock().take().unwrap();
let mut recv_events = self.event_receiver.lock().take().unwrap();
let sender = self.item_sender.clone();
let mut recv = self.event_receiver.lock().take().unwrap();
Box::pin(
futures::stream::poll_fn(move |cx|recv.poll_recv(cx))
.map(move |recv_do| {
let self_id = self_id.clone();
let bus = bus.clone();
let sender = sender.clone();
let stream1 = futures::stream::poll_fn(move |cx|recv_stream.poll_recv(cx))
.map(move |incoming| {
let buff: Vec<u8> = Vec::with_capacity(1024);
let bus = bus.clone();
let self_id = self_id.clone();
let sender = sender.clone();
match recv_do {
RecvDo::Incoming(incoming) => {
futures::stream::unfold((incoming, bus, sender, self_id), |(mut incoming, bus, sender, self_id)| async move {
loop {
let (_, recv) = match incoming.next().await? {
Ok(recv) => recv,
Err(err) => {
println!("error: {}", err);
return None;
}
};
let buff = recv
.read_to_end(usize::max_value())
.await
.unwrap();
// assert_eq!(&buff[0..4], b"MBUS");
futures::stream::unfold((incoming, bus, sender, self_id, buff), |(mut incoming, bus, sender, self_id, mut buff)| async move {
loop {
let (_, mut recv) = match incoming.next().await? {
Ok(recv) => recv,
Err(err) => {
println!("error: {}", err);
return None;
}
};
if buff.is_empty() {
println!("PONG");
continue;
}
let mut reader = &buff[4..];
let version = reader.get_u16();
let content_type = reader.get_u16();
let body_size = reader.get_u64();
println!("inbound packet {}: v: {}; ct: {}; bs: {}", String::from_utf8_lossy(&buff[0..4]), version, content_type, body_size);
let event = match content_type {
0 => { // CBOR
let proto: ProtocolPacket = serde_cbor::from_slice(&buff[16..]).unwrap();
match proto.deserialize(&bus).unwrap() {
ProtocolItem::Event(ev) => ev.map_msg(|msg|msg.upcast_box()),
ProtocolItem::Action(action) => {
match action {
Action::Close => {
println!("warning: Close recevied - ignoring!");
sender.send(Some(ProtocolItem::Event(Event::Exited))).unwrap();
},
Action::Flush => {
bus.flush().await;
sender.send(Some(ProtocolItem::Event(Event::Flushed))).unwrap();
},
Action::Sync => {
bus.sync().await;
sender.send(Some(ProtocolItem::Event(Event::Synchronized(Ok(()))))).unwrap();
},
Action::Init(..) => (),
Action::Stats => (),
_ => (),
}
continue;
}
ProtocolItem::Send(mid, msg, req) => {
if req {
let res = bus.request_boxed(
msg.upcast_box(),
SendOptions::Except(self_id.load(Ordering::SeqCst))
)
.await
.map(|x|x.as_shared_boxed().unwrap())
.map_err(|x|x.map_msg(|_|()));
sender.send(Some(ProtocolItem::Event(Event::Response(mid, res)))).unwrap();
} else {
let _ = bus.send_boxed(msg.upcast_box(), Default::default())
.await;
}
continue;
}
_ => unimplemented!()
}
},
_ => unimplemented!()
};
buff.resize(4, 0);
if let Err(err) = recv.read_exact(&mut buff[..]).await {
println!("recv err: {}", err);
continue;
}
let verb = match std::str::from_utf8(&buff[0..4]) {
Ok(m) => m,
Err(err) => {
println!("recv err parse: {}", err);
continue;
}
};
return Some((event, (incoming, bus, sender, self_id)));
}
}).right_stream()
if verb == "PING" {
println!(">> PING");
continue;
}
other => futures::stream::once(async move {
match other {
RecvDo::Pause => Event::Pause,
RecvDo::Ready => Event::Ready,
RecvDo::Closed => Event::Exited,
_ => unreachable!()
}
}).left_stream()
}
if verb != "MBUS" {
println!("Not MBUS packet!");
continue;
}
buff.resize(12, 0);
if let Err(err) = recv.read_exact(&mut buff[..]).await {
println!("recv err: {}", err);
continue;
}
let mut reader = &buff[..];
let version = reader.get_u16();
let content_type = reader.get_u16();
let body_size = reader.get_u64();
buff.resize(body_size as _, 0);
if let Err(err) = recv.read_exact(&mut buff[..]).await {
println!("recv err: {}", err);
continue;
}
drop(recv);
// println!("inbound packet MBUS v: {}; ct: {}; bs: {}",
// version, content_type, body_size);
let event = match content_type {
0 => { // CBOR
let proto: ProtocolPacket = match serde_cbor::from_slice(&buff[..]) {
Ok(val) => val,
Err(err) => {
println!("pkt parse err: {}", err);
continue;
},
};
let item: ProtocolItem = match proto.deserialize(&bus) {
Ok(val) => val,
Err(err) => {
println!("item parse err: {}", err);
continue;
},
};
match item {
ProtocolItem::Event(ev) => ev.map_msg(|msg|msg.upcast_box()),
ProtocolItem::Action(action) => {
match action {
Action::Close => {
println!("warning: Close recevied - ignoring!");
sender.send(Some(ProtocolItem::Event(Event::Exited))).unwrap();
},
Action::Flush => {
println!("flush");
bus.flush().await;
sender.send(Some(ProtocolItem::Event(Event::Flushed))).unwrap();
},
Action::Sync => {
println!("flush");
bus.sync().await;
sender.send(Some(ProtocolItem::Event(Event::Synchronized(Ok(()))))).unwrap();
},
Action::Init(..) => (),
Action::Stats => (),
_ => (),
}
continue;
}
ProtocolItem::Send(mid, msg, req) => {
let self_id = self_id.clone();
let sender = sender.clone();
let bus = bus.clone();
let _ = tokio::spawn(async move {
if req {
let res = bus.request_boxed(
msg.upcast_box(),
SendOptions::Except(self_id.load(Ordering::SeqCst))
)
.await
.map(|x|x.as_shared_boxed().unwrap())
.map_err(|x|x.map_msg(|_|()));
sender.send(Some(ProtocolItem::Event(Event::Response(mid, res)))).unwrap();
} else {
let tt = msg.type_tag();
let _ = bus.send_boxed(msg.upcast_box(), Default::default())
.await;
sender.send(Some(ProtocolItem::Event(Event::BatchComplete(tt, 1)))).unwrap();
}
});
continue;
}
_ => unimplemented!()
}
},
_ => unimplemented!()
};
return Some((event, (incoming, bus, sender, self_id, buff)));
}
})
.flatten()
})
.flatten();
let stream2 = futures::stream::poll_fn(move |cx|recv_events.poll_recv(cx));
Box::pin(
futures::stream::select(stream1, stream2)
.take_while(|x| futures::future::ready(!matches!(x, Event::Exited)))
)
}
}

View File

@ -31,7 +31,7 @@ pub struct RedisRelay {
}
impl RedisRelay {
pub fn new(path: &str, table: Vec<(TypeTag, TypeTag, TypeTag)>) -> Result<Self, crate::error::Error> {
pub fn new(path: &str, table: Vec<(TypeTag, Option<(TypeTag, TypeTag)>)>) -> Result<Self, crate::error::Error> {
let client = redis::Client::open(path)?;
let (item_sender, item_receiver) = mpsc::unbounded_channel();
@ -50,18 +50,17 @@ impl RedisRelay {
}
impl TypeTagAccept for RedisRelay {
fn accept(&self, msg: &TypeTag, resp: Option<&TypeTag>, err: Option<&TypeTag>) -> bool {
self.table.accept(msg, resp, err)
fn accept_msg(&self, msg: &TypeTag) -> bool {
self.table.accept_message(msg)
}
fn iter_types(&self, cb: &mut dyn FnMut(&TypeTag, &TypeTag, &TypeTag) -> bool) {
let iter = self.table.iter_types();
fn accept_req(&self, msg: &TypeTag, resp: Option<&TypeTag>, err: Option<&TypeTag>) -> bool {
self.table.accept_request(msg, resp, err)
}
for (m, r, e) in iter {
if cb(m, r, e) {
return;
}
}
fn iter_types(&self) -> Box<dyn Iterator<Item = (TypeTag, Option<(TypeTag, TypeTag)>)> + '_> {
let iter = self.table.iter_types();
Box::new(iter.map(|(x, y)| (x.clone(), y.cloned())))
}
}

View File

@ -0,0 +1,388 @@
use std::net::SocketAddr;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use bytes::{Buf, BufMut};
use futures::{Stream, StreamExt, pin_mut};
use futures::stream::unfold;
use messagebus::{Action, Bus, Event, EventBoxed, Message, ReciveUntypedReceiver, SendOptions, SendUntypedReceiver, TypeTag, TypeTagAccept};
use messagebus::error::GenericError;
use parking_lot::Mutex;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender};
use crate::proto::{BodyType, ProtocolItem, ProtocolPacket};
use super::{GenericEventStream, MessageList, MessageTable};
pub struct TcpRelay {
server_mode: bool,
addr: SocketAddr,
self_id: Arc<AtomicU64>,
in_table: MessageTable,
// _out_table: MessageTable,
item_sender: UnboundedSender<Option<ProtocolItem>>,
item_receiver: Mutex<Option<UnboundedReceiver<Option<ProtocolItem>>>>,
event_sender: UnboundedSender<EventBoxed<GenericError>>,
event_receiver: Mutex<Option<UnboundedReceiver<EventBoxed<GenericError>>>>,
stream_sender: UnboundedSender<OwnedReadHalf>,
stream_receiver: Mutex<Option<UnboundedReceiver<OwnedReadHalf>>>,
}
impl TcpRelay {
pub fn new(server_mode: bool, addr: SocketAddr, table: (MessageList, MessageList)) -> Self {
let (item_sender, item_receiver) = mpsc::unbounded_channel();
let (event_sender, event_receiver) = mpsc::unbounded_channel();
let (stream_sender, stream_receiver) = mpsc::unbounded_channel();
Self {
self_id: Arc::new(AtomicU64::new(0)),
server_mode,
addr,
item_sender,
in_table: MessageTable::from(table.0),
item_receiver: Mutex::new(Some(item_receiver)),
event_sender,
event_receiver: Mutex::new(Some(event_receiver)),
stream_sender,
stream_receiver: Mutex::new(Some(stream_receiver)),
}
}
fn connections(&self) -> impl Stream<Item = TcpRelayConnection> {
unfold((self.server_mode, self.addr), move |(sm, addr)| async move {
let stream = if sm {
let bind_res = TcpListener::bind(addr).await;
let listener = match bind_res {
Err(err) => {
println!("bind error: {}", err);
return None;
}
Ok(listener) => listener,
};
unfold((listener, ), move |(listener, )| async move {
let (stream, _addr) = match listener.accept().await {
Err(err) => {
println!("accept error: {}", err);
return None;
}
Ok(listener) => listener,
};
Some((TcpRelayConnection::from(stream), (listener, )))
}).left_stream()
} else {
unfold((addr, ), move |(addr, )| async move {
let stream = match TcpStream::connect(addr).await {
Err(err) => {
println!("connect error: {}", err);
return None;
}
Ok(listener) => listener,
};
Some((TcpRelayConnection::from(stream), (addr, )))
}).right_stream()
};
Some((stream, (sm, addr)))
})
.flatten()
}
}
struct TcpRelayConnection {
recv: OwnedReadHalf,
send: OwnedWriteHalf,
}
impl From<TcpStream> for TcpRelayConnection {
fn from(stream: TcpStream) -> Self {
let (recv, send) = stream.into_split();
TcpRelayConnection {
recv,
send
}
}
}
impl TypeTagAccept for TcpRelay {
fn iter_types(&self) -> Box<dyn Iterator<Item = (TypeTag, Option<(TypeTag, TypeTag)>)> + '_> {
let iter = self.in_table.iter_types();
Box::new(iter.map(|(x, y)| (x.clone(), y.cloned())))
}
fn accept_msg(&self, msg: &TypeTag) -> bool{
self.in_table.accept_message(msg)
}
fn accept_req(&self, req: &TypeTag, resp: Option<&TypeTag>, err: Option<&TypeTag>) -> bool {
self.in_table.accept_request(req, resp, err)
}
}
impl SendUntypedReceiver for TcpRelay {
fn send(&self, msg: Action, _bus: &Bus) -> Result<(), messagebus::error::Error<Action>> {
match msg {
Action::Init(self_id) => {
let mut rx = match self.item_receiver.lock().take() {
Some(x) => x,
None => return Ok(()),
};
let stream_sender = self.stream_sender.clone();
let event_sender = self.event_sender.clone();
let incoming = self.connections();
self.self_id.store(self_id, Ordering::SeqCst);
tokio::spawn(async move {
pin_mut!(incoming);
let mut body_buff = Vec::new();
let mut header_buff = Vec::new();
let mut item = None;
loop {
println!("begin");
let mut conn = match incoming.next().await {
Some(x) => x,
None => {
println!("No more connections. Message {:?} has been lost!", item);
break;
}
};
stream_sender.send(conn.recv).unwrap();
event_sender.send(Event::Ready).unwrap();
loop {
let r = if let Some(r) = item.take() {
r
} else {
match rx.recv().await {
Some(Some(r)) => r,
None | Some(None) => {
println!("closing");
drop(conn.send);
break;
},
}
};
body_buff.clear();
let pkt = r.serialize(BodyType::Cbor, &mut body_buff).unwrap();
header_buff.resize(16, 0);
serde_cbor::to_writer(&mut header_buff, &pkt).unwrap();
let body_size = header_buff.len() - 16;
let mut head = &mut header_buff[0..16];
head.put(&b"MBUS"[..]);
head.put_u16(1);
head.put_u16(0);
head.put_u64(body_size as _);
if let Err(err) = conn.send.write_all(&header_buff).await {
item = Some(r);
println!("write broken connection err {}. try with next connection", err);
break;
}
}
event_sender.send(Event::Pause).unwrap();
}
println!("exit main loop");
});
}
Action::Close => {
self.item_sender.send(None).unwrap();
self.event_sender.send(Event::Exited).unwrap();
}
other => self.item_sender.send(Some(other.into())).unwrap(),
}
Ok(())
}
fn send_msg(
&self,
mid: u64,
msg: Box<dyn Message>,
req: bool,
_bus: &Bus,
) -> Result<(), messagebus::error::Error<Box<dyn Message>>> {
match msg.as_shared_boxed() {
Ok(msg) => {
if let Err(err) = self.item_sender.send(Some((mid, msg, req).into())) {
Err(messagebus::error::Error::TryAgain(err.0.unwrap().unwrap_send().unwrap().1.upcast_box()))
} else {
Ok(())
}
}
Err(msg) => Err(messagebus::error::Error::TryAgain(msg)),
}
}
}
impl ReciveUntypedReceiver for TcpRelay {
type Stream = GenericEventStream;
fn event_stream(&self, bus: Bus) -> Self::Stream {
let self_id = self.self_id.clone();
let mut recv_stream = self.stream_receiver.lock().take().unwrap();
let mut recv_events = self.event_receiver.lock().take().unwrap();
let sender = self.item_sender.clone();
let stream1 = futures::stream::poll_fn(move |cx|recv_stream.poll_recv(cx))
.map(move |incoming| {
let buff: Vec<u8> = Vec::with_capacity(1024);
let bus = bus.clone();
let self_id = self_id.clone();
let sender = sender.clone();
futures::stream::unfold((incoming, bus, sender, self_id, buff), |(mut recv, bus, sender, self_id, mut buff)| async move {
loop {
buff.resize(4, 0);
if let Err(err) = recv.read_exact(&mut buff[..]).await {
println!("recv err: {}", err);
break None;
}
if &buff == b"PING" {
println!(">> PING");
continue;
}
if &buff != b"MBUS" {
println!("Not MBUS packet!");
continue;
}
buff.resize(12, 0);
if let Err(err) = recv.read_exact(&mut buff[..]).await {
println!("recv err: {}", err);
continue;
}
let mut reader = &buff[..];
let version = reader.get_u16();
let content_type = reader.get_u16();
let body_size = reader.get_u64();
buff.resize(body_size as _, 0);
if let Err(err) = recv.read_exact(&mut buff[..]).await {
println!("recv err: {}", err);
continue;
}
// println!("inbound packet MBUS v: {}; ct: {}; bs: {}",
// version, content_type, body_size);
let event = match content_type {
0 => { // CBOR
let proto: ProtocolPacket = match serde_cbor::from_slice(&buff[..]) {
Ok(val) => val,
Err(err) => {
println!("pkt parse err: {}", err);
continue;
},
};
let item: ProtocolItem = match proto.deserialize(&bus) {
Ok(val) => val,
Err(err) => {
println!("item parse err: {}", err);
continue;
},
};
match item {
ProtocolItem::Event(ev) => ev.map_msg(|msg|msg.upcast_box()),
ProtocolItem::Action(action) => {
match action {
Action::Close => {
println!("warning: Close recevied - ignoring!");
sender.send(Some(ProtocolItem::Event(Event::Exited))).unwrap();
},
Action::Flush => {
println!("flush");
bus.flush().await;
sender.send(Some(ProtocolItem::Event(Event::Flushed))).unwrap();
},
Action::Sync => {
println!("sync");
bus.sync().await;
sender.send(Some(ProtocolItem::Event(Event::Synchronized(Ok(()))))).unwrap();
},
Action::Init(..) => (),
Action::Stats => (),
_ => (),
}
continue;
}
ProtocolItem::Send(mid, msg, req) => {
let self_id = self_id.clone();
let sender = sender.clone();
let bus = bus.clone();
let _ = tokio::spawn(async move {
if req {
let res = bus.request_boxed(
msg.upcast_box(),
SendOptions::Except(self_id.load(Ordering::SeqCst))
)
.await
.map(|x|x.as_shared_boxed().unwrap())
.map_err(|x|x.map_msg(|_|()));
sender.send(Some(ProtocolItem::Event(Event::Response(mid, res)))).unwrap();
} else {
let tt = msg.type_tag();
let _ = bus.send_boxed(msg.upcast_box(), Default::default())
.await;
sender.send(Some(ProtocolItem::Event(Event::BatchComplete(tt, 1)))).unwrap();
}
});
continue;
}
_ => unimplemented!()
}
},
_ => unimplemented!()
};
return Some((event, (recv, bus, sender, self_id, buff)));
}
})
})
.flatten();
let stream2 = futures::stream::poll_fn(move |cx|recv_events.poll_recv(cx));
Box::pin(
futures::stream::select(stream1, stream2)
.take_while(|x| futures::future::ready(!matches!(x, Event::Exited)))
)
}
}

View File

@ -1,12 +1,11 @@
use core::{marker::PhantomData, pin::Pin};
use std::{collections::HashMap, sync::{Arc, atomic::{AtomicU64, Ordering}}};
use std::{collections::{HashMap, HashSet}, sync::{Arc, atomic::{AtomicU64, Ordering}}};
use futures::{Future, FutureExt};
use smallvec::SmallVec;
use tokio::sync::Mutex;
use crate::{AsyncBatchHandler, AsyncBatchSynchronizedHandler, AsyncHandler, AsyncSynchronizedHandler, BatchHandler, BatchSynchronizedHandler, Bus, BusInner, Handler, Message, Relay, SharedMessage, SynchronizedHandler, Untyped, envelop::{IntoSharedMessage, TypeTag}, error::{Error, StdSyncSendError}, receiver::{
use crate::{AsyncBatchHandler, AsyncBatchSynchronizedHandler, AsyncHandler, AsyncSynchronizedHandler, BatchHandler, BatchSynchronizedHandler, Bus, BusInner, Handler, Message, Relay, SharedMessage, SynchronizedHandler, TypeTag, Untyped, envelop::IntoSharedMessage, error::{Error, StdSyncSendError}, receiver::{
BusPollerCallback, Receiver, ReciveTypedReceiver, SendTypedReceiver, SendUntypedReceiver,
UntypedPollerCallback,
}, receivers};
@ -43,19 +42,19 @@ pub struct RegisterEntry<K, T, F, P, B> {
payload: B,
builder: F,
poller: P,
receivers: HashMap<TypeTag, Receiver>,
receivers: HashSet<Receiver>,
pollers: Vec<BusPollerCallback>,
_m: PhantomData<(K, T)>,
}
impl<K, T: 'static, F, P, B> RegisterEntry<K, T, F, P, B>
where
F: FnMut(&mut B, TypeTag, Receiver),
F: FnMut(&mut B, Receiver),
P: FnMut(&mut B, BusPollerCallback),
{
pub fn done(mut self) -> B {
for (tid, v) in self.receivers {
(self.builder)(&mut self.payload, tid, v);
for r in self.receivers {
(self.builder)(&mut self.payload, r);
}
for p in self.pollers {
@ -79,7 +78,7 @@ impl<T, F, P, B> RegisterEntry<UnsyncEntry, T, F, P, B> {
let receiver = Receiver::new::<M, R, E, S>(RECEVIER_ID_SEQ.fetch_add(1, Ordering::Relaxed), queue, true, inner);
let poller2 = receiver.start_polling();
self.receivers.insert(M::type_tag_(), receiver);
self.receivers.insert(receiver);
self.pollers.push(poller(self.item.clone()));
self.pollers.push(poller2);
@ -148,7 +147,7 @@ impl<T, F, P, B> RegisterEntry<SyncEntry, T, F, P, B> {
let receiver = Receiver::new::<M, R, E, S>(RECEVIER_ID_SEQ.fetch_add(1, Ordering::Relaxed), queue, true, inner);
let poller2 = receiver.start_polling();
self.receivers.insert(M::type_tag_(), receiver);
self.receivers.insert(receiver);
self.pollers.push(poller(self.item.clone()));
self.pollers.push(poller2);
@ -222,7 +221,7 @@ impl MessageTypeDescriptor {
#[derive(Default)]
pub struct Module {
message_types: HashMap<TypeTag, MessageTypeDescriptor>,
receivers: HashMap<TypeTag, SmallVec<[Receiver; 4]>>,
receivers: HashSet<Receiver>,
pollings: Vec<BusPollerCallback>,
}
@ -230,7 +229,7 @@ impl Module {
pub fn new() -> Self {
Self {
message_types: HashMap::new(),
receivers: HashMap::new(),
receivers: HashSet::new(),
pollings: Vec::new(),
}
}
@ -254,28 +253,7 @@ impl Module {
pub fn register_relay<S: Relay + Send + Sync + 'static>(mut self, inner: S) -> Self {
let receiver = Receiver::new_relay::<S>(RECEVIER_ID_SEQ.fetch_add(1, Ordering::Relaxed), inner);
self.pollings.push(receiver.start_polling());
let mut receiver_added = false;
receiver.iter_types(&mut |msg, _, _| {
self.receivers
.entry(msg.clone())
.or_insert_with(SmallVec::new)
.push(receiver.clone());
if !receiver_added {
receiver_added = true;
false
} else {
true
}
});
if !receiver_added {
self.receivers
.entry("".into())
.or_insert_with(SmallVec::new)
.push(receiver);
}
self.receivers.insert(receiver);
self
}
@ -286,18 +264,16 @@ impl Module {
) -> RegisterEntry<
SyncEntry,
T,
impl FnMut(&mut Self, TypeTag, Receiver),
impl FnMut(&mut Self, Receiver),
impl FnMut(&mut Self, Box<dyn FnOnce(Bus) -> Pin<Box<dyn Future<Output = ()> + Send>>>),
Self,
> {
RegisterEntry {
item: Arc::new(item) as Untyped,
payload: self,
builder: |p: &mut Self, tt, r| {
p.receivers.entry(tt).or_insert_with(SmallVec::new).push(r);
},
builder: |p: &mut Self, r| { p.receivers.insert(r); },
poller: |p: &mut Self, poller| p.pollings.push(poller),
receivers: HashMap::new(),
receivers: HashSet::new(),
pollers: Vec::new(),
_m: Default::default(),
}
@ -309,7 +285,7 @@ impl Module {
) -> RegisterEntry<
UnsyncEntry,
T,
impl FnMut(&mut Self, TypeTag, Receiver),
impl FnMut(&mut Self, Receiver),
impl FnMut(&mut Self, Box<dyn FnOnce(Bus) -> Pin<Box<dyn Future<Output = ()> + Send>>>),
Self,
> {
@ -318,11 +294,9 @@ impl Module {
RegisterEntry {
item,
payload: self,
builder: |p: &mut Self, tt, r| {
p.receivers.entry(tt).or_insert_with(SmallVec::new).push(r);
},
builder: |p: &mut Self, r| { p.receivers.insert(r); },
poller: |p: &mut Self, poller| p.pollings.push(poller),
receivers: HashMap::new(),
receivers: HashSet::new(),
pollers: Vec::new(),
_m: Default::default(),
}
@ -331,13 +305,7 @@ impl Module {
pub fn add_module(mut self, module: Module) -> Self {
self.message_types.extend(module.message_types);
self.pollings.extend(module.pollings);
for (k, v) in module.receivers {
self.receivers
.entry(k)
.or_insert_with(Default::default)
.extend(v);
}
self.receivers.extend(module.receivers);
self
}
@ -376,22 +344,16 @@ impl BusBuilder {
) -> RegisterEntry<
SyncEntry,
T,
impl FnMut(&mut Self, TypeTag, Receiver),
impl FnMut(&mut Self, Receiver),
impl FnMut(&mut Self, Box<dyn FnOnce(Bus) -> Pin<Box<dyn Future<Output = ()> + Send>>>),
Self,
> {
RegisterEntry {
item: Arc::new(item) as Untyped,
payload: self,
builder: |p: &mut Self, tt, r| {
p.inner
.receivers
.entry(tt)
.or_insert_with(SmallVec::new)
.push(r);
},
builder: |p: &mut Self, r| { p.inner.receivers.insert(r); },
poller: |p: &mut Self, poller| p.inner.pollings.push(poller),
receivers: HashMap::new(),
receivers: HashSet::new(),
pollers: Vec::new(),
_m: Default::default(),
}
@ -403,22 +365,16 @@ impl BusBuilder {
) -> RegisterEntry<
UnsyncEntry,
T,
impl FnMut(&mut Self, TypeTag, Receiver),
impl FnMut(&mut Self, Receiver),
impl FnMut(&mut Self, Box<dyn FnOnce(Bus) -> Pin<Box<dyn Future<Output = ()> + Send>>>),
Self,
> {
RegisterEntry {
item: Arc::new(Mutex::new(item)) as Untyped,
payload: self,
builder: |p: &mut Self, tt, r| {
p.inner
.receivers
.entry(tt)
.or_insert_with(SmallVec::new)
.push(r);
},
builder: |p: &mut Self, r| { p.inner.receivers.insert(r); },
poller: |p: &mut Self, poller| p.inner.pollings.push(poller),
receivers: HashMap::new(),
receivers: HashSet::new(),
pollers: Vec::new(),
_m: Default::default(),
}
@ -431,19 +387,8 @@ impl BusBuilder {
}
pub fn build(self) -> (Bus, impl Future<Output = ()>) {
let mut receivers = HashMap::new();
for (key, values) in self.inner.receivers {
for v in values {
receivers
.entry(key.clone())
.or_insert_with(SmallVec::new)
.push(v);
}
}
let bus = Bus {
inner: Arc::new(BusInner::new(receivers, self.inner.message_types)),
inner: Arc::new(BusInner::new(self.inner.receivers, self.inner.message_types)),
};
let mut futs = Vec::with_capacity(self.inner.pollings.len() * 2);

View File

@ -22,7 +22,7 @@ use core::{
time::Duration,
};
use smallvec::SmallVec;
use std::{collections::HashMap, sync::Arc};
use std::{collections::{HashMap, HashSet}, sync::Arc};
use tokio::sync::Mutex;
use builder::{BusBuilder, MessageTypeDescriptor};
@ -41,6 +41,7 @@ pub use receiver::{
pub use relay::Relay;
pub type Untyped = Arc<dyn Any + Send + Sync>;
type LookupQuery = (TypeTag, Option<TypeTag>, Option<TypeTag>);
static ID_COUNTER: AtomicU64 = AtomicU64::new(1);
@ -60,20 +61,46 @@ impl Default for SendOptions {
}
pub struct BusInner {
receivers: HashSet<Receiver>,
message_types: HashMap<TypeTag, MessageTypeDescriptor>,
receivers: HashMap<TypeTag, SmallVec<[Receiver; 4]>>,
lookup: HashMap<LookupQuery, SmallVec<[Receiver; 4]>>,
closed: AtomicBool,
maintain: Mutex<()>,
}
impl BusInner {
pub(crate) fn new(
receivers: HashMap<TypeTag, SmallVec<[Receiver; 4]>>,
receivers: HashSet<Receiver>,
message_types: HashMap<TypeTag, MessageTypeDescriptor>,
) -> Self {
let mut lookup = HashMap::new();
for recv in receivers.iter() {
for (msg, resp) in recv.iter_types() {
lookup.entry((msg.clone(), None, None))
.or_insert_with(HashSet::new)
.insert(recv.clone());
if let Some((resp, err)) = resp {
lookup.entry((msg.clone(), Some(resp.clone()), None))
.or_insert_with(HashSet::new)
.insert(recv.clone());
lookup.entry((msg, Some(resp), Some(err)))
.or_insert_with(HashSet::new)
.insert(recv.clone());
}
}
}
let lookup = lookup
.into_iter()
.map(|(k, v)| (k, v.into_iter().collect()))
.collect();
Self {
message_types,
receivers,
lookup,
closed: AtomicBool::new(false),
maintain: Mutex::new(()),
}
@ -96,18 +123,14 @@ impl Bus {
}
pub(crate) fn init(&self) {
for rs in self.inner.receivers.values() {
for r in rs {
r.init(self).unwrap();
}
for r in self.inner.receivers.iter() {
r.init(self).unwrap();
}
}
pub async fn ready(&self) {
for rs in self.inner.receivers.values() {
for r in rs {
r.ready().await;
}
for r in self.inner.receivers.iter() {
r.ready().await;
}
}
@ -115,13 +138,11 @@ impl Bus {
let _handle = self.inner.maintain.lock().await;
self.inner.closed.store(true, Ordering::SeqCst);
for rs in self.inner.receivers.values() {
for r in rs {
let err = tokio::time::timeout(Duration::from_secs(20), r.close(self)).await;
for r in self.inner.receivers.iter() {
let err = tokio::time::timeout(Duration::from_secs(20), r.close(self)).await;
if let Err(err) = err {
error!("Close timeout on {}: {}", r.name(), err);
}
if let Err(err) = err {
error!("Close timeout on {}: {}", r.name(), err);
}
}
}
@ -134,13 +155,11 @@ impl Bus {
for _ in 0..fuse_count {
iters += 1;
let mut flushed = false;
for rs in self.inner.receivers.values() {
for r in rs {
if r.need_flush() {
flushed = true;
for r in self.inner.receivers.iter() {
if r.need_flush() {
flushed = true;
r.flush(self).await;
}
r.flush(self).await;
}
}
@ -163,10 +182,8 @@ impl Bus {
pub async fn sync(&self) {
let _handle = self.inner.maintain.lock().await;
for rs in self.inner.receivers.values() {
for r in rs {
r.sync(self).await;
}
for r in self.inner.receivers.iter() {
r.sync(self).await;
}
}
@ -207,7 +224,7 @@ impl Bus {
let tt = msg.type_tag();
let mid = ID_COUNTER.fetch_add(1, Ordering::Relaxed);
if let Some(rs) = self.inner.receivers.get(&tt) {
if let Some(rs) = self.inner.lookup.get(&(msg.type_tag(), None, None)) {
let permits = if let Some(x) = self.try_reserve(&tt, rs) {
x
} else {
@ -270,7 +287,7 @@ impl Bus {
let tt = msg.type_tag();
let mid = ID_COUNTER.fetch_add(1, Ordering::Relaxed);
if let Some(rs) = self.inner.receivers.get(&tt) {
if let Some(rs) = self.inner.lookup.get(&(msg.type_tag(), None, None)) {
if let Some((last, head)) = rs.split_last() {
for r in head {
let _ = r.send(self, mid, msg.clone(), false, r.reserve(&tt).await);
@ -304,10 +321,9 @@ impl Bus {
return Err(SendError::Closed(msg).into());
}
let tt = msg.type_tag();
let mid = ID_COUNTER.fetch_add(1, Ordering::Relaxed);
if let Some(rs) = self.inner.receivers.get(&tt) {
if let Some(rs) = self.inner.lookup.get(&(msg.type_tag(), None, None)) {
if let Some((last, head)) = rs.split_last() {
for r in head {
let _ = r.force_send(self, mid, msg.clone(), false);
@ -336,7 +352,9 @@ impl Bus {
let tt = msg.type_tag();
let mid = ID_COUNTER.fetch_add(1, Ordering::Relaxed);
if let Some(rs) = self.inner.receivers.get(&tt).and_then(|rs| rs.first()) {
if let Some(rs) = self.inner.lookup.get(&(msg.type_tag(), None, None))
.and_then(|rs| rs.first())
{
let permits = if let Some(x) = rs.try_reserve(&tt) {
x
} else {
@ -357,7 +375,9 @@ impl Bus {
let tt = msg.type_tag();
let mid = ID_COUNTER.fetch_add(1, Ordering::Relaxed);
if let Some(rs) = self.inner.receivers.get(&tt).and_then(|rs| rs.first()) {
if let Some(rs) = self.inner.lookup.get(&(msg.type_tag(), None, None))
.and_then(|rs| rs.first())
{
Ok(rs.send(self, mid, msg, false, rs.reserve(&tt).await)?)
} else {
Err(Error::NoReceivers)
@ -374,10 +394,10 @@ impl Bus {
req: M,
options: SendOptions,
) -> Result<R, Error<M>> {
let tid = req.type_tag();
let tid = M::type_tag_();
let rid = R::type_tag_();
let mut iter = self.select_receivers(&tid, options, Some(&rid), None);
let mut iter = self.select_receivers(tid.clone(), options, Some(rid), None, true);
if let Some(rc) = iter.next() {
let (mid, rx) = rc
.add_response_waiter::<R>()
@ -402,7 +422,7 @@ impl Bus {
let rid = R::type_tag_();
let eid = E::type_tag_();
let mut iter = self.select_receivers(&tid, options, Some(&rid), Some(&eid));
let mut iter = self.select_receivers(tid.clone(), options, Some(rid), Some(eid), true);
if let Some(rc) = iter.next() {
let (mid, rx) = rc.add_response_waiter_we::<R, E>().map_err(|x| {
x.map_err(|_| unimplemented!())
@ -436,7 +456,7 @@ impl Bus {
let tt = msg.type_tag();
let mid = ID_COUNTER.fetch_add(1, Ordering::Relaxed);
let mut iter = self.select_receivers(&tt, options, None, None);
let mut iter = self.select_receivers(tt.clone(), options, None, None, false);
let first = iter.next();
for r in iter {
@ -476,7 +496,7 @@ impl Bus {
let tt = msg.type_tag();
let mid = ID_COUNTER.fetch_add(1, Ordering::Relaxed);
let mut iter = self.select_receivers(&tt, options, None, None);
let mut iter = self.select_receivers(tt.clone(), options, None, None, false);
if let Some(rs) = iter.next() {
Ok(rs.send_boxed(self, mid, msg, false, rs.reserve(&tt).await)?)
} else {
@ -495,7 +515,7 @@ impl Bus {
let tt = req.type_tag();
let mut iter = self.select_receivers(&tt, options, None, None);
let mut iter = self.select_receivers(tt.clone(), options, None, None, true);
if let Some(rc) = iter.next() {
let (mid, rx) = rc.add_response_waiter_boxed().map_err(|x| {
x.map_err(|_| unimplemented!())
@ -528,7 +548,7 @@ impl Bus {
let tt = req.type_tag();
let eid = E::type_tag_();
let mut iter = self.select_receivers(&tt, options, None, Some(&eid));
let mut iter = self.select_receivers(tt.clone(), options, None, Some(eid), true);
if let Some(rc) = iter.next() {
let (mid, rx) = rc.add_response_waiter_boxed_we().map_err(|x| {
x.map_err(|_| unimplemented!())
@ -563,7 +583,8 @@ impl Bus {
let mid = ID_COUNTER.fetch_add(1, Ordering::Relaxed);
if let Some(rs) = self.inner.receivers.get(&tt).and_then(|rs| rs.first()) {
if let Some(rs) = self.inner.lookup.get(&(tt.clone(), None, None))
.and_then(|rs| rs.first()) {
let msg = self.deserialize_message(tt.clone(), de)?;
Ok(rs.send_boxed(self, mid, msg.upcast_box(), false, rs.reserve(&tt).await)?)
} else {
@ -582,7 +603,7 @@ impl Bus {
return Err(Error::NoResponse);
}
let mut iter = self.select_receivers(&tt, options, None, None);
let mut iter = self.select_receivers(tt.clone(), options, None, None, true);
if let Some(rc) = iter.next() {
let (mid, rx) = rc.add_response_waiter_boxed().unwrap();
let msg = self.deserialize_message(tt.clone(), de)?;
@ -620,25 +641,22 @@ impl Bus {
self.inner
.receivers
.iter()
.map(|(_, r)| r.into_iter().map(|x| x.stats()))
.flatten()
.map(|x| x.stats())
}
#[inline]
fn select_receivers<'a, 'b: 'a, 'c: 'a, 'd: 'a>(
&'a self,
tid: &'b TypeTag,
fn select_receivers(
&self,
tid: TypeTag,
options: SendOptions,
rid: Option<&'c TypeTag>,
eid: Option<&'d TypeTag>,
) -> impl Iterator<Item = &Receiver> + 'a {
self.inner
.receivers
.get(tid)
.into_iter()
.map(|item| item.iter())
.flatten()
.filter(move |r| r.accept(tid, rid, eid))
rid: Option<TypeTag>,
eid: Option<TypeTag>,
is_req: bool,
) -> impl Iterator<Item = &Receiver> + '_ {
let vec = self.inner.lookup.get(&(tid.clone(), rid.clone(), eid.clone())).unwrap();
vec.iter()
.filter(move |r| r.accept(is_req, &tid, rid.as_ref(), eid.as_ref()))
.filter(move |r| match options {
SendOptions::Except(id) => id != r.id(),
SendOptions::Direct(id) => id == r.id(),

View File

@ -15,6 +15,7 @@ use core::{
pin::Pin,
sync::atomic::{AtomicBool, AtomicI64, Ordering},
};
use std::hash::{Hash, Hasher};
use futures::{pin_mut, Stream};
use futures::{Future, FutureExt, StreamExt};
use std::{borrow::Cow, sync::Arc};
@ -87,8 +88,9 @@ pub trait WrapperReturnTypeAndError<R: Message, E: StdSyncSendError>: Send + Syn
}
pub trait TypeTagAccept {
fn accept(&self, msg: &TypeTag, resp: Option<&TypeTag>, err: Option<&TypeTag>) -> bool;
fn iter_types(&self, cb: &mut dyn FnMut(&TypeTag, &TypeTag, &TypeTag) -> bool);
fn iter_types(&self) -> Box<dyn Iterator<Item = (TypeTag, Option<(TypeTag, TypeTag)>)> + '_>;
fn accept_msg(&self, msg: &TypeTag) -> bool;
fn accept_req(&self, req: &TypeTag, resp: Option<&TypeTag>, err: Option<&TypeTag>) -> bool;
}
pub trait ReceiverTrait: TypeTagAccept + Send + Sync {
@ -155,7 +157,7 @@ pub type EventBoxed<E> = Event<Box<dyn Message>, E>;
pub enum Event<M, E: StdSyncSendError> {
Response(u64, Result<M, Error<(), E>>),
Synchronized(Result<(), Error<(), E>>),
Finished(u64),
BatchComplete(TypeTag, u64),
Error(Error<(), E>),
InitFailed(Error<(), E>),
Stats(Stats),
@ -170,7 +172,7 @@ impl<M, E: StdSyncSendError> Event<M, E> {
match self {
Event::Response(mid, res) => Event::Response(mid, res.map(f)),
Event::Synchronized(res) => Event::Synchronized(res),
Event::Finished(cnt) => Event::Finished(cnt),
Event::BatchComplete(tt, cnt) => Event::BatchComplete(tt, cnt),
Event::Error(err) => Event::Error(err),
Event::InitFailed(err) => Event::InitFailed(err),
Event::Stats(st) => Event::Stats(st),
@ -256,7 +258,7 @@ where
}
}
Event::Finished(n) => {
Event::BatchComplete(_, n) => {
self.context.processing.fetch_sub(n as _, Ordering::SeqCst);
if n > 1 {
@ -348,11 +350,11 @@ where
E: StdSyncSendError,
S: ReciveTypedReceiver<R, E> + Send + Sync + 'static,
{
fn iter_types(&self, cb: &mut dyn FnMut(&TypeTag, &TypeTag, &TypeTag) -> bool) {
let _ = cb(&M::type_tag_(), &R::type_tag_(), &E::type_tag_());
fn iter_types(&self) -> Box<dyn Iterator<Item = (TypeTag, Option<(TypeTag, TypeTag)>)> + '_> {
Box::new(std::iter::once((M::type_tag_(), Some((R::type_tag_(), E::type_tag_())))))
}
fn accept(&self, msg: &TypeTag, resp: Option<&TypeTag>, err: Option<&TypeTag>) -> bool {
fn accept_req(&self, req: &TypeTag, resp: Option<&TypeTag>, err: Option<&TypeTag>) -> bool {
if let Some(resp) = resp {
if resp.as_ref() != R::type_tag_().as_ref() {
return false;
@ -365,6 +367,10 @@ where
}
}
req.as_ref() == M::type_tag_().as_ref()
}
fn accept_msg(&self, msg: &TypeTag) -> bool {
msg.as_ref() == M::type_tag_().as_ref()
}
}
@ -709,6 +715,12 @@ pub struct Receiver {
inner: Arc<dyn ReceiverTrait>,
}
impl Hash for Receiver {
fn hash<H: Hasher>(&self, state: &mut H) {
self.inner.id().hash(state);
}
}
impl fmt::Debug for Receiver {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Receiver({:?})", self.inner.type_id())?;
@ -718,7 +730,7 @@ impl fmt::Debug for Receiver {
impl core::cmp::PartialEq for Receiver {
fn eq(&self, other: &Receiver) -> bool {
self.inner.type_id() == other.inner.type_id()
self.inner.id() == other.inner.id()
}
}
@ -782,8 +794,12 @@ impl Receiver {
}
#[inline]
pub fn accept(&self, msg: &TypeTag, resp: Option<&TypeTag>, err: Option<&TypeTag>) -> bool {
self.inner.accept(msg, resp, err)
pub fn accept(&self, is_req: bool, msg: &TypeTag, resp: Option<&TypeTag>, err: Option<&TypeTag>) -> bool {
if is_req {
self.inner.accept_req(msg, resp, err)
} else {
self.inner.accept_msg(msg)
}
}
#[inline]
@ -1037,7 +1053,7 @@ impl Receiver {
}
#[inline]
pub fn iter_types(&self, cb: &mut dyn FnMut(&TypeTag, &TypeTag, &TypeTag) -> bool) {
self.inner.iter_types(cb)
pub fn iter_types(&self) -> impl Iterator<Item = (TypeTag, Option<(TypeTag, TypeTag)>)> + '_ {
self.inner.iter_types()
}
}

View File

@ -89,12 +89,16 @@ impl<S> TypeTagAccept for RelayWrapper<S>
where
S: Relay + Send + Sync + 'static,
{
fn iter_types(&self, cb: &mut dyn FnMut(&TypeTag, &TypeTag, &TypeTag) -> bool) {
self.inner.iter_types(cb)
fn iter_types(&self) -> Box<dyn Iterator<Item = (TypeTag, Option<(TypeTag, TypeTag)>)> + '_> {
self.inner.iter_types()
}
fn accept(&self, msg: &TypeTag, resp: Option<&TypeTag>, err: Option<&TypeTag>) -> bool {
self.inner.accept(msg, resp, err)
fn accept_msg(&self, msg: &TypeTag) -> bool {
self.inner.accept_msg(msg)
}
fn accept_req(&self, req: &TypeTag, resp: Option<&TypeTag>, err: Option<&TypeTag>) -> bool {
self.inner.accept_req(req, resp, err)
}
}
@ -286,6 +290,16 @@ where
}
}
Event::BatchComplete(tt, n) => {
if let Some(ctx) = self.context.receivers.get(&tt) {
ctx.processing.fetch_sub(n, Ordering::SeqCst);
for _ in 0..n {
ctx.response.notify_one();
}
}
}
_ => unimplemented!(),
}
}

View File

@ -2,12 +2,7 @@ use std::pin::Pin;
use async_trait::async_trait;
use futures::Stream;
use messagebus::{
derive::{Error as MbError, Message},
error::{self, GenericError},
receivers, Action, AsyncHandler, Bus, Event, Message, MessageBounds, ReciveUntypedReceiver,
SendUntypedReceiver, TypeTagAccept, TypeTagged,
};
use messagebus::{Action, AsyncHandler, Bus, Event, Message, MessageBounds, ReciveUntypedReceiver, SendUntypedReceiver, TypeTag, TypeTagAccept, TypeTagged, derive::{Error as MbError, Message}, error::{self, GenericError}, receivers};
use parking_lot::Mutex;
use thiserror::Error;
use tokio::sync::mpsc;
@ -56,7 +51,7 @@ pub struct TestRelay {
}
impl TypeTagAccept for TestRelay {
fn accept(
fn accept_req(
&self,
msg: &messagebus::TypeTag,
resp: Option<&messagebus::TypeTag>,
@ -64,52 +59,44 @@ impl TypeTagAccept for TestRelay {
) -> bool {
if msg.as_ref() == Msg::<i16>::type_tag_().as_ref() {
if let Some(resp) = resp {
if resp.as_ref() != Msg::<u8>::type_tag_().as_ref() {
return false;
if resp.as_ref() == Msg::<u8>::type_tag_().as_ref() {
return true
}
}
} else {
return true
}
}
if msg.as_ref() == Msg::<i32>::type_tag_().as_ref() {
if let Some(resp) = resp {
if resp.as_ref() != Msg::<i64>::type_tag_().as_ref()
&& resp.as_ref() != Msg::<()>::type_tag_().as_ref()
{
return false;
if resp.as_ref() == Msg::<u64>::type_tag_().as_ref() {
return true
}
}
} else {
return true
}
}
true
false
}
fn iter_types(
fn accept_msg(
&self,
cb: &mut dyn FnMut(
&messagebus::TypeTag,
&messagebus::TypeTag,
&messagebus::TypeTag,
) -> bool,
) {
if !cb(
&Msg::<i16>::type_tag_(),
&Msg::<u8>::type_tag_(),
&Error::type_tag_(),
) {
return;
msg: &messagebus::TypeTag,
) -> bool {
if msg.as_ref() == Msg::<i32>::type_tag_().as_ref() {
return true
}
if !cb(
&Msg::<i32>::type_tag_(),
&Msg::<()>::type_tag_(),
&Error::type_tag_(),
) {
return;
}
if !cb(
&Msg::<i32>::type_tag_(),
&Msg::<i64>::type_tag_(),
&Error::type_tag_(),
) {}
false
}
fn iter_types(&self) -> Box<dyn Iterator<Item = (TypeTag, Option<(TypeTag, TypeTag)>)>> {
Box::new(
std::iter::once((Msg::<i32>::type_tag_(), None))
.chain(std::iter::once((Msg::<i32>::type_tag_(), Some((Msg::<u64>::type_tag_(), GenericError::type_tag_())))))
.chain(std::iter::once((Msg::<i16>::type_tag_(), Some((Msg::<u8>::type_tag_(), GenericError::type_tag_())))))
)
}
}
@ -139,7 +126,7 @@ impl SendUntypedReceiver for TestRelay {
&self,
mid: u64,
msg: Box<dyn Message>,
_req: bool,
req: bool,
_bus: &Bus,
) -> Result<(), error::Error<Box<dyn Message>>> {
println!("TestRelay::send_msg [{}] {:?}", mid, msg);
@ -147,10 +134,18 @@ impl SendUntypedReceiver for TestRelay {
self.stx
.send(Event::Response(mid, Ok(Box::new(Msg(9u8)))))
.unwrap();
} else if msg.type_tag().as_ref() == Msg::<i32>::type_tag_().as_ref() {
if req {
self.stx
.send(Event::Response(mid, Ok(Box::new(Msg(22u64)))))
.unwrap();
} else {
self.stx
.send(Event::Response(mid, Ok(Box::new(()))))
.unwrap();
}
} else {
self.stx
.send(Event::Response(mid, Ok(Box::new(()))))
.unwrap();
panic!("unsupported message type {}", msg.type_tag());
}
Ok(())
@ -189,9 +184,11 @@ async fn test_relay() {
.build();
b.send(Msg(32i32)).await.unwrap();
let res: Msg<u8> = b.request(Msg(12i16), Default::default()).await.unwrap();
let res1: Msg<u8> = b.request(Msg(12i16), Default::default()).await.unwrap();
let res2: Msg<u64> = b.request(Msg(12i32), Default::default()).await.unwrap();
assert_eq!(res.0, 9u8);
assert_eq!(res1.0, 9u8);
assert_eq!(res2.0, 22u64);
b.flush().await;
b.close().await;