use crate::ber::*;
use crate::der_constraint_fail_if;
use crate::error::*;
#[cfg(feature = "std")]
use crate::ToDer;
use crate::{BerParser, Class, DerParser, DynTagged, FromBer, FromDer, Length, Tag, ToStatic};
use alloc::borrow::Cow;
use core::convert::TryFrom;
use nom::bytes::streaming::take;
/// Header of an ASN.1 object: class, constructed flag, tag number and length.
///
/// The lifetime `'a` is tied to `raw_tag`, which may borrow the identifier bytes
/// from the parsed input.
#[derive(Clone, Debug)]
pub struct Header<'a> {
    /// Class of the object (universal, application, context-specific or private).
    pub(crate) class: Class,
    /// `true` if the encoding is constructed, `false` if it is primitive.
    pub(crate) constructed: bool,
    /// Tag number.
    pub(crate) tag: Tag,
    /// Length of the content, either definite or indefinite.
    pub(crate) length: Length,
    /// Raw identifier (tag) bytes, when known.
    ///
    /// The parsers fill this from the input. `write_der_raw` writes these bytes
    /// verbatim instead of re-encoding class/constructed/tag.
    pub(crate) raw_tag: Option<Cow<'a, [u8]>>,
}
impl<'a> Header<'a> {
pub const fn new(class: Class, constructed: bool, tag: Tag, length: Length) -> Self {
Header {
tag,
constructed,
class,
length,
raw_tag: None,
}
}
#[inline]
pub const fn new_simple(tag: Tag) -> Self {
let constructed = matches!(tag, Tag::Sequence | Tag::Set);
Self::new(Class::Universal, constructed, tag, Length::Definite(0))
}
#[inline]
pub fn with_class(self, class: Class) -> Self {
Self { class, ..self }
}
#[inline]
pub fn with_constructed(self, constructed: bool) -> Self {
Self {
constructed,
..self
}
}
#[inline]
pub fn with_tag(self, tag: Tag) -> Self {
Self { tag, ..self }
}
#[inline]
pub fn with_length(self, length: Length) -> Self {
Self { length, ..self }
}
#[inline]
pub fn with_raw_tag(self, raw_tag: Option<Cow<'a, [u8]>>) -> Self {
Header { raw_tag, ..self }
}
#[inline]
pub const fn class(&self) -> Class {
self.class
}
#[inline]
pub const fn constructed(&self) -> bool {
self.constructed
}
#[inline]
pub const fn tag(&self) -> Tag {
self.tag
}
#[inline]
pub const fn length(&self) -> Length {
self.length
}
#[inline]
pub fn raw_tag(&self) -> Option<&[u8]> {
self.raw_tag.as_ref().map(|cow| cow.as_ref())
}
#[inline]
pub const fn is_primitive(&self) -> bool {
!self.constructed
}
#[inline]
pub const fn is_constructed(&self) -> bool {
self.constructed
}
#[inline]
pub const fn assert_class(&self, class: Class) -> Result<()> {
self.class.assert_eq(class)
}
#[inline]
pub const fn assert_tag(&self, tag: Tag) -> Result<()> {
self.tag.assert_eq(tag)
}
#[inline]
pub const fn assert_primitive(&self) -> Result<()> {
if self.is_primitive() {
Ok(())
} else {
Err(Error::ConstructUnexpected)
}
}
#[inline]
pub const fn assert_constructed(&self) -> Result<()> {
if !self.is_primitive() {
Ok(())
} else {
Err(Error::ConstructExpected)
}
}
#[inline]
pub const fn is_universal(&self) -> bool {
self.class as u8 == Class::Universal as u8
}
#[inline]
pub const fn is_application(&self) -> bool {
self.class as u8 == Class::Application as u8
}
#[inline]
pub const fn is_contextspecific(&self) -> bool {
self.class as u8 == Class::ContextSpecific as u8
}
#[inline]
pub const fn is_private(&self) -> bool {
self.class as u8 == Class::Private as u8
}
#[inline]
pub const fn assert_definite(&self) -> Result<()> {
if self.length.is_definite() {
Ok(())
} else {
Err(Error::DerConstraintFailed(DerConstraint::IndefiniteLength))
}
}
#[inline]
pub fn parse_ber_content<'i>(&'_ self, i: &'i [u8]) -> ParseResult<'i, &'i [u8]> {
BerParser::get_object_content(i, self, 8)
}
#[inline]
pub fn parse_der_content<'i>(&'_ self, i: &'i [u8]) -> ParseResult<'i, &'i [u8]> {
self.assert_definite()?;
DerParser::get_object_content(i, self, 8)
}
}
impl From<Tag> for Header<'_> {
#[inline]
fn from(tag: Tag) -> Self {
let constructed = matches!(tag, Tag::Sequence | Tag::Set);
Self::new(Class::Universal, constructed, tag, Length::Definite(0))
}
}
impl<'a> ToStatic for Header<'a> {
type Owned = Header<'static>;
fn to_static(&self) -> Self::Owned {
let raw_tag: Option<Cow<'static, [u8]>> =
self.raw_tag.as_ref().map(|b| Cow::Owned(b.to_vec()));
Header {
tag: self.tag,
constructed: self.constructed,
class: self.class,
length: self.length,
raw_tag,
}
}
}
impl<'a> FromBer<'a> for Header<'a> {
fn from_ber(bytes: &'a [u8]) -> ParseResult<Self> {
let (i1, el) = parse_identifier(bytes)?;
let class = match Class::try_from(el.0) {
Ok(c) => c,
Err(_) => unreachable!(), };
let (i2, len) = parse_ber_length_byte(i1)?;
let (i3, len) = match (len.0, len.1) {
(0, l1) => {
(i2, Length::Definite(usize::from(l1)))
}
(_, 0) => {
if el.1 == 0 {
return Err(nom::Err::Error(Error::ConstructExpected));
}
(i2, Length::Indefinite)
}
(_, l1) => {
if l1 == 0b0111_1111 {
return Err(nom::Err::Error(Error::InvalidLength));
}
let (i3, llen) = take(l1)(i2)?;
match bytes_to_u64(llen) {
Ok(l) => {
let l =
usize::try_from(l).or(Err(nom::Err::Error(Error::InvalidLength)))?;
(i3, Length::Definite(l))
}
Err(_) => {
return Err(nom::Err::Error(Error::InvalidLength));
}
}
}
};
let constructed = el.1 != 0;
let hdr = Header::new(class, constructed, Tag(el.2), len).with_raw_tag(Some(el.3.into()));
Ok((i3, hdr))
}
}
impl<'a> FromDer<'a> for Header<'a> {
fn from_der(bytes: &'a [u8]) -> ParseResult<Self> {
let (i1, el) = parse_identifier(bytes)?;
let class = match Class::try_from(el.0) {
Ok(c) => c,
Err(_) => unreachable!(), };
let (i2, len) = parse_ber_length_byte(i1)?;
let (i3, len) = match (len.0, len.1) {
(0, l1) => {
(i2, Length::Definite(usize::from(l1)))
}
(_, 0) => {
return Err(nom::Err::Error(Error::DerConstraintFailed(
DerConstraint::IndefiniteLength,
)));
}
(_, l1) => {
if l1 == 0b0111_1111 {
return Err(nom::Err::Error(Error::InvalidLength));
}
der_constraint_fail_if!(
&i[1..],
len.1 == 0 && el.1 != 1,
DerConstraint::NotConstructed
);
let (i3, llen) = take(l1)(i2)?;
match bytes_to_u64(llen) {
Ok(l) => {
let l =
usize::try_from(l).or(Err(nom::Err::Error(Error::InvalidLength)))?;
(i3, Length::Definite(l))
}
Err(_) => {
return Err(nom::Err::Error(Error::InvalidLength));
}
}
}
};
let constructed = el.1 != 0;
let hdr = Header::new(class, constructed, Tag(el.2), len).with_raw_tag(Some(el.3.into()));
Ok((i3, hdr))
}
}
impl DynTagged for (Class, bool, Tag) {
    /// Returns the tag component of the `(class, constructed, tag)` triple.
    fn tag(&self) -> Tag {
        let (_, _, tag) = *self;
        tag
    }
}
#[cfg(feature = "std")]
impl ToDer for (Class, bool, Tag) {
    /// Number of identifier octets needed to encode this `(class, constructed, tag)`.
    ///
    /// Tag numbers `0..=30` fit in a single octet. Larger numbers take one leading
    /// octet plus one octet per 7-bit group of the tag number.
    fn to_der_len(&self) -> Result<usize> {
        let (_, _, tag) = self;
        match tag.0 {
            0..=30 => Ok(1),
            t => {
                let mut groups = 1;
                let mut val = t >> 7;
                while val > 0 {
                    groups += 1;
                    val >>= 7;
                }
                Ok(1 + groups)
            }
        }
    }

