Tcp relay update; refactoring type tag

This commit is contained in:
Andrey Tkachenko 2021-10-15 19:31:33 +04:00
parent bc862cfac6
commit 18bd3c63f4
16 changed files with 214 additions and 106 deletions

View File

@ -1,6 +1,6 @@
[package]
name = "messagebus"
version = "0.9.9"
version = "0.9.10"
authors = ["Andrey Tkachenko <andrey@aidev.ru>"]
repository = "https://github.com/andreytkachenko/messagebus.git"
keywords = ["futures", "async", "tokio", "message", "bus"]
@ -17,7 +17,8 @@ members = [
]
[dependencies]
messagebus_derive = "0.2.4"
messagebus_derive = "0.2.5"
tokio = { version = "1", features = ["parking_lot", "rt-multi-thread", "sync", "time"] }
parking_lot = "0.11"
async-trait = "0.1"
@ -30,6 +31,7 @@ erased-serde = "0.3"
serde = "1"
serde_derive = "1"
dashmap = "4.0"
ctor = "0.1.21"
[dev-dependencies]
anyhow = "1.0"

View File

@ -1,6 +1,6 @@
[package]
name = "messagebus_derive"
version = "0.2.4"
version = "0.2.5"
authors = ["Andrey Tkachenko <andrey@aidev.ru>"]
repository = "https://github.com/andreytkachenko/messagebus.git"
keywords = ["futures", "async", "tokio", "message", "bus"]

View File

