diff --git a/README.md b/README.md index 9878e82..65952d9 100644 --- a/README.md +++ b/README.md @@ -9,8 +9,8 @@ This is a work in progress, the API is **not** stabilized yet. - [x] Reports - [ ] Filters -- [ ] Filter-level context -- [ ] Session-level context +- [x] Filter-level context +- [x] Session-level context [OpenSMTPD]: https://www.opensmtpd.org/ diff --git a/opensmtpd-derive/src/lib.rs b/opensmtpd-derive/src/lib.rs index 9f6a51a..0dde859 100644 --- a/opensmtpd-derive/src/lib.rs +++ b/opensmtpd-derive/src/lib.rs @@ -13,7 +13,7 @@ mod attributes; use attributes::OpenSmtpdAttributes; use proc_macro::TokenStream; use quote::quote; -use syn::{parse_macro_input, ExprArray, ItemFn, TypePath}; +use syn::{parse_macro_input, ExprArray, ExprTry, ItemFn, ReturnType, TypePath}; macro_rules! parse_item { ($item: expr, $type: ty) => { @@ -26,31 +26,77 @@ macro_rules! parse_item { }; } -fn get_tokenstream(attr: TokenStream, input: TokenStream, type_str: &str) -> TokenStream { +fn get_has_result(ret: &ReturnType) -> bool { + match ret { + ReturnType::Default => false, + ReturnType::Type(_, _) => true, + } +} + +fn get_inner_call(nb_args: usize, has_output: bool, has_result: bool) -> String { + let mut call_params = Vec::new(); + if has_output { + call_params.push("_output"); + } + call_params.push("_entry"); + if nb_args >= 2 { + call_params.push("_filter_ctx"); + } + if nb_args >= 3 { + call_params.push("_session_ctx"); + } + let call_params = call_params.join(", "); + let s = format!("inner_fn({})", &call_params); + if has_result { + return format!("{}?", s); + } + format!( + "(|{params}| -> Result<(), String> {{ {inner_fn}; Ok(()) }})({params})?", + params = &call_params, + inner_fn = s + ) +} + +fn get_tokenstream( + attr: TokenStream, + input: TokenStream, + type_str: &str, + has_output: bool, +) -> TokenStream { let kind = parse_item!(type_str, TypePath); + + // Parse the procedural macro attributes let attr = parse_macro_input!(attr as OpenSmtpdAttributes); let version = parse_item!(&attr.get_version(), TypePath); let subsystem = parse_item!(&attr.get_subsystem(), TypePath); let events = parse_item!(&attr.get_events(), ExprArray); + + // Parse the user-supplied function let item = parse_macro_input!(input as ItemFn); let fn_name = &item.sig.ident; let fn_params = &item.sig.inputs; + let fn_return = &item.sig.output; let fn_body = &item.block; + let has_result = get_has_result(&item.sig.output); + let inner_call = parse_item!( + &get_inner_call(fn_params.len(), has_output, has_result), + ExprTry + ); + + // Build the new function let output = quote! { - fn #fn_name() -> opensmtpd::Handler { + fn #fn_name() -> opensmtpd::Handler:: { opensmtpd::Handler::new( #version, #kind, #subsystem, &#events, - |_output: &mut dyn opensmtpd::output::FilterOutput, _entry: &opensmtpd::entry::Entry,| { - // TODO: look at `item.sig.output` and adapt the calling scheme. - // example: if no return, add `Ok(())`. - // https://docs.rs/syn/1.0.5/syn/struct.Signature.html - let inner_fn = |#fn_params| -> Result<(), String> { + |_output: &mut dyn opensmtpd::output::FilterOutput, _entry: &opensmtpd::entry::Entry, _session_ctx: &mut OpenSmtpdSessionContextType, _filter_ctx: &mut OpenSmtpdFilterContextType| { + let inner_fn = |#fn_params| #fn_return { #fn_body }; - inner_fn(_entry) + let _ = #inner_call; + Ok(()) }, ) } @@ -60,10 +106,10 @@ fn get_tokenstream(attr: TokenStream, input: TokenStream, type_str: &str) -> Tok #[proc_macro_attribute] pub fn report(attr: TokenStream, input: TokenStream) -> TokenStream { - get_tokenstream(attr, input, "opensmtpd::entry::Kind::Report") + get_tokenstream(attr, input, "opensmtpd::entry::Kind::Report", false) } #[proc_macro_attribute] pub fn filter(attr: TokenStream, input: TokenStream) -> TokenStream { - get_tokenstream(attr, input, "opensmtpd::entry::Kind::Filter") + get_tokenstream(attr, input, "opensmtpd::entry::Kind::Filter", true) } diff --git a/opensmtpd/Cargo.toml b/opensmtpd/Cargo.toml index c30ee80..64033b3 100644 --- a/opensmtpd/Cargo.toml +++ b/opensmtpd/Cargo.toml @@ -16,10 +16,14 @@ log = {version = "0.4", features = ["std"]} nom = "5.0" opensmtpd_derive = { path = "../opensmtpd-derive", version = "0.2" } +[[example]] +name = "hello" +path = "examples/hello.rs" + [[example]] name = "echo" path = "examples/echo.rs" [[example]] name = "counter" -path = "examples/session_event_counter.rs" +path = "examples/report_counter.rs" diff --git a/opensmtpd/examples/echo.rs b/opensmtpd/examples/echo.rs index 27fcf48..0b65ecb 100644 --- a/opensmtpd/examples/echo.rs +++ b/opensmtpd/examples/echo.rs @@ -1,23 +1,13 @@ use opensmtpd::entry::Entry; -use opensmtpd::{report, simple_filter}; +use opensmtpd::{register_no_context, report, simple_filter}; -#[derive(Clone, Default)] -struct MyContext { - nb: usize, -} +register_no_context!(); #[report(v1, smtp_in, match(all))] -fn echo_handler(entry: &Entry) -> Result<(), String> { - log::info!("TEST ENTRY: {:?}", entry); - Ok(()) -} - -#[report(v1, smtp_in, match(link_disconnect))] -fn test(entry: &Entry) { - log::info!("HAZ LINK DISCONNECT: {:?}", entry); - Ok(()) // TODO: REMOVE ME! +fn echo(entry: &Entry) { + log::info!("New entry: {:?}", entry); } fn main() { - simple_filter!(MyContext, [echo_handler, test]); + simple_filter!([echo]); } diff --git a/opensmtpd/examples/hello.rs b/opensmtpd/examples/hello.rs new file mode 100644 index 0000000..7467e3a --- /dev/null +++ b/opensmtpd/examples/hello.rs @@ -0,0 +1,13 @@ +use opensmtpd::entry::Entry; +use opensmtpd::{register_no_context, report, simple_filter}; + +register_no_context!(); + +#[report(v1, smtp_in, match(link_connect))] +fn hello(entry: &Entry) { + log::info!("Hello {}!", entry.get_session_id()); +} + +fn main() { + simple_filter!([hello]); +} diff --git a/opensmtpd/examples/report_counter.rs b/opensmtpd/examples/report_counter.rs new file mode 100644 index 0000000..c2c78f1 --- /dev/null +++ b/opensmtpd/examples/report_counter.rs @@ -0,0 +1,26 @@ +use log; +use opensmtpd::entry::Entry; +use opensmtpd::{register_contexts, report, simple_filter}; + +#[derive(Clone, Default)] +struct MyCounter { + nb: usize, +} + +register_contexts!(MyCounter, MyCounter); + +#[report(v1, smtp_in, match(all))] +fn on_report(entry: &Entry, total: &mut MyCounter, session: &mut MyCounter) { + total.nb += 1; + session.nb += 1; + log::info!( + "Event received for session {}: {} (total: {})", + entry.get_session_id(), + session.nb, + total.nb + ); +} + +fn main() { + simple_filter!(MyCounter, MyCounter, [on_report]); +} diff --git a/opensmtpd/examples/session_event_counter.rs b/opensmtpd/examples/session_event_counter.rs deleted file mode 100644 index b062bc1..0000000 --- a/opensmtpd/examples/session_event_counter.rs +++ /dev/null @@ -1,19 +0,0 @@ -use log; -use opensmtpd::entry::Entry; -use opensmtpd::{report, simple_filter}; - -#[derive(Clone, Default)] -struct MyContext { - nb: usize, -} - -#[report(v1, smtp_in, match(all))] -fn on_report(ctx: &mut MyContext, entry: &Entry) { - ctx.nb += 1; - log::info!("Event received: {}, {}", entry.get_session_id(), ctx.nb); - Ok(()) -} - -fn main() { - simple_filter!(MyContext, [on_report]); -} diff --git a/opensmtpd/src/entry.rs b/opensmtpd/src/entry.rs index 901b3cf..8706ae3 100644 --- a/opensmtpd/src/entry.rs +++ b/opensmtpd/src/entry.rs @@ -16,6 +16,9 @@ use nom::Err::Incomplete; use nom::IResult; use std::str::FromStr; +pub type SessionId = u64; +pub type Token = u64; + #[derive(Clone, Debug, Eq, PartialEq)] pub enum Version { V1, @@ -133,7 +136,7 @@ impl Entry { } } - pub fn get_session_id(&self) -> u64 { + pub fn get_session_id(&self) -> SessionId { match self { Entry::V1Report(r) => r.session_id, Entry::V1Filter(f) => f.session_id, @@ -153,7 +156,7 @@ pub struct V1Report { pub timestamp: TimeVal, pub subsystem: Subsystem, pub event: Event, - pub session_id: u64, + pub session_id: SessionId, pub params: Vec, } @@ -162,8 +165,8 @@ pub struct V1Filter { pub timestamp: TimeVal, pub subsystem: Subsystem, pub event: Event, - pub session_id: u64, - pub token: u64, + pub session_id: SessionId, + pub token: Token, pub params: Vec, } @@ -219,12 +222,12 @@ fn parse_event(input: &str) -> IResult<&str, Event> { ))(input) } -fn parse_token(input: &str) -> IResult<&str, u64> { - map_res(hex_digit1, |s: &str| u64::from_str_radix(s, 16))(input) +fn parse_token(input: &str) -> IResult<&str, Token> { + map_res(hex_digit1, |s: &str| Token::from_str_radix(s, 16))(input) } -fn parse_session_id(input: &str) -> IResult<&str, u64> { - map_res(hex_digit1, |s: &str| u64::from_str_radix(s, 16))(input) +fn parse_session_id(input: &str) -> IResult<&str, SessionId> { + map_res(hex_digit1, |s: &str| SessionId::from_str_radix(s, 16))(input) } fn parse_param(input: &str) -> IResult<&str, String> { diff --git a/opensmtpd/src/handler.rs b/opensmtpd/src/handler.rs index 10a2d51..542cce9 100644 --- a/opensmtpd/src/handler.rs +++ b/opensmtpd/src/handler.rs @@ -12,36 +12,36 @@ use crate::output::FilterOutput; use std::collections::HashSet; macro_rules! handle { - ($self: ident, $obj: ident, $version: expr, $kind: expr, $entry: ident, $output: ident) => {{ + ($self: ident, $obj: ident, $version: expr, $kind: expr, $entry: ident, $output: ident, $session_ctx: ident, $filter_ctx: ident) => {{ if $self.version == $version && $self.kind == $kind && $self.subsystem == $obj.subsystem && $self.events.contains(&$obj.event) { - ($self.action)($output, $entry)?; + ($self.action)($output, $entry, $session_ctx, $filter_ctx)?; } Ok(()) }}; } -type Callback = fn(&mut dyn FilterOutput, &Entry) -> Result<(), String>; +type Callback = fn(&mut dyn FilterOutput, &Entry, &mut S, &mut F) -> Result<(), String>; #[derive(Clone)] -pub struct Handler { +pub struct Handler { version: Version, pub(crate) kind: Kind, pub(crate) subsystem: Subsystem, pub(crate) events: HashSet, - action: Callback, + action: Callback, } -impl Handler { +impl Handler { pub fn new( version: Version, kind: Kind, subsystem: Subsystem, events: &[Event], - action: Callback, + action: Callback, ) -> Self { Handler { version, @@ -52,14 +52,34 @@ impl Handler { } } - pub fn send(&self, entry: &Entry, output: &mut dyn FilterOutput) -> Result<(), Error> { + pub fn send( + &self, + entry: &Entry, + output: &mut dyn FilterOutput, + session_ctx: &mut S, + filter_ctx: &mut F, + ) -> Result<(), Error> { match entry { - Entry::V1Report(report) => { - handle!(self, report, Version::V1, Kind::Report, entry, output) - } - Entry::V1Filter(filter) => { - handle!(self, filter, Version::V1, Kind::Filter, entry, output) - } + Entry::V1Report(report) => handle!( + self, + report, + Version::V1, + Kind::Report, + entry, + output, + session_ctx, + filter_ctx + ), + Entry::V1Filter(filter) => handle!( + self, + filter, + Version::V1, + Kind::Filter, + entry, + output, + session_ctx, + filter_ctx + ), } } } diff --git a/opensmtpd/src/lib.rs b/opensmtpd/src/lib.rs index 972458d..16334d6 100644 --- a/opensmtpd/src/lib.rs +++ b/opensmtpd/src/lib.rs @@ -14,9 +14,9 @@ pub mod entry; pub mod input; pub mod output; -use crate::entry::{Kind, Subsystem}; +use crate::entry::{Kind, SessionId, Subsystem}; use log; -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; use std::default::Default; pub use crate::errors::Error; @@ -24,6 +24,41 @@ pub use crate::handler::Handler; pub use crate::logger::SmtpdLogger; pub use opensmtpd_derive::report; +#[macro_export] +macro_rules! register_contexts { + ($context: ty) => { + opensmtpd::register_contexts!($context, $context); + }; + ($session_context: ty, $filter_context: ty) => { + type OpenSmtpdFilterContextType = $filter_context; + type OpenSmtpdSessionContextType = $session_context; + }; +} + +#[macro_export] +macro_rules! register_filter_context_only { + ($context: ty) => { + type OpenSmtpdFilterContextType = $context; + type OpenSmtpdSessionContextType = opensmtpd::NoContext; + }; +} + +#[macro_export] +macro_rules! register_session_context_only { + ($context: ty) => { + type OpenSmtpdFilterContextType = opensmtpd::NoContext; + type OpenSmtpdSessionContextType = $context; + }; +} + +#[macro_export] +macro_rules! register_no_context { + () => { + type OpenSmtpdFilterContextType = opensmtpd::NoContext; + type OpenSmtpdSessionContextType = opensmtpd::NoContext; + }; +} + #[macro_export] macro_rules! simple_filter { ($handlers: expr) => { @@ -49,12 +84,17 @@ macro_rules! simple_filter { let handlers = ($handlers) .iter() .map(|f| f()) - .collect::>(); + .collect::>>(); let _ = opensmtpd::SmtpdLogger::new().set_level($log_level).init(); - opensmtpd::Filter::::default() - .set_handlers(handlers.as_slice()) - .register_events() - .run(); + opensmtpd::Filter::< + opensmtpd::input::StdIn, + opensmtpd::output::StdOut, + $sesion_ctx, + $filter_ctx, + >::default() + .set_handlers(handlers.as_slice()) + .register_events() + .run(); }; } @@ -84,39 +124,49 @@ macro_rules! register_events { }; } -#[derive(Default)] +#[derive(Clone, Default)] pub struct NoContext; -pub struct Filter +pub struct Filter where I: crate::input::FilterInput + Default, O: crate::output::FilterOutput + Default, + S: Default, + F: Default, { input: I, output: O, - handlers: Vec, + session_ctx: HashMap, + filter_ctx: F, + handlers: Vec>, } -impl Default for Filter +impl Default for Filter where I: crate::input::FilterInput + Default, O: crate::output::FilterOutput + Default, + S: Default, + F: Default, { fn default() -> Self { Filter { input: I::default(), output: O::default(), + session_ctx: HashMap::new(), + filter_ctx: F::default(), handlers: Vec::new(), } } } -impl Filter +impl Filter where I: crate::input::FilterInput + Default, O: crate::output::FilterOutput + Default, + S: Clone + Default, + F: Clone + Default, { - pub fn set_handlers(&mut self, handlers: &[Handler]) -> &mut Self { + pub fn set_handlers(&mut self, handlers: &[Handler]) -> &mut Self { self.handlers = handlers.to_vec(); self } @@ -150,14 +200,30 @@ where match self.input.next() { Ok(entry) => { log::debug!("{:?}", entry); + let session_id = entry.get_session_id(); + let mut session_ctx = match self.session_ctx.get_mut(&session_id) { + Some(c) => c, + None => { + self.session_ctx.insert(session_id, S::default()); + self.session_ctx.get_mut(&session_id).unwrap() + } + }; for h in self.handlers.iter() { - match h.send(&entry, &mut self.output) { + match h.send( + &entry, + &mut self.output, + &mut session_ctx, + &mut self.filter_ctx, + ) { Ok(_) => {} Err(e) => { log::warn!("Warning: {}", e); } }; } + if entry.is_disconnect() { + self.session_ctx.remove(&session_id); + } } Err(e) => { fatal_error!(e);