const_oid/
encoder.rs

1//! OID encoder with `const` support.
2
3use crate::{
4    Arc, Buffer, Error, ObjectIdentifier, Result,
5    arcs::{ARC_MAX_FIRST, ARC_MAX_SECOND},
6};
7
/// BER/DER encoder.
///
/// Accumulates the BER encoding of an OID, one arc at a time, into a
/// fixed-size `[u8; MAX_SIZE]` byte array.
#[derive(Debug)]
pub(crate) struct Encoder<const MAX_SIZE: usize> {
    /// Current state: determines how the next arc passed to `arc` is handled.
    state: State,

    /// Bytes of the OID being BER-encoded in-progress.
    ///
    /// Only the first `cursor` bytes hold encoded data.
    bytes: [u8; MAX_SIZE],

    /// Current position within the byte buffer: the number of bytes written
    /// so far, i.e. the index where the next encoded byte is stored.
    cursor: usize,
}
20
/// Current state of the encoder.
#[derive(Debug)]
enum State {
    /// Initial state - no arcs yet encoded.
    Initial,

    /// First arc has been supplied and stored as the wrapped [`Arc`].
    ///
    /// Nothing is written to the buffer yet, because the first arc is packed
    /// together with the second arc into a single leading byte.
    FirstArc(Arc),