    /// Writes the identifier octets and returns how many bytes were written.
    ///
    /// Tag numbers above 30 use the high-tag-number form (X.690 §8.1.2.4):
    /// - the low five bits of the first octet are all set;
    /// - the tag number follows in 7-bit groups, most significant first;
    /// - bit 8 is set on every group except the last.
    ///
    /// # Errors
    /// - I/O errors from `writer`;
    /// - `SerializeError::InvalidLength` if the tag number would need more groups
    ///   than the local buffer holds. Nothing is written in that case.
    fn write_der_header(&self, writer: &mut dyn std::io::Write) -> SerializeResult<usize> {
        let (class, constructed, tag) = self;
        let b0 = (*class as u8) << 6;
        let b0 = b0 | if *constructed { 0b10_0000 } else { 0 };
        if tag.0 > 30 {
            let mut val = tag.0;
            const BUF_SZ: usize = 8;
            let mut buffer = [0u8; BUF_SZ];
            let mut current_index = BUF_SZ - 1;
            // Fill the buffer back-to-front before writing anything, so the error
            // path below cannot leave a partially written identifier.
            buffer[current_index] = (val & 0x7f) as u8;
            val >>= 7;
            while val > 0 {
                current_index -= 1;
                if current_index == 0 {
                    return Err(SerializeError::InvalidLength);
                }
                buffer[current_index] = (val & 0x7f) as u8 | 0x80;
                val >>= 7;
            }
            let b0 = b0 | 0b1_1111;
            // `write_all`, not `write`: a single `write` call may accept only part of
            // the buffer.
            writer.write_all(&[b0])?;
            writer.write_all(&buffer[current_index..])?;
            Ok(1 + (BUF_SZ - current_index))
        } else {
            writer.write_all(&[b0 | (tag.0 as u8)])?;
            Ok(1)
        }
    }

