diff options
Diffstat (limited to 'vendor/fdeflate/src/decompress.rs')
-rw-r--r-- | vendor/fdeflate/src/decompress.rs | 1270 |
1 files changed, 1270 insertions, 0 deletions
diff --git a/vendor/fdeflate/src/decompress.rs b/vendor/fdeflate/src/decompress.rs new file mode 100644 index 0000000..d38f63b --- /dev/null +++ b/vendor/fdeflate/src/decompress.rs @@ -0,0 +1,1270 @@ +use std::convert::TryInto; + +use simd_adler32::Adler32; + +use crate::tables::{ + self, CLCL_ORDER, DIST_SYM_TO_DIST_BASE, DIST_SYM_TO_DIST_EXTRA, FDEFLATE_DIST_DECODE_TABLE, + FDEFLATE_LITLEN_DECODE_TABLE, FIXED_CODE_LENGTHS, LEN_SYM_TO_LEN_BASE, LEN_SYM_TO_LEN_EXTRA, +}; + +/// An error encountered while decompressing a deflate stream. +#[derive(Debug, PartialEq)] +pub enum DecompressionError { + /// The zlib header is corrupt. + BadZlibHeader, + /// All input was consumed, but the end of the stream hasn't been reached. + InsufficientInput, + /// A block header specifies an invalid block type. + InvalidBlockType, + /// An uncompressed block's NLEN value is invalid. + InvalidUncompressedBlockLength, + /// Too many literals were specified. + InvalidHlit, + /// Too many distance codes were specified. + InvalidHdist, + /// Attempted to repeat a previous code before reading any codes, or past the end of the code + /// lengths. + InvalidCodeLengthRepeat, + /// The stream doesn't specify a valid huffman tree. + BadCodeLengthHuffmanTree, + /// The stream doesn't specify a valid huffman tree. + BadLiteralLengthHuffmanTree, + /// The stream doesn't specify a valid huffman tree. + BadDistanceHuffmanTree, + /// The stream contains a literal/length code that was not allowed by the header. + InvalidLiteralLengthCode, + /// The stream contains a distance code that was not allowed by the header. + InvalidDistanceCode, + /// The stream contains contains back-reference as the first symbol. + InputStartsWithRun, + /// The stream contains a back-reference that is too far back. + DistanceTooFarBack, + /// The deflate stream checksum is incorrect. + WrongChecksum, + /// Extra input data. + ExtraInput, +} + +struct BlockHeader { + hlit: usize, + hdist: usize, + hclen: usize, + num_lengths_read: usize, + + /// Low 3-bits are code length code length, high 5-bits are code length code. + table: [u8; 128], + code_lengths: [u8; 320], +} + +const LITERAL_ENTRY: u32 = 0x8000; +const EXCEPTIONAL_ENTRY: u32 = 0x4000; +const SECONDARY_TABLE_ENTRY: u32 = 0x2000; + +/// The Decompressor state for a compressed block. +/// +/// The main litlen_table uses a 12-bit input to lookup the meaning of the symbol. The table is +/// split into 4 sections: +/// +/// aaaaaaaa_bbbbbbbb_1000yyyy_0000xxxx x = input_advance_bits, y = output_advance_bytes (literal) +/// 0000000z_zzzzzzzz_00000yyy_0000xxxx x = input_advance_bits, y = extra_bits, z = distance_base (length) +/// 00000000_00000000_01000000_0000xxxx x = input_advance_bits (EOF) +/// 0000xxxx_xxxxxxxx_01100000_00000000 x = secondary_table_index +/// 00000000_00000000_01000000_00000000 invalid code +/// +/// The distance table is a 512-entry table that maps 9 bits of distance symbols to their meaning. +/// +/// 00000000_00000000_00000000_00000000 symbol is more than 9 bits +/// zzzzzzzz_zzzzzzzz_0000yyyy_0000xxxx x = input_advance_bits, y = extra_bits, z = distance_base +#[repr(align(64))] +#[derive(Eq, PartialEq, Debug)] +struct CompressedBlock { + litlen_table: [u32; 4096], + dist_table: [u32; 512], + + dist_symbol_lengths: [u8; 30], + dist_symbol_masks: [u16; 30], + dist_symbol_codes: [u16; 30], + + secondary_table: Vec<u16>, + eof_code: u16, + eof_mask: u16, + eof_bits: u8, +} + +const FDEFLATE_COMPRESSED_BLOCK: CompressedBlock = CompressedBlock { + litlen_table: FDEFLATE_LITLEN_DECODE_TABLE, + dist_table: FDEFLATE_DIST_DECODE_TABLE, + dist_symbol_lengths: [ + 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + dist_symbol_masks: [ + 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + dist_symbol_codes: [ + 0, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, + 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, + 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, + ], + secondary_table: Vec::new(), + eof_code: 0x8ff, + eof_mask: 0xfff, + eof_bits: 0xc, +}; + +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +enum State { + ZlibHeader, + BlockHeader, + CodeLengthCodes, + CodeLengths, + CompressedData, + UncompressedData, + Checksum, + Done, +} + +/// Decompressor for arbitrary zlib streams. +pub struct Decompressor { + /// State for decoding a compressed block. + compression: CompressedBlock, + // State for decoding a block header. + header: BlockHeader, + // Number of bytes left for uncompressed block. + uncompressed_bytes_left: u16, + + buffer: u64, + nbits: u8, + + queued_rle: Option<(u8, usize)>, + queued_backref: Option<(usize, usize)>, + last_block: bool, + + state: State, + checksum: Adler32, + ignore_adler32: bool, +} + +impl Default for Decompressor { + fn default() -> Self { + Self::new() + } +} + +impl Decompressor { + /// Create a new decompressor. + pub fn new() -> Self { + Self { + buffer: 0, + nbits: 0, + compression: CompressedBlock { + litlen_table: [0; 4096], + dist_table: [0; 512], + secondary_table: Vec::new(), + dist_symbol_lengths: [0; 30], + dist_symbol_masks: [0; 30], + dist_symbol_codes: [0xffff; 30], + eof_code: 0, + eof_mask: 0, + eof_bits: 0, + }, + header: BlockHeader { + hlit: 0, + hdist: 0, + hclen: 0, + table: [0; 128], + num_lengths_read: 0, + code_lengths: [0; 320], + }, + uncompressed_bytes_left: 0, + queued_rle: None, + queued_backref: None, + checksum: Adler32::new(), + state: State::ZlibHeader, + last_block: false, + ignore_adler32: false, + } + } + + /// Ignore the checksum at the end of the stream. + pub fn ignore_adler32(&mut self) { + self.ignore_adler32 = true; + } + + fn fill_buffer(&mut self, input: &mut &[u8]) { + if input.len() >= 8 { + self.buffer |= u64::from_le_bytes(input[..8].try_into().unwrap()) << self.nbits; + *input = &mut &input[(63 - self.nbits as usize) / 8..]; + self.nbits |= 56; + } else { + let nbytes = input.len().min((63 - self.nbits as usize) / 8); + let mut input_data = [0; 8]; + input_data[..nbytes].copy_from_slice(&input[..nbytes]); + self.buffer |= u64::from_le_bytes(input_data) + .checked_shl(self.nbits as u32) + .unwrap_or(0); + self.nbits += nbytes as u8 * 8; + *input = &mut &input[nbytes..]; + } + } + + fn peak_bits(&mut self, nbits: u8) -> u64 { + debug_assert!(nbits <= 56 && nbits <= self.nbits); + self.buffer & ((1u64 << nbits) - 1) + } + fn consume_bits(&mut self, nbits: u8) { + debug_assert!(self.nbits >= nbits); + self.buffer >>= nbits; + self.nbits -= nbits; + } + + fn read_block_header(&mut self, remaining_input: &mut &[u8]) -> Result<(), DecompressionError> { + self.fill_buffer(remaining_input); + if self.nbits < 3 { + return Ok(()); + } + + let start = self.peak_bits(3); + self.last_block = start & 1 != 0; + match start >> 1 { + 0b00 => { + let align_bits = (self.nbits - 3) % 8; + let header_bits = 3 + 32 + align_bits; + if self.nbits < header_bits { + return Ok(()); + } + + let len = (self.peak_bits(align_bits + 19) >> (align_bits + 3)) as u16; + let nlen = (self.peak_bits(header_bits) >> (align_bits + 19)) as u16; + if nlen != !len { + return Err(DecompressionError::InvalidUncompressedBlockLength); + } + + self.state = State::UncompressedData; + self.uncompressed_bytes_left = len; + self.consume_bits(header_bits); + Ok(()) + } + 0b01 => { + self.consume_bits(3); + // TODO: Do this statically rather than every time. + Self::build_tables(288, &FIXED_CODE_LENGTHS, &mut self.compression, 6)?; + self.state = State::CompressedData; + Ok(()) + } + 0b10 => { + if self.nbits < 17 { + return Ok(()); + } + + self.header.hlit = (self.peak_bits(8) >> 3) as usize + 257; + self.header.hdist = (self.peak_bits(13) >> 8) as usize + 1; + self.header.hclen = (self.peak_bits(17) >> 13) as usize + 4; + if self.header.hlit > 286 { + return Err(DecompressionError::InvalidHlit); + } + if self.header.hdist > 30 { + return Err(DecompressionError::InvalidHdist); + } + + self.consume_bits(17); + self.state = State::CodeLengthCodes; + Ok(()) + } + 0b11 => Err(DecompressionError::InvalidBlockType), + _ => unreachable!(), + } + } + + fn read_code_length_codes( + &mut self, + remaining_input: &mut &[u8], + ) -> Result<(), DecompressionError> { + self.fill_buffer(remaining_input); + if self.nbits as usize + remaining_input.len() * 8 < 3 * self.header.hclen { + return Ok(()); + } + + let mut code_length_lengths = [0; 19]; + for i in 0..self.header.hclen { + code_length_lengths[CLCL_ORDER[i]] = self.peak_bits(3) as u8; + self.consume_bits(3); + + // We need to refill the buffer after reading 3 * 18 = 54 bits since the buffer holds + // between 56 and 63 bits total. + if i == 17 { + self.fill_buffer(remaining_input); + } + } + let code_length_codes: [u16; 19] = crate::compute_codes(&code_length_lengths) + .ok_or(DecompressionError::BadCodeLengthHuffmanTree)?; + + self.header.table = [255; 128]; + for i in 0..19 { + let length = code_length_lengths[i]; + if length > 0 { + let mut j = code_length_codes[i]; + while j < 128 { + self.header.table[j as usize] = ((i as u8) << 3) | length; + j += 1 << length; + } + } + } + + self.state = State::CodeLengths; + self.header.num_lengths_read = 0; + Ok(()) + } + + fn read_code_lengths(&mut self, remaining_input: &mut &[u8]) -> Result<(), DecompressionError> { + let total_lengths = self.header.hlit + self.header.hdist; + while self.header.num_lengths_read < total_lengths { + self.fill_buffer(remaining_input); + if self.nbits < 7 { + return Ok(()); + } + + let code = self.peak_bits(7); + let entry = self.header.table[code as usize]; + let length = entry & 0x7; + let symbol = entry >> 3; + + debug_assert!(length != 0); + match symbol { + 0..=15 => { + self.header.code_lengths[self.header.num_lengths_read] = symbol; + self.header.num_lengths_read += 1; + self.consume_bits(length); + } + 16..=18 => { + let (base_repeat, extra_bits) = match symbol { + 16 => (3, 2), + 17 => (3, 3), + 18 => (11, 7), + _ => unreachable!(), + }; + + if self.nbits < length + extra_bits { + return Ok(()); + } + + let value = match symbol { + 16 => { + self.header.code_lengths[self + .header + .num_lengths_read + .checked_sub(1) + .ok_or(DecompressionError::InvalidCodeLengthRepeat)?] + // TODO: is this right? + } + 17 => 0, + 18 => 0, + _ => unreachable!(), + }; + + let repeat = + (self.peak_bits(length + extra_bits) >> length) as usize + base_repeat; + if self.header.num_lengths_read + repeat > total_lengths { + return Err(DecompressionError::InvalidCodeLengthRepeat); + } + + for i in 0..repeat { + self.header.code_lengths[self.header.num_lengths_read + i] = value; + } + self.header.num_lengths_read += repeat; + self.consume_bits(length + extra_bits); + } + _ => unreachable!(), + } + } + + self.header + .code_lengths + .copy_within(self.header.hlit..total_lengths, 288); + for i in self.header.hlit..288 { + self.header.code_lengths[i] = 0; + } + for i in 288 + self.header.hdist..320 { + self.header.code_lengths[i] = 0; + } + + if self.header.hdist == 1 + && self.header.code_lengths[..286] == tables::HUFFMAN_LENGTHS + && self.header.code_lengths[288] == 1 + { + self.compression = FDEFLATE_COMPRESSED_BLOCK; + } else { + Self::build_tables( + self.header.hlit, + &self.header.code_lengths, + &mut self.compression, + 6, + )?; + } + self.state = State::CompressedData; + Ok(()) + } + + fn build_tables( + hlit: usize, + code_lengths: &[u8], + compression: &mut CompressedBlock, + max_search_bits: u8, + ) -> Result<(), DecompressionError> { + // Build the literal/length code table. + let lengths = &code_lengths[..288]; + let codes: [u16; 288] = crate::compute_codes(&lengths.try_into().unwrap()) + .ok_or(DecompressionError::BadLiteralLengthHuffmanTree)?; + + let table_bits = lengths.iter().cloned().max().unwrap().min(12).max(6); + let table_size = 1 << table_bits; + + for i in 0..256 { + let code = codes[i]; + let length = lengths[i]; + let mut j = code; + + while j < table_size && length != 0 && length <= 12 { + compression.litlen_table[j as usize] = + ((i as u32) << 16) | LITERAL_ENTRY | (1 << 8) | length as u32; + j += 1 << length; + } + + if length > 0 && length <= max_search_bits { + for ii in 0..256 { + let code2 = codes[ii]; + let length2 = lengths[ii]; + if length2 != 0 && length + length2 <= table_bits { + let mut j = code | (code2 << length); + + while j < table_size { + compression.litlen_table[j as usize] = (ii as u32) << 24 + | (i as u32) << 16 + | LITERAL_ENTRY + | (2 << 8) + | ((length + length2) as u32); + j += 1 << (length + length2); + } + } + } + } + } + + if lengths[256] != 0 && lengths[256] <= 12 { + let mut j = codes[256]; + while j < table_size { + compression.litlen_table[j as usize] = EXCEPTIONAL_ENTRY | lengths[256] as u32; + j += 1 << lengths[256]; + } + } + + let table_size = table_size as usize; + for i in (table_size..4096).step_by(table_size) { + compression.litlen_table.copy_within(0..table_size, i); + } + + compression.eof_code = codes[256]; + compression.eof_mask = (1 << lengths[256]) - 1; + compression.eof_bits = lengths[256]; + + for i in 257..hlit { + let code = codes[i]; + let length = lengths[i]; + if length != 0 && length <= 12 { + let mut j = code; + while j < 4096 { + compression.litlen_table[j as usize] = if i < 286 { + (LEN_SYM_TO_LEN_BASE[i - 257] as u32) << 16 + | (LEN_SYM_TO_LEN_EXTRA[i - 257] as u32) << 8 + | length as u32 + } else { + EXCEPTIONAL_ENTRY + }; + j += 1 << length; + } + } + } + + for i in 0..hlit { + if lengths[i] > 12 { + compression.litlen_table[(codes[i] & 0xfff) as usize] = u32::MAX; + } + } + + let mut secondary_table_len = 0; + for i in 0..hlit { + if lengths[i] > 12 { + let j = (codes[i] & 0xfff) as usize; + if compression.litlen_table[j] == u32::MAX { + compression.litlen_table[j] = + (secondary_table_len << 16) | EXCEPTIONAL_ENTRY | SECONDARY_TABLE_ENTRY; + secondary_table_len += 8; + } + } + } + assert!(secondary_table_len <= 0x7ff); + compression.secondary_table = vec![0; secondary_table_len as usize]; + for i in 0..hlit { + let code = codes[i]; + let length = lengths[i]; + if length > 12 { + let j = (codes[i] & 0xfff) as usize; + let k = (compression.litlen_table[j] >> 16) as usize; + + let mut s = code >> 12; + while s < 8 { + debug_assert_eq!(compression.secondary_table[k + s as usize], 0); + compression.secondary_table[k + s as usize] = + ((i as u16) << 4) | (length as u16); + s += 1 << (length - 12); + } + } + } + debug_assert!(compression + .secondary_table + .iter() + .all(|&x| x != 0 && (x & 0xf) > 12)); + + // Build the distance code table. + let lengths = &code_lengths[288..320]; + if lengths == [0; 32] { + compression.dist_symbol_masks = [0; 30]; + compression.dist_symbol_codes = [0xffff; 30]; + compression.dist_table.fill(0); + } else { + let codes: [u16; 32] = match crate::compute_codes(&lengths.try_into().unwrap()) { + Some(codes) => codes, + None => { + if lengths.iter().filter(|&&l| l != 0).count() != 1 { + return Err(DecompressionError::BadDistanceHuffmanTree); + } + [0; 32] + } + }; + + compression.dist_symbol_codes.copy_from_slice(&codes[..30]); + compression + .dist_symbol_lengths + .copy_from_slice(&lengths[..30]); + compression.dist_table.fill(0); + for i in 0..30 { + let length = lengths[i]; + let code = codes[i]; + if length == 0 { + compression.dist_symbol_masks[i] = 0; + compression.dist_symbol_codes[i] = 0xffff; + } else { + compression.dist_symbol_masks[i] = (1 << lengths[i]) - 1; + if lengths[i] <= 9 { + let mut j = code; + while j < 512 { + compression.dist_table[j as usize] = (DIST_SYM_TO_DIST_BASE[i] as u32) + << 16 + | (DIST_SYM_TO_DIST_EXTRA[i] as u32) << 8 + | length as u32; + j += 1 << lengths[i]; + } + } + } + } + } + + Ok(()) + } + + fn read_compressed( + &mut self, + remaining_input: &mut &[u8], + output: &mut [u8], + mut output_index: usize, + ) -> Result<usize, DecompressionError> { + while let State::CompressedData = self.state { + self.fill_buffer(remaining_input); + if output_index == output.len() { + break; + } + + let mut bits = self.buffer; + let litlen_entry = self.compression.litlen_table[(bits & 0xfff) as usize]; + let litlen_code_bits = litlen_entry as u8; + + if litlen_entry & LITERAL_ENTRY != 0 { + // Ultra-fast path: do 3 more consecutive table lookups and bail if any of them need the slow path. + if self.nbits >= 48 { + let litlen_entry2 = + self.compression.litlen_table[(bits >> litlen_code_bits & 0xfff) as usize]; + let litlen_code_bits2 = litlen_entry2 as u8; + let litlen_entry3 = self.compression.litlen_table + [(bits >> (litlen_code_bits + litlen_code_bits2) & 0xfff) as usize]; + let litlen_code_bits3 = litlen_entry3 as u8; + let litlen_entry4 = self.compression.litlen_table[(bits + >> (litlen_code_bits + litlen_code_bits2 + litlen_code_bits3) + & 0xfff) + as usize]; + let litlen_code_bits4 = litlen_entry4 as u8; + if litlen_entry2 & litlen_entry3 & litlen_entry4 & LITERAL_ENTRY != 0 { + let advance_output_bytes = ((litlen_entry & 0xf00) >> 8) as usize; + let advance_output_bytes2 = ((litlen_entry2 & 0xf00) >> 8) as usize; + let advance_output_bytes3 = ((litlen_entry3 & 0xf00) >> 8) as usize; + let advance_output_bytes4 = ((litlen_entry4 & 0xf00) >> 8) as usize; + if output_index + + advance_output_bytes + + advance_output_bytes2 + + advance_output_bytes3 + + advance_output_bytes4 + < output.len() + { + self.consume_bits( + litlen_code_bits + + litlen_code_bits2 + + litlen_code_bits3 + + litlen_code_bits4, + ); + + output[output_index] = (litlen_entry >> 16) as u8; + output[output_index + 1] = (litlen_entry >> 24) as u8; + output_index += advance_output_bytes; + output[output_index] = (litlen_entry2 >> 16) as u8; + output[output_index + 1] = (litlen_entry2 >> 24) as u8; + output_index += advance_output_bytes2; + output[output_index] = (litlen_entry3 >> 16) as u8; + output[output_index + 1] = (litlen_entry3 >> 24) as u8; + output_index += advance_output_bytes3; + output[output_index] = (litlen_entry4 >> 16) as u8; + output[output_index + 1] = (litlen_entry4 >> 24) as u8; + output_index += advance_output_bytes4; + continue; + } + } + } + + // Fast path: the next symbol is <= 12 bits and a literal, the table specifies the + // output bytes and we can directly write them to the output buffer. + let advance_output_bytes = ((litlen_entry & 0xf00) >> 8) as usize; + + // match advance_output_bytes { + // 1 => println!("[{output_index}] LIT1 {}", litlen_entry >> 16), + // 2 => println!( + // "[{output_index}] LIT2 {} {} {}", + // (litlen_entry >> 16) as u8, + // litlen_entry >> 24, + // bits & 0xfff + // ), + // n => println!( + // "[{output_index}] LIT{n} {} {}", + // (litlen_entry >> 16) as u8, + // litlen_entry >> 24, + // ), + // } + + if self.nbits < litlen_code_bits { + break; + } else if output_index + 1 < output.len() { + output[output_index] = (litlen_entry >> 16) as u8; + output[output_index + 1] = (litlen_entry >> 24) as u8; + output_index += advance_output_bytes; + self.consume_bits(litlen_code_bits); + continue; + } else if output_index + advance_output_bytes == output.len() { + debug_assert_eq!(advance_output_bytes, 1); + output[output_index] = (litlen_entry >> 16) as u8; + output_index += 1; + self.consume_bits(litlen_code_bits); + break; + } else { + debug_assert_eq!(advance_output_bytes, 2); + output[output_index] = (litlen_entry >> 16) as u8; + self.queued_rle = Some(((litlen_entry >> 24) as u8, 1)); + output_index += 1; + self.consume_bits(litlen_code_bits); + break; + } + } + + let (length_base, length_extra_bits, litlen_code_bits) = + if litlen_entry & EXCEPTIONAL_ENTRY == 0 { + ( + litlen_entry >> 16, + (litlen_entry >> 8) as u8, + litlen_code_bits, + ) + } else if litlen_entry & SECONDARY_TABLE_ENTRY != 0 { + let secondary_index = litlen_entry >> 16; + let secondary_entry = self.compression.secondary_table + [secondary_index as usize + ((bits >> 12) & 0x7) as usize]; + let litlen_symbol = secondary_entry >> 4; + let litlen_code_bits = (secondary_entry & 0xf) as u8; + + if self.nbits < litlen_code_bits { + break; + } else if litlen_symbol < 256 { + // println!("[{output_index}] LIT1b {} (val={:04x})", litlen_symbol, self.peak_bits(15)); + + self.consume_bits(litlen_code_bits); + output[output_index] = litlen_symbol as u8; + output_index += 1; + continue; + } else if litlen_symbol == 256 { + // println!("[{output_index}] EOF"); + self.consume_bits(litlen_code_bits); + self.state = match self.last_block { + true => State::Checksum, + false => State::BlockHeader, + }; + break; + } + + ( + LEN_SYM_TO_LEN_BASE[litlen_symbol as usize - 257] as u32, + LEN_SYM_TO_LEN_EXTRA[litlen_symbol as usize - 257], + litlen_code_bits, + ) + } else if litlen_code_bits == 0 { + return Err(DecompressionError::InvalidLiteralLengthCode); + } else { + if self.nbits < litlen_code_bits { + break; + } + // println!("[{output_index}] EOF"); + self.consume_bits(litlen_code_bits); + self.state = match self.last_block { + true => State::Checksum, + false => State::BlockHeader, + }; + break; + }; + bits >>= litlen_code_bits; + + let length_extra_mask = (1 << length_extra_bits) - 1; + let length = length_base as usize + (bits & length_extra_mask) as usize; + bits >>= length_extra_bits; + + let dist_entry = self.compression.dist_table[(bits & 0x1ff) as usize]; + let (dist_base, dist_extra_bits, dist_code_bits) = if dist_entry != 0 { + ( + (dist_entry >> 16) as u16, + (dist_entry >> 8) as u8, + dist_entry as u8, + ) + } else { + let mut dist_extra_bits = 0; + let mut dist_base = 0; + let mut dist_advance_bits = 0; + for i in 0..self.compression.dist_symbol_lengths.len() { + if bits as u16 & self.compression.dist_symbol_masks[i] + == self.compression.dist_symbol_codes[i] + { + dist_extra_bits = DIST_SYM_TO_DIST_EXTRA[i]; + dist_base = DIST_SYM_TO_DIST_BASE[i]; + dist_advance_bits = self.compression.dist_symbol_lengths[i]; + break; + } + } + if dist_advance_bits == 0 { + return Err(DecompressionError::InvalidDistanceCode); + } + (dist_base, dist_extra_bits, dist_advance_bits) + }; + bits >>= dist_code_bits; + + let dist = dist_base as usize + (bits & ((1 << dist_extra_bits) - 1)) as usize; + let total_bits = + litlen_code_bits + length_extra_bits + dist_code_bits + dist_extra_bits; + + if self.nbits < total_bits { + break; + } else if dist > output_index { + return Err(DecompressionError::DistanceTooFarBack); + } + + // println!("[{output_index}] BACKREF len={} dist={} {:x}", length, dist, dist_entry); + self.consume_bits(total_bits); + + let copy_length = length.min(output.len() - output_index); + if dist == 1 { + let last = output[output_index - 1]; + output[output_index..][..copy_length].fill(last); + + if copy_length < length { + self.queued_rle = Some((last, length - copy_length)); + output_index = output.len(); + break; + } + } else if output_index + length + 15 <= output.len() { + let start = output_index - dist; + output.copy_within(start..start + 16, output_index); + + if length > 16 || dist < 16 { + for i in (0..length).step_by(dist.min(16)).skip(1) { + output.copy_within(start + i..start + i + 16, output_index + i); + } + } + } else { + if dist < copy_length { + for i in 0..copy_length { + output[output_index + i] = output[output_index + i - dist]; + } + } else { + output.copy_within( + output_index - dist..output_index + copy_length - dist, + output_index, + ) + } + + if copy_length < length { + self.queued_backref = Some((dist, length - copy_length)); + output_index = output.len(); + break; + } + } + output_index += copy_length; + } + + if self.state == State::CompressedData + && self.queued_backref.is_none() + && self.queued_rle.is_none() + && self.nbits >= 15 + && self.peak_bits(15) as u16 & self.compression.eof_mask == self.compression.eof_code + { + self.consume_bits(self.compression.eof_bits); + self.state = match self.last_block { + true => State::Checksum, + false => State::BlockHeader, + }; + } + + Ok(output_index) + } + + /// Decompresses a chunk of data. + /// + /// Returns the number of bytes read from `input` and the number of bytes written to `output`, + /// or an error if the deflate stream is not valid. `input` is the compressed data. `output` is + /// the buffer to write the decompressed data to, starting at index `output_position`. + /// `end_of_input` indicates whether more data may be available in the future. + /// + /// The contents of `output` after `output_position` are ignored. However, this function may + /// write additional data to `output` past what is indicated by the return value. + /// + /// When this function returns `Ok`, at least one of the following is true: + /// - The input is fully consumed. + /// - The output is full but there are more bytes to output. + /// - The deflate stream is complete (and `is_done` will return true). + /// + /// # Panics + /// + /// This function will panic if `output_position` is out of bounds. + pub fn read( + &mut self, + input: &[u8], + output: &mut [u8], + output_position: usize, + end_of_input: bool, + ) -> Result<(usize, usize), DecompressionError> { + if let State::Done = self.state { + return Ok((0, 0)); + } + + assert!(output_position <= output.len()); + + let mut remaining_input = input; + let mut output_index = output_position; + + if let Some((data, len)) = self.queued_rle.take() { + let n = len.min(output.len() - output_index); + output[output_index..][..n].fill(data); + output_index += n; + if n < len { + self.queued_rle = Some((data, len - n)); + return Ok((0, n)); + } + } + if let Some((dist, len)) = self.queued_backref.take() { + let n = len.min(output.len() - output_index); + for i in 0..n { + output[output_index + i] = output[output_index + i - dist]; + } + output_index += n; + if n < len { + self.queued_backref = Some((dist, len - n)); + return Ok((0, n)); + } + } + + // Main decoding state machine. + let mut last_state = None; + while last_state != Some(self.state) { + last_state = Some(self.state); + match self.state { + State::ZlibHeader => { + self.fill_buffer(&mut remaining_input); + if self.nbits < 16 { + break; + } + + let input0 = self.peak_bits(8); + let input1 = self.peak_bits(16) >> 8 & 0xff; + if input0 & 0x0f != 0x08 + || (input0 & 0xf0) > 0x70 + || input1 & 0x20 != 0 + || (input0 << 8 | input1) % 31 != 0 + { + return Err(DecompressionError::BadZlibHeader); + } + + self.consume_bits(16); + self.state = State::BlockHeader; + } + State::BlockHeader => { + self.read_block_header(&mut remaining_input)?; + } + State::CodeLengthCodes => { + self.read_code_length_codes(&mut remaining_input)?; + } + State::CodeLengths => { + self.read_code_lengths(&mut remaining_input)?; + } + State::CompressedData => { + output_index = + self.read_compressed(&mut remaining_input, output, output_index)? + } + State::UncompressedData => { + // Drain any bytes from our buffer. + debug_assert_eq!(self.nbits % 8, 0); + while self.nbits > 0 + && self.uncompressed_bytes_left > 0 + && output_index < output.len() + { + output[output_index] = self.peak_bits(8) as u8; + self.consume_bits(8); + output_index += 1; + self.uncompressed_bytes_left -= 1; + } + // Buffer may contain one additional byte. Clear it to avoid confusion. + if self.nbits == 0 { + self.buffer = 0; + } + + // Copy subsequent bytes directly from the input. + let copy_bytes = (self.uncompressed_bytes_left as usize) + .min(remaining_input.len()) + .min(output.len() - output_index); + output[output_index..][..copy_bytes] + .copy_from_slice(&remaining_input[..copy_bytes]); + remaining_input = &remaining_input[copy_bytes..]; + output_index += copy_bytes; + self.uncompressed_bytes_left -= copy_bytes as u16; + + if self.uncompressed_bytes_left == 0 { + self.state = if self.last_block { + State::Checksum + } else { + State::BlockHeader + }; + } + } + State::Checksum => { + self.fill_buffer(&mut remaining_input); + + let align_bits = self.nbits % 8; + if self.nbits >= 32 + align_bits { + self.checksum.write(&output[output_position..output_index]); + if align_bits != 0 { + self.consume_bits(align_bits); + } + #[cfg(not(fuzzing))] + if !self.ignore_adler32 + && (self.peak_bits(32) as u32).swap_bytes() != self.checksum.finish() + { + return Err(DecompressionError::WrongChecksum); + } + self.state = State::Done; + self.consume_bits(32); + break; + } + } + State::Done => unreachable!(), + } + } + + if !self.ignore_adler32 && self.state != State::Done { + self.checksum.write(&output[output_position..output_index]); + } + + if self.state == State::Done || !end_of_input || output_index >= output.len() - 1 { + let input_left = remaining_input.len(); + Ok((input.len() - input_left, output_index - output_position)) + } else { + Err(DecompressionError::InsufficientInput) + } + } + + /// Returns true if the decompressor has finished decompressing the input. + pub fn is_done(&self) -> bool { + self.state == State::Done + } +} + +/// Decompress the given data. +pub fn decompress_to_vec(input: &[u8]) -> Result<Vec<u8>, DecompressionError> { + match decompress_to_vec_bounded(input, usize::MAX) { + Ok(output) => Ok(output), + Err(BoundedDecompressionError::DecompressionError { inner }) => Err(inner), + Err(BoundedDecompressionError::OutputTooLarge { .. }) => { + unreachable!("Impossible to allocate more than isize::MAX bytes") + } + } +} + +/// An error encountered while decompressing a deflate stream given a bounded maximum output. +pub enum BoundedDecompressionError { + /// The input is not a valid deflate stream. + DecompressionError { + /// The underlying error. + inner: DecompressionError, + }, + + /// The output is too large. + OutputTooLarge { + /// The output decoded so far. + partial_output: Vec<u8>, + }, +} +impl From<DecompressionError> for BoundedDecompressionError { + fn from(inner: DecompressionError) -> Self { + BoundedDecompressionError::DecompressionError { inner } + } +} + +/// Decompress the given data, returning an error if the output is larger than +/// `maxlen` bytes. +pub fn decompress_to_vec_bounded( + input: &[u8], + maxlen: usize, +) -> Result<Vec<u8>, BoundedDecompressionError> { + let mut decoder = Decompressor::new(); + let mut output = vec![0; 1024.min(maxlen)]; + let mut input_index = 0; + let mut output_index = 0; + loop { + let (consumed, produced) = + decoder.read(&input[input_index..], &mut output, output_index, true)?; + input_index += consumed; + output_index += produced; + if decoder.is_done() || output_index == maxlen { + break; + } + output.resize((output_index + 32 * 1024).min(maxlen), 0); + } + output.resize(output_index, 0); + + if decoder.is_done() { + Ok(output) + } else { + Err(BoundedDecompressionError::OutputTooLarge { + partial_output: output, + }) + } +} + +#[cfg(test)] +mod tests { + use crate::tables::{self, LENGTH_TO_LEN_EXTRA, LENGTH_TO_SYMBOL}; + + use super::*; + use rand::Rng; + + fn roundtrip(data: &[u8]) { + let compressed = crate::compress_to_vec(data); + let decompressed = decompress_to_vec(&compressed).unwrap(); + assert_eq!(&decompressed, data); + } + + fn roundtrip_miniz_oxide(data: &[u8]) { + let compressed = miniz_oxide::deflate::compress_to_vec_zlib(data, 3); + let decompressed = decompress_to_vec(&compressed).unwrap(); + assert_eq!(decompressed.len(), data.len()); + for (i, (a, b)) in decompressed.chunks(1).zip(data.chunks(1)).enumerate() { + assert_eq!(a, b, "chunk {}..{}", i * 1, i * 1 + 1); + } + assert_eq!(&decompressed, data); + } + + #[allow(unused)] + fn compare_decompression(data: &[u8]) { + // let decompressed0 = flate2::read::ZlibDecoder::new(std::io::Cursor::new(&data)) + // .bytes() + // .collect::<Result<Vec<_>, _>>() + // .unwrap(); + let decompressed = decompress_to_vec(&data).unwrap(); + let decompressed2 = miniz_oxide::inflate::decompress_to_vec_zlib(&data).unwrap(); + for i in 0..decompressed.len().min(decompressed2.len()) { + if decompressed[i] != decompressed2[i] { + panic!( + "mismatch at index {} {:?} {:?}", + i, + &decompressed[i.saturating_sub(1)..(i + 16).min(decompressed.len())], + &decompressed2[i.saturating_sub(1)..(i + 16).min(decompressed2.len())] + ); + } + } + if decompressed != decompressed2 { + panic!( + "length mismatch {} {} {:x?}", + decompressed.len(), + decompressed2.len(), + &decompressed2[decompressed.len()..][..16] + ); + } + //assert_eq!(decompressed, decompressed2); + } + + #[test] + fn tables() { + for (i, &bits) in LEN_SYM_TO_LEN_EXTRA.iter().enumerate() { + let len_base = LEN_SYM_TO_LEN_BASE[i]; + for j in 0..(1 << bits) { + if i == 27 && j == 31 { + continue; + } + assert_eq!(LENGTH_TO_LEN_EXTRA[len_base + j - 3], bits, "{} {}", i, j); + assert_eq!( + LENGTH_TO_SYMBOL[len_base + j - 3], + i as u16 + 257, + "{} {}", + i, + j + ); + } + } + } + + #[test] + fn fdeflate_table() { + let mut compression = CompressedBlock { + litlen_table: [0; 4096], + dist_table: [0; 512], + dist_symbol_lengths: [0; 30], + dist_symbol_masks: [0; 30], + dist_symbol_codes: [0; 30], + secondary_table: Vec::new(), + eof_code: 0, + eof_mask: 0, + eof_bits: 0, + }; + let mut lengths = tables::HUFFMAN_LENGTHS.to_vec(); + lengths.resize(288, 0); + lengths.push(1); + lengths.resize(320, 0); + Decompressor::build_tables(286, &lengths, &mut compression, 11).unwrap(); + + assert_eq!( + compression, FDEFLATE_COMPRESSED_BLOCK, + "{:#x?}", + compression + ); + } + + #[test] + fn it_works() { + roundtrip(b"Hello world!"); + } + + #[test] + fn constant() { + roundtrip_miniz_oxide(&vec![0; 50]); + roundtrip_miniz_oxide(&vec![5; 2048]); + roundtrip_miniz_oxide(&vec![128; 2048]); + roundtrip_miniz_oxide(&vec![254; 2048]); + } + + #[test] + fn random() { + let mut rng = rand::thread_rng(); + let mut data = vec![0; 50000]; + for _ in 0..10 { + for byte in &mut data { + *byte = rng.gen::<u8>() % 5; + } + println!("Random data: {:?}", data); + roundtrip_miniz_oxide(&data); + } + } + + #[test] + fn ignore_adler32() { + let mut compressed = crate::compress_to_vec(b"Hello world!"); + let last_byte = compressed.len() - 1; + compressed[last_byte] = compressed[last_byte].wrapping_add(1); + + match decompress_to_vec(&compressed) { + Err(DecompressionError::WrongChecksum) => {} + r => panic!("expected WrongChecksum, got {:?}", r), + } + + let mut decompressor = Decompressor::new(); + decompressor.ignore_adler32(); + let mut decompressed = vec![0; 1024]; + let decompressed_len = decompressor + .read(&compressed, &mut decompressed, 0, true) + .unwrap() + .1; + assert_eq!(&decompressed[..decompressed_len], b"Hello world!"); + } + + #[test] + fn checksum_after_eof() { + let input = b"Hello world!"; + let compressed = crate::compress_to_vec(input); + + let mut decompressor = Decompressor::new(); + let mut decompressed = vec![0; 1024]; + let (input_consumed, output_written) = decompressor + .read( + &compressed[..compressed.len() - 1], + &mut decompressed, + 0, + false, + ) + .unwrap(); + assert_eq!(output_written, input.len()); + assert_eq!(input_consumed, compressed.len() - 1); + + let (input_consumed, output_written) = decompressor + .read( + &compressed[input_consumed..], + &mut decompressed[..output_written], + output_written, + true, + ) + .unwrap(); + assert!(decompressor.is_done()); + assert_eq!(input_consumed, 1); + assert_eq!(output_written, 0); + + assert_eq!(&decompressed[..input.len()], input); + } + + #[test] + fn zero_length() { + let mut compressed = crate::compress_to_vec(b"").to_vec(); + + // Splice in zero-length non-compressed blocks. + for _ in 0..10 { + println!("compressed len: {}", compressed.len()); + compressed.splice(2..2, [0u8, 0, 0, 0xff, 0xff].into_iter()); + } + + // Ensure that the full input is decompressed, regardless of whether + // `end_of_input` is set. + for end_of_input in [true, false] { + let mut decompressor = Decompressor::new(); + let (input_consumed, output_written) = decompressor + .read(&compressed, &mut [], 0, end_of_input) + .unwrap(); + + assert!(decompressor.is_done()); + assert_eq!(input_consumed, compressed.len()); + assert_eq!(output_written, 0); + } + } +} |