diff --git a/Cargo.toml b/Cargo.toml index 8847555..b7fed73 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,7 +17,8 @@ futures = { version = "0.3.28", default-features = false, features = ["std"] } log = { version = "0.4.17", default-features = false } nom = { version = "7.1.3", default-features = false } rand = { version = "0.8.5", default-features = false, features = ["std"] } -rsa = { version = "0.8.2", default-features = false, features = ["std"] } +rsa = { version = "0.8.2", default-features = false, features = ["std", "sha2"] } +sha2 = { version = "0.10.6", default-features = false, features = ["asm", "std"] } sqlx = { version = "0.6.3", default-features = false, features = ["runtime-tokio-native-tls", "macros", "migrate", "sqlite", "time"] } tokio = { version = "1.27.0", default-features = false, features = ["rt-multi-thread", "io-std", "io-util", "macros", "sync", "time", "process"] } uuid = { version = "1.3.1", default-features = false, features = ["v4", "fast-rng"] } diff --git a/src/action.rs b/src/action.rs index 3e141bb..d0dd742 100644 --- a/src/action.rs +++ b/src/action.rs @@ -11,7 +11,7 @@ use tokio::time::sleep; pub enum Action<'a> { ReadLine(Arc>), RotateKeys((&'a SqlitePool, &'a Config)), - SendMessage((Message, &'a Config)), + SendMessage((&'a SqlitePool, &'a Config, Message)), } pub enum ActionResult { @@ -36,8 +36,8 @@ pub async fn new_action(action: Action<'_>) -> ActionResult { sleep(duration).await; ActionResult::KeyRotation } - Action::SendMessage((msg, cnf)) => { - let msg_id = msg.sign_and_return(cnf).await; + Action::SendMessage((db, cnf, msg)) => { + let msg_id = msg.sign_and_return(db, cnf).await; ActionResult::MessageSent(msg_id) } } diff --git a/src/algorithm.rs b/src/algorithm.rs index f831c9a..6bd8873 100644 --- a/src/algorithm.rs +++ b/src/algorithm.rs @@ -1,7 +1,12 @@ +use anyhow::Result; use base64::{engine::general_purpose, Engine as _}; -use ed25519_dalek::SigningKey; +use ed25519_dalek::ed25519::SignatureEncoding; +use ed25519_dalek::{Signer, SigningKey as Ed25519SigningKey}; use rand::thread_rng; -use rsa::pkcs8::{EncodePrivateKey, EncodePublicKey}; +use rsa::pkcs1v15::SigningKey as RsaSigningKey; +use rsa::pkcs8::{DecodePrivateKey, EncodePrivateKey, EncodePublicKey}; +use rsa::sha2::Sha256; +use rsa::signature::hazmat::PrehashSigner; use rsa::{RsaPrivateKey, RsaPublicKey}; use std::str::FromStr; @@ -32,10 +37,20 @@ impl Algorithm { } } - pub fn sign(&self, encoded_pk: &str, data: &[u8]) -> String { + pub fn sign(&self, encoded_pk: &str, data: &[u8]) -> Result> { + let pk = general_purpose::STANDARD.decode(encoded_pk)?; match self { - Self::Ed25519Sha256 => String::new(), - Self::Rsa2048Sha256 | Self::Rsa3072Sha256 | Self::Rsa4096Sha256 => String::new(), + Self::Ed25519Sha256 => { + let signing_key = Ed25519SigningKey::from_bytes(pk.as_slice().try_into()?); + let signature = signing_key.try_sign(data)?; + Ok(signature.to_vec()) + } + Self::Rsa2048Sha256 | Self::Rsa3072Sha256 | Self::Rsa4096Sha256 => { + let private_key = RsaPrivateKey::from_pkcs8_der(&pk)?; + let signing_key = RsaSigningKey::::new_with_prefix(private_key); + let signature = signing_key.sign_prehash(data)?; + Ok(signature.to_vec()) + } } } } @@ -73,7 +88,7 @@ impl FromStr for Algorithm { fn gen_ed25519_kp() -> (String, String) { let mut csprng = thread_rng(); - let priv_key = SigningKey::generate(&mut csprng); + let priv_key = Ed25519SigningKey::generate(&mut csprng); let pub_key = priv_key.verifying_key(); let priv_key = general_purpose::STANDARD.encode(priv_key.to_bytes()); let pub_key = general_purpose::STANDARD.encode(pub_key.to_bytes()); diff --git a/src/db.rs b/src/db.rs index 4c19fee..661d08b 100644 --- a/src/db.rs +++ b/src/db.rs @@ -38,6 +38,14 @@ WHERE AND published IS FALSE ORDER BY not_after DESC LIMIT 1"; +pub const SELECT_LATEST_SIGNING_KEY: &str = "SELECT selector, private_key +FROM key_db +WHERE + sdid = $1 + AND algorithm = $2 + AND published IS FALSE +ORDER BY not_after DESC +LIMIT 1"; pub const SELECT_NEAREST_KEY_PUBLICATION: &str = "SELECT revocation FROM key_db WHERE published IS FALSE diff --git a/src/main.rs b/src/main.rs index c878920..6f826be 100644 --- a/src/main.rs +++ b/src/main.rs @@ -9,6 +9,7 @@ mod key; mod logs; mod message; mod parsed_message; +mod signature; mod stdin_reader; use action::{new_action, Action, ActionResult}; @@ -38,6 +39,8 @@ const DEFAULT_LIB_DIR: &str = env!("VARLIBDIR"); const DEFAULT_MSG_SIZE: usize = 1024 * 1024; const KEY_CHECK_MIN_DELAY: u64 = 60 * 60 * 3; const LOG_LEVEL_ENV_VAR: &str = "OPENSMTPD_FILTER_DKIMOUT_LOG_LEVEL"; +const SIG_RETRY_NB_RETRY: usize = 10; +const SIG_RETRY_SLEEP_TIME: u64 = 10; #[macro_export] macro_rules! display_bytes { @@ -117,7 +120,7 @@ async fn main_loop(cnf: &config::Config, db: &SqlitePool) { } else { log::debug!("message ready: {msg_id}"); if let Some(m) = messages.remove(&msg_id) { - actions.push(new_action(Action::SendMessage((m, cnf)))); + actions.push(new_action(Action::SendMessage((db, cnf, m)))); } } } @@ -127,7 +130,7 @@ async fn main_loop(cnf: &config::Config, db: &SqlitePool) { if !entry.is_end_of_message() { messages.insert(msg_id.clone(), msg); } else { - actions.push(new_action(Action::SendMessage((msg, cnf)))); + actions.push(new_action(Action::SendMessage((db, cnf, msg)))); } } } diff --git a/src/message.rs b/src/message.rs index 5b3d0bf..922ee4c 100644 --- a/src/message.rs +++ b/src/message.rs @@ -1,7 +1,9 @@ use crate::config::Config; use crate::entry::Entry; use crate::parsed_message::ParsedMessage; +use crate::signature::Signature; use anyhow::Result; +use sqlx::SqlitePool; use tokio::io::{AsyncWriteExt, BufWriter}; pub const RETURN_SEP: &[u8] = b"|"; @@ -52,8 +54,12 @@ impl Message { self.nb_lines } - pub async fn sign_and_return(&self, cnf: &Config) -> String { - log::trace!("content: {}", crate::display_bytes!(&self.content)); + pub async fn sign_and_return(&self, db: &SqlitePool, cnf: &Config) -> String { + let msg_id = get_msg_id(&self.session_id, &self.token); + log::trace!( + "{msg_id}: content: {}", + crate::display_bytes!(&self.content) + ); match ParsedMessage::from_bytes(&self.content) { Ok(parsed_msg) => { log::trace!("mail parsed"); @@ -75,16 +81,31 @@ impl Message { "ParsedMessage: body: {}", crate::display_bytes!(parsed_msg.body) ); - // TODO: sign the message using DKIM + match Signature::new(db, cnf, &parsed_msg).await { + Ok(signature) => { + let sig_header = signature.get_header(); + if let Err(err) = self.print_sig_header(&sig_header).await { + log::error!("{msg_id}: unable to add the signature header: {err}"); + } + } + Err(err) => log::error!("{msg_id}: unable to sign message: {err}"), + } } Err(err) => { log::error!("{msg_id}: unable to parse message: {err}"); } } if let Err(err) = self.print_msg().await { - log::error!("unable to write message: {err}"); + log::error!("{msg_id}: unable to write message: {err}"); } - get_msg_id(&self.session_id, &self.token) + msg_id + } + + async fn print_sig_header(&self, sig_header: &str) -> Result<()> { + for line in sig_header.split("\r\n") { + self.print_line(line.as_bytes()).await?; + } + Ok(()) } async fn print_msg(&self) -> Result<()> { diff --git a/src/signature.rs b/src/signature.rs new file mode 100644 index 0000000..3034625 --- /dev/null +++ b/src/signature.rs @@ -0,0 +1,155 @@ +use crate::algorithm::Algorithm; +use crate::canonicalization::Canonicalization; +use crate::config::Config; +use crate::parsed_message::{ParsedHeader, ParsedMessage}; +use anyhow::{anyhow, Result}; +use base64::{engine::general_purpose, Engine as _}; +use sha2::{Digest, Sha256}; +use sqlx::types::time::OffsetDateTime; +use sqlx::SqlitePool; +use tokio::time::{sleep, Duration}; + +pub struct Signature { + algorithm: Algorithm, + canonicalization: Canonicalization, + selector: String, + sdid: String, + timestamp: i64, + headers: Vec, + body_hash: Vec, + signature: Vec, +} + +impl Signature { + pub async fn new(db: &SqlitePool, cnf: &Config, msg: &ParsedMessage<'_>) -> Result { + let algorithm = cnf.algorithm(); + let sdid = get_sdid(cnf, msg)?; + let (selector, signing_key) = get_db_data(db, &sdid, algorithm).await?; + let mut sig = Self { + algorithm, + canonicalization: cnf.canonicalization(), + selector, + sdid, + timestamp: OffsetDateTime::now_utc().unix_timestamp(), + headers: get_headers(cnf, msg), + body_hash: Vec::new(), + signature: Vec::new(), + }; + sig.compute_body_hash::(msg); + let header_hash = sig.compute_header_hash::(msg); + sig.signature = algorithm.sign(&signing_key, &header_hash)?; + Ok(sig) + } + + pub fn get_header(&self) -> String { + format!( + "DKIM-Signature: v=1; a={algorithm}; c={canonicalization}; d={sdid};\r\n\tt={timestamp}; s={selector};\r\n\th={headers};\r\n\tbh={body_hash};\r\n\tb={signature}", + algorithm=self.algorithm.display(), + canonicalization=self.canonicalization.to_string(), + selector=self.selector, + sdid=self.sdid, + timestamp=self.timestamp, + headers=self.headers.join(":"), + body_hash=general_purpose::STANDARD.encode(&self.body_hash), + signature=general_purpose::STANDARD.encode(&self.signature), + ) + } + + fn compute_body_hash(&mut self, msg: &ParsedMessage<'_>) { + let mut hasher = H::new(); + let body = self.canonicalization.process_body(msg.body); + hasher.update(&body); + self.body_hash = hasher.finalize().to_vec(); + } + + fn compute_header_hash(&mut self, msg: &ParsedMessage<'_>) -> Vec { + let mut hasher = H::new(); + for header_name in &self.headers { + if let Some(raw_header) = get_header(msg, header_name) { + let header = self.canonicalization.process_header(raw_header.raw); + hasher.update(&header); + } + } + hasher.update(self.get_header().as_bytes()); + hasher.finalize().to_vec() + } +} + +fn get_sdid(cnf: &Config, msg: &ParsedMessage<'_>) -> Result { + if let Some(header) = get_header(msg, "from") { + if let Some(arb_pos) = header.value.iter().rposition(|&c| c == b'@') { + let name = &header.value[arb_pos + 1..]; + let end_pos = name + .iter() + .position(|&c| c == b'>') + .unwrap_or(name.len() - 2); + if let Ok(sdid) = String::from_utf8(name[..end_pos].to_vec()) { + if cnf.domains().contains(&sdid) { + return Ok(sdid); + } else { + return Err(anyhow!( + "unable to sign for a domain outside of the configured list: {sdid}" + )); + } + } + } + } + Err(anyhow!("unable to determine the SDID")) +} + +fn get_headers(cnf: &Config, msg: &ParsedMessage<'_>) -> Vec { + let nb_headers = cnf.headers().len() + cnf.headers_optional().len(); + let mut lst = Vec::with_capacity(nb_headers); + for header_name in cnf.headers() { + if let Some(name) = get_header_name(msg, header_name) { + lst.push(name); + } else { + lst.push(header_name.to_string()); + } + } + for header_name in cnf.headers_optional() { + if let Some(name) = get_header_name(msg, header_name) { + lst.push(name); + } + } + lst.sort(); + lst +} + +fn get_header_name(msg: &ParsedMessage<'_>, header_name: &str) -> Option { + match get_header(msg, header_name) { + Some(header) => String::from_utf8(header.name.to_vec()).ok(), + None => None, + } +} + +fn get_header<'a>( + msg: &'a ParsedMessage<'a>, + header_name: &'a str, +) -> Option<&'a ParsedHeader<'a>> { + let header_name = header_name.to_lowercase(); + msg.headers.iter().find(|&header| header.name_lower == header_name) +} + +async fn get_db_data( + db: &SqlitePool, + sdid: &str, + algorithm: Algorithm, +) -> Result<(String, String)> { + let mut ctn = 0; + loop { + let res: Option<(String, String)> = sqlx::query_as(crate::db::SELECT_LATEST_SIGNING_KEY) + .bind(sdid) + .bind(algorithm.to_string()) + .fetch_optional(db) + .await?; + if let Some((selector, private_key)) = res { + return Ok((selector, private_key)); + } + if ctn == crate::SIG_RETRY_NB_RETRY { + return Err(anyhow!("unable to retrieve key material")); + } + ctn += 1; + sleep(Duration::from_secs(crate::SIG_RETRY_SLEEP_TIME)).await; + } +}