diff --git a/opensmtpd-derive/src/lib.rs b/opensmtpd-derive/src/lib.rs index d6dc370..4aadd15 100644 --- a/opensmtpd-derive/src/lib.rs +++ b/opensmtpd-derive/src/lib.rs @@ -4,21 +4,61 @@ use proc_macro::TokenStream; use syn::{parse_macro_input, ItemFn}; use quote::quote; +fn get_type( + params: &syn::punctuated::Punctuated, +) -> Result<(Box, syn::Type), ()> { + match params.iter().count() { + 1 => { + let ctx = Box::new(syn::Type::Verbatim(syn::TypeVerbatim { + tts: quote! { + opensmtpd::NoContext + }, + })); + let cb = syn::Type::Verbatim(syn::TypeVerbatim { + tts: quote!{ opensmtpd::Callback::NoCtx }, + }); + Ok((ctx, cb)) + } + 2 => match params.iter().next().unwrap() { + syn::FnArg::Captured(ref a) => match &a.ty { + syn::Type::Reference(r) => { + let cb = match r.mutability { + Some(_) => syn::Type::Verbatim(syn::TypeVerbatim { + tts: quote!{ opensmtpd::Callback::CtxMut }, + }), + None => syn::Type::Verbatim(syn::TypeVerbatim { + tts: quote!{ opensmtpd::Callback::Ctx }, + }), + }; + Ok((r.elem.clone(), cb)) + } + _ => Err(()), + }, + _ => Err(()), + }, + _ => Err(()), + } +} + #[proc_macro_attribute] pub fn event(attr: TokenStream, input: TokenStream) -> TokenStream { let attr = attr.to_string(); let item = parse_macro_input!(input as ItemFn); let fn_name = &item.ident; let fn_params = &item.decl.inputs; + let (ctx_type, callback_type) = match get_type(fn_params) { + Ok(t) => t, + Err(_) => { + panic!(); + } + }; let fn_body = &item.block; let fn_output = &item.decl.output; let output = quote! { - // TODO: set the correct EventHandler type - fn #fn_name() -> opensmtpd::EventHandler { - // TODO: set the correct Callback type + fn #fn_name() -> opensmtpd::EventHandler<#ctx_type> { opensmtpd::EventHandler::new( #attr, - opensmtpd::Callback::CtxMut(|#fn_params| #fn_output #fn_body) + #callback_type(|#fn_params| #fn_output #fn_body) ) } }; diff --git a/opensmtpd/Cargo.toml b/opensmtpd/Cargo.toml index 627b272..d16eba1 100644 --- a/opensmtpd/Cargo.toml +++ b/opensmtpd/Cargo.toml @@ -20,5 +20,9 @@ opensmtpd_derive = { path = "../opensmtpd-derive", version="0.1" } name = "dummy" path = "examples/dummy.rs" +[[example]] +name = "counter" +path = "examples/session_event_counter.rs" + [dev-dependencies] env_logger = "0.6" diff --git a/opensmtpd/examples/dummy.rs b/opensmtpd/examples/dummy.rs index 48e5507..de1c079 100644 --- a/opensmtpd/examples/dummy.rs +++ b/opensmtpd/examples/dummy.rs @@ -1,14 +1,14 @@ use env_logger::{Builder, Env}; use log::{debug, info}; -use opensmtpd::{event, handlers, Entry, NoContext, SmtpIn}; +use opensmtpd::{event, handlers, Entry, SmtpIn}; #[event(Any)] -fn on_event(_context: &mut NoContext, entry: &Entry) { +fn on_event(entry: &Entry) { debug!("Event received: {:?}", entry); } #[event(LinkConnect)] -fn on_connect(_context: &mut NoContext, entry: &Entry) { +fn on_connect(entry: &Entry) { info!("New client on session {:x}.", entry.session_id); } diff --git a/opensmtpd/examples/session_event_counter.rs b/opensmtpd/examples/session_event_counter.rs new file mode 100644 index 0000000..f143370 --- /dev/null +++ b/opensmtpd/examples/session_event_counter.rs @@ -0,0 +1,19 @@ +use env_logger::{Builder, Env}; +use log::info; +use opensmtpd::{event, handlers, Entry, SmtpIn}; + +#[derive(Clone, Default)] +struct MyContext { + nb: usize, +} + +#[event(Any)] +fn on_event(ctx: &mut MyContext, entry: &Entry) { + ctx.nb += 1; + info!("Event received: {}, {}", entry.session_id, ctx.nb); +} + +fn main() { + Builder::from_env(Env::default().default_filter_or("debug")).init(); + SmtpIn::new().event_handlers(handlers!(on_event)).run(); +} diff --git a/opensmtpd/src/event_handlers.rs b/opensmtpd/src/event_handlers.rs index 69ef4dd..84c3cb8 100644 --- a/opensmtpd/src/event_handlers.rs +++ b/opensmtpd/src/event_handlers.rs @@ -65,11 +65,7 @@ mod test { #[test] fn test_eventhandler_build_noctx() { - // TODO: Remove the :: - EventHandler::new( - "Any", - Callback::NoCtx::(|_entry: &Entry| {}), - ); + EventHandler::new("Any", Callback::NoCtx::(|_entry: &Entry| {})); } #[test] diff --git a/opensmtpd/src/lib.rs b/opensmtpd/src/lib.rs index 0c1d5de..487bc72 100644 --- a/opensmtpd/src/lib.rs +++ b/opensmtpd/src/lib.rs @@ -79,16 +79,14 @@ impl SmtpIn { let mut evts = Vec::new(); for eh in self.event_handlers.iter() { match eh.event { - MatchEvent::Evt(ref v) => { - for e in v.iter() { - evts.push(e); - } + MatchEvent::Evt(ref v) => for e in v.iter() { + evts.push(e); }, MatchEvent::All => { println!("register|report|smtp-in|*"); evts.clear(); - break ; - }, + break; + } } } evts.dedup();