Replace the encrypt and decrypt function by the CipherBox struct

This commit is contained in:
Rodolphe Bréard 2024-03-17 14:23:03 +01:00
parent 47557fe350
commit 749dc03f71
5 changed files with 84 additions and 69 deletions

View file

@ -32,6 +32,15 @@ pub(crate) struct EncryptedData {
pub(crate) ciphertext: Vec<u8>, pub(crate) ciphertext: Vec<u8>,
} }
pub struct CipherBox<'a> {
ikm_list: &'a InputKeyMaterialList,
}
impl<'a> CipherBox<'a> {
pub fn new(ikm_list: &'a InputKeyMaterialList) -> Self {
Self { ikm_list }
}
#[inline] #[inline]
fn generate_aad( fn generate_aad(
ikm_id: IkmId, ikm_id: IkmId,
@ -54,7 +63,7 @@ fn generate_aad(
} }
pub fn encrypt( pub fn encrypt(
ikml: &InputKeyMaterialList, &self,
key_context: &KeyContext, key_context: &KeyContext,
data: impl AsRef<[u8]>, data: impl AsRef<[u8]>,
data_context: &DataContext, data_context: &DataContext,
@ -65,29 +74,30 @@ pub fn encrypt(
} else { } else {
None None
}; };
let ikm = ikml.get_latest_ikm()?; let ikm = self.ikm_list.get_latest_ikm()?;
let key = derive_key(ikm, key_context, tp); let key = derive_key(ikm, key_context, tp);
let gen_nonce_function = ikm.scheme.get_gen_nonce(); let gen_nonce_function = ikm.scheme.get_gen_nonce();
let nonce = gen_nonce_function()?; let nonce = gen_nonce_function()?;
let aad = generate_aad(ikm.id, &nonce, key_context, data_context, tp); let aad = Self::generate_aad(ikm.id, &nonce, key_context, data_context, tp);
let encryption_function = ikm.scheme.get_encryption(); let encryption_function = ikm.scheme.get_encryption();
let encrypted_data = encryption_function(&key, &nonce, data.as_ref(), &aad)?; let encrypted_data = encryption_function(&key, &nonce, data.as_ref(), &aad)?;
Ok(storage::encode_cipher(ikm.id, &encrypted_data, tp)) Ok(storage::encode_cipher(ikm.id, &encrypted_data, tp))
} }
pub fn decrypt( pub fn decrypt(
ikml: &InputKeyMaterialList, &self,
key_context: &KeyContext, key_context: &KeyContext,
stored_data: &str, stored_data: &str,
data_context: &DataContext, data_context: &DataContext,
) -> Result<Vec<u8>> { ) -> Result<Vec<u8>> {
let (ikm_id, encrypted_data, tp) = storage::decode_cipher(stored_data)?; let (ikm_id, encrypted_data, tp) = storage::decode_cipher(stored_data)?;
let ikm = ikml.get_ikm_by_id(ikm_id)?; let ikm = self.ikm_list.get_ikm_by_id(ikm_id)?;
let key = derive_key(ikm, key_context, tp); let key = derive_key(ikm, key_context, tp);
let aad = generate_aad(ikm.id, &encrypted_data.nonce, key_context, data_context, tp); let aad = Self::generate_aad(ikm.id, &encrypted_data.nonce, key_context, data_context, tp);
let decryption_function = ikm.scheme.get_decryption(); let decryption_function = ikm.scheme.get_decryption();
decryption_function(&key, &encrypted_data, &aad) decryption_function(&key, &encrypted_data, &aad)
} }
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
@ -120,19 +130,20 @@ mod tests {
#[test] #[test]
fn encrypt_decrypt_no_context() { fn encrypt_decrypt_no_context() {
let lst = get_ikm_lst();
let key_ctx = get_static_empty_key_ctx(); let key_ctx = get_static_empty_key_ctx();
let data_ctx = DataContext::from([]); let data_ctx = DataContext::from([]);
let cb = CipherBox::new(&lst);
// Encrypt // Encrypt
let lst = get_ikm_lst(); let res = cb.encrypt(&key_ctx, TEST_DATA, &data_ctx);
let res = encrypt(&lst, &key_ctx, TEST_DATA, &data_ctx);
assert!(res.is_ok(), "res: {res:?}"); assert!(res.is_ok(), "res: {res:?}");
let ciphertext = res.unwrap(); let ciphertext = res.unwrap();
assert!(ciphertext.starts_with("AQAAAA:")); assert!(ciphertext.starts_with("AQAAAA:"));
assert_eq!(ciphertext.len(), 98); assert_eq!(ciphertext.len(), 98);
// Decrypt // Decrypt
let res = decrypt(&lst, &key_ctx, &ciphertext, &data_ctx); let res = cb.decrypt(&key_ctx, &ciphertext, &data_ctx);
assert!(res.is_ok(), "res: {res:?}"); assert!(res.is_ok(), "res: {res:?}");
let plaintext = res.unwrap(); let plaintext = res.unwrap();
assert_eq!(plaintext, TEST_DATA); assert_eq!(plaintext, TEST_DATA);
@ -143,16 +154,17 @@ mod tests {
let lst = get_ikm_lst(); let lst = get_ikm_lst();
let key_ctx = get_static_key_ctx(); let key_ctx = get_static_key_ctx();
let data_ctx = DataContext::from(TEST_DATA_CTX); let data_ctx = DataContext::from(TEST_DATA_CTX);
let cb = CipherBox::new(&lst);
// Encrypt // Encrypt
let res = encrypt(&lst, &key_ctx, TEST_DATA, &data_ctx); let res = cb.encrypt(&key_ctx, TEST_DATA, &data_ctx);
assert!(res.is_ok(), "res: {res:?}"); assert!(res.is_ok(), "res: {res:?}");
let ciphertext = res.unwrap(); let ciphertext = res.unwrap();
assert!(ciphertext.starts_with("AQAAAA:")); assert!(ciphertext.starts_with("AQAAAA:"));
assert_eq!(ciphertext.len(), 98); assert_eq!(ciphertext.len(), 98);
// Decrypt // Decrypt
let res = decrypt(&lst, &key_ctx, &ciphertext, &data_ctx); let res = cb.decrypt(&key_ctx, &ciphertext, &data_ctx);
assert!(res.is_ok(), "res: {res:?}"); assert!(res.is_ok(), "res: {res:?}");
let plaintext = res.unwrap(); let plaintext = res.unwrap();
assert_eq!(plaintext, TEST_DATA); assert_eq!(plaintext, TEST_DATA);
@ -163,16 +175,17 @@ mod tests {
let lst = get_ikm_lst(); let lst = get_ikm_lst();
let key_ctx = KeyContext::from(TEST_KEY_CTX); let key_ctx = KeyContext::from(TEST_KEY_CTX);
let data_ctx = DataContext::from(TEST_DATA_CTX); let data_ctx = DataContext::from(TEST_DATA_CTX);
let cb = CipherBox::new(&lst);
// Encrypt // Encrypt
let res = encrypt(&lst, &key_ctx, TEST_DATA, &data_ctx); let res = cb.encrypt(&key_ctx, TEST_DATA, &data_ctx);
assert!(res.is_ok(), "res: {res:?}"); assert!(res.is_ok(), "res: {res:?}");
let ciphertext = res.unwrap(); let ciphertext = res.unwrap();
assert!(ciphertext.starts_with("AQAAAA:")); assert!(ciphertext.starts_with("AQAAAA:"));
assert_eq!(ciphertext.len(), 110); assert_eq!(ciphertext.len(), 110);
// Decrypt // Decrypt
let res = decrypt(&lst, &key_ctx, &ciphertext, &data_ctx); let res = cb.decrypt(&key_ctx, &ciphertext, &data_ctx);
assert!(res.is_ok(), "res: {res:?}"); assert!(res.is_ok(), "res: {res:?}");
let plaintext = res.unwrap(); let plaintext = res.unwrap();
assert_eq!(plaintext, TEST_DATA); assert_eq!(plaintext, TEST_DATA);
@ -193,14 +206,15 @@ mod tests {
let lst = get_ikm_lst(); let lst = get_ikm_lst();
let key_ctx = KeyContext::from(TEST_KEY_CTX); let key_ctx = KeyContext::from(TEST_KEY_CTX);
let data_ctx = DataContext::from(TEST_DATA_CTX); let data_ctx = DataContext::from(TEST_DATA_CTX);
let cb = CipherBox::new(&lst);
// Test if the reference ciphertext used for the tests is actually valid // Test if the reference ciphertext used for the tests is actually valid
let res = decrypt(&lst, &key_ctx, TEST_CIPHERTEXT, &data_ctx); let res = cb.decrypt(&key_ctx, TEST_CIPHERTEXT, &data_ctx);
assert!(res.is_ok(), "invalid reference ciphertext"); assert!(res.is_ok(), "invalid reference ciphertext");
// Test if altered versions of the reference ciphertext are refused // Test if altered versions of the reference ciphertext are refused
for (ciphertext, error_str) in tests { for (ciphertext, error_str) in tests {
let res = decrypt(&lst, &key_ctx, ciphertext, &data_ctx); let res = cb.decrypt(&key_ctx, ciphertext, &data_ctx);
assert!(res.is_err(), "failed error detection: {error_str}"); assert!(res.is_err(), "failed error detection: {error_str}");
} }
} }
@ -210,16 +224,17 @@ mod tests {
let lst = get_ikm_lst(); let lst = get_ikm_lst();
let key_ctx = KeyContext::from(TEST_KEY_CTX); let key_ctx = KeyContext::from(TEST_KEY_CTX);
let data_ctx = DataContext::from(TEST_DATA_CTX); let data_ctx = DataContext::from(TEST_DATA_CTX);
let cb = CipherBox::new(&lst);
let res = decrypt(&lst, &key_ctx, TEST_CIPHERTEXT, &data_ctx); let res = cb.decrypt(&key_ctx, TEST_CIPHERTEXT, &data_ctx);
assert!(res.is_ok(), "invalid reference ciphertext"); assert!(res.is_ok(), "invalid reference ciphertext");
let invalid_key_ctx = KeyContext::from(["invalid", "key", "context"]); let invalid_key_ctx = KeyContext::from(["invalid", "key", "context"]);
let res = decrypt(&lst, &invalid_key_ctx, TEST_CIPHERTEXT, &data_ctx); let res = cb.decrypt(&invalid_key_ctx, TEST_CIPHERTEXT, &data_ctx);
assert!(res.is_err(), "failed error detection: invalid key context"); assert!(res.is_err(), "failed error detection: invalid key context");
let invalid_data_ctx = DataContext::from(["invalid", "data", "context"]); let invalid_data_ctx = DataContext::from(["invalid", "data", "context"]);
let res = decrypt(&lst, &key_ctx, TEST_CIPHERTEXT, &invalid_data_ctx); let res = cb.decrypt(&key_ctx, TEST_CIPHERTEXT, &invalid_data_ctx);
assert!(res.is_err(), "failed error detection: invalid key context"); assert!(res.is_err(), "failed error detection: invalid key context");
} }
} }

View file

@ -1,7 +1,7 @@
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
mod canonicalization; mod canonicalization;
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
mod encryption; mod cipher_box;
#[cfg(any(feature = "encryption", feature = "ikm-management"))] #[cfg(any(feature = "encryption", feature = "ikm-management"))]
mod error; mod error;
#[cfg(any(feature = "encryption", feature = "ikm-management"))] #[cfg(any(feature = "encryption", feature = "ikm-management"))]
@ -14,7 +14,7 @@ mod scheme;
mod storage; mod storage;
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
pub use encryption::{decrypt, encrypt, DataContext}; pub use cipher_box::{CipherBox, DataContext};
#[cfg(any(feature = "encryption", feature = "ikm-management"))] #[cfg(any(feature = "encryption", feature = "ikm-management"))]
pub use error::Error; pub use error::Error;
#[cfg(any(feature = "encryption", feature = "ikm-management"))] #[cfg(any(feature = "encryption", feature = "ikm-management"))]

View file

@ -1,5 +1,5 @@
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
use crate::encryption::{DecryptionFunction, EncryptionFunction, GenNonceFunction}; use crate::cipher_box::{DecryptionFunction, EncryptionFunction, GenNonceFunction};
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
use crate::kdf::KdfFunction; use crate::kdf::KdfFunction;
use crate::Error; use crate::Error;

View file

@ -1,4 +1,4 @@
use crate::encryption::EncryptedData; use crate::cipher_box::EncryptedData;
use crate::error::Result; use crate::error::Result;
use chacha20poly1305::aead::{Aead, KeyInit, Payload}; use chacha20poly1305::aead::{Aead, KeyInit, Payload};
use chacha20poly1305::{Key, XChaCha20Poly1305, XNonce}; use chacha20poly1305::{Key, XChaCha20Poly1305, XNonce};

View file

@ -1,5 +1,5 @@
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
use crate::encryption::EncryptedData; use crate::cipher_box::EncryptedData;
use crate::error::{Error, Result}; use crate::error::{Error, Result};
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
use crate::ikm::IkmId; use crate::ikm::IkmId;