From 06d3b8dfcaae2145e1338508fa9f38876a94653d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rodolphe=20Br=C3=A9ard?= Date: Sun, 9 Apr 2023 15:13:18 +0200 Subject: [PATCH] Move the main loop to async --- Cargo.toml | 3 +- src/action.rs | 27 ++++++++++++++++ src/entry.rs | 23 ++++++++++++-- src/handshake.rs | 21 ++++++++----- src/main.rs | 75 +++++++++++++++++++++++++++++---------------- src/message.rs | 8 ++++- src/stdin_reader.rs | 6 ++-- 7 files changed, 123 insertions(+), 40 deletions(-) create mode 100644 src/action.rs diff --git a/Cargo.toml b/Cargo.toml index 93c055d..69e57a8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,6 +10,7 @@ publish = false [dependencies] clap = { version = "4.1.13", default-features = false, features = ["std", "derive"] } env_logger = { version = "0.10.0", default-features = false } +futures = { version = "0.3.28", default-features = false, features = ["std"] } log = { version = "0.4.17", default-features = false } nom = { version = "7.1.3", default-features = false } -tokio = { version = "1.27.0", default-features = false, features = ["rt-multi-thread", "io-std", "io-util", "macros", "time", "process"] } +tokio = { version = "1.27.0", default-features = false, features = ["rt-multi-thread", "io-std", "io-util", "macros", "sync", "time", "process"] } diff --git a/src/action.rs b/src/action.rs new file mode 100644 index 0000000..b5b69cd --- /dev/null +++ b/src/action.rs @@ -0,0 +1,27 @@ +use crate::config::Config; +use crate::entry::read_entry; +use crate::message::Message; +use crate::stdin_reader::StdinReader; +use std::sync::Arc; +use tokio::sync::RwLock; + +pub enum ActionResult { + EndOfStream, + MessageSent(String), + MessageSentError(String), + NewEntry(crate::entry::Entry), + NewEntryError(String), +} + +pub async fn new_action( + reader_lock: Option>>, + msg_tpl: Option<(Message, &Config)>, +) -> ActionResult { + if let Some(reader_lock) = reader_lock { + return read_entry(reader_lock).await; + } + if let Some((msg, cnf)) = msg_tpl { + return msg.sign_and_return(cnf).await; + } + ActionResult::MessageSentError("new_action: invalid parameters".to_string()) +} diff --git a/src/entry.rs b/src/entry.rs index 0058d53..701393c 100644 --- a/src/entry.rs +++ b/src/entry.rs @@ -1,5 +1,9 @@ +use crate::action::ActionResult; +use crate::stdin_reader::StdinReader; use nom::bytes::streaming::{tag, take_till, take_while1}; use nom::IResult; +use std::sync::Arc; +use tokio::sync::RwLock; #[derive(Debug)] pub struct Entry { @@ -10,7 +14,7 @@ pub struct Entry { impl Entry { pub fn get_msg_id(&self) -> String { - format!("{}.{}", self.session_id, self.token) + crate::message::get_msg_id(&self.session_id, &self.token) } pub fn get_session_id(&self) -> &str { @@ -29,12 +33,27 @@ impl Entry { self.data == vec![b'.'] } - pub fn from_bytes(input: &[u8]) -> Result { + fn from_bytes(input: &[u8]) -> Result { let (_, entry) = parse_entry(input).map_err(|e| format!("parsing error: {e}"))?; Ok(entry) } } +pub async fn read_entry(reader_lock: Arc>) -> ActionResult { + let mut reader = reader_lock.write().await; + log::trace!("reader lock on stdin locked"); + let line_res = reader.read_line().await; + drop(reader); + log::trace!("reader lock on stdin released"); + match line_res { + Some(line) => match Entry::from_bytes(&line) { + Ok(entry) => ActionResult::NewEntry(entry), + Err(err) => ActionResult::NewEntryError(err), + }, + None => ActionResult::EndOfStream, + } +} + fn is_eol(c: u8) -> bool { c == b'\n' } diff --git a/src/handshake.rs b/src/handshake.rs index a69989d..92da616 100644 --- a/src/handshake.rs +++ b/src/handshake.rs @@ -6,13 +6,20 @@ pub const CONFIG_TAG: &[u8] = b"config|"; pub async fn read_config(reader: &mut StdinReader) { loop { - let line = reader.read_line().await; - if line == CONFIG_END { - log::trace!("configuration is ready"); - return; - } - if !line.starts_with(CONFIG_TAG) { - log::warn!("invalid config line: {}", display_bytes!(line)); + match reader.read_line().await { + Some(line) => { + if line == CONFIG_END { + log::trace!("configuration is ready"); + return; + } + if !line.starts_with(CONFIG_TAG) { + log::warn!("invalid config line: {}", display_bytes!(line)); + } + } + None => { + log::debug!("end of input stream"); + std::process::exit(0); + } } } } diff --git a/src/main.rs b/src/main.rs index 59c4f23..f7e5c2a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,3 +1,4 @@ +mod action; mod algorithm; mod canonicalization; mod config; @@ -8,12 +9,16 @@ mod message; mod parsed_message; mod stdin_reader; +use action::{new_action, ActionResult}; use algorithm::Algorithm; use canonicalization::CanonicalizationType; -use entry::Entry; +use futures::stream::FuturesUnordered; +use futures::StreamExt; use message::Message; use std::collections::HashMap; +use std::sync::Arc; use stdin_reader::StdinReader; +use tokio::sync::RwLock; const DEFAULT_BUFF_SIZE: usize = 1024; const DEFAULT_CNF_ALGORITHM: Algorithm = Algorithm::Rsa2048Sha256; @@ -70,43 +75,61 @@ async fn main() -> Result<(), Box> { } async fn main_loop(cnf: &config::Config) { + let mut actions = FuturesUnordered::new(); let mut reader = StdinReader::new(); let mut messages: HashMap = HashMap::new(); handshake::read_config(&mut reader).await; handshake::register_filter(); log_messages!(messages); + let reader_lock = Arc::new(RwLock::new(reader)); + actions.push(new_action(Some(reader_lock.clone()), None)); loop { - match Entry::from_bytes(&reader.read_line().await) { - Ok(entry) => { - let msg_id = entry.get_msg_id(); - match messages.get_mut(&msg_id) { - Some(msg) => { - if !entry.is_end_of_message() { - log::debug!("new line in message: {msg_id}"); - msg.append_line(entry.get_data()); - } else { - log::debug!("message ready: {msg_id}"); - msg.sign_and_return(cnf).await; - messages.remove(&msg_id); - log::debug!("message removed: {msg_id}"); + if actions.is_empty() { + break; + } + if let Some(action_res) = actions.next().await { + match action_res { + ActionResult::EndOfStream => { + log::debug!("end of input stream"); + } + ActionResult::MessageSent(msg_id) => { + log::debug!("message removed: {msg_id}"); + } + ActionResult::MessageSentError(err) => { + log::error!("{err}"); + } + ActionResult::NewEntry(entry) => { + let msg_id = entry.get_msg_id(); + match messages.get_mut(&msg_id) { + Some(msg) => { + if !entry.is_end_of_message() { + log::debug!("new line in message: {msg_id}"); + msg.append_line(entry.get_data()); + } else { + log::debug!("message ready: {msg_id}"); + if let Some(m) = messages.remove(&msg_id) { + actions.push(new_action(None, Some((m, cnf)))); + } + } } - } - None => { - let msg = Message::from_entry(&entry); - if !entry.is_end_of_message() { + None => { + let msg = Message::from_entry(&entry); log::debug!("new message: {msg_id}"); - messages.insert(msg_id, msg); - } else { - log::debug!("empty new message: {msg_id}"); - msg.sign_and_return(cnf).await; + if !entry.is_end_of_message() { + messages.insert(msg_id.clone(), msg); + } else { + actions.push(new_action(None, Some((msg, cnf)))); + } } } + log_messages!(messages); + actions.push(new_action(Some(reader_lock.clone()), None)); + } + ActionResult::NewEntryError(err) => { + log::error!("invalid filter line: {err}"); + actions.push(new_action(Some(reader_lock.clone()), None)); } } - Err(err) => { - log::error!("invalid filter line: {err}"); - } } - log_messages!(messages); } } diff --git a/src/message.rs b/src/message.rs index c394508..925b46b 100644 --- a/src/message.rs +++ b/src/message.rs @@ -1,3 +1,4 @@ +use crate::action::ActionResult; use crate::config::Config; use crate::entry::Entry; use crate::parsed_message::ParsedMessage; @@ -51,7 +52,7 @@ impl Message { self.nb_lines } - pub async fn sign_and_return(&self, cnf: &Config) { + pub async fn sign_and_return(&self, cnf: &Config) -> ActionResult { log::trace!("content: {}", crate::display_bytes!(&self.content)); match ParsedMessage::from_bytes(&self.content) { Ok(parsed_msg) => { @@ -81,6 +82,7 @@ impl Message { } } self.print_msg().await; + ActionResult::MessageSent(get_msg_id(&self.session_id, &self.token)) } async fn print_msg(&self) { @@ -103,3 +105,7 @@ impl Message { stdout.flush().await.unwrap(); } } + +pub fn get_msg_id(session_id: &str, token: &str) -> String { + format!("{session_id}.{token}") +} diff --git a/src/stdin_reader.rs b/src/stdin_reader.rs index 1868173..58c30f0 100644 --- a/src/stdin_reader.rs +++ b/src/stdin_reader.rs @@ -14,7 +14,7 @@ impl StdinReader { } } - pub async fn read_line(&mut self) -> Vec { + pub async fn read_line(&mut self) -> Option> { self.buffer.clear(); log::trace!("reading line from stdin"); if self @@ -23,9 +23,9 @@ impl StdinReader { .await .unwrap() == 0 { - std::process::exit(0) + return None; } log::trace!("line read from stdin: {}", display_bytes!(self.buffer)); - self.buffer.clone() + Some(self.buffer.clone()) } }