From 423476c987effa69a8e3a3f1d19b5f2889e8841e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rodolphe=20Br=C3=A9ard?= Date: Sat, 2 Mar 2024 11:00:59 +0100 Subject: [PATCH] Refactor the IKM storage format --- src/encryption.rs | 6 ++-- src/error.rs | 6 ++-- src/ikm.rs | 92 ++++++++++++++++++++++++----------------------- src/kdf.rs | 2 +- src/scheme.rs | 6 ++++ src/storage.rs | 50 +++++++++++++++++++++----- 6 files changed, 102 insertions(+), 60 deletions(-) diff --git a/src/encryption.rs b/src/encryption.rs index b13ee85..55891e5 100644 --- a/src/encryption.rs +++ b/src/encryption.rs @@ -29,7 +29,7 @@ pub fn encrypt( let aad = generate_aad(key_context, data_context); let encryption_function = ikm.scheme.get_encryption(); let encrypted_data = encryption_function(&key, data.as_ref(), &aad)?; - Ok(storage::encode(ikm.id, &encrypted_data)) + Ok(storage::encode_cipher(ikm.id, &encrypted_data)) } pub fn decrypt( @@ -38,7 +38,7 @@ pub fn decrypt( stored_data: &str, data_context: &[impl AsRef<[u8]>], ) -> Result> { - let (ikm_id, encrypted_data) = storage::decode(stored_data)?; + let (ikm_id, encrypted_data) = storage::decode_cipher(stored_data)?; let ikm = ikml.get_ikm_by_id(ikm_id)?; let key = derive_key(ikm, key_context); let aad = generate_aad(key_context, data_context); @@ -57,7 +57,7 @@ mod tests { fn get_ikm_lst() -> InputKeyMaterialList { InputKeyMaterialList::import( - "AQAAAAEAAAABAAAANGFtbdYEN0s7dzCfMm7dYeQWD64GdmuKsYSiKwppAhmkz81lAAAAACQDr2cAAAAAAA", + "AQAAAA:AQAAAAEAAAC_vYEw1ujVG5i-CtoPYSzik_6xaAq59odjPm5ij01-e6zz4mUAAAAALJGBiwAAAAAA", ) .unwrap() } diff --git a/src/error.rs b/src/error.rs index 8a4a206..a9919ba 100644 --- a/src/error.rs +++ b/src/error.rs @@ -15,10 +15,12 @@ pub enum Error { ParsingBase64Error(base64ct::Error), #[error("parsing error: encoded data: invalid IKM id: {0:?}")] ParsingEncodedDataInvalidIkmId(Vec), + #[error("parsing error: encoded data: invalid IKM length{0}")] + ParsingEncodedDataInvalidIkmLen(usize), + #[error("parsing error: encoded data: invalid IKM list length{0}")] + ParsingEncodedDataInvalidIkmListLen(usize), #[error("parsing error: encoded data: invalid number of parts: got {1} instead of {0}")] ParsingEncodedDataInvalidPartLen(usize, usize), - #[error("parsing error: ikm: invalid data length: {0} bytes")] - ParsingIkmInvalidLength(usize), #[error("parsing error: scheme: {0}: unknown scheme")] ParsingSchemeUnknownScheme(SchemeSerializeType), #[error("unable to generate random values: {0}")] diff --git a/src/ikm.rs b/src/ikm.rs index fb8f544..d936ae7 100644 --- a/src/ikm.rs +++ b/src/ikm.rs @@ -1,10 +1,8 @@ use crate::error::{Error, Result}; use crate::scheme::{Scheme, SchemeSerializeType}; -use base64ct::{Base64UrlUnpadded, Encoding}; use std::time::{Duration, SystemTime}; -const IKM_STRUCT_SIZE: usize = 57; -const IKM_CONTENT_SIZE: usize = 32; +pub(crate) const IKM_BASE_STRUCT_SIZE: usize = 25; pub(crate) type CounterId = u32; pub type IkmId = u32; @@ -13,7 +11,7 @@ pub type IkmId = u32; pub struct InputKeyMaterial { pub id: IkmId, pub scheme: Scheme, - pub(crate) content: [u8; IKM_CONTENT_SIZE], + pub(crate) content: Vec, pub created_at: SystemTime, pub expire_at: SystemTime, pub is_revoked: bool, @@ -21,8 +19,8 @@ pub struct InputKeyMaterial { impl InputKeyMaterial { #[cfg(feature = "ikm-management")] - fn as_bytes(&self) -> Result<[u8; IKM_STRUCT_SIZE]> { - let mut res = Vec::with_capacity(IKM_STRUCT_SIZE); + pub(crate) fn as_bytes(&self) -> Result> { + let mut res = Vec::with_capacity(IKM_BASE_STRUCT_SIZE + self.scheme.get_ikm_size()); res.extend_from_slice(&self.id.to_le_bytes()); res.extend_from_slice(&(self.scheme as SchemeSerializeType).to_le_bytes()); res.extend_from_slice(&self.content); @@ -41,17 +39,26 @@ impl InputKeyMaterial { .to_le_bytes(), ); res.push(self.is_revoked as u8); - Ok(res.try_into().unwrap()) + Ok(res) } - pub(crate) fn from_bytes(b: [u8; IKM_STRUCT_SIZE]) -> Result { + pub(crate) fn from_bytes(b: &[u8]) -> Result { + if b.len() < IKM_BASE_STRUCT_SIZE { + return Err(Error::ParsingEncodedDataInvalidIkmLen(b.len())); + } + let scheme: Scheme = + SchemeSerializeType::from_le_bytes(b[4..8].try_into().unwrap()).try_into()?; + let is = scheme.get_ikm_size(); + if b.len() != IKM_BASE_STRUCT_SIZE + is { + return Err(Error::ParsingEncodedDataInvalidIkmLen(b.len())); + } Ok(Self { id: IkmId::from_le_bytes(b[0..4].try_into().unwrap()), - scheme: SchemeSerializeType::from_le_bytes(b[4..8].try_into().unwrap()).try_into()?, - content: b[8..40].try_into().unwrap(), - created_at: InputKeyMaterial::bytes_to_system_time(&b[40..48])?, - expire_at: InputKeyMaterial::bytes_to_system_time(&b[48..56])?, - is_revoked: b[56] != 0, + scheme, + content: b[8..8 + is].into(), + created_at: InputKeyMaterial::bytes_to_system_time(&b[8 + is..8 + is + 8])?, + expire_at: InputKeyMaterial::bytes_to_system_time(&b[8 + is + 8..8 + is + 8 + 8])?, + is_revoked: b[8 + is + 8 + 8] != 0, }) } @@ -66,8 +73,8 @@ impl InputKeyMaterial { #[derive(Debug, Default)] pub struct InputKeyMaterialList { - ikm_lst: Vec, - id_counter: CounterId, + pub(crate) ikm_lst: Vec, + pub(crate) id_counter: CounterId, } impl InputKeyMaterialList { @@ -78,18 +85,22 @@ impl InputKeyMaterialList { #[cfg(feature = "ikm-management")] pub fn add_ikm(&mut self) -> Result<()> { - self.add_ikm_with_duration(Duration::from_secs(crate::DEFAULT_IKM_DURATION)) + self.add_custom_ikm( + crate::DEFAULT_SCHEME, + Duration::from_secs(crate::DEFAULT_IKM_DURATION), + ) } #[cfg(feature = "ikm-management")] - pub fn add_ikm_with_duration(&mut self, duration: Duration) -> Result<()> { - let mut content: [u8; 32] = [0; 32]; - getrandom::getrandom(&mut content)?; + pub fn add_custom_ikm(&mut self, scheme: Scheme, duration: Duration) -> Result<()> { + let ikm_len = scheme.get_ikm_size(); + let mut content: Vec = vec![0; ikm_len]; + getrandom::getrandom(content.as_mut_slice())?; let created_at = SystemTime::now(); self.id_counter += 1; self.ikm_lst.push(InputKeyMaterial { id: self.id_counter, - scheme: crate::DEFAULT_SCHEME, + scheme, created_at, expire_at: created_at + duration, is_revoked: false, @@ -116,28 +127,11 @@ impl InputKeyMaterialList { #[cfg(feature = "ikm-management")] pub fn export(&self) -> Result { - let data_size = (self.ikm_lst.len() * IKM_STRUCT_SIZE) + 4; - let mut data = Vec::with_capacity(data_size); - data.extend_from_slice(&self.id_counter.to_le_bytes()); - for ikm in &self.ikm_lst { - data.extend_from_slice(&ikm.as_bytes()?); - } - Ok(Base64UrlUnpadded::encode_string(&data)) + crate::storage::encode_ikm_list(self) } pub fn import(s: &str) -> Result { - let data = Base64UrlUnpadded::decode_vec(s)?; - if data.len() % IKM_STRUCT_SIZE != 4 { - return Err(Error::ParsingIkmInvalidLength(data.len())); - } - let mut ikm_lst = Vec::with_capacity(data.len() / IKM_STRUCT_SIZE); - for ikm_slice in data[4..].chunks_exact(IKM_STRUCT_SIZE) { - ikm_lst.push(InputKeyMaterial::from_bytes(ikm_slice.try_into().unwrap())?); - } - Ok(Self { - ikm_lst, - id_counter: CounterId::from_le_bytes(data[0..4].try_into().unwrap()), - }) + crate::storage::decode_ikm_list(s) } #[cfg(feature = "encryption")] @@ -195,7 +189,10 @@ mod tests { assert_eq!(el.id, 1); assert_eq!(el.is_revoked, false); - let res = lst.add_ikm(); + let res = lst.add_custom_ikm( + Scheme::XChaCha20Poly1305WithBlake3, + Duration::from_secs(crate::DEFAULT_IKM_DURATION), + ); assert!(res.is_ok()); assert_eq!(lst.id_counter, 2); assert_eq!(lst.ikm_lst.len(), 2); @@ -228,13 +225,13 @@ mod tests { let res = lst.export(); assert!(res.is_ok()); let s = res.unwrap(); - assert_eq!(s.len(), 82); + assert_eq!(s.len(), 83); } #[test] fn import() { let s = - "AQAAAAEAAAABAAAANGFtbdYEN0s7dzCfMm7dYeQWD64GdmuKsYSiKwppAhmkz81lAAAAACQDr2cAAAAAAA"; + "AQAAAA:AQAAAAEAAAC_vYEw1ujVG5i-CtoPYSzik_6xaAq59odjPm5ij01-e6zz4mUAAAAALJGBiwAAAAAA"; let res = InputKeyMaterialList::import(s); assert!(res.is_ok()); let lst = res.unwrap(); @@ -246,8 +243,8 @@ mod tests { assert_eq!( ikm.content, [ - 52, 97, 109, 109, 214, 4, 55, 75, 59, 119, 48, 159, 50, 110, 221, 97, 228, 22, 15, - 174, 6, 118, 107, 138, 177, 132, 162, 43, 10, 105, 2, 25 + 191, 189, 129, 48, 214, 232, 213, 27, 152, 190, 10, 218, 15, 97, 44, 226, 147, 254, + 177, 104, 10, 185, 246, 135, 99, 62, 110, 98, 143, 77, 126, 123 ] ); assert_eq!(ikm.is_revoked, false); @@ -360,11 +357,16 @@ mod tests { let mut lst = InputKeyMaterialList::new(); let _ = lst.add_ikm(); let _ = lst.add_ikm(); - let _ = lst.add_ikm(); + let _ = lst.add_custom_ikm( + Scheme::XChaCha20Poly1305WithBlake3, + Duration::from_secs(crate::DEFAULT_IKM_DURATION), + ); let res = lst.get_latest_ikm(); assert!(res.is_ok()); let latest_ikm = res.unwrap(); assert_eq!(latest_ikm.id, 3); + assert_eq!(latest_ikm.scheme, Scheme::XChaCha20Poly1305WithBlake3); + assert_eq!(latest_ikm.content.len(), 32); } #[test] diff --git a/src/kdf.rs b/src/kdf.rs index 4b88d0f..0ffc721 100644 --- a/src/kdf.rs +++ b/src/kdf.rs @@ -22,7 +22,7 @@ mod tests { 0xd0, 0x65, 0x00, 0x00, 0x00, 0x00, 0x3d, 0x82, 0x6f, 0x8b, 0x00, 0x00, 0x00, 0x00, 0x00, ]; - let ikm = InputKeyMaterial::from_bytes(ikm_raw).unwrap(); + let ikm = InputKeyMaterial::from_bytes(&ikm_raw).unwrap(); let ctx = ["some", "context"]; assert_eq!( diff --git a/src/scheme.rs b/src/scheme.rs index 232bd70..8c602fe 100644 --- a/src/scheme.rs +++ b/src/scheme.rs @@ -13,6 +13,12 @@ pub enum Scheme { } impl Scheme { + pub(crate) fn get_ikm_size(&self) -> usize { + match self { + Scheme::XChaCha20Poly1305WithBlake3 => 32, + } + } + pub(crate) fn get_kdf(&self) -> Box { match self { Scheme::XChaCha20Poly1305WithBlake3 => Box::new(blake3::blake3_derive), diff --git a/src/storage.rs b/src/storage.rs index 7ae5759..3167e44 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -1,6 +1,6 @@ use crate::encryption::EncryptedData; use crate::error::{Error, Result}; -use crate::ikm::IkmId; +use crate::ikm::{CounterId, IkmId, InputKeyMaterial, InputKeyMaterialList, IKM_BASE_STRUCT_SIZE}; use base64ct::{Base64UrlUnpadded, Encoding}; const STORAGE_SEPARATOR: &str = ":"; @@ -16,7 +16,20 @@ fn decode_data(s: &str) -> Result> { Ok(Base64UrlUnpadded::decode_vec(s)?) } -pub(crate) fn encode(ikm_id: IkmId, encrypted_data: &EncryptedData) -> String { +pub(crate) fn encode_ikm_list(ikml: &InputKeyMaterialList) -> Result { + let data_size = (ikml.ikm_lst.iter().fold(0, |acc, ikm| { + acc + IKM_BASE_STRUCT_SIZE + ikm.scheme.get_ikm_size() + })) + 4; + let mut ret = String::with_capacity(data_size); + ret += &encode_data(&ikml.id_counter.to_le_bytes()); + for ikm in &ikml.ikm_lst { + ret += STORAGE_SEPARATOR; + ret += &encode_data(&ikm.as_bytes()?); + } + Ok(ret) +} + +pub(crate) fn encode_cipher(ikm_id: IkmId, encrypted_data: &EncryptedData) -> String { let mut ret = String::new(); ret += &encode_data(&ikm_id.to_le_bytes()); ret += STORAGE_SEPARATOR; @@ -26,7 +39,26 @@ pub(crate) fn encode(ikm_id: IkmId, encrypted_data: &EncryptedData) -> String { ret } -pub(crate) fn decode(data: &str) -> Result<(IkmId, EncryptedData)> { +pub(crate) fn decode_ikm_list(data: &str) -> Result { + let v: Vec<&str> = data.split(STORAGE_SEPARATOR).collect(); + if v.is_empty() { + return Err(Error::ParsingEncodedDataInvalidIkmListLen(v.len())); + } + let id_data = decode_data(v[0])?; + let id_counter = CounterId::from_le_bytes(id_data[0..4].try_into().unwrap()); + let mut ikm_lst = Vec::with_capacity(v.len() - 1); + for ikm_str in &v[1..] { + let raw_ikm = decode_data(ikm_str)?; + let ikm = InputKeyMaterial::from_bytes(&raw_ikm)?; + ikm_lst.push(ikm); + } + Ok(InputKeyMaterialList { + ikm_lst, + id_counter, + }) +} + +pub(crate) fn decode_cipher(data: &str) -> Result<(IkmId, EncryptedData)> { let v: Vec<&str> = data.split(STORAGE_SEPARATOR).collect(); if v.len() != NB_PARTS { return Err(Error::ParsingEncodedDataInvalidPartLen(NB_PARTS, v.len())); @@ -67,13 +99,13 @@ mod tests { nonce: TEST_NONCE.into(), ciphertext: TEST_CIPHERTEXT.into(), }; - let s = super::encode(TEST_IKM_ID, &data); + let s = super::encode_cipher(TEST_IKM_ID, &data); assert_eq!(&s, TEST_STR); } #[test] fn decode() { - let res = super::decode(TEST_STR); + let res = super::decode_cipher(TEST_STR); assert!(res.is_ok()); let (id, data) = res.unwrap(); assert_eq!(id, TEST_IKM_ID); @@ -87,8 +119,8 @@ mod tests { nonce: TEST_NONCE.into(), ciphertext: TEST_CIPHERTEXT.into(), }; - let s = super::encode(TEST_IKM_ID, &data); - let (id, decoded_data) = super::decode(&s).unwrap(); + let s = super::encode_cipher(TEST_IKM_ID, &data); + let (id, decoded_data) = super::decode_cipher(&s).unwrap(); assert_eq!(id, TEST_IKM_ID); assert_eq!(decoded_data.nonce, data.nonce); assert_eq!(decoded_data.ciphertext, data.ciphertext); @@ -96,8 +128,8 @@ mod tests { #[test] fn decode_encode() { - let (id, data) = super::decode(TEST_STR).unwrap(); - let s = super::encode(id, &data); + let (id, data) = super::decode_cipher(TEST_STR).unwrap(); + let s = super::encode_cipher(id, &data); assert_eq!(&s, TEST_STR); } }