    /// A bare identifier has no content: writes nothing.
    fn write_der_content(&self, _writer: &mut dyn std::io::Write) -> SerializeResult<usize> {
        Ok(0)
    }
}
impl DynTagged for Header<'_> {
fn tag(&self) -> Tag {
self.tag
}
}
#[cfg(feature = "std")]
impl ToDer for Header<'_> {
    /// Total size of the identifier octets plus the length octets.
    fn to_der_len(&self) -> Result<usize> {
        let tag_len = (self.class, self.constructed, self.tag).to_der_len()?;
        let len_len = self.length.to_der_len()?;
        Ok(tag_len + len_len)
    }

    /// Writes the identifier octets, re-encoded from class/constructed/tag, and
    /// then the length octets.
    fn write_der_header(&self, writer: &mut dyn std::io::Write) -> SerializeResult<usize> {
        let sz = (self.class, self.constructed, self.tag).write_der_header(writer)?;
        let sz = sz + self.length.write_der_header(writer)?;
        Ok(sz)
    }

    /// A header carries no content of its own: writes nothing.
    fn write_der_content(&self, _writer: &mut dyn std::io::Write) -> SerializeResult<usize> {
        Ok(0)
    }

    /// Like `write_der_header`, except that stored raw identifier bytes (when
    /// present) are written verbatim instead of being re-encoded.
    fn write_der_raw(&self, writer: &mut dyn std::io::Write) -> SerializeResult<usize> {
        let sz = match &self.raw_tag {
            Some(t) => {
                // `write_all`, not `write`: a single `write` call may accept only part
                // of the raw tag, silently dropping the rest.
                writer.write_all(t)?;
                t.len()
            }
            None => (self.class, self.constructed, self.tag).write_der_header(writer)?,
        };
        let sz = sz + self.length.write_der_header(writer)?;
        Ok(sz)
    }
}
impl<'a> PartialEq<Header<'a>> for Header<'a> {
    /// Header equality.
    ///
    /// Class, tag and constructed flag are compared strictly. Length and raw tag
    /// bytes are only compared in the cases described inline.
    fn eq(&self, other: &Header) -> bool {
        let same_identity = self.class == other.class
            && self.tag == other.tag
            && self.constructed == other.constructed;
        if !same_identity {
            return false;
        }
        // Lengths are compared only when both sides report `is_null()`; any other
        // combination counts as matching.
        // NOTE(review): if `is_null()` means "definite length 0", this check can never
        // fail, so length never affects equality. Confirm this is intended.
        if self.length.is_null() && other.length.is_null() && self.length != other.length {
            return false;
        }
        // Raw identifier bytes are compared only when both headers carry them. If just
        // one side has them they are ignored, and two `None`s compare equal.
        match (&self.raw_tag, &other.raw_tag) {
            (Some(mine), Some(theirs)) => mine == theirs,
            _ => true,
        }
    }
}
/// `Header` is declared `Eq`.
///
/// NOTE(review): the `PartialEq` impl for `Header` ignores raw tag bytes when only
/// one side has them. As a result, two headers with different raw tags can each
/// equal a third header that has none, but not each other. Equality is therefore
/// not transitive, which `Eq` formally requires. Confirm this is acceptable.
impl<'a> Eq for Header<'a> {}
#[cfg(test)]
mod tests {
    use crate::*;
    use hex_literal::hex;

