diff --git a/examples/dummy.rs b/examples/dummy.rs index b905a67..6cab936 100644 --- a/examples/dummy.rs +++ b/examples/dummy.rs @@ -1,6 +1,3 @@ fn main() { - match opensmtpd::dispatch() { - Ok(_) => {} - Err(e) => eprintln!("Error: {}", e.as_str()), - } + opensmtpd::run(); } diff --git a/src/errors.rs b/src/errors.rs index 63a49b1..02803a7 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -15,12 +15,8 @@ impl Error { Error::new(&msg) } - pub fn new_param(param: &str, msg: &str) -> Self { - Error::new(&format!("{}: {}", param, msg)) - } - - pub fn as_str(&self) -> &str { - &self.message + pub fn display(&self) { + eprintln!("Error: {}", self.message); } } diff --git a/src/lib.rs b/src/lib.rs index 5e5ffff..ce71a66 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,33 +8,79 @@ use std::io; use std::sync::mpsc; use std::thread; -pub fn dispatch() -> Result<(), Error> { - let mut sessions = HashMap::new(); - loop { - let mut input = String::new(); - let nb = io::stdin().read_line(&mut input)?; - if nb == 0 { - continue; - } - let entry = Entry::from_str(input.as_str())?; - let channel = match sessions.get(&entry.session_id) { - Some(c) => c, - None => { - let (tx, rx) = mpsc::channel(); - let name = entry.session_id.to_string(); - thread::Builder::new().name(name).spawn(move || { - for e in rx.iter() { - println!( - "Debug: thread {}: {:?}", - thread::current().name().unwrap(), - e - ); - } - })?; - sessions.insert(entry.session_id, tx); - sessions.get(&entry.session_id).unwrap() - } - }; - channel.send(entry)?; +/// Read a line from the standard input. +/// Since EOF should not append, it is considered as an error. +fn read() -> Result { + let mut input = String::new(); + let nb = io::stdin().read_line(&mut input)?; + match nb { + 0 => Err(Error::new("end of file")), + _ => Ok(input), + } +} + +/// Dispatch the entry into its session's thread. If such thread does not +/// already exists, creates it. +fn dispatch( + sessions: &mut HashMap, thread::JoinHandle<()>)>, + input: &str, +) -> Result<(), Error> { + let entry = Entry::from_str(input)?; + let channel = match sessions.get(&entry.session_id) { + Some((r, _)) => r, + None => { + let (tx, rx) = mpsc::channel(); + let name = entry.session_id.to_string(); + let handle = thread::Builder::new().name(name).spawn(move || { + println!("New thread: {}", thread::current().name().unwrap()); + for e in rx.iter() { + println!( + "Debug: thread {}: {:?}", + thread::current().name().unwrap(), + e + ); + } + })?; + sessions.insert(entry.session_id, (tx, handle)); + let (r, _) = sessions.get(&entry.session_id).unwrap(); + r + } + }; + channel.send(entry)?; + Ok(()) +} + +/// Allow each child thread to exit gracefully. First, the session table is +/// drained so all the references to the senders are dropped, which will +/// cause the receivers threads to exit. Then, we uses the join handlers in +/// order to wait for the actual exit. +fn graceful_exit_children( + sessions: &mut HashMap, thread::JoinHandle<()>)>, +) { + let mut handles = Vec::new(); + for (_, (_, h)) in sessions.drain() { + handles.push(h); + } + for h in handles { + let _ = h.join(); + } +} + +/// Run the infinite loop that will read and process input from stdin. +pub fn run() { + let mut sessions = HashMap::new(); + loop { + let line = match read() { + Ok(l) => l, + Err(e) => { + graceful_exit_children(&mut sessions); + e.display(); + std::process::exit(1); + } + }; + match dispatch(&mut sessions, &line) { + Ok(_) => {} + Err(e) => e.display(), + } } }