    /// Encoding base 128 body of the OID: the leading byte has been written
    /// and every further arc is appended in base 128.
    Body,
}
33
34impl<const MAX_SIZE: usize> Encoder<MAX_SIZE> {
35    /// Create a new encoder initialized to an empty default state.
36    pub(crate) const fn new() -> Self {
37        Self {
38            state: State::Initial,
39            bytes: [0u8; MAX_SIZE],
40            cursor: 0,
41        }
42    }
43
44    /// Extend an existing OID.
45    pub(crate) const fn extend(oid: ObjectIdentifier<MAX_SIZE>) -> Self {
46        Self {
47            state: State::Body,
48            bytes: oid.ber.bytes,
49            cursor: oid.ber.length as usize,
50        }
51    }
52
53    /// Encode an [`Arc`] as base 128 into the internal buffer.
54    pub(crate) const fn arc(mut self, arc: Arc) -> Result<Self> {
55        match self.state {
56            State::Initial => {
57                if arc > ARC_MAX_FIRST {
58                    return Err(Error::ArcInvalid { arc });
59                }
60
61                self.state = State::FirstArc(arc);
62                Ok(self)
63            }
64            State::FirstArc(first_arc) => {
65                if arc > ARC_MAX_SECOND {
66                    return Err(Error::ArcInvalid { arc });
67                }
68
69                self.state = State::Body;
70                self.bytes[0] = checked_add!(
71                    checked_mul!(checked_add!(ARC_MAX_SECOND, 1), first_arc),
72                    arc
73                ) as u8;
74                self.cursor = 1;
75                Ok(self)
76            }
77            State::Body => self.encode_base128(arc),
78        }
79    }
80
81    /// Finish encoding an OID.
82    pub(crate) const fn finish(self) -> Result<ObjectIdentifier<MAX_SIZE>> {
83        if self.cursor == 0 {
84            return Err(Error::Empty);
85        }
86
87        let ber = Buffer {
88            bytes: self.bytes,
89            length: self.cursor as u8,
90        };
91
92        Ok(ObjectIdentifier { ber })
93    }
94
95    /// Encode base 128.
96    const fn encode_base128(mut self, arc: Arc) -> Result<Self> {
97        let nbytes = base128_len(arc);
98        let end_pos = checked_add!(self.cursor, nbytes);
99
100        if end_pos > MAX_SIZE {
101            return Err(Error::Length);
102        }
103
104        let mut i = 0;
105        while i < nbytes {
106            // TODO(tarcieri): use `?` when stable in `const fn`
107            self.bytes[self.cursor] = match base128_byte(arc, i, nbytes) {
108                Ok(byte) => byte,
109                Err(e) => return Err(e),
110            };
111            self.cursor = checked_add!(self.cursor, 1);
112            i = checked_add!(i, 1);
113        }
114
115        Ok(self)
116    }
117}
118
119/// Compute the length of an arc when encoded in base 128.
120const fn base128_len(arc: Arc) -> usize {
121    match arc {
122        0..=0x7f => 1,              // up to 7 bits
123        0x80..=0x3fff => 2,         // up to 14 bits
124        0x4000..=0x1fffff => 3,     // up to 21 bits
125        0x200000..=0x0fffffff => 4, // up to 28 bits
126        _ => 5,
127    }
128}
129
130/// Compute the big endian base 128 encoding of the given [`Arc`] at the given byte.
131const fn base128_byte(arc: Arc, pos: usize, total: usize) -> Result<u8> {
132    debug_assert!(pos < total);
133    let last_byte = checked_add!(pos, 1) == total;
134    let mask = if last_byte { 0 } else { 0b10000000 };
135    let shift = checked_mul!(checked_sub!(checked_sub!(total, pos), 1), 7);
136    Ok(((arc >> shift) & 0b1111111) as u8 | mask)
137}
138
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::Encoder;
    use crate::{
        Error,
        arcs::{ARC_MAX_FIRST, ARC_MAX_SECOND},
    };
    use hex_literal::hex;

    /// OID `1.2.840.10045.2.1` encoded as ASN.1 BER/DER
    const EXAMPLE_OID_BER: &[u8] = &hex!("2A8648CE3D0201");

    #[test]
    fn base128_byte() {
        let example_arc = 0x44332211;
        assert_eq!(super::base128_len(example_arc), 5);
        assert_eq!(super::base128_byte(example_arc, 0, 5).unwrap(), 0b10000100);
        assert_eq!(super::base128_byte(example_arc, 1, 5).unwrap(), 0b10100001);
        assert_eq!(super::base128_byte(example_arc, 2, 5).unwrap(), 0b11001100);
        assert_eq!(super::base128_byte(example_arc, 3, 5).unwrap(), 0b11000100);
        assert_eq!(super::base128_byte(example_arc, 4, 5).unwrap(), 0b10001);
    }

    #[test]
    fn encode() {
        let encoder = Encoder::<7>::new();
        let encoder = encoder.arc(1).unwrap();
        let encoder = encoder.arc(2).unwrap();
        let encoder = encoder.arc(840).unwrap();
        let encoder = encoder.arc(10045).unwrap();
        let encoder = encoder.arc(2).unwrap();
        let encoder = encoder.arc(1).unwrap();
        assert_eq!(&encoder.bytes[..encoder.cursor], EXAMPLE_OID_BER);
    }

    /// The first arc is limited to `ARC_MAX_FIRST`.
    #[test]
    fn rejects_out_of_range_first_arc() {
        let bad = ARC_MAX_FIRST + 1;
        let result = Encoder::<7>::new().arc(bad);
        assert!(matches!(result, Err(Error::ArcInvalid { arc }) if arc == bad));
    }

    /// The second arc is limited to `ARC_MAX_SECOND`.
    #[test]
    fn rejects_out_of_range_second_arc() {
        let bad = ARC_MAX_SECOND + 1;
        let result = Encoder::<7>::new().arc(0).unwrap().arc(bad);
        assert!(matches!(result, Err(Error::ArcInvalid { arc }) if arc == bad));
    }

    /// An arc that does not fit in the remaining buffer space is rejected.
    #[test]
    fn rejects_arc_that_overflows_buffer() {
        // `1.2` fills the single leading byte; `840` needs two more bytes.
        let encoder = Encoder::<2>::new().arc(1).unwrap().arc(2).unwrap();
        assert!(matches!(encoder.arc(840), Err(Error::Length)));
    }

    /// Finishing before any bytes are written reports an empty OID.
    #[test]
    fn finish_without_arcs_is_empty() {
        assert!(matches!(Encoder::<7>::new().finish(), Err(Error::Empty)));
    }
}