Add filter-level and session-level contexts

This commit is contained in:
Rodolphe Breard 2019-09-18 20:43:46 +02:00
parent 995c0c35c1
commit fdc8bd3dc4
10 changed files with 233 additions and 84 deletions

View file

@ -9,8 +9,8 @@ This is a work in progress, the API is **not** stabilized yet.
- [x] Reports - [x] Reports
- [ ] Filters - [ ] Filters
- [ ] Filter-level context - [x] Filter-level context
- [ ] Session-level context - [x] Session-level context
[OpenSMTPD]: https://www.opensmtpd.org/ [OpenSMTPD]: https://www.opensmtpd.org/

View file

@ -13,7 +13,7 @@ mod attributes;
use attributes::OpenSmtpdAttributes; use attributes::OpenSmtpdAttributes;
use proc_macro::TokenStream; use proc_macro::TokenStream;
use quote::quote; use quote::quote;
use syn::{parse_macro_input, ExprArray, ItemFn, TypePath}; use syn::{parse_macro_input, ExprArray, ExprTry, ItemFn, ReturnType, TypePath};
macro_rules! parse_item { macro_rules! parse_item {
($item: expr, $type: ty) => { ($item: expr, $type: ty) => {
@ -26,31 +26,77 @@ macro_rules! parse_item {
}; };
} }
fn get_tokenstream(attr: TokenStream, input: TokenStream, type_str: &str) -> TokenStream { fn get_has_result(ret: &ReturnType) -> bool {
match ret {
ReturnType::Default => false,
ReturnType::Type(_, _) => true,
}
}
fn get_inner_call(nb_args: usize, has_output: bool, has_result: bool) -> String {
let mut call_params = Vec::new();
if has_output {
call_params.push("_output");
}
call_params.push("_entry");
if nb_args >= 2 {
call_params.push("_filter_ctx");
}
if nb_args >= 3 {
call_params.push("_session_ctx");
}
let call_params = call_params.join(", ");
let s = format!("inner_fn({})", &call_params);
if has_result {
return format!("{}?", s);
}
format!(
"(|{params}| -> Result<(), String> {{ {inner_fn}; Ok(()) }})({params})?",
params = &call_params,
inner_fn = s
)
}
fn get_tokenstream(
attr: TokenStream,
input: TokenStream,
type_str: &str,
has_output: bool,
) -> TokenStream {
let kind = parse_item!(type_str, TypePath); let kind = parse_item!(type_str, TypePath);
// Parse the procedural macro attributes
let attr = parse_macro_input!(attr as OpenSmtpdAttributes); let attr = parse_macro_input!(attr as OpenSmtpdAttributes);
let version = parse_item!(&attr.get_version(), TypePath); let version = parse_item!(&attr.get_version(), TypePath);
let subsystem = parse_item!(&attr.get_subsystem(), TypePath); let subsystem = parse_item!(&attr.get_subsystem(), TypePath);
let events = parse_item!(&attr.get_events(), ExprArray); let events = parse_item!(&attr.get_events(), ExprArray);
// Parse the user-supplied function
let item = parse_macro_input!(input as ItemFn); let item = parse_macro_input!(input as ItemFn);
let fn_name = &item.sig.ident; let fn_name = &item.sig.ident;
let fn_params = &item.sig.inputs; let fn_params = &item.sig.inputs;
let fn_return = &item.sig.output;
let fn_body = &item.block; let fn_body = &item.block;
let has_result = get_has_result(&item.sig.output);
let inner_call = parse_item!(
&get_inner_call(fn_params.len(), has_output, has_result),
ExprTry
);
// Build the new function
let output = quote! { let output = quote! {
fn #fn_name() -> opensmtpd::Handler { fn #fn_name() -> opensmtpd::Handler::<OpenSmtpdSessionContextType, OpenSmtpdFilterContextType> {
opensmtpd::Handler::new( opensmtpd::Handler::new(
#version, #version,
#kind, #kind,
#subsystem, #subsystem,
&#events, &#events,
|_output: &mut dyn opensmtpd::output::FilterOutput, _entry: &opensmtpd::entry::Entry,| { |_output: &mut dyn opensmtpd::output::FilterOutput, _entry: &opensmtpd::entry::Entry, _session_ctx: &mut OpenSmtpdSessionContextType, _filter_ctx: &mut OpenSmtpdFilterContextType| {
// TODO: look at `item.sig.output` and adapt the calling scheme. let inner_fn = |#fn_params| #fn_return {
// example: if no return, add `Ok(())`.
// https://docs.rs/syn/1.0.5/syn/struct.Signature.html
let inner_fn = |#fn_params| -> Result<(), String> {
#fn_body #fn_body
}; };
inner_fn(_entry) let _ = #inner_call;
Ok(())
}, },
) )
} }
@ -60,10 +106,10 @@ fn get_tokenstream(attr: TokenStream, input: TokenStream, type_str: &str) -> Tok
#[proc_macro_attribute] #[proc_macro_attribute]
pub fn report(attr: TokenStream, input: TokenStream) -> TokenStream { pub fn report(attr: TokenStream, input: TokenStream) -> TokenStream {
get_tokenstream(attr, input, "opensmtpd::entry::Kind::Report") get_tokenstream(attr, input, "opensmtpd::entry::Kind::Report", false)
} }
#[proc_macro_attribute] #[proc_macro_attribute]
pub fn filter(attr: TokenStream, input: TokenStream) -> TokenStream { pub fn filter(attr: TokenStream, input: TokenStream) -> TokenStream {
get_tokenstream(attr, input, "opensmtpd::entry::Kind::Filter") get_tokenstream(attr, input, "opensmtpd::entry::Kind::Filter", true)
} }

View file

@ -16,10 +16,14 @@ log = {version = "0.4", features = ["std"]}
nom = "5.0" nom = "5.0"
opensmtpd_derive = { path = "../opensmtpd-derive", version = "0.2" } opensmtpd_derive = { path = "../opensmtpd-derive", version = "0.2" }
[[example]]
name = "hello"
path = "examples/hello.rs"
[[example]] [[example]]
name = "echo" name = "echo"
path = "examples/echo.rs" path = "examples/echo.rs"
[[example]] [[example]]
name = "counter" name = "counter"
path = "examples/session_event_counter.rs" path = "examples/report_counter.rs"

View file

@ -1,23 +1,13 @@
use opensmtpd::entry::Entry; use opensmtpd::entry::Entry;
use opensmtpd::{report, simple_filter}; use opensmtpd::{register_no_context, report, simple_filter};
#[derive(Clone, Default)] register_no_context!();
struct MyContext {
nb: usize,
}
#[report(v1, smtp_in, match(all))] #[report(v1, smtp_in, match(all))]
fn echo_handler(entry: &Entry) -> Result<(), String> { fn echo(entry: &Entry) {
log::info!("TEST ENTRY: {:?}", entry); log::info!("New entry: {:?}", entry);
Ok(())
}
#[report(v1, smtp_in, match(link_disconnect))]
fn test(entry: &Entry) {
log::info!("HAZ LINK DISCONNECT: {:?}", entry);
Ok(()) // TODO: REMOVE ME!
} }
fn main() { fn main() {
simple_filter!(MyContext, [echo_handler, test]); simple_filter!([echo]);
} }

View file

@ -0,0 +1,13 @@
use opensmtpd::entry::Entry;
use opensmtpd::{register_no_context, report, simple_filter};
register_no_context!();
#[report(v1, smtp_in, match(link_connect))]
fn hello(entry: &Entry) {
log::info!("Hello {}!", entry.get_session_id());
}
fn main() {
simple_filter!([hello]);
}

View file

@ -0,0 +1,26 @@
use log;
use opensmtpd::entry::Entry;
use opensmtpd::{register_contexts, report, simple_filter};
#[derive(Clone, Default)]
struct MyCounter {
nb: usize,
}
register_contexts!(MyCounter, MyCounter);
#[report(v1, smtp_in, match(all))]
fn on_report(entry: &Entry, total: &mut MyCounter, session: &mut MyCounter) {
total.nb += 1;
session.nb += 1;
log::info!(
"Event received for session {}: {} (total: {})",
entry.get_session_id(),
session.nb,
total.nb
);
}
fn main() {
simple_filter!(MyCounter, MyCounter, [on_report]);
}

View file

@ -1,19 +0,0 @@
use log;
use opensmtpd::entry::Entry;
use opensmtpd::{report, simple_filter};
#[derive(Clone, Default)]
struct MyContext {
nb: usize,
}
#[report(v1, smtp_in, match(all))]
fn on_report(ctx: &mut MyContext, entry: &Entry) {
ctx.nb += 1;
log::info!("Event received: {}, {}", entry.get_session_id(), ctx.nb);
Ok(())
}
fn main() {
simple_filter!(MyContext, [on_report]);
}

View file

@ -16,6 +16,9 @@ use nom::Err::Incomplete;
use nom::IResult; use nom::IResult;
use std::str::FromStr; use std::str::FromStr;
pub type SessionId = u64;
pub type Token = u64;
#[derive(Clone, Debug, Eq, PartialEq)] #[derive(Clone, Debug, Eq, PartialEq)]
pub enum Version { pub enum Version {
V1, V1,
@ -133,7 +136,7 @@ impl Entry {
} }
} }
pub fn get_session_id(&self) -> u64 { pub fn get_session_id(&self) -> SessionId {
match self { match self {
Entry::V1Report(r) => r.session_id, Entry::V1Report(r) => r.session_id,
Entry::V1Filter(f) => f.session_id, Entry::V1Filter(f) => f.session_id,
@ -153,7 +156,7 @@ pub struct V1Report {
pub timestamp: TimeVal, pub timestamp: TimeVal,
pub subsystem: Subsystem, pub subsystem: Subsystem,
pub event: Event, pub event: Event,
pub session_id: u64, pub session_id: SessionId,
pub params: Vec<String>, pub params: Vec<String>,
} }
@ -162,8 +165,8 @@ pub struct V1Filter {
pub timestamp: TimeVal, pub timestamp: TimeVal,
pub subsystem: Subsystem, pub subsystem: Subsystem,
pub event: Event, pub event: Event,
pub session_id: u64, pub session_id: SessionId,
pub token: u64, pub token: Token,
pub params: Vec<String>, pub params: Vec<String>,
} }
@ -219,12 +222,12 @@ fn parse_event(input: &str) -> IResult<&str, Event> {
))(input) ))(input)
} }
fn parse_token(input: &str) -> IResult<&str, u64> { fn parse_token(input: &str) -> IResult<&str, Token> {
map_res(hex_digit1, |s: &str| u64::from_str_radix(s, 16))(input) map_res(hex_digit1, |s: &str| Token::from_str_radix(s, 16))(input)
} }
fn parse_session_id(input: &str) -> IResult<&str, u64> { fn parse_session_id(input: &str) -> IResult<&str, SessionId> {
map_res(hex_digit1, |s: &str| u64::from_str_radix(s, 16))(input) map_res(hex_digit1, |s: &str| SessionId::from_str_radix(s, 16))(input)
} }
fn parse_param(input: &str) -> IResult<&str, String> { fn parse_param(input: &str) -> IResult<&str, String> {

View file

@ -12,36 +12,36 @@ use crate::output::FilterOutput;
use std::collections::HashSet; use std::collections::HashSet;
macro_rules! handle { macro_rules! handle {
($self: ident, $obj: ident, $version: expr, $kind: expr, $entry: ident, $output: ident) => {{ ($self: ident, $obj: ident, $version: expr, $kind: expr, $entry: ident, $output: ident, $session_ctx: ident, $filter_ctx: ident) => {{
if $self.version == $version if $self.version == $version
&& $self.kind == $kind && $self.kind == $kind
&& $self.subsystem == $obj.subsystem && $self.subsystem == $obj.subsystem
&& $self.events.contains(&$obj.event) && $self.events.contains(&$obj.event)
{ {
($self.action)($output, $entry)?; ($self.action)($output, $entry, $session_ctx, $filter_ctx)?;
} }
Ok(()) Ok(())
}}; }};
} }
type Callback = fn(&mut dyn FilterOutput, &Entry) -> Result<(), String>; type Callback<S, F> = fn(&mut dyn FilterOutput, &Entry, &mut S, &mut F) -> Result<(), String>;
#[derive(Clone)] #[derive(Clone)]
pub struct Handler { pub struct Handler<S, F> {
version: Version, version: Version,
pub(crate) kind: Kind, pub(crate) kind: Kind,
pub(crate) subsystem: Subsystem, pub(crate) subsystem: Subsystem,
pub(crate) events: HashSet<Event>, pub(crate) events: HashSet<Event>,
action: Callback, action: Callback<S, F>,
} }
impl Handler { impl<S, F> Handler<S, F> {
pub fn new( pub fn new(
version: Version, version: Version,
kind: Kind, kind: Kind,
subsystem: Subsystem, subsystem: Subsystem,
events: &[Event], events: &[Event],
action: Callback, action: Callback<S, F>,
) -> Self { ) -> Self {
Handler { Handler {
version, version,
@ -52,14 +52,34 @@ impl Handler {
} }
} }
pub fn send(&self, entry: &Entry, output: &mut dyn FilterOutput) -> Result<(), Error> { pub fn send(
&self,
entry: &Entry,
output: &mut dyn FilterOutput,
session_ctx: &mut S,
filter_ctx: &mut F,
) -> Result<(), Error> {
match entry { match entry {
Entry::V1Report(report) => { Entry::V1Report(report) => handle!(
handle!(self, report, Version::V1, Kind::Report, entry, output) self,
} report,
Entry::V1Filter(filter) => { Version::V1,
handle!(self, filter, Version::V1, Kind::Filter, entry, output) Kind::Report,
} entry,
output,
session_ctx,
filter_ctx
),
Entry::V1Filter(filter) => handle!(
self,
filter,
Version::V1,
Kind::Filter,
entry,
output,
session_ctx,
filter_ctx
),
} }
} }
} }

View file

@ -14,9 +14,9 @@ pub mod entry;
pub mod input; pub mod input;
pub mod output; pub mod output;
use crate::entry::{Kind, Subsystem}; use crate::entry::{Kind, SessionId, Subsystem};
use log; use log;
use std::collections::HashSet; use std::collections::{HashMap, HashSet};
use std::default::Default; use std::default::Default;
pub use crate::errors::Error; pub use crate::errors::Error;
@ -24,6 +24,41 @@ pub use crate::handler::Handler;
pub use crate::logger::SmtpdLogger; pub use crate::logger::SmtpdLogger;
pub use opensmtpd_derive::report; pub use opensmtpd_derive::report;
#[macro_export]
macro_rules! register_contexts {
($context: ty) => {
opensmtpd::register_contexts!($context, $context);
};
($session_context: ty, $filter_context: ty) => {
type OpenSmtpdFilterContextType = $filter_context;
type OpenSmtpdSessionContextType = $session_context;
};
}
#[macro_export]
macro_rules! register_filter_context_only {
($context: ty) => {
type OpenSmtpdFilterContextType = $context;
type OpenSmtpdSessionContextType = opensmtpd::NoContext;
};
}
#[macro_export]
macro_rules! register_session_context_only {
($context: ty) => {
type OpenSmtpdFilterContextType = opensmtpd::NoContext;
type OpenSmtpdSessionContextType = $context;
};
}
#[macro_export]
macro_rules! register_no_context {
() => {
type OpenSmtpdFilterContextType = opensmtpd::NoContext;
type OpenSmtpdSessionContextType = opensmtpd::NoContext;
};
}
#[macro_export] #[macro_export]
macro_rules! simple_filter { macro_rules! simple_filter {
($handlers: expr) => { ($handlers: expr) => {
@ -49,9 +84,14 @@ macro_rules! simple_filter {
let handlers = ($handlers) let handlers = ($handlers)
.iter() .iter()
.map(|f| f()) .map(|f| f())
.collect::<Vec<opensmtpd::Handler>>(); .collect::<Vec<opensmtpd::Handler<$sesion_ctx, $filter_ctx>>>();
let _ = opensmtpd::SmtpdLogger::new().set_level($log_level).init(); let _ = opensmtpd::SmtpdLogger::new().set_level($log_level).init();
opensmtpd::Filter::<opensmtpd::input::StdIn, opensmtpd::output::StdOut>::default() opensmtpd::Filter::<
opensmtpd::input::StdIn,
opensmtpd::output::StdOut,
$sesion_ctx,
$filter_ctx,
>::default()
.set_handlers(handlers.as_slice()) .set_handlers(handlers.as_slice())
.register_events() .register_events()
.run(); .run();
@ -84,39 +124,49 @@ macro_rules! register_events {
}; };
} }
#[derive(Default)] #[derive(Clone, Default)]
pub struct NoContext; pub struct NoContext;
pub struct Filter<I, O> pub struct Filter<I, O, S, F>
where where
I: crate::input::FilterInput + Default, I: crate::input::FilterInput + Default,
O: crate::output::FilterOutput + Default, O: crate::output::FilterOutput + Default,
S: Default,
F: Default,
{ {
input: I, input: I,
output: O, output: O,
handlers: Vec<Handler>, session_ctx: HashMap<SessionId, S>,
filter_ctx: F,
handlers: Vec<Handler<S, F>>,
} }
impl<I, O> Default for Filter<I, O> impl<I, O, S, F> Default for Filter<I, O, S, F>
where where
I: crate::input::FilterInput + Default, I: crate::input::FilterInput + Default,
O: crate::output::FilterOutput + Default, O: crate::output::FilterOutput + Default,
S: Default,
F: Default,
{ {
fn default() -> Self { fn default() -> Self {
Filter { Filter {
input: I::default(), input: I::default(),
output: O::default(), output: O::default(),
session_ctx: HashMap::new(),
filter_ctx: F::default(),
handlers: Vec::new(), handlers: Vec::new(),
} }
} }
} }
impl<I, O> Filter<I, O> impl<I, O, S, F> Filter<I, O, S, F>
where where
I: crate::input::FilterInput + Default, I: crate::input::FilterInput + Default,
O: crate::output::FilterOutput + Default, O: crate::output::FilterOutput + Default,
S: Clone + Default,
F: Clone + Default,
{ {
pub fn set_handlers(&mut self, handlers: &[Handler]) -> &mut Self { pub fn set_handlers(&mut self, handlers: &[Handler<S, F>]) -> &mut Self {
self.handlers = handlers.to_vec(); self.handlers = handlers.to_vec();
self self
} }
@ -150,14 +200,30 @@ where
match self.input.next() { match self.input.next() {
Ok(entry) => { Ok(entry) => {
log::debug!("{:?}", entry); log::debug!("{:?}", entry);
let session_id = entry.get_session_id();
let mut session_ctx = match self.session_ctx.get_mut(&session_id) {
Some(c) => c,
None => {
self.session_ctx.insert(session_id, S::default());
self.session_ctx.get_mut(&session_id).unwrap()
}
};
for h in self.handlers.iter() { for h in self.handlers.iter() {
match h.send(&entry, &mut self.output) { match h.send(
&entry,
&mut self.output,
&mut session_ctx,
&mut self.filter_ctx,
) {
Ok(_) => {} Ok(_) => {}
Err(e) => { Err(e) => {
log::warn!("Warning: {}", e); log::warn!("Warning: {}", e);
} }
}; };
} }
if entry.is_disconnect() {
self.session_ctx.remove(&session_id);
}
} }
Err(e) => { Err(e) => {
fatal_error!(e); fatal_error!(e);