use crate::{
asn1::*, Encode, EncodeValue, ErrorKind, Header, Length, Result, Tag, TagMode, TagNumber,
Tagged, Writer,
};
/// Writer which encodes into a caller-provided, fixed-size byte slice.
///
/// Output is appended at an advancing cursor; once an error has been
/// recorded through the `failed` flag, `encode`, `reserve` and `finish`
/// all refuse to proceed.
#[derive(Debug)]
pub struct SliceWriter<'a> {
/// Destination buffer that encoded bytes are copied into.
bytes: &'a mut [u8],
/// Set once an error has been recorded; checked before further output.
failed: bool,
/// Offset into `bytes` one past the last reserved byte.
position: Length,
}
impl<'a> SliceWriter<'a> {
pub fn new(bytes: &'a mut [u8]) -> Self {
Self {
bytes,
failed: false,
position: Length::ZERO,
}
}
pub fn encode<T: Encode>(&mut self, encodable: &T) -> Result<()> {
if self.is_failed() {
self.error(ErrorKind::Failed)?
}
encodable.encode(self).map_err(|e| {
self.failed = true;
e.nested(self.position)
})
}
pub fn error<T>(&mut self, kind: ErrorKind) -> Result<T> {
self.failed = true;
Err(kind.at(self.position))
}
pub fn is_failed(&self) -> bool {
self.failed
}
pub fn finish(self) -> Result<&'a [u8]> {
let position = self.position;
if self.is_failed() {
return Err(ErrorKind::Failed.at(position));
}
self.bytes
.get(..usize::try_from(position)?)
.ok_or_else(|| ErrorKind::Overlength.at(position))
}
pub fn context_specific<T>(
&mut self,
tag_number: TagNumber,
tag_mode: TagMode,
value: &T,
) -> Result<()>
where
T: EncodeValue + Tagged,
{
ContextSpecificRef {
tag_number,
tag_mode,
value,
}
.encode(self)
}
pub fn sequence<F>(&mut self, length: Length, f: F) -> Result<()>
where
F: FnOnce(&mut SliceWriter<'_>) -> Result<()>,
{
Header::new(Tag::Sequence, length).and_then(|header| header.encode(self))?;
let mut nested_encoder = SliceWriter::new(self.reserve(length)?);
f(&mut nested_encoder)?;
if nested_encoder.finish()?.len() == usize::try_from(length)? {
Ok(())
} else {
self.error(ErrorKind::Length { tag: Tag::Sequence })
}
}
fn reserve(&mut self, len: impl TryInto<Length>) -> Result<&mut [u8]> {
if self.is_failed() {
return Err(ErrorKind::Failed.at(self.position));
}
let len = len
.try_into()
.or_else(|_| self.error(ErrorKind::Overflow))?;
let end = (self.position + len).or_else(|e| self.error(e.kind()))?;
let slice = self
.bytes
.get_mut(self.position.try_into()?..end.try_into()?)
.ok_or_else(|| ErrorKind::Overlength.at(end))?;
self.position = end;
Ok(slice)
}
}
impl Writer for SliceWriter<'_> {
    /// Append `slice` to the output by reserving exactly `slice.len()` bytes
    /// and copying the input into them.
    ///
    /// Fails with whatever error `reserve` reports when the space cannot be
    /// obtained.
    fn write(&mut self, slice: &[u8]) -> Result<()> {
        let destination = self.reserve(slice.len())?;
        destination.copy_from_slice(slice);
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::SliceWriter;
    use crate::{Encode, ErrorKind, Length};

    /// Encoding into an empty buffer must report `Overlength` at the
    /// position the first byte would have ended on.
    #[test]
    fn overlength_message() {
        let mut empty: [u8; 0] = [];
        let mut slice_writer = SliceWriter::new(&mut empty);

        let outcome = false.encode(&mut slice_writer);
        let failure = outcome.err().unwrap();

        assert_eq!(failure.kind(), ErrorKind::Overlength);
        assert_eq!(failure.position(), Some(Length::ONE));
    }
}