diff --git a/crates/remote/examples/quic_client.rs b/crates/remote/examples/quic_client.rs index c51f735..c738c87 100644 --- a/crates/remote/examples/quic_client.rs +++ b/crates/remote/examples/quic_client.rs @@ -1,15 +1,15 @@ +use messagebus::derive::Message; use messagebus::error::GenericError; use messagebus::{Bus, TypeTagged}; use messagebus_remote::relays::QuicClientRelay; -use serde_derive::{Serialize, Deserialize}; -use messagebus::derive::Message; +use serde_derive::{Deserialize, Serialize}; #[derive(Serialize, Deserialize, Debug, Clone, Message)] #[namespace("example")] #[message(shared, clone)] pub struct Req { data: i32, - text: String + text: String, } #[derive(Serialize, Deserialize, Debug, Clone, Message)] @@ -17,39 +17,44 @@ pub struct Req { #[message(shared, clone)] pub struct Resp { data: i32, - text: String + text: String, } #[tokio::main] async fn main() { let relay = QuicClientRelay::new( - "./examples/cert.der", - "127.0.0.1:8083".parse().unwrap(), + "./examples/cert.der", + "127.0.0.1:8083".parse().unwrap(), "localhost".into(), - (vec![ - (Req::type_tag_(), Some((Resp::type_tag_(), GenericError::type_tag_()))) - ], - vec![]) - ).unwrap(); + ( + vec![( + Req::type_tag_(), + Some((Resp::type_tag_(), GenericError::type_tag_())), + )], + vec![], + ), + ) + .unwrap(); - let (b, poller) = Bus::build() - .register_relay(relay) - .build(); + let (b, poller) = Bus::build().register_relay(relay).build(); - b.ready().await; println!("ready"); - let resp: Resp = b.request(Req { - data: 12, - text: String::from("test") - }, Default::default()) - .await - .unwrap(); + let resp: Resp = b + .request( + Req { + data: 12, + text: String::from("test"), + }, + Default::default(), + ) + .await + .unwrap(); println!("resp {:?}", resp); - b.flush().await; + b.flush_all().await; b.close().await; poller.await; } diff --git a/crates/remote/examples/tcp_client.rs b/crates/remote/examples/tcp_client.rs index 6bdd211..aafd006 100644 --- a/crates/remote/examples/tcp_client.rs +++ b/crates/remote/examples/tcp_client.rs @@ -46,7 +46,7 @@ async fn main() { println!("resp {:?}", resp); - b.flush().await; + b.flush_all().await; b.close().await; poller.await; } diff --git a/crates/remote/src/relays/quic/mod.rs b/crates/remote/src/relays/quic/mod.rs index c5c3abb..13cbe58 100644 --- a/crates/remote/src/relays/quic/mod.rs +++ b/crates/remote/src/relays/quic/mod.rs @@ -2,18 +2,21 @@ mod client; mod server; use std::net::SocketAddr; -use std::sync::Arc; use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::Arc; +use bytes::{Buf, BufMut}; pub use client::QuicClientEndpoint; +use futures::{pin_mut, Future, Stream, StreamExt}; use messagebus::error::GenericError; -use messagebus::{Action, Bus, Event, EventBoxed, Message, ReciveUntypedReceiver, SendOptions, SendUntypedReceiver, TypeTag, TypeTagAccept}; +use messagebus::{ + Action, Bus, Event, EventBoxed, Message, ReciveUntypedReceiver, SendOptions, + SendUntypedReceiver, TypeTag, TypeTagAccept, +}; use parking_lot::Mutex; use quinn::{Connecting, IncomingBiStreams}; pub use server::QuicServerEndpoint; use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender}; -use futures::{Future, Stream, StreamExt, pin_mut}; -use bytes::{Buf, BufMut}; pub const ALPN_QUIC_HTTP: &[&[u8]] = &[b"hq-29"]; @@ -47,7 +50,12 @@ pub struct QuicRelay { } impl QuicRelay { - pub fn new(cert: &str, addr: SocketAddr, host: String, table: (MessageList, MessageList)) -> Result { + pub fn new( + cert: &str, + addr: SocketAddr, + host: String, + table: (MessageList, MessageList), + ) -> Result { let (item_sender, item_receiver) = mpsc::unbounded_channel(); let (event_sender, event_receiver) = mpsc::unbounded_channel(); let (stream_sender, stream_receiver) = mpsc::unbounded_channel(); @@ -68,13 +76,18 @@ impl QuicRelay { } impl QuicRelay { - pub fn new(key_path: &str, cert_path: &str, addr: SocketAddr, table: (MessageList, MessageList)) -> Result { + pub fn new( + key_path: &str, + cert_path: &str, + addr: SocketAddr, + table: (MessageList, MessageList), + ) -> Result { let (item_sender, item_receiver) = mpsc::unbounded_channel(); let (event_sender, event_receiver) = mpsc::unbounded_channel(); let (stream_sender, stream_receiver) = mpsc::unbounded_channel(); Ok(QuicRelay { - base: Mutex::new(Some(QuicServerEndpoint::new(key_path, cert_path, &addr )?)), + base: Mutex::new(Some(QuicServerEndpoint::new(key_path, cert_path, &addr)?)), self_id: Arc::new(AtomicU64::new(0)), in_table: MessageTable::from(table.0), _out_table: MessageTable::from(table.1), @@ -88,16 +101,16 @@ impl QuicRelay { } } -impl TypeTagAccept for QuicRelay -where B: Stream + Send + 'static +impl TypeTagAccept for QuicRelay +where + B: Stream + Send + 'static, { - fn iter_types(&self) -> Box)> + '_> { let iter = self.in_table.iter_types(); Box::new(iter.map(|(x, y)| (x.clone(), y.cloned()))) } - fn accept_msg(&self, msg: &TypeTag) -> bool{ + fn accept_msg(&self, msg: &TypeTag) -> bool { self.in_table.accept_message(msg) } @@ -106,8 +119,9 @@ where B: Stream + Send + 'static } } -impl SendUntypedReceiver for QuicRelay - where B: for<'a> WaitIdle<'a> + Stream + Send + 'static +impl SendUntypedReceiver for QuicRelay +where + B: for<'a> WaitIdle<'a> + Stream + Send + 'static, { fn send(&self, msg: Action, _bus: &Bus) -> Result<(), messagebus::error::Error> { match msg { @@ -129,7 +143,7 @@ impl SendUntypedReceiver for QuicRelay let mut body_buff = Vec::new(); let mut header_buff = Vec::new(); let mut item = None; - + loop { println!("begin"); @@ -144,7 +158,10 @@ impl SendUntypedReceiver for QuicRelay let conn = match conn.await { Ok(conn) => conn, Err(err) => { - println!("connection dropped with err {}. waiting next connection", err); + println!( + "connection dropped with err {}. waiting next connection", + err + ); continue; } }; @@ -163,7 +180,7 @@ impl SendUntypedReceiver for QuicRelay conn.connection.close(0u32.into(), b"done"); incoming.wait_idle().await; break; - }, + } } // match tokio::time::timeout(Duration::from_secs(5), rx.recv()).await { @@ -189,7 +206,7 @@ impl SendUntypedReceiver for QuicRelay // } // } }; - + let (mut send, _) = match conn.connection.open_bi().await { Ok(x) => x, Err(err) => { @@ -200,7 +217,7 @@ impl SendUntypedReceiver for QuicRelay body_buff.clear(); let pkt = r.serialize(BodyType::Cbor, &mut body_buff).unwrap(); - + header_buff.resize(16, 0); serde_cbor::to_writer(&mut header_buff, &pkt).unwrap(); let body_size = header_buff.len() - 16; @@ -214,13 +231,19 @@ impl SendUntypedReceiver for QuicRelay if let Err(err) = send.write_all(&header_buff).await { item = Some(r); - println!("write broken connection err {}. try with next connection", err); + println!( + "write broken connection err {}. try with next connection", + err + ); break; } if let Err(err) = send.finish().await { item = Some(r); - println!("finish broken connection err {}. try with next connection", err); + println!( + "finish broken connection err {}. try with next connection", + err + ); break; } } @@ -253,177 +276,215 @@ impl SendUntypedReceiver for QuicRelay match msg.as_shared_boxed() { Ok(msg) => { if let Err(err) = self.item_sender.send(Some((mid, msg, req).into())) { - Err(messagebus::error::Error::TryAgain(err.0.unwrap().unwrap_send().unwrap().1.upcast_box())) + Err(messagebus::error::Error::TryAgain( + err.0.unwrap().unwrap_send().unwrap().1.upcast_box(), + )) } else { Ok(()) } } - + Err(msg) => Err(messagebus::error::Error::TryAgain(msg)), } } } -impl ReciveUntypedReceiver for QuicRelay - where B: Send +impl ReciveUntypedReceiver for QuicRelay +where + B: Send, { type Stream = GenericEventStream; fn event_stream(&self, bus: Bus) -> Self::Stream { let self_id = self.self_id.clone(); - + let mut recv_stream = self.stream_receiver.lock().take().unwrap(); let mut recv_events = self.event_receiver.lock().take().unwrap(); let sender = self.item_sender.clone(); - let stream1 = futures::stream::poll_fn(move |cx|recv_stream.poll_recv(cx)) + let stream1 = futures::stream::poll_fn(move |cx| recv_stream.poll_recv(cx)) .map(move |incoming| { let buff: Vec = Vec::with_capacity(1024); let bus = bus.clone(); let self_id = self_id.clone(); let sender = sender.clone(); - futures::stream::unfold((incoming, bus, sender, self_id, buff), |(mut incoming, bus, sender, self_id, mut buff)| async move { - loop { - let (_, mut recv) = match incoming.next().await? { - Ok(recv) => recv, - Err(err) => { - println!("error: {}", err); - return None; - } - }; + futures::stream::unfold( + (incoming, bus, sender, self_id, buff), + |(mut incoming, bus, sender, self_id, mut buff)| async move { + loop { + let (_, mut recv) = match incoming.next().await? { + Ok(recv) => recv, + Err(err) => { + println!("error: {}", err); + return None; + } + }; - buff.resize(4, 0); - if let Err(err) = recv.read_exact(&mut buff[..]).await { - println!("recv err: {}", err); - continue; - } - - let verb = match std::str::from_utf8(&buff[0..4]) { - Ok(m) => m, - Err(err) => { - println!("recv err parse: {}", err); + buff.resize(4, 0); + if let Err(err) = recv.read_exact(&mut buff[..]).await { + println!("recv err: {}", err); continue; } - }; - if verb == "PING" { - println!(">> PING"); - continue; - } - - if verb != "MBUS" { - println!("Not MBUS packet!"); - continue; - } - - buff.resize(12, 0); - if let Err(err) = recv.read_exact(&mut buff[..]).await { - println!("recv err: {}", err); - continue; - } - - let mut reader = &buff[..]; - let version = reader.get_u16(); - let content_type = reader.get_u16(); - let body_size = reader.get_u64(); - - buff.resize(body_size as _, 0); - if let Err(err) = recv.read_exact(&mut buff[..]).await { - println!("recv err: {}", err); - continue; - } - - drop(recv); - - // println!("inbound packet MBUS v: {}; ct: {}; bs: {}", - // version, content_type, body_size); - - let event = match content_type { - 0 => { // CBOR - let proto: ProtocolPacket = match serde_cbor::from_slice(&buff[..]) { - Ok(val) => val, - Err(err) => { - println!("pkt parse err: {}", err); - continue; - }, - }; - - let item: ProtocolItem = match proto.deserialize(&bus) { - Ok(val) => val, - Err(err) => { - println!("item parse err: {}", err); - continue; - }, - }; - - match item { - ProtocolItem::Event(ev) => ev.map_msg(|msg|msg.upcast_box()), - ProtocolItem::Action(action) => { - match action { - Action::Close => { - println!("warning: Close recevied - ignoring!"); - sender.send(Some(ProtocolItem::Event(Event::Exited))).unwrap(); - }, - Action::Flush => { - println!("flush"); - bus.flush().await; - sender.send(Some(ProtocolItem::Event(Event::Flushed))).unwrap(); - }, - Action::Sync => { - println!("flush"); - bus.sync().await; - sender.send(Some(ProtocolItem::Event(Event::Synchronized(Ok(()))))).unwrap(); - }, - Action::Init(..) => (), - Action::Stats => (), - _ => (), - } - continue; - } - ProtocolItem::Send(mid, msg, req) => { - let self_id = self_id.clone(); - let sender = sender.clone(); - let bus = bus.clone(); - - let _ = tokio::spawn(async move { - if req { - let res = bus.request_boxed( - msg.upcast_box(), - SendOptions::Except(self_id.load(Ordering::SeqCst)) - ) - .await - .map(|x|x.as_shared_boxed().unwrap()) - .map_err(|x|x.map_msg(|_|())); - - sender.send(Some(ProtocolItem::Event(Event::Response(mid, res)))).unwrap(); - } else { - let tt = msg.type_tag(); - let _ = bus.send_boxed(msg.upcast_box(), Default::default()) - .await; - - sender.send(Some(ProtocolItem::Event(Event::BatchComplete(tt, 1)))).unwrap(); - } - }); - - continue; - } - _ => unimplemented!() + let verb = match std::str::from_utf8(&buff[0..4]) { + Ok(m) => m, + Err(err) => { + println!("recv err parse: {}", err); + continue; } - }, - _ => unimplemented!() - }; + }; - return Some((event, (incoming, bus, sender, self_id, buff))); - } - }) + if verb == "PING" { + println!(">> PING"); + continue; + } + + if verb != "MBUS" { + println!("Not MBUS packet!"); + continue; + } + + buff.resize(12, 0); + if let Err(err) = recv.read_exact(&mut buff[..]).await { + println!("recv err: {}", err); + continue; + } + + let mut reader = &buff[..]; + let version = reader.get_u16(); + let content_type = reader.get_u16(); + let body_size = reader.get_u64(); + + buff.resize(body_size as _, 0); + if let Err(err) = recv.read_exact(&mut buff[..]).await { + println!("recv err: {}", err); + continue; + } + + drop(recv); + + // println!("inbound packet MBUS v: {}; ct: {}; bs: {}", + // version, content_type, body_size); + + let event = match content_type { + 0 => { + // CBOR + let proto: ProtocolPacket = + match serde_cbor::from_slice(&buff[..]) { + Ok(val) => val, + Err(err) => { + println!("pkt parse err: {}", err); + continue; + } + }; + + let item: ProtocolItem = match proto.deserialize(&bus) { + Ok(val) => val, + Err(err) => { + println!("item parse err: {}", err); + continue; + } + }; + + match item { + ProtocolItem::Event(ev) => { + ev.map_msg(|msg| msg.upcast_box()) + } + ProtocolItem::Action(action) => { + match action { + Action::Close => { + println!("warning: Close recevied - ignoring!"); + sender + .send(Some(ProtocolItem::Event( + Event::Exited, + ))) + .unwrap(); + } + Action::Flush => { + println!("flush"); + bus.flush_all().await; + sender + .send(Some(ProtocolItem::Event( + Event::Flushed, + ))) + .unwrap(); + } + Action::Sync => { + println!("flush"); + bus.sync_all().await; + sender + .send(Some(ProtocolItem::Event( + Event::Synchronized(Ok(())), + ))) + .unwrap(); + } + Action::Init(..) => (), + Action::Stats => (), + _ => (), + } + continue; + } + ProtocolItem::Send(mid, msg, req) => { + let self_id = self_id.clone(); + let sender = sender.clone(); + let bus = bus.clone(); + + let _ = tokio::spawn(async move { + if req { + let res = bus + .request_boxed( + msg.upcast_box(), + SendOptions::Except( + self_id.load(Ordering::SeqCst), + ), + ) + .await + .map(|x| x.as_shared_boxed().unwrap()) + .map_err(|x| x.map_msg(|_| ())); + + sender + .send(Some(ProtocolItem::Event( + Event::Response(mid, res), + ))) + .unwrap(); + } else { + let tt = msg.type_tag(); + let _ = bus + .send_boxed( + msg.upcast_box(), + Default::default(), + ) + .await; + + sender + .send(Some(ProtocolItem::Event( + Event::BatchComplete(tt, 1), + ))) + .unwrap(); + } + }); + + continue; + } + _ => unimplemented!(), + } + } + _ => unimplemented!(), + }; + + return Some((event, (incoming, bus, sender, self_id, buff))); + } + }, + ) }) .flatten(); - let stream2 = futures::stream::poll_fn(move |cx|recv_events.poll_recv(cx)); + let stream2 = futures::stream::poll_fn(move |cx| recv_events.poll_recv(cx)); Box::pin( futures::stream::select(stream1, stream2) - .take_while(|x| futures::future::ready(!matches!(x, Event::Exited))) + .take_while(|x| futures::future::ready(!matches!(x, Event::Exited))), ) } -} \ No newline at end of file +} + diff --git a/crates/remote/src/relays/tcp/mod.rs b/crates/remote/src/relays/tcp/mod.rs index 10e6ceb..a322070 100644 --- a/crates/remote/src/relays/tcp/mod.rs +++ b/crates/remote/src/relays/tcp/mod.rs @@ -1,12 +1,15 @@ use std::net::SocketAddr; -use std::sync::Arc; use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::Arc; use bytes::{Buf, BufMut}; -use futures::{Stream, StreamExt, pin_mut}; use futures::stream::unfold; -use messagebus::{Action, Bus, Event, EventBoxed, Message, ReciveUntypedReceiver, SendOptions, SendUntypedReceiver, TypeTag, TypeTagAccept}; +use futures::{pin_mut, Stream, StreamExt}; use messagebus::error::GenericError; +use messagebus::{ + Action, Bus, Event, EventBoxed, Message, ReciveUntypedReceiver, SendOptions, + SendUntypedReceiver, TypeTag, TypeTagAccept, +}; use parking_lot::Mutex; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; @@ -17,15 +20,13 @@ use crate::proto::{BodyType, ProtocolItem, ProtocolPacket}; use super::{GenericEventStream, MessageList, MessageTable}; - pub struct TcpRelay { server_mode: bool, addr: SocketAddr, - + self_id: Arc, in_table: MessageTable, // _out_table: MessageTable, - item_sender: UnboundedSender>, item_receiver: Mutex>>>, @@ -41,7 +42,7 @@ impl TcpRelay { let (item_sender, item_receiver) = mpsc::unbounded_channel(); let (event_sender, event_receiver) = mpsc::unbounded_channel(); let (stream_sender, stream_receiver) = mpsc::unbounded_channel(); - + Self { self_id: Arc::new(AtomicU64::new(0)), server_mode, @@ -56,65 +57,66 @@ impl TcpRelay { } } - fn connections(&self) -> impl Stream { - unfold((self.server_mode, self.addr), move |(sm, addr)| async move { - let stream = if sm { - let bind_res = TcpListener::bind(addr).await; - let listener = match bind_res { - Err(err) => { - println!("bind error: {}", err); - return None; - } - - Ok(listener) => listener, + fn connections(&self) -> impl Stream { + unfold( + (self.server_mode, self.addr), + move |(sm, addr)| async move { + let stream = if sm { + let bind_res = TcpListener::bind(addr).await; + let listener = match bind_res { + Err(err) => { + println!("bind error: {}", err); + return None; + } + + Ok(listener) => listener, + }; + + unfold((listener,), move |(listener,)| async move { + let (stream, _addr) = match listener.accept().await { + Err(err) => { + println!("accept error: {}", err); + return None; + } + + Ok(listener) => listener, + }; + + Some((TcpRelayConnection::from(stream), (listener,))) + }) + .left_stream() + } else { + unfold((addr,), move |(addr,)| async move { + let stream = match TcpStream::connect(addr).await { + Err(err) => { + println!("connect error: {}", err); + return None; + } + + Ok(listener) => listener, + }; + + Some((TcpRelayConnection::from(stream), (addr,))) + }) + .right_stream() }; - unfold((listener, ), move |(listener, )| async move { - let (stream, _addr) = match listener.accept().await { - Err(err) => { - println!("accept error: {}", err); - return None; - } - - Ok(listener) => listener, - }; - - Some((TcpRelayConnection::from(stream), (listener, ))) - }).left_stream() - } else { - unfold((addr, ), move |(addr, )| async move { - let stream = match TcpStream::connect(addr).await { - Err(err) => { - println!("connect error: {}", err); - return None; - } - - Ok(listener) => listener, - }; - - Some((TcpRelayConnection::from(stream), (addr, ))) - }).right_stream() - }; - - Some((stream, (sm, addr))) - }) + Some((stream, (sm, addr))) + }, + ) .flatten() } } - struct TcpRelayConnection { - recv: OwnedReadHalf, + recv: OwnedReadHalf, send: OwnedWriteHalf, } impl From for TcpRelayConnection { fn from(stream: TcpStream) -> Self { let (recv, send) = stream.into_split(); - TcpRelayConnection { - recv, - send - } + TcpRelayConnection { recv, send } } } @@ -124,7 +126,7 @@ impl TypeTagAccept for TcpRelay { Box::new(iter.map(|(x, y)| (x.clone(), y.cloned()))) } - fn accept_msg(&self, msg: &TypeTag) -> bool{ + fn accept_msg(&self, msg: &TypeTag) -> bool { self.in_table.accept_message(msg) } @@ -154,7 +156,7 @@ impl SendUntypedReceiver for TcpRelay { let mut body_buff = Vec::new(); let mut header_buff = Vec::new(); let mut item = None; - + loop { println!("begin"); @@ -179,13 +181,13 @@ impl SendUntypedReceiver for TcpRelay { println!("closing"); drop(conn.send); break; - }, + } } }; body_buff.clear(); let pkt = r.serialize(BodyType::Cbor, &mut body_buff).unwrap(); - + header_buff.resize(16, 0); serde_cbor::to_writer(&mut header_buff, &pkt).unwrap(); let body_size = header_buff.len() - 16; @@ -199,7 +201,10 @@ impl SendUntypedReceiver for TcpRelay { if let Err(err) = conn.send.write_all(&header_buff).await { item = Some(r); - println!("write broken connection err {}. try with next connection", err); + println!( + "write broken connection err {}. try with next connection", + err + ); break; } } @@ -232,12 +237,14 @@ impl SendUntypedReceiver for TcpRelay { match msg.as_shared_boxed() { Ok(msg) => { if let Err(err) = self.item_sender.send(Some((mid, msg, req).into())) { - Err(messagebus::error::Error::TryAgain(err.0.unwrap().unwrap_send().unwrap().1.upcast_box())) + Err(messagebus::error::Error::TryAgain( + err.0.unwrap().unwrap_send().unwrap().1.upcast_box(), + )) } else { Ok(()) } } - + Err(msg) => Err(messagebus::error::Error::TryAgain(msg)), } } @@ -248,141 +255,176 @@ impl ReciveUntypedReceiver for TcpRelay { fn event_stream(&self, bus: Bus) -> Self::Stream { let self_id = self.self_id.clone(); - + let mut recv_stream = self.stream_receiver.lock().take().unwrap(); let mut recv_events = self.event_receiver.lock().take().unwrap(); let sender = self.item_sender.clone(); - let stream1 = futures::stream::poll_fn(move |cx|recv_stream.poll_recv(cx)) + let stream1 = futures::stream::poll_fn(move |cx| recv_stream.poll_recv(cx)) .map(move |incoming| { let buff: Vec = Vec::with_capacity(1024); let bus = bus.clone(); let self_id = self_id.clone(); let sender = sender.clone(); - futures::stream::unfold((incoming, bus, sender, self_id, buff), |(mut recv, bus, sender, self_id, mut buff)| async move { - loop { - buff.resize(4, 0); - if let Err(err) = recv.read_exact(&mut buff[..]).await { - println!("recv err: {}", err); - break None; - } + futures::stream::unfold( + (incoming, bus, sender, self_id, buff), + |(mut recv, bus, sender, self_id, mut buff)| async move { + loop { + buff.resize(4, 0); + if let Err(err) = recv.read_exact(&mut buff[..]).await { + println!("recv err: {}", err); + break None; + } - if &buff == b"PING" { - println!(">> PING"); - continue; - } + if &buff == b"PING" { + println!(">> PING"); + continue; + } - if &buff != b"MBUS" { - println!("Not MBUS packet!"); - continue; - } + if &buff != b"MBUS" { + println!("Not MBUS packet!"); + continue; + } - buff.resize(12, 0); - if let Err(err) = recv.read_exact(&mut buff[..]).await { - println!("recv err: {}", err); - continue; - } + buff.resize(12, 0); + if let Err(err) = recv.read_exact(&mut buff[..]).await { + println!("recv err: {}", err); + continue; + } - let mut reader = &buff[..]; - let version = reader.get_u16(); - let content_type = reader.get_u16(); - let body_size = reader.get_u64(); + let mut reader = &buff[..]; + let version = reader.get_u16(); + let content_type = reader.get_u16(); + let body_size = reader.get_u64(); - buff.resize(body_size as _, 0); - if let Err(err) = recv.read_exact(&mut buff[..]).await { - println!("recv err: {}", err); - continue; - } + buff.resize(body_size as _, 0); + if let Err(err) = recv.read_exact(&mut buff[..]).await { + println!("recv err: {}", err); + continue; + } - // println!("inbound packet MBUS v: {}; ct: {}; bs: {}", - // version, content_type, body_size); + // println!("inbound packet MBUS v: {}; ct: {}; bs: {}", + // version, content_type, body_size); - let event = match content_type { - 0 => { // CBOR - let proto: ProtocolPacket = match serde_cbor::from_slice(&buff[..]) { - Ok(val) => val, - Err(err) => { - println!("pkt parse err: {}", err); - continue; - }, - }; - - let item: ProtocolItem = match proto.deserialize(&bus) { - Ok(val) => val, - Err(err) => { - println!("item parse err: {}", err); - continue; - }, - }; - - match item { - ProtocolItem::Event(ev) => ev.map_msg(|msg|msg.upcast_box()), - ProtocolItem::Action(action) => { - match action { - Action::Close => { - println!("warning: Close recevied - ignoring!"); - sender.send(Some(ProtocolItem::Event(Event::Exited))).unwrap(); - }, - Action::Flush => { - println!("flush"); - bus.flush().await; - sender.send(Some(ProtocolItem::Event(Event::Flushed))).unwrap(); - }, - Action::Sync => { - println!("sync"); - bus.sync().await; - sender.send(Some(ProtocolItem::Event(Event::Synchronized(Ok(()))))).unwrap(); - }, - Action::Init(..) => (), - Action::Stats => (), - _ => (), - } - continue; - } - ProtocolItem::Send(mid, msg, req) => { - let self_id = self_id.clone(); - let sender = sender.clone(); - let bus = bus.clone(); - - let _ = tokio::spawn(async move { - if req { - let res = bus.request_boxed( - msg.upcast_box(), - SendOptions::Except(self_id.load(Ordering::SeqCst)) - ) - .await - .map(|x|x.as_shared_boxed().unwrap()) - .map_err(|x|x.map_msg(|_|())); - - sender.send(Some(ProtocolItem::Event(Event::Response(mid, res)))).unwrap(); - } else { - let tt = msg.type_tag(); - let _ = bus.send_boxed(msg.upcast_box(), Default::default()) - .await; - - sender.send(Some(ProtocolItem::Event(Event::BatchComplete(tt, 1)))).unwrap(); + let event = match content_type { + 0 => { + // CBOR + let proto: ProtocolPacket = + match serde_cbor::from_slice(&buff[..]) { + Ok(val) => val, + Err(err) => { + println!("pkt parse err: {}", err); + continue; } - }); + }; - continue; + let item: ProtocolItem = match proto.deserialize(&bus) { + Ok(val) => val, + Err(err) => { + println!("item parse err: {}", err); + continue; + } + }; + + match item { + ProtocolItem::Event(ev) => { + ev.map_msg(|msg| msg.upcast_box()) + } + ProtocolItem::Action(action) => { + match action { + Action::Close => { + println!("warning: Close recevied - ignoring!"); + sender + .send(Some(ProtocolItem::Event( + Event::Exited, + ))) + .unwrap(); + } + Action::Flush => { + println!("flush"); + bus.flush_all().await; + sender + .send(Some(ProtocolItem::Event( + Event::Flushed, + ))) + .unwrap(); + } + Action::Sync => { + println!("sync"); + bus.sync_all().await; + sender + .send(Some(ProtocolItem::Event( + Event::Synchronized(Ok(())), + ))) + .unwrap(); + } + Action::Init(..) => (), + Action::Stats => (), + _ => (), + } + continue; + } + ProtocolItem::Send(mid, msg, req) => { + let self_id = self_id.clone(); + let sender = sender.clone(); + let bus = bus.clone(); + + let _ = tokio::spawn(async move { + if req { + let res = bus + .request_boxed( + msg.upcast_box(), + SendOptions::Except( + self_id.load(Ordering::SeqCst), + ), + ) + .await + .map(|x| x.as_shared_boxed().unwrap()) + .map_err(|x| x.map_msg(|_| ())); + + sender + .send(Some(ProtocolItem::Event( + Event::Response(mid, res), + ))) + .unwrap(); + } else { + let tt = msg.type_tag(); + let _ = bus + .send_boxed( + msg.upcast_box(), + Default::default(), + ) + .await; + + sender + .send(Some(ProtocolItem::Event( + Event::BatchComplete(tt, 1), + ))) + .unwrap(); + } + }); + + continue; + } + _ => unimplemented!(), } - _ => unimplemented!() } - }, - _ => unimplemented!() - }; + _ => unimplemented!(), + }; - return Some((event, (recv, bus, sender, self_id, buff))); - } - }) + return Some((event, (recv, bus, sender, self_id, buff))); + } + }, + ) }) .flatten(); - let stream2 = futures::stream::poll_fn(move |cx|recv_events.poll_recv(cx)); + let stream2 = futures::stream::poll_fn(move |cx| recv_events.poll_recv(cx)); Box::pin( futures::stream::select(stream1, stream2) - .take_while(|x| futures::future::ready(!matches!(x, Event::Exited))) + .take_while(|x| futures::future::ready(!matches!(x, Event::Exited))), ) } -} \ No newline at end of file +} + diff --git a/examples/benchmark.rs b/examples/benchmark.rs index 458b639..78dae11 100644 --- a/examples/benchmark.rs +++ b/examples/benchmark.rs @@ -117,7 +117,7 @@ async fn iter(bus: &Bus) { bus.send(MsgF32(0.)).await.unwrap(); } - bus.flush().await; + bus.flush_all().await; } #[tokio::main] @@ -156,7 +156,7 @@ async fn main() { println!("Avg time: {:.4}", time_sum as f64 / (count as f64 * 1000.0)); println!("flush"); - b.flush().await; + b.flush_all().await; println!("close"); b.close().await; diff --git a/examples/demo_async.rs b/examples/demo_async.rs index 2d44ae6..d1779b3 100644 --- a/examples/demo_async.rs +++ b/examples/demo_async.rs @@ -187,7 +187,7 @@ async fn main() { b.send(MsgF32(0.)).await.unwrap(); println!("flush"); - b.flush().await; + b.flush_all().await; println!("close"); b.close().await; diff --git a/examples/demo_boxed.rs b/examples/demo_boxed.rs index 4d033e7..d3f4344 100644 --- a/examples/demo_boxed.rs +++ b/examples/demo_boxed.rs @@ -190,7 +190,7 @@ async fn main() { b.send(MsgF32(0f32)).await.unwrap(); println!("flush"); - b.flush().await; + b.flush_all().await; println!("sending boxed variant"); @@ -199,7 +199,7 @@ async fn main() { .unwrap(); println!("flush"); - b.flush().await; + b.flush_all().await; println!("close"); b.close().await; diff --git a/examples/demo_relay.rs b/examples/demo_relay.rs index e209d79..d46adbd 100644 --- a/examples/demo_relay.rs +++ b/examples/demo_relay.rs @@ -45,7 +45,7 @@ async fn main() { // b. println!("flush"); - b.flush().await; + b.flush_all().await; println!("close"); b.close().await; diff --git a/examples/demo_slow.rs b/examples/demo_slow.rs index 973db87..f4ce7f7 100644 --- a/examples/demo_slow.rs +++ b/examples/demo_slow.rs @@ -80,7 +80,7 @@ async fn main() { b.send(MsgU32(32u32)).await.unwrap(); println!("flush"); - b.flush().await; + b.flush_all().await; println!("close"); b.close().await; diff --git a/examples/demo_sync_batch.rs b/examples/demo_sync_batch.rs index 841e37b..2789569 100644 --- a/examples/demo_sync_batch.rs +++ b/examples/demo_sync_batch.rs @@ -81,7 +81,7 @@ async fn main() { b.send(MsgI16(7i16)).await.unwrap(); println!("flush"); - b.flush().await; + b.flush_all().await; println!("close"); b.close().await; diff --git a/examples/non_sync.rs b/examples/non_sync.rs index 83fe489..21de329 100644 --- a/examples/non_sync.rs +++ b/examples/non_sync.rs @@ -83,7 +83,7 @@ async fn main() { println!("flush"); - b.flush().await; + b.flush_all().await; println!("closing"); diff --git a/src/lib.rs b/src/lib.rs index 6970176..bc10c2d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -9,10 +9,9 @@ mod stats; mod trait_object; pub mod type_tag; - pub mod __reexport { - pub use serde; pub use ctor; + pub use serde; } #[macro_use] @@ -29,7 +28,10 @@ use core::{ time::Duration, }; use smallvec::SmallVec; -use std::{collections::{HashMap, HashSet}, sync::Arc}; +use std::{ + collections::{HashMap, HashSet}, + sync::Arc, +}; use tokio::sync::Mutex; use builder::BusBuilder; @@ -39,6 +41,7 @@ use stats::Stats; // public pub use builder::Module; +pub use ctor; pub use envelop::{IntoBoxedMessage, Message, MessageBounds, SharedMessage, TypeTag, TypeTagged}; pub use handler::*; pub use receiver::{ @@ -46,8 +49,7 @@ pub use receiver::{ SendUntypedReceiver, TypeTagAccept, }; pub use relay::Relay; -pub use ctor; -pub use type_tag::{register_shared_message, deserialize_shared_message}; +pub use type_tag::{deserialize_shared_message, register_shared_message}; pub type Untyped = Arc; type LookupQuery = (TypeTag, Option, Option); @@ -77,26 +79,28 @@ pub struct BusInner { } impl BusInner { - pub(crate) fn new( - receivers: HashSet, - ) -> Self { + pub(crate) fn new(receivers: HashSet) -> Self { let mut lookup = HashMap::new(); for recv in receivers.iter() { for (msg, resp) in recv.iter_types() { - lookup.entry((msg.clone(), None, None)) + lookup + .entry((msg.clone(), None, None)) .or_insert_with(HashSet::new) .insert(recv.clone()); if let Some((resp, err)) = resp { - lookup.entry((msg.clone(), Some(resp.clone()), None)) + lookup + .entry((msg.clone(), Some(resp.clone()), None)) .or_insert_with(HashSet::new) .insert(recv.clone()); - lookup.entry((msg.clone(), None, Some(err.clone()))) + lookup + .entry((msg.clone(), None, Some(err.clone()))) .or_insert_with(HashSet::new) .insert(recv.clone()); - lookup.entry((msg, Some(resp), Some(err))) + lookup + .entry((msg, Some(resp), Some(err))) .or_insert_with(HashSet::new) .insert(recv.clone()); } @@ -157,7 +161,7 @@ impl Bus { } } - pub async fn flush(&self) { + pub async fn flush_all(&self) { let _handle = self.inner.maintain.lock().await; let fuse_count = 32i32; let mut breaked = false; @@ -189,7 +193,80 @@ impl Bus { } } - pub async fn sync(&self) { + pub async fn flush(&self) { + let _handle = self.inner.maintain.lock().await; + let fuse_count = 32i32; + let mut breaked = false; + let mut iters = 0usize; + + for _ in 0..fuse_count { + let receivers = + self.select_receivers(M::type_tag_(), Default::default(), None, None, false); + iters += 1; + let mut flushed = false; + for r in receivers { + if r.need_flush() { + flushed = true; + + r.flush(self).await; + } + } + + if !flushed { + breaked = true; + break; + } + } + + if !breaked { + warn!( + "!!! WARNING: unable to reach equilibrium in {} iterations !!!", + fuse_count + ); + } else { + info!("flushed in {} iterations !!!", iters); + } + } + + pub async fn flush2(&self) { + let _handle = self.inner.maintain.lock().await; + let fuse_count = 32i32; + let mut breaked = false; + let mut iters = 0usize; + for _ in 0..fuse_count { + let receivers1 = + self.select_receivers(M1::type_tag_(), Default::default(), None, None, false); + + let receivers2 = + self.select_receivers(M2::type_tag_(), Default::default(), None, None, false); + + iters += 1; + let mut flushed = false; + for r in receivers1.chain(receivers2) { + if r.need_flush() { + flushed = true; + + r.flush(self).await; + } + } + + if !flushed { + breaked = true; + break; + } + } + + if !breaked { + warn!( + "!!! WARNING: unable to reach equilibrium in {} iterations !!!", + fuse_count + ); + } else { + info!("flushed in {} iterations !!!", iters); + } + } + + pub async fn sync_all(&self) { let _handle = self.inner.maintain.lock().await; for r in self.inner.receivers.iter() { @@ -197,12 +274,45 @@ impl Bus { } } - #[inline] - pub async fn flush_and_sync(&self) { - self.flush().await; - self.sync().await; + pub async fn sync(&self) { + let _handle = self.inner.maintain.lock().await; + let receivers = + self.select_receivers(M::type_tag_(), Default::default(), None, None, false); + + for r in receivers { + r.sync(self).await; + } } + pub async fn sync2(&self) { + let _handle = self.inner.maintain.lock().await; + + let receivers1 = + self.select_receivers(M1::type_tag_(), Default::default(), None, None, false); + + let receivers2 = + self.select_receivers(M2::type_tag_(), Default::default(), None, None, false); + + for r in receivers1.chain(receivers2) { + r.sync(self).await; + } + } + + #[inline] + pub async fn flush_and_sync_all(&self) { + self.flush_all().await; + self.sync_all().await; + } + #[inline] + pub async fn flush_and_sync(&self) { + self.flush::().await; + self.sync::().await; + } + #[inline] + pub async fn flush_and_sync2(&self) { + self.flush2::().await; + self.sync2::().await; + } fn try_reserve(&self, tt: &TypeTag, rs: &[Receiver]) -> Option> { let mut permits = SmallVec::<[Permit; 32]>::new(); @@ -362,8 +472,11 @@ impl Bus { let tt = msg.type_tag(); let mid = ID_COUNTER.fetch_add(1, Ordering::Relaxed); - if let Some(rs) = self.inner.lookup.get(&(msg.type_tag(), None, None)) - .and_then(|rs| rs.first()) + if let Some(rs) = self + .inner + .lookup + .get(&(msg.type_tag(), None, None)) + .and_then(|rs| rs.first()) { let permits = if let Some(x) = rs.try_reserve(&tt) { x @@ -385,8 +498,11 @@ impl Bus { let tt = msg.type_tag(); let mid = ID_COUNTER.fetch_add(1, Ordering::Relaxed); - if let Some(rs) = self.inner.lookup.get(&(msg.type_tag(), None, None)) - .and_then(|rs| rs.first()) + if let Some(rs) = self + .inner + .lookup + .get(&(msg.type_tag(), None, None)) + .and_then(|rs| rs.first()) { Ok(rs.send(self, mid, msg, false, rs.reserve(&tt).await)?) } else { @@ -593,11 +709,14 @@ impl Bus { let mid = ID_COUNTER.fetch_add(1, Ordering::Relaxed); - if let Some(rs) = self.inner.lookup.get(&(tt.clone(), None, None)) - .and_then(|rs| rs.first()) { - + if let Some(rs) = self + .inner + .lookup + .get(&(tt.clone(), None, None)) + .and_then(|rs| rs.first()) + { let msg = deserialize_shared_message(tt.clone(), de)?; - + Ok(rs.send_boxed(self, mid, msg.upcast_box(), false, rs.reserve(&tt).await)?) } else { Err(Error::NoReceivers) @@ -635,10 +754,7 @@ impl Bus { } pub fn stats(&self) -> impl Iterator + '_ { - self.inner - .receivers - .iter() - .map(|x| x.stats()) + self.inner.receivers.iter().map(|x| x.stats()) } #[inline] @@ -650,14 +766,16 @@ impl Bus { eid: Option, is_req: bool, ) -> impl Iterator + '_ { - self.inner.lookup.get(&(tid.clone(), rid.clone(), eid.clone())) + self.inner + .lookup + .get(&(tid.clone(), rid.clone(), eid.clone())) .into_iter() .flatten() .filter(move |r| r.accept(is_req, &tid, rid.as_ref(), eid.as_ref())) .filter(move |r| match options { SendOptions::Except(id) => id != r.id(), SendOptions::Direct(id) => id == r.id(), - _ => true + _ => true, }) } } diff --git a/src/receiver.rs b/src/receiver.rs index 687bdfb..8f98fd4 100644 --- a/src/receiver.rs +++ b/src/receiver.rs @@ -15,9 +15,9 @@ use core::{ pin::Pin, sync::atomic::{AtomicBool, AtomicI64, Ordering}, }; -use std::hash::{Hash, Hasher}; use futures::{pin_mut, Stream}; use futures::{Future, FutureExt, StreamExt}; +use std::hash::{Hash, Hasher}; use std::{borrow::Cow, sync::Arc}; use tokio::sync::{oneshot, Notify}; @@ -251,7 +251,7 @@ where if self.context.resend_unused_resp { // TODO } - }, + } Ok(None) => (), Err(err) => error!("Response Error: {}", err), @@ -350,8 +350,11 @@ where E: StdSyncSendError, S: ReciveTypedReceiver + Send + Sync + 'static, { - fn iter_types(&self) -> Box)> + '_> { - Box::new(std::iter::once((M::type_tag_(), Some((R::type_tag_(), E::type_tag_()))))) + fn iter_types(&self) -> Box)> + '_> { + Box::new(std::iter::once(( + M::type_tag_(), + Some((R::type_tag_(), E::type_tag_())), + ))) } fn accept_req(&self, req: &TypeTag, resp: Option<&TypeTag>, err: Option<&TypeTag>) -> bool { @@ -411,7 +414,7 @@ where .map_err(|_| Error::MessageCastError)?; SendTypedReceiver::send(&self.inner, mid, *boxed, req, bus) - .map_err(|err| err.map_msg(|m|m.into_boxed())) + .map_err(|err| err.map_msg(|m| m.into_boxed())) } fn stats(&self) -> Stats { @@ -794,7 +797,13 @@ impl Receiver { } #[inline] - pub fn accept(&self, is_req: bool, msg: &TypeTag, resp: Option<&TypeTag>, err: Option<&TypeTag>) -> bool { + pub fn accept( + &self, + is_req: bool, + msg: &TypeTag, + resp: Option<&TypeTag>, + err: Option<&TypeTag>, + ) -> bool { if is_req { self.inner.accept_req(msg, resp, err) } else { diff --git a/tests/test_backpressure.rs b/tests/test_backpressure.rs index 9cf1319..f79f4d7 100644 --- a/tests/test_backpressure.rs +++ b/tests/test_backpressure.rs @@ -52,7 +52,7 @@ async fn test_backpressure() { assert!(b.try_send(MsgF32(32f32)).is_err()); - b.flush().await; + b.flush_all().await; b.close().await; poller.await; } diff --git a/tests/test_batch.rs b/tests/test_batch.rs index 6d059de..ca3e426 100644 --- a/tests/test_batch.rs +++ b/tests/test_batch.rs @@ -98,7 +98,7 @@ async fn test_batch() { } println!("flush"); - b.flush().await; + b.flush_all().await; let mut lock = batches.lock(); lock.sort_by(|a, b| a[0].cmp(&b[0])); diff --git a/tests/test_concurrency.rs b/tests/test_concurrency.rs index 2c5ab3b..286d8b9 100644 --- a/tests/test_concurrency.rs +++ b/tests/test_concurrency.rs @@ -113,7 +113,7 @@ async fn test_sync() { println!("sent"); - b.flush().await; + b.flush_all().await; println!("flushed"); diff --git a/tests/test_relay.rs b/tests/test_relay.rs index 0f72f9d..354ab5f 100644 --- a/tests/test_relay.rs +++ b/tests/test_relay.rs @@ -2,7 +2,12 @@ use std::pin::Pin; use async_trait::async_trait; use futures::Stream; -use messagebus::{Action, AsyncHandler, Bus, Event, Message, MessageBounds, ReciveUntypedReceiver, SendUntypedReceiver, TypeTag, TypeTagAccept, TypeTagged, derive::{Error as MbError, Message}, error::{self, GenericError}, receivers}; +use messagebus::{ + derive::{Error as MbError, Message}, + error::{self, GenericError}, + receivers, Action, AsyncHandler, Bus, Event, Message, MessageBounds, ReciveUntypedReceiver, + SendUntypedReceiver, TypeTag, TypeTagAccept, TypeTagged, +}; use parking_lot::Mutex; use thiserror::Error; use tokio::sync::mpsc; @@ -60,32 +65,29 @@ impl TypeTagAccept for TestRelay { if msg.as_ref() == Msg::::type_tag_().as_ref() { if let Some(resp) = resp { if resp.as_ref() == Msg::::type_tag_().as_ref() { - return true + return true; } } else { - return true - } + return true; + } } if msg.as_ref() == Msg::::type_tag_().as_ref() { if let Some(resp) = resp { if resp.as_ref() == Msg::::type_tag_().as_ref() { - return true + return true; } } else { - return true - } + return true; + } } false } - fn accept_msg( - &self, - msg: &messagebus::TypeTag, - ) -> bool { + fn accept_msg(&self, msg: &messagebus::TypeTag) -> bool { if msg.as_ref() == Msg::::type_tag_().as_ref() { - return true + return true; } false @@ -94,8 +96,14 @@ impl TypeTagAccept for TestRelay { fn iter_types(&self) -> Box)>> { Box::new( std::iter::once((Msg::::type_tag_(), None)) - .chain(std::iter::once((Msg::::type_tag_(), Some((Msg::::type_tag_(), GenericError::type_tag_()))))) - .chain(std::iter::once((Msg::::type_tag_(), Some((Msg::::type_tag_(), GenericError::type_tag_()))))) + .chain(std::iter::once(( + Msg::::type_tag_(), + Some((Msg::::type_tag_(), GenericError::type_tag_())), + ))) + .chain(std::iter::once(( + Msg::::type_tag_(), + Some((Msg::::type_tag_(), GenericError::type_tag_())), + ))), ) } } @@ -190,7 +198,7 @@ async fn test_relay() { assert_eq!(res1.0, 9u8); assert_eq!(res2.0, 22u64); - b.flush().await; + b.flush_all().await; b.close().await; poller.await; } diff --git a/tests/test_req_resp.rs b/tests/test_req_resp.rs index f916e3c..dd19a1b 100644 --- a/tests/test_req_resp.rs +++ b/tests/test_req_resp.rs @@ -240,7 +240,7 @@ async fn test() { assert!((val - 1633.0f64).abs() < f64::EPSILON); - b.flush().await; + b.flush_all().await; b.close().await; poller.await; } diff --git a/tests/test_serde.rs b/tests/test_serde.rs index cf28ac7..0aab2b9 100644 --- a/tests/test_serde.rs +++ b/tests/test_serde.rs @@ -79,7 +79,7 @@ async fn test() { assert_eq!(val.type_tag(), TypeTag::from("MsgResponse")); assert_eq!(buff.as_slice(), br#"{"test1":24,"test2":"Hello, World!"}"#); - b.flush().await; + b.flush_all().await; b.close().await; poller.await; } diff --git a/tests/test_shared.rs b/tests/test_shared.rs index a6c403f..7fabe67 100644 --- a/tests/test_shared.rs +++ b/tests/test_shared.rs @@ -82,7 +82,7 @@ async fn test_shared() { b.send_one(Msg).await.unwrap(); b.send_one(SharedMsg(0.0f32)).await.unwrap(); - b.flush().await; + b.flush_all().await; b.close().await; poller.await; diff --git a/tests/test_sync.rs b/tests/test_sync.rs index 9b4a5f2..e904fe1 100644 --- a/tests/test_sync.rs +++ b/tests/test_sync.rs @@ -79,7 +79,7 @@ async fn test_sync() { b.send(MsgU16(11u16)).await.unwrap(); b.send(MsgU32(32u32)).await.unwrap(); - b.flush_and_sync().await; + b.flush_and_sync_all().await; b.close().await; poller.await; }