diff --git a/crates/remote/src/proto.rs b/crates/remote/src/proto.rs index 67a7195..b38f325 100644 --- a/crates/remote/src/proto.rs +++ b/crates/remote/src/proto.rs @@ -169,7 +169,7 @@ impl ProtocolItem { ProtocolItem::Action(action) => match action { Action::Close => ProtocolHeaderActionKind::Close, Action::Flush => ProtocolHeaderActionKind::Flush, - Action::Init => ProtocolHeaderActionKind::Initialize, + Action::Init(..) => ProtocolHeaderActionKind::Initialize, Action::Sync => ProtocolHeaderActionKind::Synchronize, _ => unimplemented!(), } @@ -347,7 +347,7 @@ impl<'a> ProtocolPacket<'a> { ProtocolHeaderActionKind::Pause => Event::Pause, other => return Ok(ProtocolItem::Action(match other { - ProtocolHeaderActionKind::Initialize => Action::Init, + ProtocolHeaderActionKind::Initialize => Action::Init(self.header.argument), ProtocolHeaderActionKind::Close => Action::Close, ProtocolHeaderActionKind::Flush => Action::Flush, ProtocolHeaderActionKind::Synchronize => Action::Sync, diff --git a/crates/remote/src/relays/quic/client.rs b/crates/remote/src/relays/quic/client.rs index b184e63..33b16f5 100644 --- a/crates/remote/src/relays/quic/client.rs +++ b/crates/remote/src/relays/quic/client.rs @@ -1,9 +1,10 @@ use crate::{error::Error, proto::{BodyType, ProtocolItem, ProtocolPacket}, relays::{GenericEventStream, MessageTable}}; -use futures::{Future, FutureExt}; -use messagebus::{Action, Bus, Event, Message, ReciveUntypedReceiver, SendUntypedReceiver, TypeTag, TypeTagAccept}; +use futures::StreamExt; +use messagebus::{Action, Bus, Event, Message, ReciveUntypedReceiver, SendOptions, SendUntypedReceiver, TypeTag, TypeTagAccept}; use parking_lot::Mutex; -use std::{net::SocketAddr, sync::atomic::AtomicBool}; -use tokio::{sync::{mpsc::{self, UnboundedSender, UnboundedReceiver}, oneshot}}; +use quinn::IncomingBiStreams; +use std::{net::SocketAddr, sync::{Arc, atomic::{AtomicU64, Ordering}}}; +use tokio::sync::mpsc::{self, UnboundedSender, UnboundedReceiver}; use bytes::{Buf, BufMut}; pub const ALPN_QUIC_HTTP: &[&[u8]] = &[b"hq-29"]; @@ -30,64 +31,36 @@ impl QuicClientRelayEndpoint { Ok(Self { endpoint }) } - - pub fn connect( - &self, - addr: &SocketAddr, - host: &str, - ) -> impl Future> { - let conn = self.endpoint.connect(addr, host); - - async move { - let quinn::NewConnection { connection, .. } = conn?.await?; - let (send, recv) = connection.open_bi().await?; - - Ok(QuicClientConnection { - connection, - send, - recv, - }) - } - } - - #[inline] - pub async fn wait_idle(&self) { - self.endpoint.wait_idle().await; - } -} - -pub struct QuicClientConnection { - connection: quinn::Connection, - send: quinn::SendStream, - recv: quinn::RecvStream, } pub struct QuicClientRelay { - ready_flag: AtomicBool, + // ready_flag: AtomicBool, + self_id: Arc, addr: SocketAddr, host: String, endpoint: QuicClientRelayEndpoint, outgoing_table: MessageTable, sender: UnboundedSender, - receiver_send: Mutex, UnboundedReceiver)>>, - receiver_recv: Mutex>>, + receiver: Mutex, UnboundedSender)>>, + st_receiver: Mutex>>, } impl QuicClientRelay { pub fn new(cert: &str, addr: SocketAddr, host: String, table: Vec<(TypeTag, TypeTag, TypeTag)>) -> Result { let endpoint = QuicClientRelayEndpoint::new(cert)?; let (sender, receiver) = mpsc::unbounded_channel(); - let (recv_send, recv_recv) = oneshot::channel(); + let (st_sender, st_receiver) = mpsc::unbounded_channel(); Ok(Self { - ready_flag: AtomicBool::new(false), + // ready_flag: AtomicBool::new(false), + self_id: Arc::new(AtomicU64::new(0)), addr, host, endpoint, outgoing_table: MessageTable::from(table), sender, - receiver_send: Mutex::new(Some((recv_send, receiver))), - receiver_recv: Mutex::new(Some(recv_recv)), + receiver: Mutex::new(Some((receiver, st_sender))), + st_receiver: Mutex::new(Some(st_receiver)), }) } } @@ -111,38 +84,35 @@ impl TypeTagAccept for QuicClientRelay { impl SendUntypedReceiver for QuicClientRelay { fn send(&self, msg: Action, _bus: &Bus) -> Result<(), messagebus::error::Error> { match msg { - Action::Init => { - let (sender, mut rx) = self.receiver_send.lock().take().unwrap(); - let conn = self.endpoint.connect(&self.addr, &self.host); - + Action::Init(self_id) => { + let (mut rx, recv_stream) = self.receiver.lock().take().unwrap(); + let conn = self.endpoint.endpoint.connect(&self.addr, &self.host).unwrap(); + self.self_id.store(self_id, Ordering::SeqCst); + tokio::spawn(async move { - println!("spawn"); - let mut conn = conn.await.unwrap(); - sender.send((conn.recv, conn.connection)).unwrap(); let mut body_buff = Vec::new(); let mut header_buff = Vec::new(); + let conn = conn.await.unwrap(); + recv_stream.send(conn.bi_streams).unwrap(); + while let Some(r) = rx.recv().await { - body_buff.clear(); header_buff.clear(); - + body_buff.clear(); + + let (mut send, _) = conn.connection.open_bi().await.unwrap(); let pkt = r.serialize(BodyType::Cbor, &mut body_buff).unwrap(); + + header_buff.put(&b"MBUS"[..]); + header_buff.put_u16(1); + header_buff.put_u16(0); + header_buff.put_u64(header_buff.len() as _); serde_cbor::to_writer(&mut header_buff, &pkt).unwrap(); - println!("msg {:?}", pkt); + send.write_all(&header_buff).await.unwrap(); + send.finish().await.unwrap(); - let mut buf = [0u8; 16]; - let mut writer = &mut buf[..]; - - writer.put(&b"MBUS"[..]); - writer.put_u16(1); - writer.put_u16(0); - writer.put_u64(header_buff.len() as _); - - conn.send.write_all(&buf).await.unwrap(); - println!("header sent"); - conn.send.write_all(&header_buff).await.unwrap(); - println!("body sent"); + println!("sent"); } }); } @@ -178,54 +148,102 @@ impl ReciveUntypedReceiver for QuicClientRelay { type Stream = GenericEventStream; fn event_stream(&self, bus: Bus) -> Self::Stream { - let recv = self.receiver_recv.lock().take().unwrap(); + let self_id = self.self_id.clone(); + let sender = self.sender.clone(); + let mut recv = self.st_receiver.lock().take().unwrap(); - Box::pin(async move { - let buff = Vec::with_capacity(1024); - let (recv, conn) = recv.await.unwrap(); - futures::stream::unfold( - (true, recv, conn, bus, buff), - |(first, mut recv, conn, bus, mut buff)| async move { - if first { - return Some((Event::Ready, (false, recv, conn, bus, buff))); - } - - unsafe { buff.set_len(16) }; - recv.read_exact(&mut buff).await.unwrap(); + Box::pin( + futures::stream::poll_fn(move |cx|recv.poll_recv(cx)) + .map(move |uni_streams| { + let self_id = self_id.clone(); + let bus = bus.clone(); + let sender = sender.clone(); + // let buff = Bytes::new(); - let mut reader = &buff[..]; - let mut sign = [0u8; 4]; - reader.copy_to_slice(&mut sign); - assert!(&sign != b"MBUS"); - - let version = reader.get_u16(); - assert!(version == 1); - - let content_type = reader.get_u16(); - - let body_size = reader.get_u64(); - let diff = buff.capacity() as i64 - body_size as i64; - if diff < 0 { - buff.reserve(-diff as usize); - } - - unsafe { buff.set_len(body_size as usize); } - recv.read_exact(&mut buff).await.unwrap(); - - let event = match content_type { - 0 => { // CBOR - let proto: ProtocolPacket = serde_cbor::from_slice(&buff).unwrap(); - match proto.deserialize(&bus).unwrap() { - ProtocolItem::Event(ev) => ev.map_msg(|msg|msg.upcast_box()), - _ => unimplemented!() + futures::stream::unfold((true, uni_streams, bus, sender, self_id), |(first, mut uni_streams, bus, sender, self_id)| async move { + loop { + if first { + return Some((Event::Ready, (false, uni_streams, bus, sender, self_id))); } - }, - _ => unimplemented!() - }; - - Some((event, (false, recv, conn, bus, buff))) - }, - ) - }.flatten_stream()) + + let (_, recv) = match uni_streams.next().await? { + Ok(recv) => recv, + Err(err) => { + println!("error: {}", err); + return None; + } + }; + + let buff = recv + .read_to_end(usize::max_value()) + .await + .unwrap(); + + // assert_eq!(&buff[0..4], b"MBUS"); + + let mut reader = &buff[4..]; + + let version = reader.get_u16(); + let content_type = reader.get_u16(); + let body_size = reader.get_u64(); + + println!("inbound packet {}: v: {}; ct: {}; bs: {}", String::from_utf8_lossy(&buff[0..4]), version, content_type, body_size); + + let event = match content_type { + 0 => { // CBOR + let proto: ProtocolPacket = serde_cbor::from_slice(&buff[16..]).unwrap(); + + match proto.deserialize(&bus).unwrap() { + ProtocolItem::Event(ev) => ev.map_msg(|msg|msg.upcast_box()), + ProtocolItem::Action(action) => { + match action { + Action::Close => { + println!("warning: Close recevied - ignoring!"); + sender.send(ProtocolItem::Event(Event::Exited)).unwrap(); + }, + Action::Flush => { + bus.flush().await; + sender.send(ProtocolItem::Event(Event::Flushed)).unwrap(); + }, + Action::Sync => { + bus.sync().await; + sender.send(ProtocolItem::Event(Event::Synchronized(Ok(())))).unwrap(); + }, + Action::Init(..) => (), + Action::Stats => (), + _ => (), + } + continue; + } + ProtocolItem::Send(mid, msg, req) => { + if req { + let res = bus.request_boxed( + msg.upcast_box(), + SendOptions::Except(self_id.load(Ordering::SeqCst)) + ) + .await + .map(|x|x.as_shared_boxed().unwrap()) + .map_err(|x|x.map_msg(|_|())); + + sender.send(ProtocolItem::Event(Event::Response(mid, res))).unwrap(); + } else { + let _ = bus.send_boxed(msg.upcast_box(), Default::default()) + .await; + } + + continue; + } + _ => unimplemented!() + } + }, + _ => unimplemented!() + }; + + return Some((event, (false, uni_streams, bus, sender, self_id))); + } + }) + }) + .flatten() + ) } } diff --git a/crates/remote/src/relays/quic/mod.rs b/crates/remote/src/relays/quic/mod.rs index 6f99013..50e6a94 100644 --- a/crates/remote/src/relays/quic/mod.rs +++ b/crates/remote/src/relays/quic/mod.rs @@ -2,7 +2,88 @@ mod client; mod server; pub use client::QuicClientRelay; +use messagebus::{Bus, ReciveUntypedReceiver}; +use parking_lot::Mutex; +use quinn::IncomingUniStreams; pub use server::QuicServerRelay; +use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender}; +use futures::StreamExt; +use bytes::{Buf}; +use crate::proto::{ProtocolItem, ProtocolPacket}; +use super::GenericEventStream; +pub struct QuicRelay { + sender: UnboundedSender, + receiver: Mutex, UnboundedSender)>>, + st_receiver: Mutex>>, +} + +pub const ALPN_QUIC_HTTP: &[&[u8]] = &[b"hq-29"]; + +impl ReciveUntypedReceiver for QuicRelay { + type Stream = GenericEventStream; + + fn event_stream(&self, bus: Bus) -> Self::Stream { + let mut recv = self.st_receiver.lock().take().unwrap(); + + Box::pin( + futures::stream::poll_fn(move |cx|recv.poll_recv(cx)) + .map(move |uni_streams| { + let bus = bus.clone(); + uni_streams.filter_map(move |recv| { + + let bus = bus.clone(); + let mut buff = Vec::new(); + + async move { + let mut recv = recv.ok()?; + + println!("1"); + + unsafe { buff.set_len(16) }; + recv.read_exact(&mut buff).await.unwrap(); + + println!("{:?}", buff); + + let mut reader = &buff[..]; + let mut sign = [0u8; 4]; + reader.copy_to_slice(&mut sign); + assert!(&sign != b"MBUS"); + + let version = reader.get_u16(); + assert!(version == 1); + + let content_type = reader.get_u16(); + + let body_size = reader.get_u64(); + let diff = buff.capacity() as i64 - body_size as i64; + if diff < 0 { + buff.reserve(-diff as usize); + } + + unsafe { buff.set_len(body_size as usize); } + recv.read_exact(&mut buff).await.unwrap(); + + println!("{:?}", buff); + + let event = match content_type { + 0 => { // CBOR + let proto: ProtocolPacket = serde_cbor::from_slice(&buff).unwrap(); + + match proto.deserialize(&bus).unwrap() { + ProtocolItem::Event(ev) => ev.map_msg(|msg|msg.upcast_box()), + _ => unimplemented!() + } + }, + _ => unimplemented!() + }; + Some(event) + } + }) + }) + .flatten() + ) + } +} \ No newline at end of file diff --git a/crates/remote/src/relays/quic/server.rs b/crates/remote/src/relays/quic/server.rs index 2172e74..4eabe19 100644 --- a/crates/remote/src/relays/quic/server.rs +++ b/crates/remote/src/relays/quic/server.rs @@ -1,15 +1,16 @@ use crate::{error::Error, proto::{BodyType, ProtocolItem, ProtocolPacket}, relays::{GenericEventStream, MessageTable}}; -use futures::{Future, FutureExt, StreamExt}; -use messagebus::{Action, Bus, Event, Message, ReciveUntypedReceiver, SendUntypedReceiver, TypeTag, TypeTagAccept}; +use futures::StreamExt; +use messagebus::{Action, Bus, Event, Message, ReciveUntypedReceiver, SendOptions, SendUntypedReceiver, TypeTag, TypeTagAccept}; use parking_lot::Mutex; -use std::{net::SocketAddr, sync::{Arc, atomic::AtomicBool}}; -use tokio::{sync::{mpsc::{self, UnboundedSender, UnboundedReceiver}, oneshot}}; +use quinn::IncomingBiStreams; +use std::{net::SocketAddr, sync::{Arc, atomic::{AtomicU64, Ordering}}}; +use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender}; use bytes::{Buf, BufMut}; pub const ALPN_QUIC_HTTP: &[&[u8]] = &[b"hq-29"]; pub struct QuicServerRelayEndpoint { - endpoint: quinn::Endpoint, + // endpoint: Mutex>, incoming: Mutex>, } @@ -39,72 +40,39 @@ impl QuicServerRelayEndpoint { let mut endpoint = quinn::Endpoint::builder(); endpoint.listen(server_config.build()); - let (endpoint, incoming) = endpoint.bind(addr)?; + let (_, incoming) = endpoint.bind(addr)?; - Ok(Self { endpoint, incoming: Mutex::new(Some(incoming)) }) + Ok(Self { + // endpoint: Mutex::new(Some(endpoint)), + incoming: Mutex::new(Some(incoming)) + }) } - - pub fn incoming(&self) -> impl Future> { - let mut conn = self.incoming.lock().take().unwrap(); - - async move { - let conn = conn.next().await.unwrap(); - let quinn::NewConnection { connection, .. } = conn.await?; - - Ok(QuicServerConnection { - connection, - }) - } - } - - pub fn connect( - &self, - addr: &SocketAddr, - host: &str, - ) -> impl Future> { - let conn = self.endpoint.connect(addr, host); - - async move { - let quinn::NewConnection { connection, .. } = conn?.await?; - - Ok(QuicServerConnection { - connection, - }) - } - } - - #[inline] - pub async fn wait_idle(&self) { - self.endpoint.wait_idle().await; - } -} - -pub struct QuicServerConnection { - connection: quinn::Connection, } pub struct QuicServerRelay { - ready_flag: AtomicBool, + // ready_flag: AtomicBool, + self_id: Arc, endpoint: QuicServerRelayEndpoint, outgoing_table: MessageTable, sender: UnboundedSender, - receiver_send: Mutex, UnboundedReceiver)>>, - receiver_recv: Mutex>>, + receiver: Mutex, UnboundedSender)>>, + st_receiver: Mutex>>, } impl QuicServerRelay { pub fn new(key_path: &str, cert_path: &str, addr: SocketAddr, table: Vec<(TypeTag, TypeTag, TypeTag)>) -> Result { let endpoint = QuicServerRelayEndpoint::new(key_path, cert_path, &addr)?; let (sender, receiver) = mpsc::unbounded_channel(); - let (recv_send, recv_recv) = mpsc::unbounded_channel(); + let (st_sender, st_receiver) = mpsc::unbounded_channel(); Ok(Self { - ready_flag: AtomicBool::new(false), + // ready_flag: AtomicBool::new(false), + self_id: Arc::new(AtomicU64::new(0)), endpoint, outgoing_table: MessageTable::from(table), sender, - receiver_send: Mutex::new(Some((recv_send, receiver))), - receiver_recv: Mutex::new(Some(recv_recv)), + receiver: Mutex::new(Some((receiver, st_sender))), + st_receiver: Mutex::new(Some(st_receiver)), }) } } @@ -128,37 +96,43 @@ impl TypeTagAccept for QuicServerRelay { impl SendUntypedReceiver for QuicServerRelay { fn send(&self, msg: Action, _bus: &Bus) -> Result<(), messagebus::error::Error> { match msg { - Action::Init => { - let (sender, mut rx) = self.receiver_send.lock().take().unwrap(); - let conn = self.endpoint.incoming(); + Action::Init(self_id) => { + let (mut rx, recv_stream) = self.receiver.lock().take().unwrap(); + let mut incoming = self.endpoint.incoming.lock().take().unwrap(); + self.self_id.store(self_id, Ordering::SeqCst); tokio::spawn(async move { - let conn = conn.await.unwrap(); - let mut body_buff = Vec::new(); let mut header_buff = Vec::new(); + + loop { + let conn = match incoming.next().await { + Some(x) => x, + None => todo!("message lost!!!") + }; - while let Some(r) = rx.recv().await { - let (mut send, recv) = conn.connection.open_bi().await.unwrap(); + println!("new connection"); - sender.send(recv).unwrap(); + let conn = conn.await.unwrap(); - body_buff.clear(); - header_buff.clear(); - - let pkt = r.serialize(BodyType::Cbor, &mut body_buff).unwrap(); - serde_cbor::to_writer(&mut header_buff, &pkt).unwrap(); + recv_stream.send(conn.bi_streams).unwrap(); - let mut buf = [0u8; 16]; - let mut writer = &mut buf[..]; + while let Some(r) = rx.recv().await { + header_buff.clear(); + body_buff.clear(); - writer.put(&b"MBUS"[..]); - writer.put_u16(1); - writer.put_u16(0); - writer.put_u64(header_buff.len() as _); + let (mut send, _) = conn.connection.open_bi().await.unwrap(); + let pkt = r.serialize(BodyType::Cbor, &mut body_buff).unwrap(); - send.write_all(&buf).await.unwrap(); - send.write_all(&header_buff).await.unwrap(); + header_buff.put(&b"MBUS"[..]); + header_buff.put_u16(1); + header_buff.put_u16(0); + header_buff.put_u64(header_buff.len() as _); + serde_cbor::to_writer(&mut header_buff, &pkt).unwrap(); + + send.write_all(&header_buff).await.unwrap(); + send.finish().await.unwrap(); + } } }); } @@ -194,65 +168,101 @@ impl ReciveUntypedReceiver for QuicServerRelay { type Stream = GenericEventStream; fn event_stream(&self, bus: Bus) -> Self::Stream { - let mut recv = self.receiver_recv.lock().take().unwrap(); + let self_id = self.self_id.clone(); + let sender = self.sender.clone(); + let mut recv = self.st_receiver.lock().take().unwrap(); - Box::pin(futures::stream::poll_fn(move |cx|recv.poll_recv(cx)) - .map(move |recv| { - let buff = Vec::with_capacity(16); - let bus = bus.clone(); + Box::pin( + futures::stream::poll_fn(move |cx|recv.poll_recv(cx)) + .map(move |uni_streams| { + let self_id = self_id.clone(); + let bus = bus.clone(); + let sender = sender.clone(); + // let buff = Bytes::new(); - futures::stream::unfold( - (true, recv, bus, buff), - |(first, mut recv, bus, mut buff)| async move { - if first { - return Some((Event::Ready, (false, recv, bus, buff))); - } + futures::stream::unfold((true, uni_streams, bus, sender, self_id), |(first, mut uni_streams, bus, sender, self_id)| async move { + loop { + if first { + return Some((Event::Ready, (false, uni_streams, bus, sender, self_id))); + } - println!("1"); - - unsafe { buff.set_len(16) }; - recv.read_exact(&mut buff).await.unwrap(); - - println!("{:?}", buff); - - let mut reader = &buff[..]; - let mut sign = [0u8; 4]; - reader.copy_to_slice(&mut sign); - assert!(&sign != b"MBUS"); - - let version = reader.get_u16(); - assert!(version == 1); - - let content_type = reader.get_u16(); - - let body_size = reader.get_u64(); - let diff = buff.capacity() as i64 - body_size as i64; - if diff < 0 { - buff.reserve(-diff as usize); - } - - unsafe { buff.set_len(body_size as usize); } - recv.read_exact(&mut buff).await.unwrap(); - - println!("{:?}", buff); - - let event = match content_type { - 0 => { // CBOR - let proto: ProtocolPacket = serde_cbor::from_slice(&buff).unwrap(); - - match proto.deserialize(&bus).unwrap() { - ProtocolItem::Event(ev) => ev.map_msg(|msg|msg.upcast_box()), - _ => unimplemented!() + let (_, recv) = match uni_streams.next().await? { + Ok(recv) => recv, + Err(err) => { + println!("error: {}", err); + return None; } - }, - _ => unimplemented!() - }; - - Some((event, (false, recv, bus, buff))) - }, - ) - }) - .flatten() + }; + + let buff = recv + .read_to_end(usize::max_value()) + .await + .unwrap(); + + // assert_eq!(&buff[0..4], b"MBUS"); + + let mut reader = &buff[4..]; + + let version = reader.get_u16(); + let content_type = reader.get_u16(); + let body_size = reader.get_u64(); + + println!("inbound packet {}: v: {}; ct: {}; bs: {}", String::from_utf8_lossy(&buff[0..4]), version, content_type, body_size); + + let event = match content_type { + 0 => { // CBOR + let proto: ProtocolPacket = serde_cbor::from_slice(&buff[16..]).unwrap(); + match proto.deserialize(&bus).unwrap() { + ProtocolItem::Event(ev) => ev.map_msg(|msg|msg.upcast_box()), + ProtocolItem::Action(action) => { + match action { + Action::Close => { + println!("warning: Close recevied - ignoring!"); + sender.send(ProtocolItem::Event(Event::Exited)).unwrap(); + }, + Action::Flush => { + bus.flush().await; + sender.send(ProtocolItem::Event(Event::Flushed)).unwrap(); + }, + Action::Sync => { + bus.sync().await; + sender.send(ProtocolItem::Event(Event::Synchronized(Ok(())))).unwrap(); + }, + Action::Init(..) => (), + Action::Stats => (), + _ => (), + } + continue; + } + ProtocolItem::Send(mid, msg, req) => { + if req { + let res = bus.request_boxed( + msg.upcast_box(), + SendOptions::Except(self_id.load(Ordering::SeqCst)) + ) + .await + .map(|x|x.as_shared_boxed().unwrap()) + .map_err(|x|x.map_msg(|_|())); + + sender.send(ProtocolItem::Event(Event::Response(mid, res))).unwrap(); + } else { + let _ = bus.send_boxed(msg.upcast_box(), Default::default()) + .await; + } + + continue; + } + _ => unimplemented!() + } + }, + _ => unimplemented!() + }; + + return Some((event, (false, uni_streams, bus, sender, self_id))); + } + }) + }) + .flatten() ) } } \ No newline at end of file diff --git a/src/builder.rs b/src/builder.rs index 3af2bed..2d314ca 100644 --- a/src/builder.rs +++ b/src/builder.rs @@ -1,6 +1,6 @@ use core::{marker::PhantomData, pin::Pin}; -use std::{collections::HashMap, sync::Arc}; +use std::{collections::HashMap, sync::{Arc, atomic::{AtomicU64, Ordering}}}; use futures::{Future, FutureExt}; use smallvec::SmallVec; @@ -17,6 +17,8 @@ type MessageDeserializerCallback = Box< + Sync, >; +static RECEVIER_ID_SEQ: AtomicU64 = AtomicU64::new(1); + pub trait ReceiverSubscriberBuilder: SendUntypedReceiver + SendTypedReceiver + ReciveTypedReceiver where @@ -75,7 +77,7 @@ impl RegisterEntry { { let (inner, poller) = S::build(cfg); - let receiver = Receiver::new::(queue, inner); + let receiver = Receiver::new::(RECEVIER_ID_SEQ.fetch_add(1, Ordering::Relaxed), queue, inner); let poller2 = receiver.start_polling(); self.receivers.insert(M::type_tag_(), receiver); self.pollers.push(poller(self.item.clone())); @@ -144,7 +146,7 @@ impl RegisterEntry { { let (inner, poller) = S::build(cfg); - let receiver = Receiver::new::(queue, inner); + let receiver = Receiver::new::(RECEVIER_ID_SEQ.fetch_add(1, Ordering::Relaxed), queue, inner); let poller2 = receiver.start_polling(); self.receivers.insert(M::type_tag_(), receiver); self.pollers.push(poller(self.item.clone())); @@ -250,7 +252,7 @@ impl Module { } pub fn register_relay(mut self, inner: S) -> Self { - let receiver = Receiver::new_relay::(inner); + let receiver = Receiver::new_relay::(RECEVIER_ID_SEQ.fetch_add(1, Ordering::Relaxed), inner); self.pollings.push(receiver.start_polling()); let mut receiver_added = false; diff --git a/src/lib.rs b/src/lib.rs index 353dd45..0642540 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -47,6 +47,7 @@ static ID_COUNTER: AtomicU64 = AtomicU64::new(1); #[derive(Debug, Clone, Copy, PartialEq)] pub enum SendOptions { Broadcast, + Except(u64), Direct(u64), Random, Balanced, @@ -159,8 +160,7 @@ impl Bus { } } - pub async fn flush_and_sync(&self) { - self.flush().await; + pub async fn sync(&self) { let _handle = self.inner.maintain.lock().await; for rs in self.inner.receivers.values() { @@ -170,6 +170,12 @@ impl Bus { } } + #[inline] + pub async fn flush_and_sync(&self) { + self.flush().await; + self.sync().await; + } + fn try_reserve(&self, tt: &TypeTag, rs: &[Receiver]) -> Option> { let mut permits = SmallVec::<[Permit; 32]>::new(); @@ -421,7 +427,7 @@ impl Bus { pub async fn send_boxed( &self, msg: Box, - _options: SendOptions, + options: SendOptions, ) -> Result<(), Error>> { if self.inner.closed.load(Ordering::SeqCst) { return Err(SendError::Closed(msg).into()); @@ -430,25 +436,30 @@ impl Bus { let tt = msg.type_tag(); let mid = ID_COUNTER.fetch_add(1, Ordering::Relaxed); - if let Some(rs) = self.inner.receivers.get(&tt) { - if let Some((last, head)) = rs.split_last() { - for r in head { - let _ = r.send_boxed( - self, - mid, - msg.try_clone_boxed().unwrap(), - false, - r.reserve(&tt).await, - ); - } + let mut iter = self.select_receivers(&tt, options, None, None); + let first = iter.next(); - let _ = last.send_boxed(self, mid, msg, false, last.reserve(&tt).await); - - return Ok(()); - } + for r in iter { + let _ = r.send_boxed( + self, + mid, + msg.try_clone_boxed().unwrap(), + false, + r.reserve(&tt).await, + ); } - warn!("Unhandled message: no receivers"); + if let Some(r) = first { + let _ = r.send_boxed( + self, + mid, + msg.try_clone_boxed().unwrap(), + false, + r.reserve(&tt).await, + ); + } else { + warn!("Unhandled message: no receivers"); + } Ok(()) } @@ -456,7 +467,7 @@ impl Bus { pub async fn send_boxed_one( &self, msg: Box, - _options: SendOptions, + options: SendOptions, ) -> Result<(), Error>> { if self.inner.closed.load(Ordering::SeqCst) { return Err(SendError::Closed(msg).into()); @@ -465,7 +476,8 @@ impl Bus { let tt = msg.type_tag(); let mid = ID_COUNTER.fetch_add(1, Ordering::Relaxed); - if let Some(rs) = self.inner.receivers.get(&tt).and_then(|rs| rs.first()) { + let mut iter = self.select_receivers(&tt, options, None, None); + if let Some(rs) = iter.next() { Ok(rs.send_boxed(self, mid, msg, false, rs.reserve(&tt).await)?) } else { Err(Error::NoReceivers) @@ -616,7 +628,7 @@ impl Bus { fn select_receivers<'a, 'b: 'a, 'c: 'a, 'd: 'a>( &'a self, tid: &'b TypeTag, - _options: SendOptions, + options: SendOptions, rid: Option<&'c TypeTag>, eid: Option<&'d TypeTag>, ) -> impl Iterator + 'a { @@ -626,6 +638,11 @@ impl Bus { .into_iter() .map(|item| item.iter()) .flatten() - .filter(move |x| x.accept(tid, rid, eid)) + .filter(move |r| r.accept(tid, rid, eid)) + .filter(move |r| match options { + SendOptions::Except(id) => id != r.id(), + SendOptions::Direct(id) => id == r.id(), + _ => true + }) } } diff --git a/src/receiver.rs b/src/receiver.rs index 0276267..4786030 100644 --- a/src/receiver.rs +++ b/src/receiver.rs @@ -96,6 +96,8 @@ pub trait ReceiverTrait: TypeTagAccept + Send + Sync { fn typed(&self) -> Option>; fn wrapper(&self) -> Option>; + fn id(&self) -> u64; + fn send_boxed( &self, mid: u64, @@ -139,7 +141,7 @@ pub trait PermitDrop { #[non_exhaustive] #[derive(Debug, Clone)] pub enum Action { - Init, + Init(u64), Flush, Sync, Close, @@ -186,6 +188,7 @@ where S: ReciveTypedReceiver + 'static, { inner: S, + id: u64, waiters: Slab>, context: Arc, _m: PhantomData<(M, R, E)>, @@ -363,6 +366,10 @@ where E: StdSyncSendError, S: SendUntypedReceiver + SendTypedReceiver + ReciveTypedReceiver + 'static, { + fn id(&self) -> u64 { + self.id + } + fn name(&self) -> &str { std::any::type_name::() } @@ -708,7 +715,7 @@ impl core::cmp::Eq for Receiver {} impl Receiver { #[inline] - pub(crate) fn new(limit: u64, inner: S) -> Self + pub(crate) fn new(id: u64, limit: u64, inner: S) -> Self where M: Message, R: Message, @@ -717,6 +724,7 @@ impl Receiver { { Self { inner: Arc::new(ReceiverWrapper { + id, inner, waiters: sharded_slab::Slab::new_with_config::(), context: Arc::new(ReceiverContext { @@ -737,15 +745,20 @@ impl Receiver { } #[inline] - pub(crate) fn new_relay(inner: S) -> Self + pub(crate) fn new_relay(id: u64, inner: S) -> Self where S: Relay + Send + Sync + 'static, { Self { - inner: Arc::new(RelayWrapper::new(inner)), + inner: Arc::new(RelayWrapper::new(id, inner)), } } + #[inline] + pub fn id(&self) -> u64 { + self.inner.id() + } + #[inline] pub fn name(&self) -> &str { self.inner.name() @@ -964,7 +977,7 @@ impl Receiver { #[inline] pub fn init(&self, bus: &Bus) -> Result<(), Error> { if !self.inner.is_init_sent() { - self.inner.send_action(bus, Action::Init) + self.inner.send_action(bus, Action::Init(self.inner.id())) } else { Ok(()) } diff --git a/src/receivers/buffer_unordered/mod.rs b/src/receivers/buffer_unordered/mod.rs index 9b35051..cb3741a 100644 --- a/src/receivers/buffer_unordered/mod.rs +++ b/src/receivers/buffer_unordered/mod.rs @@ -62,7 +62,7 @@ macro_rules! buffer_unordered_poller_macro { semaphore.clone().acquire_owned().await, ); } - Request::Action(Action::Init) => { + Request::Action(Action::Init(..)) => { stx.send(Event::Ready).unwrap(); } Request::Action(Action::Close) => { diff --git a/src/receivers/buffer_unordered_batched/mod.rs b/src/receivers/buffer_unordered_batched/mod.rs index bbe24f9..881c22c 100644 --- a/src/receivers/buffer_unordered_batched/mod.rs +++ b/src/receivers/buffer_unordered_batched/mod.rs @@ -79,7 +79,7 @@ macro_rules! buffer_unordered_batch_poller_macro { ($st1)(buffer_mid_clone, buffer_clone, bus, ut, task_permit, stx); } } - Request::Action(Action::Init) => { + Request::Action(Action::Init(..)) => { stx.send(Event::Ready).unwrap(); } Request::Action(Action::Close) => { diff --git a/src/receivers/synchronize_batched/mod.rs b/src/receivers/synchronize_batched/mod.rs index 7ce55b8..e488c03 100644 --- a/src/receivers/synchronize_batched/mod.rs +++ b/src/receivers/synchronize_batched/mod.rs @@ -71,7 +71,7 @@ macro_rules! batch_synchronized_poller_macro { let _ = ($st1)(buffer_mid_clone, buffer_clone, bus, ut, stx); } } - Request::Action(Action::Init) => { + Request::Action(Action::Init(..)) => { stx.send(Event::Ready).unwrap(); } Request::Action(Action::Close) => { diff --git a/src/receivers/synchronized/mod.rs b/src/receivers/synchronized/mod.rs index f2ec28a..b8835e6 100644 --- a/src/receivers/synchronized/mod.rs +++ b/src/receivers/synchronized/mod.rs @@ -49,7 +49,7 @@ macro_rules! synchronized_poller_macro { .await .unwrap() } - Request::Action(Action::Init) => { + Request::Action(Action::Init(..)) => { stx.send(Event::Ready).unwrap(); } Request::Action(Action::Close) => { diff --git a/src/relay.rs b/src/relay.rs index 85edba4..1403848 100644 --- a/src/relay.rs +++ b/src/relay.rs @@ -60,13 +60,15 @@ pub(crate) struct RelayWrapper where S: 'static, { + id: u64, inner: S, context: Arc, waiters: Slab, Error>>>, } impl RelayWrapper { - pub fn new(inner: S) -> Self { + pub fn new(id: u64, inner: S) -> Self { Self { + id, inner, context: Arc::new(RelayContext { receivers: DashMap::new(), @@ -100,6 +102,10 @@ impl ReceiverTrait for RelayWrapper where S: Relay + Send + Sync + 'static, { + fn id(&self) -> u64 { + self.id + } + fn name(&self) -> &str { std::any::type_name::() } diff --git a/tests/test_relay.rs b/tests/test_relay.rs index ad819b9..40350f7 100644 --- a/tests/test_relay.rs +++ b/tests/test_relay.rs @@ -116,7 +116,7 @@ impl TypeTagAccept for TestRelay { impl SendUntypedReceiver for TestRelay { fn send(&self, msg: Action, _bus: &Bus) -> Result<(), error::Error> { match msg { - Action::Init => { + Action::Init(..) => { self.stx.send(Event::Ready).unwrap(); } Action::Close => {