diff --git a/src/encryption.rs b/src/encryption.rs index 0dd8105..99e7257 100644 --- a/src/encryption.rs +++ b/src/encryption.rs @@ -1,7 +1,8 @@ use crate::canonicalization::{canonicalize, join_canonicalized_str}; use crate::error::Result; -use crate::kdf::derive_key; +use crate::kdf::{derive_key, KeyContext}; use crate::{storage, InputKeyMaterialList}; +use std::time::{SystemTime, UNIX_EPOCH}; pub(crate) type DecryptionFunction = dyn Fn(&[u8], &EncryptedData, &str) -> Result>; pub(crate) type EncryptionFunction = dyn Fn(&[u8], &[u8], &str) -> Result; @@ -13,36 +14,45 @@ pub(crate) struct EncryptedData { } #[inline] -fn generate_aad(key_context: &[&str], data_context: &[impl AsRef<[u8]>]) -> String { - let key_context_canon = canonicalize(key_context); +fn generate_aad( + key_context: &KeyContext, + data_context: &[impl AsRef<[u8]>], + ts: Option, +) -> String { + let key_context_canon = canonicalize(&key_context.get_value(ts)); let data_context_canon = canonicalize(data_context); join_canonicalized_str(&key_context_canon, &data_context_canon) } pub fn encrypt( ikml: &InputKeyMaterialList, - key_context: &[&str], + key_context: &KeyContext, data: impl AsRef<[u8]>, data_context: &[impl AsRef<[u8]>], ) -> Result { + let ts = if key_context.is_periodic() { + Some(SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs()) + } else { + None + }; let ikm = ikml.get_latest_ikm()?; - let key = derive_key(ikm, key_context); - let aad = generate_aad(key_context, data_context); + let key = derive_key(ikm, key_context, ts); + let aad = generate_aad(key_context, data_context, ts); let encryption_function = ikm.scheme.get_encryption(); let encrypted_data = encryption_function(&key, data.as_ref(), &aad)?; - Ok(storage::encode_cipher(ikm.id, &encrypted_data)) + Ok(storage::encode_cipher(ikm.id, &encrypted_data, ts)) } pub fn decrypt( ikml: &InputKeyMaterialList, - key_context: &[&str], + key_context: &KeyContext, stored_data: &str, data_context: &[impl AsRef<[u8]>], ) -> Result> { - let (ikm_id, encrypted_data) = storage::decode_cipher(stored_data)?; + let (ikm_id, encrypted_data, ts) = storage::decode_cipher(stored_data)?; let ikm = ikml.get_ikm_by_id(ikm_id)?; - let key = derive_key(ikm, key_context); - let aad = generate_aad(key_context, data_context); + let key = derive_key(ikm, key_context, ts); + let aad = generate_aad(key_context, data_context, ts); let decryption_function = ikm.scheme.get_decryption(); decryption_function(&key, &encrypted_data, &aad) } @@ -50,9 +60,10 @@ pub fn decrypt( #[cfg(test)] mod tests { use super::*; + use crate::KeyContext; const TEST_DATA: &[u8] = b"Lorem ipsum dolor sit amet."; - const TEST_KEY_CTX: &[&str] = &["db_name", "table_name", "column_name"]; + const TEST_KEY_CTX: [&str; 3] = ["db_name", "table_name", "column_name"]; const TEST_DATA_CTX: &[&str] = &["018db876-3d9d-79af-9460-55d17da991d8"]; const EMPTY_DATA_CTX: &[[u8; 0]] = &[]; @@ -65,16 +76,18 @@ mod tests { #[test] fn encrypt_decrypt_no_context() { + let ctx = KeyContext::from([]); + // Encrypt let lst = get_ikm_lst(); - let res = encrypt(&lst, &[], TEST_DATA, EMPTY_DATA_CTX); + let res = encrypt(&lst, &ctx, TEST_DATA, EMPTY_DATA_CTX); assert!(res.is_ok(), "res: {res:?}"); let ciphertext = res.unwrap(); assert!(ciphertext.starts_with("AQAAAA:")); assert_eq!(ciphertext.len(), 98); // Decrypt - let res = decrypt(&lst, &[], &ciphertext, EMPTY_DATA_CTX); + let res = decrypt(&lst, &ctx, &ciphertext, EMPTY_DATA_CTX); assert!(res.is_ok(), "res: {res:?}"); let plaintext = res.unwrap(); assert_eq!(plaintext, TEST_DATA); @@ -84,14 +97,14 @@ mod tests { fn encrypt_decrypt_with_context() { // Encrypt let lst = get_ikm_lst(); - let res = encrypt(&lst, TEST_KEY_CTX, TEST_DATA, TEST_DATA_CTX); + let res = encrypt(&lst, &TEST_KEY_CTX.into(), TEST_DATA, TEST_DATA_CTX); assert!(res.is_ok(), "res: {res:?}"); let ciphertext = res.unwrap(); assert!(ciphertext.starts_with("AQAAAA:")); assert_eq!(ciphertext.len(), 98); // Decrypt - let res = decrypt(&lst, TEST_KEY_CTX, &ciphertext, TEST_DATA_CTX); + let res = decrypt(&lst, &TEST_KEY_CTX.into(), &ciphertext, TEST_DATA_CTX); assert!(res.is_ok(), "res: {res:?}"); let plaintext = res.unwrap(); assert_eq!(plaintext, TEST_DATA); diff --git a/src/error.rs b/src/error.rs index a9919ba..a465d47 100644 --- a/src/error.rs +++ b/src/error.rs @@ -21,6 +21,8 @@ pub enum Error { ParsingEncodedDataInvalidIkmListLen(usize), #[error("parsing error: encoded data: invalid number of parts: got {1} instead of {0}")] ParsingEncodedDataInvalidPartLen(usize, usize), + #[error("parsing error: encoded data: invalid timestamp: {0:?}")] + ParsingEncodedDataInvalidTimestamp(Vec), #[error("parsing error: scheme: {0}: unknown scheme")] ParsingSchemeUnknownScheme(SchemeSerializeType), #[error("unable to generate random values: {0}")] diff --git a/src/kdf.rs b/src/kdf.rs index 0ffc721..5f47d52 100644 --- a/src/kdf.rs +++ b/src/kdf.rs @@ -3,8 +3,38 @@ use crate::ikm::InputKeyMaterial; pub(crate) type KdfFunction = dyn Fn(&str, &[u8]) -> Vec; -pub(crate) fn derive_key(ikm: &InputKeyMaterial, key_context: &[&str]) -> Vec { - let key_context = canonicalize(key_context); +pub struct KeyContext { + ctx: Vec, + periodicity: Option, +} + +impl KeyContext { + pub(crate) fn get_value(&self, ts: Option) -> Vec> { + let mut ret: Vec> = self.ctx.iter().map(|s| s.as_bytes().to_vec()).collect(); + if let Some(p) = self.periodicity { + let ts = ts.unwrap_or(0); + let c = ts % p; + ret.push(c.to_le_bytes().to_vec()); + } + ret + } + + pub(crate) fn is_periodic(&self) -> bool { + self.periodicity.is_some() + } +} + +impl From<[&str; N]> for KeyContext { + fn from(ctx: [&str; N]) -> Self { + Self { + ctx: ctx.iter().map(|s| s.to_string()).collect(), + periodicity: None, + } + } +} + +pub(crate) fn derive_key(ikm: &InputKeyMaterial, ctx: &KeyContext, ts: Option) -> Vec { + let key_context = canonicalize(&ctx.get_value(ts)); let kdf = ikm.scheme.get_kdf(); kdf(&key_context, &ikm.content) } @@ -12,6 +42,7 @@ pub(crate) fn derive_key(ikm: &InputKeyMaterial, key_context: &[&str]) -> Vec Result { Ok(ret) } -pub(crate) fn encode_cipher(ikm_id: IkmId, encrypted_data: &EncryptedData) -> String { +pub(crate) fn encode_cipher( + ikm_id: IkmId, + encrypted_data: &EncryptedData, + ts: Option, +) -> String { let mut ret = String::new(); ret += &encode_data(&ikm_id.to_le_bytes()); ret += STORAGE_SEPARATOR; ret += &encode_data(&encrypted_data.nonce); ret += STORAGE_SEPARATOR; ret += &encode_data(&encrypted_data.ciphertext); + if let Some(ts) = ts { + ret += STORAGE_SEPARATOR; + ret += &encode_data(&ts.to_le_bytes()); + } ret } @@ -58,8 +66,23 @@ pub(crate) fn decode_ikm_list(data: &str) -> Result { }) } -pub(crate) fn decode_cipher(data: &str) -> Result<(IkmId, EncryptedData)> { - let v: Vec<&str> = data.split(STORAGE_SEPARATOR).collect(); +pub(crate) fn decode_cipher(data: &str) -> Result<(IkmId, EncryptedData, Option)> { + let mut v: Vec<&str> = data.split(STORAGE_SEPARATOR).collect(); + let ts = if v.len() == NB_PARTS + 1 { + match v.pop() { + Some(ts_raw) => { + let ts_raw = decode_data(ts_raw)?; + let ts_raw: [u8; 8] = ts_raw + .clone() + .try_into() + .map_err(|_| Error::ParsingEncodedDataInvalidTimestamp(ts_raw))?; + Some(u64::from_le_bytes(ts_raw)) + } + None => None, + } + } else { + None + }; if v.len() != NB_PARTS { return Err(Error::ParsingEncodedDataInvalidPartLen(NB_PARTS, v.len())); } @@ -73,7 +96,7 @@ pub(crate) fn decode_cipher(data: &str) -> Result<(IkmId, EncryptedData)> { nonce: decode_data(v[1])?, ciphertext: decode_data(v[2])?, }; - Ok((id, encrypted_data)) + Ok((id, encrypted_data, ts)) } #[cfg(test)] @@ -99,7 +122,7 @@ mod tests { nonce: TEST_NONCE.into(), ciphertext: TEST_CIPHERTEXT.into(), }; - let s = super::encode_cipher(TEST_IKM_ID, &data); + let s = super::encode_cipher(TEST_IKM_ID, &data, None); assert_eq!(&s, TEST_STR); } @@ -107,10 +130,11 @@ mod tests { fn decode() { let res = super::decode_cipher(TEST_STR); assert!(res.is_ok(), "res: {res:?}"); - let (id, data) = res.unwrap(); + let (id, data, ts) = res.unwrap(); assert_eq!(id, TEST_IKM_ID); assert_eq!(data.nonce, TEST_NONCE); assert_eq!(data.ciphertext, TEST_CIPHERTEXT); + assert_eq!(ts, None); } #[test] @@ -119,17 +143,18 @@ mod tests { nonce: TEST_NONCE.into(), ciphertext: TEST_CIPHERTEXT.into(), }; - let s = super::encode_cipher(TEST_IKM_ID, &data); - let (id, decoded_data) = super::decode_cipher(&s).unwrap(); + let s = super::encode_cipher(TEST_IKM_ID, &data, None); + let (id, decoded_data, ts) = super::decode_cipher(&s).unwrap(); assert_eq!(id, TEST_IKM_ID); assert_eq!(decoded_data.nonce, data.nonce); assert_eq!(decoded_data.ciphertext, data.ciphertext); + assert_eq!(ts, None); } #[test] fn decode_encode() { - let (id, data) = super::decode_cipher(TEST_STR).unwrap(); - let s = super::encode_cipher(id, &data); + let (id, data, ts) = super::decode_cipher(TEST_STR).unwrap(); + let s = super::encode_cipher(id, &data, ts); assert_eq!(&s, TEST_STR); } }