Refactor the IKM storage format

This commit is contained in:
Rodolphe Bréard 2024-03-02 11:00:59 +01:00
parent 349ed79b4c
commit 423476c987
6 changed files with 102 additions and 60 deletions

View file

@ -29,7 +29,7 @@ pub fn encrypt(
let aad = generate_aad(key_context, data_context); let aad = generate_aad(key_context, data_context);
let encryption_function = ikm.scheme.get_encryption(); let encryption_function = ikm.scheme.get_encryption();
let encrypted_data = encryption_function(&key, data.as_ref(), &aad)?; let encrypted_data = encryption_function(&key, data.as_ref(), &aad)?;
Ok(storage::encode(ikm.id, &encrypted_data)) Ok(storage::encode_cipher(ikm.id, &encrypted_data))
} }
pub fn decrypt( pub fn decrypt(
@ -38,7 +38,7 @@ pub fn decrypt(
stored_data: &str, stored_data: &str,
data_context: &[impl AsRef<[u8]>], data_context: &[impl AsRef<[u8]>],
) -> Result<Vec<u8>> { ) -> Result<Vec<u8>> {
let (ikm_id, encrypted_data) = storage::decode(stored_data)?; let (ikm_id, encrypted_data) = storage::decode_cipher(stored_data)?;
let ikm = ikml.get_ikm_by_id(ikm_id)?; let ikm = ikml.get_ikm_by_id(ikm_id)?;
let key = derive_key(ikm, key_context); let key = derive_key(ikm, key_context);
let aad = generate_aad(key_context, data_context); let aad = generate_aad(key_context, data_context);
@ -57,7 +57,7 @@ mod tests {
fn get_ikm_lst() -> InputKeyMaterialList { fn get_ikm_lst() -> InputKeyMaterialList {
InputKeyMaterialList::import( InputKeyMaterialList::import(
"AQAAAAEAAAABAAAANGFtbdYEN0s7dzCfMm7dYeQWD64GdmuKsYSiKwppAhmkz81lAAAAACQDr2cAAAAAAA", "AQAAAA:AQAAAAEAAAC_vYEw1ujVG5i-CtoPYSzik_6xaAq59odjPm5ij01-e6zz4mUAAAAALJGBiwAAAAAA",
) )
.unwrap() .unwrap()
} }

View file

@ -15,10 +15,12 @@ pub enum Error {
ParsingBase64Error(base64ct::Error), ParsingBase64Error(base64ct::Error),
#[error("parsing error: encoded data: invalid IKM id: {0:?}")] #[error("parsing error: encoded data: invalid IKM id: {0:?}")]
ParsingEncodedDataInvalidIkmId(Vec<u8>), ParsingEncodedDataInvalidIkmId(Vec<u8>),
#[error("parsing error: encoded data: invalid IKM length{0}")]
ParsingEncodedDataInvalidIkmLen(usize),
#[error("parsing error: encoded data: invalid IKM list length{0}")]
ParsingEncodedDataInvalidIkmListLen(usize),
#[error("parsing error: encoded data: invalid number of parts: got {1} instead of {0}")] #[error("parsing error: encoded data: invalid number of parts: got {1} instead of {0}")]
ParsingEncodedDataInvalidPartLen(usize, usize), ParsingEncodedDataInvalidPartLen(usize, usize),
#[error("parsing error: ikm: invalid data length: {0} bytes")]
ParsingIkmInvalidLength(usize),
#[error("parsing error: scheme: {0}: unknown scheme")] #[error("parsing error: scheme: {0}: unknown scheme")]
ParsingSchemeUnknownScheme(SchemeSerializeType), ParsingSchemeUnknownScheme(SchemeSerializeType),
#[error("unable to generate random values: {0}")] #[error("unable to generate random values: {0}")]

View file

@ -1,10 +1,8 @@
use crate::error::{Error, Result}; use crate::error::{Error, Result};
use crate::scheme::{Scheme, SchemeSerializeType}; use crate::scheme::{Scheme, SchemeSerializeType};
use base64ct::{Base64UrlUnpadded, Encoding};
use std::time::{Duration, SystemTime}; use std::time::{Duration, SystemTime};
const IKM_STRUCT_SIZE: usize = 57; pub(crate) const IKM_BASE_STRUCT_SIZE: usize = 25;
const IKM_CONTENT_SIZE: usize = 32;
pub(crate) type CounterId = u32; pub(crate) type CounterId = u32;
pub type IkmId = u32; pub type IkmId = u32;
@ -13,7 +11,7 @@ pub type IkmId = u32;
pub struct InputKeyMaterial { pub struct InputKeyMaterial {
pub id: IkmId, pub id: IkmId,
pub scheme: Scheme, pub scheme: Scheme,
pub(crate) content: [u8; IKM_CONTENT_SIZE], pub(crate) content: Vec<u8>,
pub created_at: SystemTime, pub created_at: SystemTime,
pub expire_at: SystemTime, pub expire_at: SystemTime,
pub is_revoked: bool, pub is_revoked: bool,
@ -21,8 +19,8 @@ pub struct InputKeyMaterial {
impl InputKeyMaterial { impl InputKeyMaterial {
#[cfg(feature = "ikm-management")] #[cfg(feature = "ikm-management")]
fn as_bytes(&self) -> Result<[u8; IKM_STRUCT_SIZE]> { pub(crate) fn as_bytes(&self) -> Result<Vec<u8>> {
let mut res = Vec::with_capacity(IKM_STRUCT_SIZE); let mut res = Vec::with_capacity(IKM_BASE_STRUCT_SIZE + self.scheme.get_ikm_size());
res.extend_from_slice(&self.id.to_le_bytes()); res.extend_from_slice(&self.id.to_le_bytes());
res.extend_from_slice(&(self.scheme as SchemeSerializeType).to_le_bytes()); res.extend_from_slice(&(self.scheme as SchemeSerializeType).to_le_bytes());
res.extend_from_slice(&self.content); res.extend_from_slice(&self.content);
@ -41,17 +39,26 @@ impl InputKeyMaterial {
.to_le_bytes(), .to_le_bytes(),
); );
res.push(self.is_revoked as u8); res.push(self.is_revoked as u8);
Ok(res.try_into().unwrap()) Ok(res)
} }
pub(crate) fn from_bytes(b: [u8; IKM_STRUCT_SIZE]) -> Result<Self> { pub(crate) fn from_bytes(b: &[u8]) -> Result<Self> {
if b.len() < IKM_BASE_STRUCT_SIZE {
return Err(Error::ParsingEncodedDataInvalidIkmLen(b.len()));
}
let scheme: Scheme =
SchemeSerializeType::from_le_bytes(b[4..8].try_into().unwrap()).try_into()?;
let is = scheme.get_ikm_size();
if b.len() != IKM_BASE_STRUCT_SIZE + is {
return Err(Error::ParsingEncodedDataInvalidIkmLen(b.len()));
}
Ok(Self { Ok(Self {
id: IkmId::from_le_bytes(b[0..4].try_into().unwrap()), id: IkmId::from_le_bytes(b[0..4].try_into().unwrap()),
scheme: SchemeSerializeType::from_le_bytes(b[4..8].try_into().unwrap()).try_into()?, scheme,
content: b[8..40].try_into().unwrap(), content: b[8..8 + is].into(),
created_at: InputKeyMaterial::bytes_to_system_time(&b[40..48])?, created_at: InputKeyMaterial::bytes_to_system_time(&b[8 + is..8 + is + 8])?,
expire_at: InputKeyMaterial::bytes_to_system_time(&b[48..56])?, expire_at: InputKeyMaterial::bytes_to_system_time(&b[8 + is + 8..8 + is + 8 + 8])?,
is_revoked: b[56] != 0, is_revoked: b[8 + is + 8 + 8] != 0,
}) })
} }
@ -66,8 +73,8 @@ impl InputKeyMaterial {
#[derive(Debug, Default)] #[derive(Debug, Default)]
pub struct InputKeyMaterialList { pub struct InputKeyMaterialList {
ikm_lst: Vec<InputKeyMaterial>, pub(crate) ikm_lst: Vec<InputKeyMaterial>,
id_counter: CounterId, pub(crate) id_counter: CounterId,
} }
impl InputKeyMaterialList { impl InputKeyMaterialList {
@ -78,18 +85,22 @@ impl InputKeyMaterialList {
#[cfg(feature = "ikm-management")] #[cfg(feature = "ikm-management")]
pub fn add_ikm(&mut self) -> Result<()> { pub fn add_ikm(&mut self) -> Result<()> {
self.add_ikm_with_duration(Duration::from_secs(crate::DEFAULT_IKM_DURATION)) self.add_custom_ikm(
crate::DEFAULT_SCHEME,
Duration::from_secs(crate::DEFAULT_IKM_DURATION),
)
} }
#[cfg(feature = "ikm-management")] #[cfg(feature = "ikm-management")]
pub fn add_ikm_with_duration(&mut self, duration: Duration) -> Result<()> { pub fn add_custom_ikm(&mut self, scheme: Scheme, duration: Duration) -> Result<()> {
let mut content: [u8; 32] = [0; 32]; let ikm_len = scheme.get_ikm_size();
getrandom::getrandom(&mut content)?; let mut content: Vec<u8> = vec![0; ikm_len];
getrandom::getrandom(content.as_mut_slice())?;
let created_at = SystemTime::now(); let created_at = SystemTime::now();
self.id_counter += 1; self.id_counter += 1;
self.ikm_lst.push(InputKeyMaterial { self.ikm_lst.push(InputKeyMaterial {
id: self.id_counter, id: self.id_counter,
scheme: crate::DEFAULT_SCHEME, scheme,
created_at, created_at,
expire_at: created_at + duration, expire_at: created_at + duration,
is_revoked: false, is_revoked: false,
@ -116,28 +127,11 @@ impl InputKeyMaterialList {
#[cfg(feature = "ikm-management")] #[cfg(feature = "ikm-management")]
pub fn export(&self) -> Result<String> { pub fn export(&self) -> Result<String> {
let data_size = (self.ikm_lst.len() * IKM_STRUCT_SIZE) + 4; crate::storage::encode_ikm_list(self)
let mut data = Vec::with_capacity(data_size);
data.extend_from_slice(&self.id_counter.to_le_bytes());
for ikm in &self.ikm_lst {
data.extend_from_slice(&ikm.as_bytes()?);
}
Ok(Base64UrlUnpadded::encode_string(&data))
} }
pub fn import(s: &str) -> Result<Self> { pub fn import(s: &str) -> Result<Self> {
let data = Base64UrlUnpadded::decode_vec(s)?; crate::storage::decode_ikm_list(s)
if data.len() % IKM_STRUCT_SIZE != 4 {
return Err(Error::ParsingIkmInvalidLength(data.len()));
}
let mut ikm_lst = Vec::with_capacity(data.len() / IKM_STRUCT_SIZE);
for ikm_slice in data[4..].chunks_exact(IKM_STRUCT_SIZE) {
ikm_lst.push(InputKeyMaterial::from_bytes(ikm_slice.try_into().unwrap())?);
}
Ok(Self {
ikm_lst,
id_counter: CounterId::from_le_bytes(data[0..4].try_into().unwrap()),
})
} }
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
@ -195,7 +189,10 @@ mod tests {
assert_eq!(el.id, 1); assert_eq!(el.id, 1);
assert_eq!(el.is_revoked, false); assert_eq!(el.is_revoked, false);
let res = lst.add_ikm(); let res = lst.add_custom_ikm(
Scheme::XChaCha20Poly1305WithBlake3,
Duration::from_secs(crate::DEFAULT_IKM_DURATION),
);
assert!(res.is_ok()); assert!(res.is_ok());
assert_eq!(lst.id_counter, 2); assert_eq!(lst.id_counter, 2);
assert_eq!(lst.ikm_lst.len(), 2); assert_eq!(lst.ikm_lst.len(), 2);
@ -228,13 +225,13 @@ mod tests {
let res = lst.export(); let res = lst.export();
assert!(res.is_ok()); assert!(res.is_ok());
let s = res.unwrap(); let s = res.unwrap();
assert_eq!(s.len(), 82); assert_eq!(s.len(), 83);
} }
#[test] #[test]
fn import() { fn import() {
let s = let s =
"AQAAAAEAAAABAAAANGFtbdYEN0s7dzCfMm7dYeQWD64GdmuKsYSiKwppAhmkz81lAAAAACQDr2cAAAAAAA"; "AQAAAA:AQAAAAEAAAC_vYEw1ujVG5i-CtoPYSzik_6xaAq59odjPm5ij01-e6zz4mUAAAAALJGBiwAAAAAA";
let res = InputKeyMaterialList::import(s); let res = InputKeyMaterialList::import(s);
assert!(res.is_ok()); assert!(res.is_ok());
let lst = res.unwrap(); let lst = res.unwrap();
@ -246,8 +243,8 @@ mod tests {
assert_eq!( assert_eq!(
ikm.content, ikm.content,
[ [
52, 97, 109, 109, 214, 4, 55, 75, 59, 119, 48, 159, 50, 110, 221, 97, 228, 22, 15, 191, 189, 129, 48, 214, 232, 213, 27, 152, 190, 10, 218, 15, 97, 44, 226, 147, 254,
174, 6, 118, 107, 138, 177, 132, 162, 43, 10, 105, 2, 25 177, 104, 10, 185, 246, 135, 99, 62, 110, 98, 143, 77, 126, 123
] ]
); );
assert_eq!(ikm.is_revoked, false); assert_eq!(ikm.is_revoked, false);
@ -360,11 +357,16 @@ mod tests {
let mut lst = InputKeyMaterialList::new(); let mut lst = InputKeyMaterialList::new();
let _ = lst.add_ikm(); let _ = lst.add_ikm();
let _ = lst.add_ikm(); let _ = lst.add_ikm();
let _ = lst.add_ikm(); let _ = lst.add_custom_ikm(
Scheme::XChaCha20Poly1305WithBlake3,
Duration::from_secs(crate::DEFAULT_IKM_DURATION),
);
let res = lst.get_latest_ikm(); let res = lst.get_latest_ikm();
assert!(res.is_ok()); assert!(res.is_ok());
let latest_ikm = res.unwrap(); let latest_ikm = res.unwrap();
assert_eq!(latest_ikm.id, 3); assert_eq!(latest_ikm.id, 3);
assert_eq!(latest_ikm.scheme, Scheme::XChaCha20Poly1305WithBlake3);
assert_eq!(latest_ikm.content.len(), 32);
} }
#[test] #[test]

View file

@ -22,7 +22,7 @@ mod tests {
0xd0, 0x65, 0x00, 0x00, 0x00, 0x00, 0x3d, 0x82, 0x6f, 0x8b, 0x00, 0x00, 0x00, 0x00, 0xd0, 0x65, 0x00, 0x00, 0x00, 0x00, 0x3d, 0x82, 0x6f, 0x8b, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00,
]; ];
let ikm = InputKeyMaterial::from_bytes(ikm_raw).unwrap(); let ikm = InputKeyMaterial::from_bytes(&ikm_raw).unwrap();
let ctx = ["some", "context"]; let ctx = ["some", "context"];
assert_eq!( assert_eq!(

View file

@ -13,6 +13,12 @@ pub enum Scheme {
} }
impl Scheme { impl Scheme {
pub(crate) fn get_ikm_size(&self) -> usize {
match self {
Scheme::XChaCha20Poly1305WithBlake3 => 32,
}
}
pub(crate) fn get_kdf(&self) -> Box<KdfFunction> { pub(crate) fn get_kdf(&self) -> Box<KdfFunction> {
match self { match self {
Scheme::XChaCha20Poly1305WithBlake3 => Box::new(blake3::blake3_derive), Scheme::XChaCha20Poly1305WithBlake3 => Box::new(blake3::blake3_derive),

View file

@ -1,6 +1,6 @@
use crate::encryption::EncryptedData; use crate::encryption::EncryptedData;
use crate::error::{Error, Result}; use crate::error::{Error, Result};
use crate::ikm::IkmId; use crate::ikm::{CounterId, IkmId, InputKeyMaterial, InputKeyMaterialList, IKM_BASE_STRUCT_SIZE};
use base64ct::{Base64UrlUnpadded, Encoding}; use base64ct::{Base64UrlUnpadded, Encoding};
const STORAGE_SEPARATOR: &str = ":"; const STORAGE_SEPARATOR: &str = ":";
@ -16,7 +16,20 @@ fn decode_data(s: &str) -> Result<Vec<u8>> {
Ok(Base64UrlUnpadded::decode_vec(s)?) Ok(Base64UrlUnpadded::decode_vec(s)?)
} }
pub(crate) fn encode(ikm_id: IkmId, encrypted_data: &EncryptedData) -> String { pub(crate) fn encode_ikm_list(ikml: &InputKeyMaterialList) -> Result<String> {
let data_size = (ikml.ikm_lst.iter().fold(0, |acc, ikm| {
acc + IKM_BASE_STRUCT_SIZE + ikm.scheme.get_ikm_size()
})) + 4;
let mut ret = String::with_capacity(data_size);
ret += &encode_data(&ikml.id_counter.to_le_bytes());
for ikm in &ikml.ikm_lst {
ret += STORAGE_SEPARATOR;
ret += &encode_data(&ikm.as_bytes()?);
}
Ok(ret)
}
pub(crate) fn encode_cipher(ikm_id: IkmId, encrypted_data: &EncryptedData) -> String {
let mut ret = String::new(); let mut ret = String::new();
ret += &encode_data(&ikm_id.to_le_bytes()); ret += &encode_data(&ikm_id.to_le_bytes());
ret += STORAGE_SEPARATOR; ret += STORAGE_SEPARATOR;
@ -26,7 +39,26 @@ pub(crate) fn encode(ikm_id: IkmId, encrypted_data: &EncryptedData) -> String {
ret ret
} }
pub(crate) fn decode(data: &str) -> Result<(IkmId, EncryptedData)> { pub(crate) fn decode_ikm_list(data: &str) -> Result<InputKeyMaterialList> {
let v: Vec<&str> = data.split(STORAGE_SEPARATOR).collect();
if v.is_empty() {
return Err(Error::ParsingEncodedDataInvalidIkmListLen(v.len()));
}
let id_data = decode_data(v[0])?;
let id_counter = CounterId::from_le_bytes(id_data[0..4].try_into().unwrap());
let mut ikm_lst = Vec::with_capacity(v.len() - 1);
for ikm_str in &v[1..] {
let raw_ikm = decode_data(ikm_str)?;
let ikm = InputKeyMaterial::from_bytes(&raw_ikm)?;
ikm_lst.push(ikm);
}
Ok(InputKeyMaterialList {
ikm_lst,
id_counter,
})
}
pub(crate) fn decode_cipher(data: &str) -> Result<(IkmId, EncryptedData)> {
let v: Vec<&str> = data.split(STORAGE_SEPARATOR).collect(); let v: Vec<&str> = data.split(STORAGE_SEPARATOR).collect();
if v.len() != NB_PARTS { if v.len() != NB_PARTS {
return Err(Error::ParsingEncodedDataInvalidPartLen(NB_PARTS, v.len())); return Err(Error::ParsingEncodedDataInvalidPartLen(NB_PARTS, v.len()));
@ -67,13 +99,13 @@ mod tests {
nonce: TEST_NONCE.into(), nonce: TEST_NONCE.into(),
ciphertext: TEST_CIPHERTEXT.into(), ciphertext: TEST_CIPHERTEXT.into(),
}; };
let s = super::encode(TEST_IKM_ID, &data); let s = super::encode_cipher(TEST_IKM_ID, &data);
assert_eq!(&s, TEST_STR); assert_eq!(&s, TEST_STR);
} }
#[test] #[test]
fn decode() { fn decode() {
let res = super::decode(TEST_STR); let res = super::decode_cipher(TEST_STR);
assert!(res.is_ok()); assert!(res.is_ok());
let (id, data) = res.unwrap(); let (id, data) = res.unwrap();
assert_eq!(id, TEST_IKM_ID); assert_eq!(id, TEST_IKM_ID);
@ -87,8 +119,8 @@ mod tests {
nonce: TEST_NONCE.into(), nonce: TEST_NONCE.into(),
ciphertext: TEST_CIPHERTEXT.into(), ciphertext: TEST_CIPHERTEXT.into(),
}; };
let s = super::encode(TEST_IKM_ID, &data); let s = super::encode_cipher(TEST_IKM_ID, &data);
let (id, decoded_data) = super::decode(&s).unwrap(); let (id, decoded_data) = super::decode_cipher(&s).unwrap();
assert_eq!(id, TEST_IKM_ID); assert_eq!(id, TEST_IKM_ID);
assert_eq!(decoded_data.nonce, data.nonce); assert_eq!(decoded_data.nonce, data.nonce);
assert_eq!(decoded_data.ciphertext, data.ciphertext); assert_eq!(decoded_data.ciphertext, data.ciphertext);
@ -96,8 +128,8 @@ mod tests {
#[test] #[test]
fn decode_encode() { fn decode_encode() {
let (id, data) = super::decode(TEST_STR).unwrap(); let (id, data) = super::decode_cipher(TEST_STR).unwrap();
let s = super::encode(id, &data); let s = super::encode_cipher(id, &data);
assert_eq!(&s, TEST_STR); assert_eq!(&s, TEST_STR);
} }
} }