From 749dc03f711e681ef42e560a8ab95b4da740689b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rodolphe=20Br=C3=A9ard?= Date: Sun, 17 Mar 2024 14:23:03 +0100 Subject: [PATCH] Replace the encrypt and decrypt function by the CipherBox struct --- src/{encryption.rs => cipher_box.rs} | 143 +++++++++++++++------------ src/lib.rs | 4 +- src/scheme.rs | 2 +- src/scheme/xchacha20poly1305.rs | 2 +- src/storage.rs | 2 +- 5 files changed, 84 insertions(+), 69 deletions(-) rename src/{encryption.rs => cipher_box.rs} (64%) diff --git a/src/encryption.rs b/src/cipher_box.rs similarity index 64% rename from src/encryption.rs rename to src/cipher_box.rs index 4bc104c..faed85f 100644 --- a/src/encryption.rs +++ b/src/cipher_box.rs @@ -32,61 +32,71 @@ pub(crate) struct EncryptedData { pub(crate) ciphertext: Vec, } -#[inline] -fn generate_aad( - ikm_id: IkmId, - nonce: &[u8], - key_context: &KeyContext, - data_context: &DataContext, - time_period: Option, -) -> String { - let ikm_id_canon = canonicalize(&[ikm_id.to_le_bytes()]); - let nonce_canon = canonicalize(&[nonce]); - let elems = key_context.get_ctx_elems(time_period); - let key_context_canon = canonicalize(&elems); - let data_context_canon = canonicalize(data_context.get_ctx_elems()); - join_canonicalized_str(&[ - ikm_id_canon, - nonce_canon, - key_context_canon, - data_context_canon, - ]) +pub struct CipherBox<'a> { + ikm_list: &'a InputKeyMaterialList, } -pub fn encrypt( - ikml: &InputKeyMaterialList, - key_context: &KeyContext, - data: impl AsRef<[u8]>, - data_context: &DataContext, -) -> Result { - let tp = if key_context.is_periodic() { - let ts = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs(); - key_context.get_time_period(ts) - } else { - None - }; - let ikm = ikml.get_latest_ikm()?; - let key = derive_key(ikm, key_context, tp); - let gen_nonce_function = ikm.scheme.get_gen_nonce(); - let nonce = gen_nonce_function()?; - let aad = generate_aad(ikm.id, &nonce, key_context, data_context, tp); - let encryption_function = ikm.scheme.get_encryption(); - let encrypted_data = encryption_function(&key, &nonce, data.as_ref(), &aad)?; - Ok(storage::encode_cipher(ikm.id, &encrypted_data, tp)) -} +impl<'a> CipherBox<'a> { + pub fn new(ikm_list: &'a InputKeyMaterialList) -> Self { + Self { ikm_list } + } -pub fn decrypt( - ikml: &InputKeyMaterialList, - key_context: &KeyContext, - stored_data: &str, - data_context: &DataContext, -) -> Result> { - let (ikm_id, encrypted_data, tp) = storage::decode_cipher(stored_data)?; - let ikm = ikml.get_ikm_by_id(ikm_id)?; - let key = derive_key(ikm, key_context, tp); - let aad = generate_aad(ikm.id, &encrypted_data.nonce, key_context, data_context, tp); - let decryption_function = ikm.scheme.get_decryption(); - decryption_function(&key, &encrypted_data, &aad) + #[inline] + fn generate_aad( + ikm_id: IkmId, + nonce: &[u8], + key_context: &KeyContext, + data_context: &DataContext, + time_period: Option, + ) -> String { + let ikm_id_canon = canonicalize(&[ikm_id.to_le_bytes()]); + let nonce_canon = canonicalize(&[nonce]); + let elems = key_context.get_ctx_elems(time_period); + let key_context_canon = canonicalize(&elems); + let data_context_canon = canonicalize(data_context.get_ctx_elems()); + join_canonicalized_str(&[ + ikm_id_canon, + nonce_canon, + key_context_canon, + data_context_canon, + ]) + } + + pub fn encrypt( + &self, + key_context: &KeyContext, + data: impl AsRef<[u8]>, + data_context: &DataContext, + ) -> Result { + let tp = if key_context.is_periodic() { + let ts = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs(); + key_context.get_time_period(ts) + } else { + None + }; + let ikm = self.ikm_list.get_latest_ikm()?; + let key = derive_key(ikm, key_context, tp); + let gen_nonce_function = ikm.scheme.get_gen_nonce(); + let nonce = gen_nonce_function()?; + let aad = Self::generate_aad(ikm.id, &nonce, key_context, data_context, tp); + let encryption_function = ikm.scheme.get_encryption(); + let encrypted_data = encryption_function(&key, &nonce, data.as_ref(), &aad)?; + Ok(storage::encode_cipher(ikm.id, &encrypted_data, tp)) + } + + pub fn decrypt( + &self, + key_context: &KeyContext, + stored_data: &str, + data_context: &DataContext, + ) -> Result> { + let (ikm_id, encrypted_data, tp) = storage::decode_cipher(stored_data)?; + let ikm = self.ikm_list.get_ikm_by_id(ikm_id)?; + let key = derive_key(ikm, key_context, tp); + let aad = Self::generate_aad(ikm.id, &encrypted_data.nonce, key_context, data_context, tp); + let decryption_function = ikm.scheme.get_decryption(); + decryption_function(&key, &encrypted_data, &aad) + } } #[cfg(test)] @@ -120,19 +130,20 @@ mod tests { #[test] fn encrypt_decrypt_no_context() { + let lst = get_ikm_lst(); let key_ctx = get_static_empty_key_ctx(); let data_ctx = DataContext::from([]); + let cb = CipherBox::new(&lst); // Encrypt - let lst = get_ikm_lst(); - let res = encrypt(&lst, &key_ctx, TEST_DATA, &data_ctx); + let res = cb.encrypt(&key_ctx, TEST_DATA, &data_ctx); assert!(res.is_ok(), "res: {res:?}"); let ciphertext = res.unwrap(); assert!(ciphertext.starts_with("AQAAAA:")); assert_eq!(ciphertext.len(), 98); // Decrypt - let res = decrypt(&lst, &key_ctx, &ciphertext, &data_ctx); + let res = cb.decrypt(&key_ctx, &ciphertext, &data_ctx); assert!(res.is_ok(), "res: {res:?}"); let plaintext = res.unwrap(); assert_eq!(plaintext, TEST_DATA); @@ -143,16 +154,17 @@ mod tests { let lst = get_ikm_lst(); let key_ctx = get_static_key_ctx(); let data_ctx = DataContext::from(TEST_DATA_CTX); + let cb = CipherBox::new(&lst); // Encrypt - let res = encrypt(&lst, &key_ctx, TEST_DATA, &data_ctx); + let res = cb.encrypt(&key_ctx, TEST_DATA, &data_ctx); assert!(res.is_ok(), "res: {res:?}"); let ciphertext = res.unwrap(); assert!(ciphertext.starts_with("AQAAAA:")); assert_eq!(ciphertext.len(), 98); // Decrypt - let res = decrypt(&lst, &key_ctx, &ciphertext, &data_ctx); + let res = cb.decrypt(&key_ctx, &ciphertext, &data_ctx); assert!(res.is_ok(), "res: {res:?}"); let plaintext = res.unwrap(); assert_eq!(plaintext, TEST_DATA); @@ -163,16 +175,17 @@ mod tests { let lst = get_ikm_lst(); let key_ctx = KeyContext::from(TEST_KEY_CTX); let data_ctx = DataContext::from(TEST_DATA_CTX); + let cb = CipherBox::new(&lst); // Encrypt - let res = encrypt(&lst, &key_ctx, TEST_DATA, &data_ctx); + let res = cb.encrypt(&key_ctx, TEST_DATA, &data_ctx); assert!(res.is_ok(), "res: {res:?}"); let ciphertext = res.unwrap(); assert!(ciphertext.starts_with("AQAAAA:")); assert_eq!(ciphertext.len(), 110); // Decrypt - let res = decrypt(&lst, &key_ctx, &ciphertext, &data_ctx); + let res = cb.decrypt(&key_ctx, &ciphertext, &data_ctx); assert!(res.is_ok(), "res: {res:?}"); let plaintext = res.unwrap(); assert_eq!(plaintext, TEST_DATA); @@ -193,14 +206,15 @@ mod tests { let lst = get_ikm_lst(); let key_ctx = KeyContext::from(TEST_KEY_CTX); let data_ctx = DataContext::from(TEST_DATA_CTX); + let cb = CipherBox::new(&lst); // Test if the reference ciphertext used for the tests is actually valid - let res = decrypt(&lst, &key_ctx, TEST_CIPHERTEXT, &data_ctx); + let res = cb.decrypt(&key_ctx, TEST_CIPHERTEXT, &data_ctx); assert!(res.is_ok(), "invalid reference ciphertext"); // Test if altered versions of the reference ciphertext are refused for (ciphertext, error_str) in tests { - let res = decrypt(&lst, &key_ctx, ciphertext, &data_ctx); + let res = cb.decrypt(&key_ctx, ciphertext, &data_ctx); assert!(res.is_err(), "failed error detection: {error_str}"); } } @@ -210,16 +224,17 @@ mod tests { let lst = get_ikm_lst(); let key_ctx = KeyContext::from(TEST_KEY_CTX); let data_ctx = DataContext::from(TEST_DATA_CTX); + let cb = CipherBox::new(&lst); - let res = decrypt(&lst, &key_ctx, TEST_CIPHERTEXT, &data_ctx); + let res = cb.decrypt(&key_ctx, TEST_CIPHERTEXT, &data_ctx); assert!(res.is_ok(), "invalid reference ciphertext"); let invalid_key_ctx = KeyContext::from(["invalid", "key", "context"]); - let res = decrypt(&lst, &invalid_key_ctx, TEST_CIPHERTEXT, &data_ctx); + let res = cb.decrypt(&invalid_key_ctx, TEST_CIPHERTEXT, &data_ctx); assert!(res.is_err(), "failed error detection: invalid key context"); let invalid_data_ctx = DataContext::from(["invalid", "data", "context"]); - let res = decrypt(&lst, &key_ctx, TEST_CIPHERTEXT, &invalid_data_ctx); + let res = cb.decrypt(&key_ctx, TEST_CIPHERTEXT, &invalid_data_ctx); assert!(res.is_err(), "failed error detection: invalid key context"); } } diff --git a/src/lib.rs b/src/lib.rs index eb2fff1..1c07dec 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,7 +1,7 @@ #[cfg(feature = "encryption")] mod canonicalization; #[cfg(feature = "encryption")] -mod encryption; +mod cipher_box; #[cfg(any(feature = "encryption", feature = "ikm-management"))] mod error; #[cfg(any(feature = "encryption", feature = "ikm-management"))] @@ -14,7 +14,7 @@ mod scheme; mod storage; #[cfg(feature = "encryption")] -pub use encryption::{decrypt, encrypt, DataContext}; +pub use cipher_box::{CipherBox, DataContext}; #[cfg(any(feature = "encryption", feature = "ikm-management"))] pub use error::Error; #[cfg(any(feature = "encryption", feature = "ikm-management"))] diff --git a/src/scheme.rs b/src/scheme.rs index 2ea18d1..853bc6a 100644 --- a/src/scheme.rs +++ b/src/scheme.rs @@ -1,5 +1,5 @@ #[cfg(feature = "encryption")] -use crate::encryption::{DecryptionFunction, EncryptionFunction, GenNonceFunction}; +use crate::cipher_box::{DecryptionFunction, EncryptionFunction, GenNonceFunction}; #[cfg(feature = "encryption")] use crate::kdf::KdfFunction; use crate::Error; diff --git a/src/scheme/xchacha20poly1305.rs b/src/scheme/xchacha20poly1305.rs index 79fca39..ad196f8 100644 --- a/src/scheme/xchacha20poly1305.rs +++ b/src/scheme/xchacha20poly1305.rs @@ -1,4 +1,4 @@ -use crate::encryption::EncryptedData; +use crate::cipher_box::EncryptedData; use crate::error::Result; use chacha20poly1305::aead::{Aead, KeyInit, Payload}; use chacha20poly1305::{Key, XChaCha20Poly1305, XNonce}; diff --git a/src/storage.rs b/src/storage.rs index e855304..d1f4ee8 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -1,5 +1,5 @@ #[cfg(feature = "encryption")] -use crate::encryption::EncryptedData; +use crate::cipher_box::EncryptedData; use crate::error::{Error, Result}; #[cfg(feature = "encryption")] use crate::ikm::IkmId;