diff --git a/Cargo.toml b/Cargo.toml index 7f3a3c4..eca93cb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,10 +8,15 @@ license = "MIT OR Apache-2.0" publish = false [dependencies] +base64 = { version = "0.21.0", default-features = false, features = ["std"] } clap = { version = "4.1.13", default-features = false, features = ["std", "derive"] } +ed25519-dalek = { version = "2.0.0-rc.2", default-features = false, features = ["fast", "rand_core", "std"] } env_logger = { version = "0.10.0", default-features = false } futures = { version = "0.3.28", default-features = false, features = ["std"] } log = { version = "0.4.17", default-features = false } nom = { version = "7.1.3", default-features = false } -sqlx = { version = "0.6.3", default-features = false, features = ["runtime-tokio-rustls", "macros", "migrate", "sqlite", "uuid"] } +rand = { version = "0.8.5", default-features = false, features = ["std"] } +rsa = { version = "0.8.2", default-features = false, features = ["std"] } +sqlx = { version = "0.6.3", default-features = false, features = ["runtime-tokio-rustls", "macros", "migrate", "sqlite", "time"] } tokio = { version = "1.27.0", default-features = false, features = ["rt-multi-thread", "io-std", "io-util", "macros", "sync", "time", "process"] } +uuid = { version = "1.3.1", default-features = false, features = ["v4", "fast-rng"] } diff --git a/src/action.rs b/src/action.rs index 845fccb..fbf44a4 100644 --- a/src/action.rs +++ b/src/action.rs @@ -6,6 +6,7 @@ use crate::stdin_reader::StdinReader; use sqlx::SqlitePool; use std::sync::Arc; use tokio::sync::RwLock; +use tokio::time::sleep; pub enum Action<'a> { ReadLine(Arc>), @@ -31,7 +32,8 @@ pub async fn new_action(action: Action<'_>) -> ActionResult { None => ActionResult::EndOfStream, }, Action::RotateKeys((db, cnf)) => { - key_rotation(db, cnf).await; + let duration = key_rotation(db, cnf).await; + sleep(duration).await; ActionResult::KeyRotation } Action::SendMessage((msg, cnf)) => { diff --git a/src/algorithm.rs b/src/algorithm.rs index e490b91..1bdb04f 100644 --- a/src/algorithm.rs +++ b/src/algorithm.rs @@ -1,3 +1,9 @@ +use base64::{engine::general_purpose, Engine as _}; +use ed25519_dalek::{SigningKey, VerifyingKey}; +use rand::prelude::ThreadRng; +use rand::thread_rng; +use rsa::pkcs8::{EncodePrivateKey, EncodePublicKey}; +use rsa::{RsaPrivateKey, RsaPublicKey}; use std::str::FromStr; #[derive(Clone, Copy, Debug)] @@ -17,6 +23,22 @@ impl Algorithm { } } } + + pub fn gen_keys(&self) -> (String, String) { + match self { + Self::Ed25519Sha256 => gen_ed25519_kp(), + Self::Rsa2048Sha256 => gen_rsa_kp(2048), + Self::Rsa3072Sha256 => gen_rsa_kp(3072), + Self::Rsa4096Sha256 => gen_rsa_kp(4096), + } + } + + pub fn sign(&self, encoded_pk: &str, data: &[u8]) -> String { + match self { + Self::Ed25519Sha256 => String::new(), + Self::Rsa2048Sha256 | Self::Rsa3072Sha256 | Self::Rsa4096Sha256 => String::new(), + } + } } impl Default for Algorithm { @@ -49,3 +71,37 @@ impl FromStr for Algorithm { } } } + +fn gen_ed25519_kp() -> (String, String) { + let mut csprng = thread_rng(); + let priv_key = SigningKey::generate(&mut csprng); + let pub_key = priv_key.verifying_key(); + let priv_key = general_purpose::STANDARD.encode(priv_key.to_bytes()); + let pub_key = general_purpose::STANDARD.encode(pub_key.to_bytes()); + (priv_key, pub_key) +} + +fn gen_rsa_kp(bits: usize) -> (String, String) { + let mut csprng = thread_rng(); + loop { + if let Ok(priv_key) = RsaPrivateKey::new(&mut csprng, bits) { + let pub_key = RsaPublicKey::from(&priv_key); + let priv_key = match priv_key.to_pkcs8_der() { + Ok(d) => d, + Err(_) => { + continue; + } + }; + let pub_key = match pub_key.to_public_key_der() { + Ok(d) => d, + Err(_) => { + continue; + } + }; + let priv_key = general_purpose::STANDARD.encode(priv_key.as_bytes()); + let pub_key = general_purpose::STANDARD.encode(pub_key.as_bytes()); + return (priv_key, pub_key); + } + log::trace!("failed to generate an RSA-{bits} key"); + } +} diff --git a/src/key.rs b/src/key.rs index 32a33df..c517c30 100644 --- a/src/key.rs +++ b/src/key.rs @@ -1,7 +1,105 @@ use crate::config::Config; +use crate::Algorithm; +use sqlx::types::time::OffsetDateTime; use sqlx::SqlitePool; +use tokio::time::Duration; +use uuid::Uuid; -pub async fn key_rotation(db: &SqlitePool, cnf: &Config) { - use tokio::time::{sleep, Duration}; - sleep(Duration::from_secs(10)).await; +const INSERT_KEY: &str = "INSERT INTO key_db ( + selector, + sdid, + algorithm, + creation, + not_after, + revocation, + private_key, + public_key +) VALUES ( + $1, + $2, + $3, + $4, + $5, + $6, + $7, + $8 +)"; +const SELECT_LATEST_KEY: &str = "SELECT not_after +FROM key_db +WHERE + sdid = $1 + AND algorithm = $2 +ORDER BY not_after DESC +LIMIT 1"; + +pub async fn key_rotation(db: &SqlitePool, cnf: &Config) -> Duration { + let mut durations = Vec::with_capacity(cnf.domains().len()); + let expiration = cnf + .expiration() + .map(Duration::from_secs) + .unwrap_or_else(|| Duration::from_secs(cnf.cryptoperiod().get() / 10)); + for domain in cnf.domains() { + if let Ok(d) = renew_key_if_expired(db, cnf, domain, cnf.algorithm(), expiration).await { + durations.push(d); + } + } + durations.sort(); + durations.reverse(); + durations.pop().unwrap_or(Duration::from_secs(3600)) +} + +async fn renew_key_if_expired( + db: &SqlitePool, + cnf: &Config, + domain: &str, + algorithm: Algorithm, + expiration: Duration, +) -> Result { + let res: Option<(OffsetDateTime,)> = sqlx::query_as(SELECT_LATEST_KEY) + .bind(domain) + .bind(algorithm.to_string()) + .fetch_optional(db) + .await + .map_err(|_| ())?; + match res { + Some((not_after,)) => { + log::debug!("{domain}: key is valid until {not_after}"); + if not_after - expiration <= OffsetDateTime::now_utc() { + generate_key(db, cnf, domain, algorithm).await?; + } + } + None => { + log::debug!("no key found for domain {domain}"); + generate_key(db, cnf, domain, algorithm).await?; + } + }; + Ok(Duration::from_secs(10)) +} + +async fn generate_key( + db: &SqlitePool, + cnf: &Config, + domain: &str, + algorithm: Algorithm, +) -> Result<(), ()> { + let selector = format!("dkim-{}", Uuid::new_v4().simple()); + let now = OffsetDateTime::now_utc(); + let not_after = now + Duration::from_secs(cnf.cryptoperiod().get()); + let revocation = not_after + Duration::from_secs(cnf.revocation()); + let (priv_key, pub_key) = algorithm.gen_keys(); + sqlx::query(INSERT_KEY) + .bind(selector) + .bind(domain) + .bind(algorithm.to_string()) + .bind(now) + .bind(not_after) + .bind(revocation) + .bind(priv_key) + .bind(pub_key) + .execute(db) + .await + .map_err(|_| ())?; + // TODO: dns_update_cmd + log::debug!("{domain}: new {} key generated", algorithm.to_string()); + Ok(()) }