    #[test]
    fn methods_header() {
        let input = &hex! {"02 01 00"};
        let (rem, header) = Header::from_ber(input).expect("parsing header failed");
        assert_eq!(header.class(), Class::Universal);
        assert_eq!(header.tag(), Tag::Integer);
        assert!(header.assert_primitive().is_ok());
        assert!(header.assert_constructed().is_err());
        assert!(header.is_universal());
        assert!(!header.is_application());
        assert!(!header.is_private());
        assert_eq!(rem, &input[2..]);
        let hdr2 = Header::new_simple(Tag::Integer);
        assert_eq!(header, hdr2);
        let hdr3 = hdr2
            .with_class(Class::ContextSpecific)
            .with_constructed(true)
            .with_length(Length::Definite(1));
        assert!(hdr3.constructed());
        assert!(hdr3.is_constructed());
        assert!(hdr3.assert_constructed().is_ok());
        assert!(hdr3.is_contextspecific());
        let xx = hdr3.to_der_vec().expect("serialize failed");
        assert_eq!(&xx, &[0xa2, 0x01]);
        let hdr4 = hdr3.with_length(Length::Indefinite);
        assert!(hdr4.assert_definite().is_err());
        let xx = hdr4.to_der_vec().expect("serialize failed");
        assert_eq!(&xx, &[0xa2, 0x80]);
        let hdr = Header::new_simple(Tag(2)).with_length(Length::Definite(1));
        let (_, r) = hdr.parse_ber_content(&input[2..]).unwrap();
        assert_eq!(r, &input[2..]);
        let (_, r) = hdr.parse_der_content(&input[2..]).unwrap();
        assert_eq!(r, &input[2..]);
    }

    /// Tag numbers above 30 use the multi-octet (high-tag-number) identifier form.
    #[test]
    fn high_tag_number_identifier() {
        let t31 = (Class::ContextSpecific, false, Tag(31));
        assert_eq!(t31.to_der_len().expect("length failed"), 2);
        let xx = t31.to_der_vec().expect("serialize failed");
        assert_eq!(&xx, &[0x9f, 0x1f]);

        let t128 = (Class::ContextSpecific, false, Tag(128));
        assert_eq!(t128.to_der_len().expect("length failed"), 3);
        let xx = t128.to_der_vec().expect("serialize failed");
        assert_eq!(&xx, &[0x9f, 0x81, 0x00]);
    }

    /// Short, long, reserved and indefinite length forms in BER and DER.
    #[test]
    fn header_length_forms() {
        // Long form: 0x82 announces two length octets (0x0100 = 256).
        let long_form = hex!("04 82 01 00");
        let (rem, hdr) = Header::from_der(&long_form).expect("DER parsing failed");
        assert!(rem.is_empty());
        assert_eq!(hdr.length(), Length::Definite(256));
        let (rem, hdr) = Header::from_ber(&long_form).expect("BER parsing failed");
        assert!(rem.is_empty());
        assert_eq!(hdr.length(), Length::Definite(256));

        // 0xFF as initial length octet is reserved.
        let reserved = hex!("04 ff");
        assert!(Header::from_ber(&reserved).is_err());
        assert!(Header::from_der(&reserved).is_err());

        // Indefinite length: accepted by BER for constructed encodings only.
        let indefinite_constructed = hex!("30 80");
        let (_, hdr) = Header::from_ber(&indefinite_constructed).expect("BER parsing failed");
        assert!(hdr.is_constructed());
        assert_eq!(hdr.length(), Length::Indefinite);
        let indefinite_primitive = hex!("04 80");
        assert!(Header::from_ber(&indefinite_primitive).is_err());

        // DER never accepts the indefinite form.
        let err = Header::from_der(&indefinite_constructed).unwrap_err();
        assert!(matches!(
            err,
            nom::Err::Error(crate::error::Error::DerConstraintFailed(
                crate::error::DerConstraint::IndefiniteLength
            ))
        ));
    }
}