diff --git a/src/encryption.rs b/src/encryption.rs index 6a4b4a7..0b93e77 100644 --- a/src/encryption.rs +++ b/src/encryption.rs @@ -7,6 +7,24 @@ use std::time::{SystemTime, UNIX_EPOCH}; pub(crate) type DecryptionFunction = dyn Fn(&[u8], &EncryptedData, &str) -> Result>; pub(crate) type EncryptionFunction = dyn Fn(&[u8], &[u8], &str) -> Result; +pub struct DataContext { + ctx: Vec, +} + +impl DataContext { + pub(crate) fn get_ctx_elems(&self) -> &[String] { + self.ctx.as_ref() + } +} + +impl From<[&str; N]> for DataContext { + fn from(ctx: [&str; N]) -> Self { + Self { + ctx: ctx.iter().map(|s| s.to_string()).collect(), + } + } +} + #[derive(Debug)] pub(crate) struct EncryptedData { pub(crate) nonce: Vec, @@ -16,12 +34,12 @@ pub(crate) struct EncryptedData { #[inline] fn generate_aad( key_context: &KeyContext, - data_context: &[impl AsRef<[u8]>], + data_context: &DataContext, time_period: Option, ) -> String { let elems = key_context.get_ctx_elems(time_period); let key_context_canon = canonicalize(&elems); - let data_context_canon = canonicalize(data_context); + let data_context_canon = canonicalize(data_context.get_ctx_elems()); join_canonicalized_str(&key_context_canon, &data_context_canon) } @@ -29,7 +47,7 @@ pub fn encrypt( ikml: &InputKeyMaterialList, key_context: &KeyContext, data: impl AsRef<[u8]>, - data_context: &[impl AsRef<[u8]>], + data_context: &DataContext, ) -> Result { let tp = if key_context.is_periodic() { let ts = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs(); @@ -49,7 +67,7 @@ pub fn decrypt( ikml: &InputKeyMaterialList, key_context: &KeyContext, stored_data: &str, - data_context: &[impl AsRef<[u8]>], + data_context: &DataContext, ) -> Result> { let (ikm_id, encrypted_data, tp) = storage::decode_cipher(stored_data)?; let ikm = ikml.get_ikm_by_id(ikm_id)?; @@ -62,12 +80,11 @@ pub fn decrypt( #[cfg(test)] mod tests { use super::*; - use crate::KeyContext; + use crate::{DataContext, KeyContext}; const TEST_DATA: &[u8] = b"Lorem ipsum dolor sit amet."; const TEST_KEY_CTX: [&str; 3] = ["db_name", "table_name", "column_name"]; - const TEST_DATA_CTX: &[&str] = &["018db876-3d9d-79af-9460-55d17da991d8"]; - const EMPTY_DATA_CTX: &[[u8; 0]] = &[]; + const TEST_DATA_CTX: [&str; 1] = ["018db876-3d9d-79af-9460-55d17da991d8"]; fn get_static_key_ctx() -> KeyContext { let mut ctx: KeyContext = TEST_KEY_CTX.into(); @@ -90,18 +107,19 @@ mod tests { #[test] fn encrypt_decrypt_no_context() { - let ctx = get_static_empty_key_ctx(); + let key_ctx = get_static_empty_key_ctx(); + let data_ctx = DataContext::from([]); // Encrypt let lst = get_ikm_lst(); - let res = encrypt(&lst, &ctx, TEST_DATA, EMPTY_DATA_CTX); + let res = encrypt(&lst, &key_ctx, TEST_DATA, &data_ctx); assert!(res.is_ok(), "res: {res:?}"); let ciphertext = res.unwrap(); assert!(ciphertext.starts_with("AQAAAA:")); assert_eq!(ciphertext.len(), 98); // Decrypt - let res = decrypt(&lst, &ctx, &ciphertext, EMPTY_DATA_CTX); + let res = decrypt(&lst, &key_ctx, &ciphertext, &data_ctx); assert!(res.is_ok(), "res: {res:?}"); let plaintext = res.unwrap(); assert_eq!(plaintext, TEST_DATA); @@ -111,16 +129,17 @@ mod tests { fn encrypt_decrypt_with_static_context() { let lst = get_ikm_lst(); let key_ctx = get_static_key_ctx(); + let data_ctx = DataContext::from(TEST_DATA_CTX); // Encrypt - let res = encrypt(&lst, &key_ctx, TEST_DATA, TEST_DATA_CTX); + let res = encrypt(&lst, &key_ctx, TEST_DATA, &data_ctx); assert!(res.is_ok(), "res: {res:?}"); let ciphertext = res.unwrap(); assert!(ciphertext.starts_with("AQAAAA:")); assert_eq!(ciphertext.len(), 98); // Decrypt - let res = decrypt(&lst, &key_ctx, &ciphertext, TEST_DATA_CTX); + let res = decrypt(&lst, &key_ctx, &ciphertext, &data_ctx); assert!(res.is_ok(), "res: {res:?}"); let plaintext = res.unwrap(); assert_eq!(plaintext, TEST_DATA); @@ -130,16 +149,17 @@ mod tests { fn encrypt_decrypt_with_context() { let lst = get_ikm_lst(); let key_ctx = KeyContext::from(TEST_KEY_CTX); + let data_ctx = DataContext::from(TEST_DATA_CTX); // Encrypt - let res = encrypt(&lst, &key_ctx, TEST_DATA, TEST_DATA_CTX); + let res = encrypt(&lst, &key_ctx, TEST_DATA, &data_ctx); assert!(res.is_ok(), "res: {res:?}"); let ciphertext = res.unwrap(); assert!(ciphertext.starts_with("AQAAAA:")); assert_eq!(ciphertext.len(), 110); // Decrypt - let res = decrypt(&lst, &key_ctx, &ciphertext, TEST_DATA_CTX); + let res = decrypt(&lst, &key_ctx, &ciphertext, &data_ctx); assert!(res.is_ok(), "res: {res:?}"); let plaintext = res.unwrap(); assert_eq!(plaintext, TEST_DATA); diff --git a/src/lib.rs b/src/lib.rs index 4b1d5f4..367cc05 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -13,7 +13,7 @@ mod scheme; mod storage; #[cfg(feature = "encryption")] -pub use encryption::{decrypt, encrypt}; +pub use encryption::{decrypt, encrypt, DataContext}; #[cfg(any(feature = "encryption", feature = "ikm-management"))] pub use error::Error; #[cfg(any(feature = "encryption", feature = "ikm-management"))]