@ -2,9 +2,12 @@
extern crate proc_macro;
use proc_macro::TokenStream;
use quote::quote;
// use proc_macro::{TokenStream};
use proc_macro2::{Ident, Span, TokenStream};
use quote::{ToTokens, quote};
use std::collections::hash_map;
use std::fmt::Write;
use std::hash::Hasher;
use syn::parse::{Parse, ParseStream};
use syn::{parenthesized, Result};
use syn::{punctuated::Punctuated, token::Comma, DeriveInput};
@ -126,6 +129,7 @@ fn type_tag_part(
fn type_tag_() -> messagebus::TypeTag { format!(#type_name, #type_values).into() }
fn type_tag(&self) -> messagebus::TypeTag { Self::type_tag_() }
fn type_name(&self) -> std::borrow::Cow<str> { Self::type_tag_() }
fn type_layout(&self) -> std::alloc::Layout { std::alloc::Layout::for_value(self) }
}
}
} else {
@ -134,6 +138,7 @@ fn type_tag_part(
fn type_tag_() -> messagebus::TypeTag { #type_name.into() }
fn type_tag(&self) -> messagebus::TypeTag { Self::type_tag_() }
fn type_name(&self) -> std::borrow::Cow<str> { Self::type_tag_() }
fn type_layout(&self) -> std::alloc::Layout { std::alloc::Layout::for_value(self) }
}
}
}
@ -196,7 +201,7 @@ impl Parse for Tags {
}
#[proc_macro_derive(Message, attributes(type_tag, message, namespace))]
pub fn derive_message(input: TokenStream) -> TokenStream {
pub fn derive_message(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let mut tags = Tags::default();
let mut type_tag = None;
let mut namespace = None;
@ -231,12 +236,32 @@ pub fn derive_message(input: TokenStream) -> TokenStream {
for mut param in impl_generics.params.pairs_mut() {
match &mut param.value_mut() {
syn::GenericParam::Lifetime(_) => {}
syn::GenericParam::Type(param) => {
syn::GenericParam::Type(params) => {
let bound: syn::TypeParamBound =
syn::parse_str("messagebus::MessageBounds").unwrap();
param.bounds.push(bound);
params.bounds.push(bound);
if tags.has_shared {
let bound: syn::TypeParamBound =
syn::parse_str("messagebus::__reexport::serde::Serialize").unwrap();
params.bounds.push(bound);
let bound: syn::TypeParamBound =
syn::parse_str("messagebus::__reexport::serde::Deserialize<'de>").unwrap();
params.bounds.push(bound);
}
if tags.has_clone {
let bound: syn::TypeParamBound =
syn::parse_str("core::clone::Clone").unwrap();
params.bounds.push(bound);
}
}
syn::GenericParam::Const(_param) => {}
syn::GenericParam::Const(_) => {}
}
}
@ -244,6 +269,23 @@ pub fn derive_message(input: TokenStream) -> TokenStream {
let shared_part = shared_part(&ast, tags.has_shared);
let clone_part = clone_part(&ast, tags.has_clone);
let init = Ident::new(&format!("__init_{}", hash(ast.clone().into_token_stream())), Span::call_site());
let init_impl = if tags.has_shared && impl_generics.params.is_empty() {
quote! {
#[allow(non_upper_case_globals)]
#[messagebus::ctor::ctor]
fn #init() {
messagebus::register_shared_message::<#name>();
}
}
} else {
quote!{}
};
if !impl_generics.params.is_empty() && tags.has_shared {
impl_generics.params.push(syn::GenericParam::Lifetime(syn::LifetimeDef::new(syn::Lifetime::new("'de", Span::call_site()))))
}
let tokens = quote! {
#type_tag_part
@ -256,13 +298,15 @@ pub fn derive_message(input: TokenStream) -> TokenStream {
#shared_part
#clone_part
}
#init_impl
};
tokens.into()
}
#[proc_macro_derive(Error, attributes(type_tag, namespace))]
pub fn derive_error(input: TokenStream) -> TokenStream {
pub fn derive_error(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let mut type_tag = None;
let mut namespace = None;
@ -292,3 +336,9 @@ pub fn derive_error(input: TokenStream) -> TokenStream {
tokens.into()
}
fn hash(input: TokenStream) -> u64 {
let mut hasher = hash_map::DefaultHasher::new();
hasher.write(input.to_string().as_bytes());
hasher.finish()
}

View File

@ -33,8 +33,6 @@ async fn main() {
).unwrap();
let (b, poller) = Bus::build()
.register_shared_message::<Req>()
.register_shared_message::<Resp>()
.register_relay(relay)
.build();

View File

@ -70,8 +70,6 @@ async fn main() {
).unwrap();
let (b, poller) = Bus::build()
.register_shared_message::<Req>()
.register_shared_message::<Resp>()
.register_relay(relay)
.register(TmpReceiver)
.subscribe_async::<Req>(8, Default::default())

View File

@ -30,8 +30,6 @@ async fn main() {
);
let (b, poller) = Bus::build()
.register_shared_message::<Req>()
.register_shared_message::<Resp>()
.register_relay(relay)
.build();

View File

@ -67,8 +67,6 @@ async fn main() {
);
let (b, poller) = Bus::build()
.register_shared_message::<Req>()
.register_shared_message::<Resp>()
.register_relay(relay)
.register(TmpReceiver)
.subscribe_async::<Req>(8, Default::default())

View File

@ -288,7 +288,7 @@ pub struct ProtocolPacket<'a> {
impl<'a> ProtocolPacket<'a> {
pub fn deserialize(
self,
bus: &Bus,
_bus: &Bus,
) -> Result<ProtocolItem, crate::error::Error>
{
let type_tag: Option<TypeTag> = if self.header.flags.contains(ProtocolHeaderFlags::TYPE_TAG) {
@ -309,7 +309,7 @@ impl<'a> ProtocolPacket<'a> {
} else if self.header.flags.contains(ProtocolHeaderFlags::TT_AND_BODY) {
let body = self.body.ok_or_else(|| crate::error::Error::ProtocolParseError("No body".to_string()))?;
let res = generic_deserialize(self.header.body_type, body.as_ref(), |de| {
bus.deserialize_message(type_tag.clone().unwrap(), de)
messagebus::deserialize_shared_message(type_tag.clone().unwrap(), de)
.map_err(|x| x.map_msg(|_| ()))
})?;
@ -436,7 +436,6 @@ mod tests {
#[test]
fn test_proto_pack_event() {
let (bus, _) = Bus::build()
.register_shared_message::<TestSharedMessage>()
.build();
let pkt = ProtocolPacket {
@ -476,7 +475,6 @@ mod tests {
#[test]
fn test_proto_pack_event_error() {
let (bus, _) = Bus::build()
.register_shared_message::<TestSharedMessage>()
.build();
let pkt = ProtocolPacket {

View File

@ -1,21 +1,15 @@
use core::{marker::PhantomData, pin::Pin};
use std::{collections::{HashMap, HashSet}, sync::{Arc, atomic::{AtomicU64, Ordering}}};
use std::{collections::HashSet, sync::{Arc, atomic::{AtomicU64, Ordering}}};
use futures::{Future, FutureExt};
use tokio::sync::Mutex;
use crate::{AsyncBatchHandler, AsyncBatchSynchronizedHandler, AsyncHandler, AsyncSynchronizedHandler, BatchHandler, BatchSynchronizedHandler, Bus, BusInner, Handler, Message, Relay, SharedMessage, SynchronizedHandler, TypeTag, Untyped, envelop::IntoSharedMessage, error::{Error, StdSyncSendError}, receiver::{
use crate::{AsyncBatchHandler, AsyncBatchSynchronizedHandler, AsyncHandler, AsyncSynchronizedHandler, BatchHandler, BatchSynchronizedHandler, Bus, BusInner, Handler, Message, Relay, SynchronizedHandler, Untyped, error::StdSyncSendError, receiver::{
BusPollerCallback, Receiver, ReciveTypedReceiver, SendTypedReceiver, SendUntypedReceiver,
UntypedPollerCallback,
}, receivers};
type MessageDeserializerCallback = Box<
dyn Fn(&mut dyn erased_serde::Deserializer<'_>) -> Result<Box<dyn SharedMessage>, Error>
+ Send
+ Sync,
>;
static RECEVIER_ID_SEQ: AtomicU64 = AtomicU64::new(1);
pub trait ReceiverSubscriberBuilder<T, M, R, E>:
@ -204,23 +198,8 @@ impl<T, F, P, B> RegisterEntry<SyncEntry, T, F, P, B> {
}
}
pub struct MessageTypeDescriptor {
de: MessageDeserializerCallback,
}
impl MessageTypeDescriptor {
#[inline]
pub fn deserialize_boxed(
&self,
de: &mut dyn erased_serde::Deserializer<'_>,
) -> Result<Box<dyn SharedMessage>, Error> {
(self.de)(de)
}
}
#[derive(Default)]
pub struct Module {
message_types: HashMap<TypeTag, MessageTypeDescriptor>,
receivers: HashSet<Receiver>,
pollings: Vec<BusPollerCallback>,
}
@ -228,28 +207,11 @@ pub struct Module {
impl Module {
pub fn new() -> Self {
Self {
message_types: HashMap::new(),
receivers: HashSet::new(),
pollings: Vec::new(),
}
}
pub fn register_shared_message<
M: Message + Clone + serde::Serialize + serde::de::DeserializeOwned,
>(
mut self,
) -> Self {
println!("insert {}", M::type_tag_());
self.message_types.insert(
M::type_tag_(),
MessageTypeDescriptor {
de: Box::new(move |de| Ok(M::deserialize(de)?.into_shared())),
},
);
self
}
pub fn register_relay<S: Relay + Send + Sync + 'static>(mut self, inner: S) -> Self {
let receiver = Receiver::new_relay::<S>(RECEVIER_ID_SEQ.fetch_add(1, Ordering::Relaxed), inner);
self.pollings.push(receiver.start_polling());
@ -303,7 +265,6 @@ impl Module {
}
pub fn add_module(mut self, module: Module) -> Self {
self.message_types.extend(module.message_types);
self.pollings.extend(module.pollings);
self.receivers.extend(module.receivers);
@ -322,16 +283,6 @@ impl BusBuilder {
}
}
pub fn register_shared_message<
M: Message + Clone + serde::Serialize + serde::de::DeserializeOwned,
>(
self,
) -> Self {
let inner = self.inner.register_shared_message::<M>();
BusBuilder { inner }
}
pub fn register_relay<S: Relay + Send + Sync + 'static>(self, inner: S) -> Self {
let inner = self.inner.register_relay(inner);
@ -388,7 +339,7 @@ impl BusBuilder {
pub fn build(self) -> (Bus, impl Future<Output = ()>) {
let bus = Bus {
inner: Arc::new(BusInner::new(self.inner.receivers, self.inner.message_types)),
inner: Arc::new(BusInner::new(self.inner.receivers)),
};
let mut futs = Vec::with_capacity(self.inner.pollings.len() * 2);

View File

@ -3,7 +3,7 @@ use core::{
fmt,
};
use std::{borrow::Cow, sync::Arc};
use std::{alloc::Layout, borrow::Cow, sync::Arc};
pub trait MessageBounds: TypeTagged + fmt::Debug + Unpin + Send + Sync + 'static {}
impl<T: TypeTagged + fmt::Debug + Unpin + Send + Sync + 'static> MessageBounds for T {}
@ -17,6 +17,7 @@ pub trait TypeTagged {
fn type_tag(&self) -> TypeTag;
fn type_name(&self) -> Cow<str>;
fn type_layout(&self) -> Layout;
}
pub trait Message: MessageBounds {
@ -48,6 +49,9 @@ macro_rules! gen_impls {
fn type_name(&self) -> Cow<str> {
type_name::<$t>().into()
}
fn type_layout(&self) -> Layout {
Layout::for_value(self)
}
}
gen_impls!{ $($rest)* }
@ -64,6 +68,9 @@ macro_rules! gen_impls {
fn type_name(&self) -> Cow<str> {
type_name::<$t>().into()
}
fn type_layout(&self) -> Layout {
Layout::for_value(self)
}
}
};
}
@ -90,6 +97,9 @@ impl<T: TypeTagged> TypeTagged for Arc<T> {
fn type_name(&self) -> Cow<str> {
T::type_name(&*self)
}
fn type_layout(&self) -> Layout {
Layout::for_value(self)
}
}
impl<T: TypeTagged> TypeTagged for Box<T> {
@ -103,6 +113,9 @@ impl<T: TypeTagged> TypeTagged for Box<T> {
fn type_name(&self) -> Cow<str> {
T::type_name(&*self)
}
fn type_layout(&self) -> Layout {
Layout::for_value(self)
}
}
impl Message for () {
@ -226,6 +239,9 @@ mod tests {
fn type_name(&self) -> Cow<str> {
type_name::<Self>().into()
}
fn type_layout(&self) -> Layout {
Layout::for_value(self)
}
}
impl Message for Msg0 {
@ -279,6 +295,9 @@ mod tests {
fn type_name(&self) -> Cow<str> {
type_name::<Self>().into()
}
fn type_layout(&self) -> Layout {
Layout::for_value(self)
}
}
impl Message for Msg1 {
@ -341,6 +360,9 @@ mod tests {
fn type_name(&self) -> Cow<str> {
type_name::<Self>().into()
}
fn type_layout(&self) -> Layout {
Layout::for_value(self)
}
}
impl Message for Msg2 {

View File

@ -58,6 +58,9 @@ impl TypeTagged for GenericError {
fn type_name(&self) -> TypeTag {
type_name::<GenericError>().into()
}
fn type_layout(&self) -> std::alloc::Layout {
std::alloc::Layout::for_value(self)
}
}
#[derive(Debug, Error)]

View File

@ -1,8 +1,21 @@
use core::iter::FromIterator;
use std::ops::ControlFlow;
use crate::{error::StdSyncSendError, Bus, Message};
use async_trait::async_trait;
#[async_trait]
pub trait AsyncProducer<'a, M: Message>: Send {
type Error: StdSyncSendError;
type Context: Send + 'a;
type Response: Message;
async fn start(&'a self, msg: M, bus: &Bus) -> Result<Self::Context, Self::Error>;
async fn next(&'a self, ctx: &mut Self::Context, bus: &Bus) -> Result<ControlFlow<Self::Response>, Self::Error>;
async fn finish(&'a self, _ctx: Self::Context, _bus: &Bus) -> Result<(), Self::Error>;
}
pub trait Handler<M: Message>: Send + Sync {
type Error: StdSyncSendError;
type Response: Message;

View File

@ -7,6 +7,13 @@ pub mod receivers;
mod relay;
mod stats;
mod trait_object;
pub mod type_tag;
pub mod __reexport {
pub use serde;
pub use ctor;
}
#[macro_use]
extern crate log;
@ -25,7 +32,7 @@ use smallvec::SmallVec;
use std::{collections::{HashMap, HashSet}, sync::Arc};
use tokio::sync::Mutex;
use builder::{BusBuilder, MessageTypeDescriptor};
use builder::BusBuilder;
use error::{Error, SendError, StdSyncSendError};
use receiver::{Permit, Receiver};
use stats::Stats;
@ -39,8 +46,10 @@ pub use receiver::{
SendUntypedReceiver, TypeTagAccept,
};
pub use relay::Relay;
pub use ctor;
pub use type_tag::{register_shared_message, deserialize_shared_message};
pub type Untyped = Arc<dyn Any + Send + Sync>;
type LookupQuery = (TypeTag, Option<TypeTag>, Option<TypeTag>);
static ID_COUNTER: AtomicU64 = AtomicU64::new(1);
@ -62,7 +71,6 @@ impl Default for SendOptions {
pub struct BusInner {
receivers: HashSet<Receiver>,
message_types: HashMap<TypeTag, MessageTypeDescriptor>,
lookup: HashMap<LookupQuery, SmallVec<[Receiver; 4]>>,
closed: AtomicBool,
maintain: Mutex<()>,
@ -71,7 +79,6 @@ pub struct BusInner {
impl BusInner {
pub(crate) fn new(
receivers: HashSet<Receiver>,
message_types: HashMap<TypeTag, MessageTypeDescriptor>,
) -> Self {
let mut lookup = HashMap::new();
for recv in receivers.iter() {
@ -85,6 +92,10 @@ impl BusInner {
.or_insert_with(HashSet::new)
.insert(recv.clone());
lookup.entry((msg.clone(), None, Some(err.clone())))
.or_insert_with(HashSet::new)
.insert(recv.clone());
lookup.entry((msg, Some(resp), Some(err)))
.or_insert_with(HashSet::new)
.insert(recv.clone());
@ -98,7 +109,6 @@ impl BusInner {
.collect();
Self {
message_types,
receivers,
lookup,
closed: AtomicBool::new(false),
@ -585,7 +595,9 @@ impl Bus {
if let Some(rs) = self.inner.lookup.get(&(tt.clone(), None, None))
.and_then(|rs| rs.first()) {
let msg = self.deserialize_message(tt.clone(), de)?;
let msg = deserialize_shared_message(tt.clone(), de)?;
Ok(rs.send_boxed(self, mid, msg.upcast_box(), false, rs.reserve(&tt).await)?)
} else {
Err(Error::NoReceivers)
@ -606,7 +618,7 @@ impl Bus {
let mut iter = self.select_receivers(tt.clone(), options, None, None, true);
if let Some(rc) = iter.next() {
let (mid, rx) = rc.add_response_waiter_boxed().unwrap();
let msg = self.deserialize_message(tt.clone(), de)?;
let msg = deserialize_shared_message(tt.clone(), de)?;
rc.send_boxed(
self,
@ -622,21 +634,6 @@ impl Bus {
}
}
pub fn deserialize_message(
&self,
tt: TypeTag,
de: &mut dyn erased_serde::Deserializer<'_>,
) -> Result<Box<dyn SharedMessage>, Error<Box<dyn Message>>> {
let md = self
.inner
.message_types
.get(&tt)
.ok_or(Error::TypeTagNotRegistered(tt))?;
md.deserialize_boxed(de)
.map_err(|err| err.specify::<Box<dyn Message>>())
}
pub fn stats(&self) -> impl Iterator<Item = Stats> + '_ {
self.inner
.receivers
@ -653,9 +650,9 @@ impl Bus {
eid: Option<TypeTag>,
is_req: bool,
) -> impl Iterator<Item = &Receiver> + '_ {
let vec = self.inner.lookup.get(&(tid.clone(), rid.clone(), eid.clone())).unwrap();
vec.iter()
self.inner.lookup.get(&(tid.clone(), rid.clone(), eid.clone()))
.into_iter()
.flatten()
.filter(move |r| r.accept(is_req, &tid, rid.as_ref(), eid.as_ref()))
.filter(move |r| match options {
SendOptions::Except(id) => id != r.id(),

82
src/type_tag.rs Normal file
View File

@ -0,0 +1,82 @@
use std::collections::HashMap;
use parking_lot::RwLock;
use crate::envelop::IntoSharedMessage;
use crate::error::Error;
use crate::{Message, SharedMessage, TypeTag};
static TYPE_REGISTRY: TypeRegistry = TypeRegistry::new();
type MessageDeserializerCallback = Box<
dyn Fn(&mut dyn erased_serde::Deserializer<'_>) -> Result<Box<dyn SharedMessage>, Error>
+ Send
+ Sync,
>;
pub struct MessageTypeDescriptor {
de: MessageDeserializerCallback,
}
impl MessageTypeDescriptor {
#[inline]
pub fn deserialize_boxed(
&self,
de: &mut dyn erased_serde::Deserializer<'_>,
) -> Result<Box<dyn SharedMessage>, Error> {
(self.de)(de)
}
}
#[derive(Default)]
pub struct TypeRegistry {
message_types: RwLock<Option<HashMap<TypeTag, MessageTypeDescriptor>>>,
}
impl TypeRegistry {
pub const fn new() -> Self {
Self {
message_types: parking_lot::const_rwlock(None)
}
}
pub fn deserialize(
&self,
tt: TypeTag,
de: &mut dyn erased_serde::Deserializer<'_>,
) -> Result<Box<dyn SharedMessage>, Error<Box<dyn Message>>> {
let guard = self
.message_types
.read();
let md = guard
.as_ref()
.ok_or_else(|| Error::TypeTagNotRegistered(tt.clone()))?
.get(&tt)
.ok_or(Error::TypeTagNotRegistered(tt))?;
md.deserialize_boxed(de)
.map_err(|err| err.specify::<Box<dyn Message>>())
}
pub fn register<M: Message + serde::Serialize + serde::de::DeserializeOwned>(&self) {
println!("insert {}", M::type_tag_());
self.message_types.write()
.get_or_insert_with(HashMap::new)
.insert(
M::type_tag_(),
MessageTypeDescriptor {
de: Box::new(move |de| Ok(M::deserialize(de)?.into_shared())),
},
);
}
}
#[inline]
pub fn deserialize_shared_message(tt: TypeTag, de: &mut dyn erased_serde::Deserializer<'_>) -> Result<Box<dyn SharedMessage>, Error<Box<dyn Message>>> {
TYPE_REGISTRY.deserialize(tt, de)
}
#[inline]
pub fn register_shared_message<M: Message + serde::Serialize + serde::de::DeserializeOwned>() {
TYPE_REGISTRY.register::<M>();
}

View File

@ -52,7 +52,6 @@ impl AsyncHandler<Msg> for TmpReceiver {
#[tokio::test]
async fn test() {
let (b, poller) = Bus::build()
.register_shared_message::<Msg>()
.register(TmpReceiver)
.subscribe_async::<Msg>(8, Default::default())
.done()

View File

@ -6,7 +6,7 @@ use std::sync::{
use async_trait::async_trait;
use messagebus::{
derive::{Error as MbError, Message},
error, AsyncHandler, Bus, Message, MessageBounds,
error, AsyncHandler, Bus, Message,
};
use thiserror::Error;
@ -22,14 +22,13 @@ impl<M: Message> From<error::Error<M>> for Error {
}
}
#[derive(Debug, Clone, Message)]
#[derive(Debug, Clone, serde_derive::Serialize, serde_derive::Deserialize, Message)]
#[message(clone, shared)]
struct Msg;
#[derive(Debug, Clone, serde_derive::Serialize, serde_derive::Deserialize, Message)]
#[message(clone, shared)]
struct SharedMsg<T: serde::Serialize + MessageBounds + Clone>(
#[serde(bound(deserialize = "T: serde::Deserialize<'de>"))] T,
);
struct SharedMsg<T>(T);
struct TmpReceiverContext {
sync1: AtomicBool,