Use the DataContext type

This commit is contained in:
Rodolphe Bréard 2024-03-09 17:29:55 +01:00
parent 165b197a3a
commit d922297e91
2 changed files with 35 additions and 15 deletions

View file

@ -7,6 +7,24 @@ use std::time::{SystemTime, UNIX_EPOCH};
pub(crate) type DecryptionFunction = dyn Fn(&[u8], &EncryptedData, &str) -> Result<Vec<u8>>; pub(crate) type DecryptionFunction = dyn Fn(&[u8], &EncryptedData, &str) -> Result<Vec<u8>>;
pub(crate) type EncryptionFunction = dyn Fn(&[u8], &[u8], &str) -> Result<EncryptedData>; pub(crate) type EncryptionFunction = dyn Fn(&[u8], &[u8], &str) -> Result<EncryptedData>;
pub struct DataContext {
ctx: Vec<String>,
}
impl DataContext {
pub(crate) fn get_ctx_elems(&self) -> &[String] {
self.ctx.as_ref()
}
}
impl<const N: usize> From<[&str; N]> for DataContext {
fn from(ctx: [&str; N]) -> Self {
Self {
ctx: ctx.iter().map(|s| s.to_string()).collect(),
}
}
}
#[derive(Debug)] #[derive(Debug)]
pub(crate) struct EncryptedData { pub(crate) struct EncryptedData {
pub(crate) nonce: Vec<u8>, pub(crate) nonce: Vec<u8>,
@ -16,12 +34,12 @@ pub(crate) struct EncryptedData {
#[inline] #[inline]
fn generate_aad( fn generate_aad(
key_context: &KeyContext, key_context: &KeyContext,
data_context: &[impl AsRef<[u8]>], data_context: &DataContext,
time_period: Option<u64>, time_period: Option<u64>,
) -> String { ) -> String {
let elems = key_context.get_ctx_elems(time_period); let elems = key_context.get_ctx_elems(time_period);
let key_context_canon = canonicalize(&elems); let key_context_canon = canonicalize(&elems);
let data_context_canon = canonicalize(data_context); let data_context_canon = canonicalize(data_context.get_ctx_elems());
join_canonicalized_str(&key_context_canon, &data_context_canon) join_canonicalized_str(&key_context_canon, &data_context_canon)
} }
@ -29,7 +47,7 @@ pub fn encrypt(
ikml: &InputKeyMaterialList, ikml: &InputKeyMaterialList,
key_context: &KeyContext, key_context: &KeyContext,
data: impl AsRef<[u8]>, data: impl AsRef<[u8]>,
data_context: &[impl AsRef<[u8]>], data_context: &DataContext,
) -> Result<String> { ) -> Result<String> {
let tp = if key_context.is_periodic() { let tp = if key_context.is_periodic() {
let ts = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs(); let ts = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
@ -49,7 +67,7 @@ pub fn decrypt(
ikml: &InputKeyMaterialList, ikml: &InputKeyMaterialList,
key_context: &KeyContext, key_context: &KeyContext,
stored_data: &str, stored_data: &str,
data_context: &[impl AsRef<[u8]>], data_context: &DataContext,
) -> Result<Vec<u8>> { ) -> Result<Vec<u8>> {
let (ikm_id, encrypted_data, tp) = storage::decode_cipher(stored_data)?; let (ikm_id, encrypted_data, tp) = storage::decode_cipher(stored_data)?;
let ikm = ikml.get_ikm_by_id(ikm_id)?; let ikm = ikml.get_ikm_by_id(ikm_id)?;
@ -62,12 +80,11 @@ pub fn decrypt(
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::KeyContext; use crate::{DataContext, KeyContext};
const TEST_DATA: &[u8] = b"Lorem ipsum dolor sit amet."; const TEST_DATA: &[u8] = b"Lorem ipsum dolor sit amet.";
const TEST_KEY_CTX: [&str; 3] = ["db_name", "table_name", "column_name"]; const TEST_KEY_CTX: [&str; 3] = ["db_name", "table_name", "column_name"];
const TEST_DATA_CTX: &[&str] = &["018db876-3d9d-79af-9460-55d17da991d8"]; const TEST_DATA_CTX: [&str; 1] = ["018db876-3d9d-79af-9460-55d17da991d8"];
const EMPTY_DATA_CTX: &[[u8; 0]] = &[];
fn get_static_key_ctx() -> KeyContext { fn get_static_key_ctx() -> KeyContext {
let mut ctx: KeyContext = TEST_KEY_CTX.into(); let mut ctx: KeyContext = TEST_KEY_CTX.into();
@ -90,18 +107,19 @@ mod tests {
#[test] #[test]
fn encrypt_decrypt_no_context() { fn encrypt_decrypt_no_context() {
let ctx = get_static_empty_key_ctx(); let key_ctx = get_static_empty_key_ctx();
let data_ctx = DataContext::from([]);
// Encrypt // Encrypt
let lst = get_ikm_lst(); let lst = get_ikm_lst();
let res = encrypt(&lst, &ctx, TEST_DATA, EMPTY_DATA_CTX); let res = encrypt(&lst, &key_ctx, TEST_DATA, &data_ctx);
assert!(res.is_ok(), "res: {res:?}"); assert!(res.is_ok(), "res: {res:?}");
let ciphertext = res.unwrap(); let ciphertext = res.unwrap();
assert!(ciphertext.starts_with("AQAAAA:")); assert!(ciphertext.starts_with("AQAAAA:"));
assert_eq!(ciphertext.len(), 98); assert_eq!(ciphertext.len(), 98);
// Decrypt // Decrypt
let res = decrypt(&lst, &ctx, &ciphertext, EMPTY_DATA_CTX); let res = decrypt(&lst, &key_ctx, &ciphertext, &data_ctx);
assert!(res.is_ok(), "res: {res:?}"); assert!(res.is_ok(), "res: {res:?}");
let plaintext = res.unwrap(); let plaintext = res.unwrap();
assert_eq!(plaintext, TEST_DATA); assert_eq!(plaintext, TEST_DATA);
@ -111,16 +129,17 @@ mod tests {
fn encrypt_decrypt_with_static_context() { fn encrypt_decrypt_with_static_context() {
let lst = get_ikm_lst(); let lst = get_ikm_lst();
let key_ctx = get_static_key_ctx(); let key_ctx = get_static_key_ctx();
let data_ctx = DataContext::from(TEST_DATA_CTX);
// Encrypt // Encrypt
let res = encrypt(&lst, &key_ctx, TEST_DATA, TEST_DATA_CTX); let res = encrypt(&lst, &key_ctx, TEST_DATA, &data_ctx);
assert!(res.is_ok(), "res: {res:?}"); assert!(res.is_ok(), "res: {res:?}");
let ciphertext = res.unwrap(); let ciphertext = res.unwrap();
assert!(ciphertext.starts_with("AQAAAA:")); assert!(ciphertext.starts_with("AQAAAA:"));
assert_eq!(ciphertext.len(), 98); assert_eq!(ciphertext.len(), 98);
// Decrypt // Decrypt
let res = decrypt(&lst, &key_ctx, &ciphertext, TEST_DATA_CTX); let res = decrypt(&lst, &key_ctx, &ciphertext, &data_ctx);
assert!(res.is_ok(), "res: {res:?}"); assert!(res.is_ok(), "res: {res:?}");
let plaintext = res.unwrap(); let plaintext = res.unwrap();
assert_eq!(plaintext, TEST_DATA); assert_eq!(plaintext, TEST_DATA);
@ -130,16 +149,17 @@ mod tests {
fn encrypt_decrypt_with_context() { fn encrypt_decrypt_with_context() {
let lst = get_ikm_lst(); let lst = get_ikm_lst();
let key_ctx = KeyContext::from(TEST_KEY_CTX); let key_ctx = KeyContext::from(TEST_KEY_CTX);
let data_ctx = DataContext::from(TEST_DATA_CTX);
// Encrypt // Encrypt
let res = encrypt(&lst, &key_ctx, TEST_DATA, TEST_DATA_CTX); let res = encrypt(&lst, &key_ctx, TEST_DATA, &data_ctx);
assert!(res.is_ok(), "res: {res:?}"); assert!(res.is_ok(), "res: {res:?}");
let ciphertext = res.unwrap(); let ciphertext = res.unwrap();
assert!(ciphertext.starts_with("AQAAAA:")); assert!(ciphertext.starts_with("AQAAAA:"));
assert_eq!(ciphertext.len(), 110); assert_eq!(ciphertext.len(), 110);
// Decrypt // Decrypt
let res = decrypt(&lst, &key_ctx, &ciphertext, TEST_DATA_CTX); let res = decrypt(&lst, &key_ctx, &ciphertext, &data_ctx);
assert!(res.is_ok(), "res: {res:?}"); assert!(res.is_ok(), "res: {res:?}");
let plaintext = res.unwrap(); let plaintext = res.unwrap();
assert_eq!(plaintext, TEST_DATA); assert_eq!(plaintext, TEST_DATA);

View file

@ -13,7 +13,7 @@ mod scheme;
mod storage; mod storage;
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
pub use encryption::{decrypt, encrypt}; pub use encryption::{decrypt, encrypt, DataContext};
#[cfg(any(feature = "encryption", feature = "ikm-management"))] #[cfg(any(feature = "encryption", feature = "ikm-management"))]
pub use error::Error; pub use error::Error;
#[cfg(any(feature = "encryption", feature = "ikm-management"))] #[cfg(any(feature = "encryption", feature = "ikm-management"))]