use std::convert::TryInto; use simd_adler32::Adler32; use crate::tables::{ self, CLCL_ORDER, DIST_SYM_TO_DIST_BASE, DIST_SYM_TO_DIST_EXTRA, FDEFLATE_DIST_DECODE_TABLE, FDEFLATE_LITLEN_DECODE_TABLE, FIXED_CODE_LENGTHS, LEN_SYM_TO_LEN_BASE, LEN_SYM_TO_LEN_EXTRA, }; /// An error encountered while decompressing a deflate stream. #[derive(Debug, PartialEq)] pub enum DecompressionError { /// The zlib header is corrupt. BadZlibHeader, /// All input was consumed, but the end of the stream hasn't been reached. InsufficientInput, /// A block header specifies an invalid block type. InvalidBlockType, /// An uncompressed block's NLEN value is invalid. InvalidUncompressedBlockLength, /// Too many literals were specified. InvalidHlit, /// Too many distance codes were specified. InvalidHdist, /// Attempted to repeat a previous code before reading any codes, or past the end of the code /// lengths. InvalidCodeLengthRepeat, /// The stream doesn't specify a valid huffman tree. BadCodeLengthHuffmanTree, /// The stream doesn't specify a valid huffman tree. BadLiteralLengthHuffmanTree, /// The stream doesn't specify a valid huffman tree. BadDistanceHuffmanTree, /// The stream contains a literal/length code that was not allowed by the header. InvalidLiteralLengthCode, /// The stream contains a distance code that was not allowed by the header. InvalidDistanceCode, /// The stream contains contains back-reference as the first symbol. InputStartsWithRun, /// The stream contains a back-reference that is too far back. DistanceTooFarBack, /// The deflate stream checksum is incorrect. WrongChecksum, /// Extra input data. ExtraInput, } struct BlockHeader { hlit: usize, hdist: usize, hclen: usize, num_lengths_read: usize, /// Low 3-bits are code length code length, high 5-bits are code length code. table: [u8; 128], code_lengths: [u8; 320], } const LITERAL_ENTRY: u32 = 0x8000; const EXCEPTIONAL_ENTRY: u32 = 0x4000; const SECONDARY_TABLE_ENTRY: u32 = 0x2000; /// The Decompressor state for a compressed block. /// /// The main litlen_table uses a 12-bit input to lookup the meaning of the symbol. The table is /// split into 4 sections: /// /// aaaaaaaa_bbbbbbbb_1000yyyy_0000xxxx x = input_advance_bits, y = output_advance_bytes (literal) /// 0000000z_zzzzzzzz_00000yyy_0000xxxx x = input_advance_bits, y = extra_bits, z = distance_base (length) /// 00000000_00000000_01000000_0000xxxx x = input_advance_bits (EOF) /// 0000xxxx_xxxxxxxx_01100000_00000000 x = secondary_table_index /// 00000000_00000000_01000000_00000000 invalid code /// /// The distance table is a 512-entry table that maps 9 bits of distance symbols to their meaning. /// /// 00000000_00000000_00000000_00000000 symbol is more than 9 bits /// zzzzzzzz_zzzzzzzz_0000yyyy_0000xxxx x = input_advance_bits, y = extra_bits, z = distance_base #[repr(align(64))] #[derive(Eq, PartialEq, Debug)] struct CompressedBlock { litlen_table: [u32; 4096], dist_table: [u32; 512], dist_symbol_lengths: [u8; 30], dist_symbol_masks: [u16; 30], dist_symbol_codes: [u16; 30], secondary_table: Vec, eof_code: u16, eof_mask: u16, eof_bits: u8, } const FDEFLATE_COMPRESSED_BLOCK: CompressedBlock = CompressedBlock { litlen_table: FDEFLATE_LITLEN_DECODE_TABLE, dist_table: FDEFLATE_DIST_DECODE_TABLE, dist_symbol_lengths: [ 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ], dist_symbol_masks: [ 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ], dist_symbol_codes: [ 0, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, ], secondary_table: Vec::new(), eof_code: 0x8ff, eof_mask: 0xfff, eof_bits: 0xc, }; #[derive(Debug, Copy, Clone, Eq, PartialEq)] enum State { ZlibHeader, BlockHeader, CodeLengthCodes, CodeLengths, CompressedData, UncompressedData, Checksum, Done, } /// Decompressor for arbitrary zlib streams. pub struct Decompressor { /// State for decoding a compressed block. compression: CompressedBlock, // State for decoding a block header. header: BlockHeader, // Number of bytes left for uncompressed block. uncompressed_bytes_left: u16, buffer: u64, nbits: u8, queued_rle: Option<(u8, usize)>, queued_backref: Option<(usize, usize)>, last_block: bool, state: State, checksum: Adler32, ignore_adler32: bool, } impl Default for Decompressor { fn default() -> Self { Self::new() } } impl Decompressor { /// Create a new decompressor. pub fn new() -> Self { Self { buffer: 0, nbits: 0, compression: CompressedBlock { litlen_table: [0; 4096], dist_table: [0; 512], secondary_table: Vec::new(), dist_symbol_lengths: [0; 30], dist_symbol_masks: [0; 30], dist_symbol_codes: [0xffff; 30], eof_code: 0, eof_mask: 0, eof_bits: 0, }, header: BlockHeader { hlit: 0, hdist: 0, hclen: 0, table: [0; 128], num_lengths_read: 0, code_lengths: [0; 320], }, uncompressed_bytes_left: 0, queued_rle: None, queued_backref: None, checksum: Adler32::new(), state: State::ZlibHeader, last_block: false, ignore_adler32: false, } } /// Ignore the checksum at the end of the stream. pub fn ignore_adler32(&mut self) { self.ignore_adler32 = true; } fn fill_buffer(&mut self, input: &mut &[u8]) { if input.len() >= 8 { self.buffer |= u64::from_le_bytes(input[..8].try_into().unwrap()) << self.nbits; *input = &mut &input[(63 - self.nbits as usize) / 8..]; self.nbits |= 56; } else { let nbytes = input.len().min((63 - self.nbits as usize) / 8); let mut input_data = [0; 8]; input_data[..nbytes].copy_from_slice(&input[..nbytes]); self.buffer |= u64::from_le_bytes(input_data) .checked_shl(self.nbits as u32) .unwrap_or(0); self.nbits += nbytes as u8 * 8; *input = &mut &input[nbytes..]; } } fn peak_bits(&mut self, nbits: u8) -> u64 { debug_assert!(nbits <= 56 && nbits <= self.nbits); self.buffer & ((1u64 << nbits) - 1) } fn consume_bits(&mut self, nbits: u8) { debug_assert!(self.nbits >= nbits); self.buffer >>= nbits; self.nbits -= nbits; } fn read_block_header(&mut self, remaining_input: &mut &[u8]) -> Result<(), DecompressionError> { self.fill_buffer(remaining_input); if self.nbits < 3 { return Ok(()); } let start = self.peak_bits(3); self.last_block = start & 1 != 0; match start >> 1 { 0b00 => { let align_bits = (self.nbits - 3) % 8; let header_bits = 3 + 32 + align_bits; if self.nbits < header_bits { return Ok(()); } let len = (self.peak_bits(align_bits + 19) >> (align_bits + 3)) as u16; let nlen = (self.peak_bits(header_bits) >> (align_bits + 19)) as u16; if nlen != !len { return Err(DecompressionError::InvalidUncompressedBlockLength); } self.state = State::UncompressedData; self.uncompressed_bytes_left = len; self.consume_bits(header_bits); Ok(()) } 0b01 => { self.consume_bits(3); // TODO: Do this statically rather than every time. Self::build_tables(288, &FIXED_CODE_LENGTHS, &mut self.compression, 6)?; self.state = State::CompressedData; Ok(()) } 0b10 => { if self.nbits < 17 { return Ok(()); } self.header.hlit = (self.peak_bits(8) >> 3) as usize + 257; self.header.hdist = (self.peak_bits(13) >> 8) as usize + 1; self.header.hclen = (self.peak_bits(17) >> 13) as usize + 4; if self.header.hlit > 286 { return Err(DecompressionError::InvalidHlit); } if self.header.hdist > 30 { return Err(DecompressionError::InvalidHdist); } self.consume_bits(17); self.state = State::CodeLengthCodes; Ok(()) } 0b11 => Err(DecompressionError::InvalidBlockType), _ => unreachable!(), } } fn read_code_length_codes( &mut self, remaining_input: &mut &[u8], ) -> Result<(), DecompressionError> { self.fill_buffer(remaining_input); if self.nbits as usize + remaining_input.len() * 8 < 3 * self.header.hclen { return Ok(()); } let mut code_length_lengths = [0; 19]; for i in 0..self.header.hclen { code_length_lengths[CLCL_ORDER[i]] = self.peak_bits(3) as u8; self.consume_bits(3); // We need to refill the buffer after reading 3 * 18 = 54 bits since the buffer holds // between 56 and 63 bits total. if i == 17 { self.fill_buffer(remaining_input); } } let code_length_codes: [u16; 19] = crate::compute_codes(&code_length_lengths) .ok_or(DecompressionError::BadCodeLengthHuffmanTree)?; self.header.table = [255; 128]; for i in 0..19 { let length = code_length_lengths[i]; if length > 0 { let mut j = code_length_codes[i]; while j < 128 { self.header.table[j as usize] = ((i as u8) << 3) | length; j += 1 << length; } } } self.state = State::CodeLengths; self.header.num_lengths_read = 0; Ok(()) } fn read_code_lengths(&mut self, remaining_input: &mut &[u8]) -> Result<(), DecompressionError> { let total_lengths = self.header.hlit + self.header.hdist; while self.header.num_lengths_read < total_lengths { self.fill_buffer(remaining_input); if self.nbits < 7 { return Ok(()); } let code = self.peak_bits(7); let entry = self.header.table[code as usize]; let length = entry & 0x7; let symbol = entry >> 3; debug_assert!(length != 0); match symbol { 0..=15 => { self.header.code_lengths[self.header.num_lengths_read] = symbol; self.header.num_lengths_read += 1; self.consume_bits(length); } 16..=18 => { let (base_repeat, extra_bits) = match symbol { 16 => (3, 2), 17 => (3, 3), 18 => (11, 7), _ => unreachable!(), }; if self.nbits < length + extra_bits { return Ok(()); } let value = match symbol { 16 => { self.header.code_lengths[self .header .num_lengths_read .checked_sub(1) .ok_or(DecompressionError::InvalidCodeLengthRepeat)?] // TODO: is this right? } 17 => 0, 18 => 0, _ => unreachable!(), }; let repeat = (self.peak_bits(length + extra_bits) >> length) as usize + base_repeat; if self.header.num_lengths_read + repeat > total_lengths { return Err(DecompressionError::InvalidCodeLengthRepeat); } for i in 0..repeat { self.header.code_lengths[self.header.num_lengths_read + i] = value; } self.header.num_lengths_read += repeat; self.consume_bits(length + extra_bits); } _ => unreachable!(), } } self.header .code_lengths .copy_within(self.header.hlit..total_lengths, 288); for i in self.header.hlit..288 { self.header.code_lengths[i] = 0; } for i in 288 + self.header.hdist..320 { self.header.code_lengths[i] = 0; } if self.header.hdist == 1 && self.header.code_lengths[..286] == tables::HUFFMAN_LENGTHS && self.header.code_lengths[288] == 1 { self.compression = FDEFLATE_COMPRESSED_BLOCK; } else { Self::build_tables( self.header.hlit, &self.header.code_lengths, &mut self.compression, 6, )?; } self.state = State::CompressedData; Ok(()) } fn build_tables( hlit: usize, code_lengths: &[u8], compression: &mut CompressedBlock, max_search_bits: u8, ) -> Result<(), DecompressionError> { // Build the literal/length code table. let lengths = &code_lengths[..288]; let codes: [u16; 288] = crate::compute_codes(&lengths.try_into().unwrap()) .ok_or(DecompressionError::BadLiteralLengthHuffmanTree)?; let table_bits = lengths.iter().cloned().max().unwrap().min(12).max(6); let table_size = 1 << table_bits; for i in 0..256 { let code = codes[i]; let length = lengths[i]; let mut j = code; while j < table_size && length != 0 && length <= 12 { compression.litlen_table[j as usize] = ((i as u32) << 16) | LITERAL_ENTRY | (1 << 8) | length as u32; j += 1 << length; } if length > 0 && length <= max_search_bits { for ii in 0..256 { let code2 = codes[ii]; let length2 = lengths[ii]; if length2 != 0 && length + length2 <= table_bits { let mut j = code | (code2 << length); while j < table_size { compression.litlen_table[j as usize] = (ii as u32) << 24 | (i as u32) << 16 | LITERAL_ENTRY | (2 << 8) | ((length + length2) as u32); j += 1 << (length + length2); } } } } } if lengths[256] != 0 && lengths[256] <= 12 { let mut j = codes[256]; while j < table_size { compression.litlen_table[j as usize] = EXCEPTIONAL_ENTRY | lengths[256] as u32; j += 1 << lengths[256]; } } let table_size = table_size as usize; for i in (table_size..4096).step_by(table_size) { compression.litlen_table.copy_within(0..table_size, i); } compression.eof_code = codes[256]; compression.eof_mask = (1 << lengths[256]) - 1; compression.eof_bits = lengths[256]; for i in 257..hlit { let code = codes[i]; let length = lengths[i]; if length != 0 && length <= 12 { let mut j = code; while j < 4096 { compression.litlen_table[j as usize] = if i < 286 { (LEN_SYM_TO_LEN_BASE[i - 257] as u32) << 16 | (LEN_SYM_TO_LEN_EXTRA[i - 257] as u32) << 8 | length as u32 } else { EXCEPTIONAL_ENTRY }; j += 1 << length; } } } for i in 0..hlit { if lengths[i] > 12 { compression.litlen_table[(codes[i] & 0xfff) as usize] = u32::MAX; } } let mut secondary_table_len = 0; for i in 0..hlit { if lengths[i] > 12 { let j = (codes[i] & 0xfff) as usize; if compression.litlen_table[j] == u32::MAX { compression.litlen_table[j] = (secondary_table_len << 16) | EXCEPTIONAL_ENTRY | SECONDARY_TABLE_ENTRY; secondary_table_len += 8; } } } assert!(secondary_table_len <= 0x7ff); compression.secondary_table = vec![0; secondary_table_len as usize]; for i in 0..hlit { let code = codes[i]; let length = lengths[i]; if length > 12 { let j = (codes[i] & 0xfff) as usize; let k = (compression.litlen_table[j] >> 16) as usize; let mut s = code >> 12; while s < 8 { debug_assert_eq!(compression.secondary_table[k + s as usize], 0); compression.secondary_table[k + s as usize] = ((i as u16) << 4) | (length as u16); s += 1 << (length - 12); } } } debug_assert!(compression .secondary_table .iter() .all(|&x| x != 0 && (x & 0xf) > 12)); // Build the distance code table. let lengths = &code_lengths[288..320]; if lengths == [0; 32] { compression.dist_symbol_masks = [0; 30]; compression.dist_symbol_codes = [0xffff; 30]; compression.dist_table.fill(0); } else { let codes: [u16; 32] = match crate::compute_codes(&lengths.try_into().unwrap()) { Some(codes) => codes, None => { if lengths.iter().filter(|&&l| l != 0).count() != 1 { return Err(DecompressionError::BadDistanceHuffmanTree); } [0; 32] } }; compression.dist_symbol_codes.copy_from_slice(&codes[..30]); compression .dist_symbol_lengths .copy_from_slice(&lengths[..30]); compression.dist_table.fill(0); for i in 0..30 { let length = lengths[i]; let code = codes[i]; if length == 0 { compression.dist_symbol_masks[i] = 0; compression.dist_symbol_codes[i] = 0xffff; } else { compression.dist_symbol_masks[i] = (1 << lengths[i]) - 1; if lengths[i] <= 9 { let mut j = code; while j < 512 { compression.dist_table[j as usize] = (DIST_SYM_TO_DIST_BASE[i] as u32) << 16 | (DIST_SYM_TO_DIST_EXTRA[i] as u32) << 8 | length as u32; j += 1 << lengths[i]; } } } } } Ok(()) } fn read_compressed( &mut self, remaining_input: &mut &[u8], output: &mut [u8], mut output_index: usize, ) -> Result { while let State::CompressedData = self.state { self.fill_buffer(remaining_input); if output_index == output.len() { break; } let mut bits = self.buffer; let litlen_entry = self.compression.litlen_table[(bits & 0xfff) as usize]; let litlen_code_bits = litlen_entry as u8; if litlen_entry & LITERAL_ENTRY != 0 { // Ultra-fast path: do 3 more consecutive table lookups and bail if any of them need the slow path. if self.nbits >= 48 { let litlen_entry2 = self.compression.litlen_table[(bits >> litlen_code_bits & 0xfff) as usize]; let litlen_code_bits2 = litlen_entry2 as u8; let litlen_entry3 = self.compression.litlen_table [(bits >> (litlen_code_bits + litlen_code_bits2) & 0xfff) as usize]; let litlen_code_bits3 = litlen_entry3 as u8; let litlen_entry4 = self.compression.litlen_table[(bits >> (litlen_code_bits + litlen_code_bits2 + litlen_code_bits3) & 0xfff) as usize]; let litlen_code_bits4 = litlen_entry4 as u8; if litlen_entry2 & litlen_entry3 & litlen_entry4 & LITERAL_ENTRY != 0 { let advance_output_bytes = ((litlen_entry & 0xf00) >> 8) as usize; let advance_output_bytes2 = ((litlen_entry2 & 0xf00) >> 8) as usize; let advance_output_bytes3 = ((litlen_entry3 & 0xf00) >> 8) as usize; let advance_output_bytes4 = ((litlen_entry4 & 0xf00) >> 8) as usize; if output_index + advance_output_bytes + advance_output_bytes2 + advance_output_bytes3 + advance_output_bytes4 < output.len() { self.consume_bits( litlen_code_bits + litlen_code_bits2 + litlen_code_bits3 + litlen_code_bits4, ); output[output_index] = (litlen_entry >> 16) as u8; output[output_index + 1] = (litlen_entry >> 24) as u8; output_index += advance_output_bytes; output[output_index] = (litlen_entry2 >> 16) as u8; output[output_index + 1] = (litlen_entry2 >> 24) as u8; output_index += advance_output_bytes2; output[output_index] = (litlen_entry3 >> 16) as u8; output[output_index + 1] = (litlen_entry3 >> 24) as u8; output_index += advance_output_bytes3; output[output_index] = (litlen_entry4 >> 16) as u8; output[output_index + 1] = (litlen_entry4 >> 24) as u8; output_index += advance_output_bytes4; continue; } } } // Fast path: the next symbol is <= 12 bits and a literal, the table specifies the // output bytes and we can directly write them to the output buffer. let advance_output_bytes = ((litlen_entry & 0xf00) >> 8) as usize; // match advance_output_bytes { // 1 => println!("[{output_index}] LIT1 {}", litlen_entry >> 16), // 2 => println!( // "[{output_index}] LIT2 {} {} {}", // (litlen_entry >> 16) as u8, // litlen_entry >> 24, // bits & 0xfff // ), // n => println!( // "[{output_index}] LIT{n} {} {}", // (litlen_entry >> 16) as u8, // litlen_entry >> 24, // ), // } if self.nbits < litlen_code_bits { break; } else if output_index + 1 < output.len() { output[output_index] = (litlen_entry >> 16) as u8; output[output_index + 1] = (litlen_entry >> 24) as u8; output_index += advance_output_bytes; self.consume_bits(litlen_code_bits); continue; } else if output_index + advance_output_bytes == output.len() { debug_assert_eq!(advance_output_bytes, 1); output[output_index] = (litlen_entry >> 16) as u8; output_index += 1; self.consume_bits(litlen_code_bits); break; } else { debug_assert_eq!(advance_output_bytes, 2); output[output_index] = (litlen_entry >> 16) as u8; self.queued_rle = Some(((litlen_entry >> 24) as u8, 1)); output_index += 1; self.consume_bits(litlen_code_bits); break; } } let (length_base, length_extra_bits, litlen_code_bits) = if litlen_entry & EXCEPTIONAL_ENTRY == 0 { ( litlen_entry >> 16, (litlen_entry >> 8) as u8, litlen_code_bits, ) } else if litlen_entry & SECONDARY_TABLE_ENTRY != 0 { let secondary_index = litlen_entry >> 16; let secondary_entry = self.compression.secondary_table [secondary_index as usize + ((bits >> 12) & 0x7) as usize]; let litlen_symbol = secondary_entry >> 4; let litlen_code_bits = (secondary_entry & 0xf) as u8; if self.nbits < litlen_code_bits { break; } else if litlen_symbol < 256 { // println!("[{output_index}] LIT1b {} (val={:04x})", litlen_symbol, self.peak_bits(15)); self.consume_bits(litlen_code_bits); output[output_index] = litlen_symbol as u8; output_index += 1; continue; } else if litlen_symbol == 256 { // println!("[{output_index}] EOF"); self.consume_bits(litlen_code_bits); self.state = match self.last_block { true => State::Checksum, false => State::BlockHeader, }; break; } ( LEN_SYM_TO_LEN_BASE[litlen_symbol as usize - 257] as u32, LEN_SYM_TO_LEN_EXTRA[litlen_symbol as usize - 257], litlen_code_bits, ) } else if litlen_code_bits == 0 { return Err(DecompressionError::InvalidLiteralLengthCode); } else { if self.nbits < litlen_code_bits { break; } // println!("[{output_index}] EOF"); self.consume_bits(litlen_code_bits); self.state = match self.last_block { true => State::Checksum, false => State::BlockHeader, }; break; }; bits >>= litlen_code_bits; let length_extra_mask = (1 << length_extra_bits) - 1; let length = length_base as usize + (bits & length_extra_mask) as usize; bits >>= length_extra_bits; let dist_entry = self.compression.dist_table[(bits & 0x1ff) as usize]; let (dist_base, dist_extra_bits, dist_code_bits) = if dist_entry != 0 { ( (dist_entry >> 16) as u16, (dist_entry >> 8) as u8, dist_entry as u8, ) } else { let mut dist_extra_bits = 0; let mut dist_base = 0; let mut dist_advance_bits = 0; for i in 0..self.compression.dist_symbol_lengths.len() { if bits as u16 & self.compression.dist_symbol_masks[i] == self.compression.dist_symbol_codes[i] { dist_extra_bits = DIST_SYM_TO_DIST_EXTRA[i]; dist_base = DIST_SYM_TO_DIST_BASE[i]; dist_advance_bits = self.compression.dist_symbol_lengths[i]; break; } } if dist_advance_bits == 0 { return Err(DecompressionError::InvalidDistanceCode); } (dist_base, dist_extra_bits, dist_advance_bits) }; bits >>= dist_code_bits; let dist = dist_base as usize + (bits & ((1 << dist_extra_bits) - 1)) as usize; let total_bits = litlen_code_bits + length_extra_bits + dist_code_bits + dist_extra_bits; if self.nbits < total_bits { break; } else if dist > output_index { return Err(DecompressionError::DistanceTooFarBack); } // println!("[{output_index}] BACKREF len={} dist={} {:x}", length, dist, dist_entry); self.consume_bits(total_bits); let copy_length = length.min(output.len() - output_index); if dist == 1 { let last = output[output_index - 1]; output[output_index..][..copy_length].fill(last); if copy_length < length { self.queued_rle = Some((last, length - copy_length)); output_index = output.len(); break; } } else if output_index + length + 15 <= output.len() { let start = output_index - dist; output.copy_within(start..start + 16, output_index); if length > 16 || dist < 16 { for i in (0..length).step_by(dist.min(16)).skip(1) { output.copy_within(start + i..start + i + 16, output_index + i); } } } else { if dist < copy_length { for i in 0..copy_length { output[output_index + i] = output[output_index + i - dist]; } } else { output.copy_within( output_index - dist..output_index + copy_length - dist, output_index, ) } if copy_length < length { self.queued_backref = Some((dist, length - copy_length)); output_index = output.len(); break; } } output_index += copy_length; } if self.state == State::CompressedData && self.queued_backref.is_none() && self.queued_rle.is_none() && self.nbits >= 15 && self.peak_bits(15) as u16 & self.compression.eof_mask == self.compression.eof_code { self.consume_bits(self.compression.eof_bits); self.state = match self.last_block { true => State::Checksum, false => State::BlockHeader, }; } Ok(output_index) } /// Decompresses a chunk of data. /// /// Returns the number of bytes read from `input` and the number of bytes written to `output`, /// or an error if the deflate stream is not valid. `input` is the compressed data. `output` is /// the buffer to write the decompressed data to, starting at index `output_position`. /// `end_of_input` indicates whether more data may be available in the future. /// /// The contents of `output` after `output_position` are ignored. However, this function may /// write additional data to `output` past what is indicated by the return value. /// /// When this function returns `Ok`, at least one of the following is true: /// - The input is fully consumed. /// - The output is full but there are more bytes to output. /// - The deflate stream is complete (and `is_done` will return true). /// /// # Panics /// /// This function will panic if `output_position` is out of bounds. pub fn read( &mut self, input: &[u8], output: &mut [u8], output_position: usize, end_of_input: bool, ) -> Result<(usize, usize), DecompressionError> { if let State::Done = self.state { return Ok((0, 0)); } assert!(output_position <= output.len()); let mut remaining_input = input; let mut output_index = output_position; if let Some((data, len)) = self.queued_rle.take() { let n = len.min(output.len() - output_index); output[output_index..][..n].fill(data); output_index += n; if n < len { self.queued_rle = Some((data, len - n)); return Ok((0, n)); } } if let Some((dist, len)) = self.queued_backref.take() { let n = len.min(output.len() - output_index); for i in 0..n { output[output_index + i] = output[output_index + i - dist]; } output_index += n; if n < len { self.queued_backref = Some((dist, len - n)); return Ok((0, n)); } } // Main decoding state machine. let mut last_state = None; while last_state != Some(self.state) { last_state = Some(self.state); match self.state { State::ZlibHeader => { self.fill_buffer(&mut remaining_input); if self.nbits < 16 { break; } let input0 = self.peak_bits(8); let input1 = self.peak_bits(16) >> 8 & 0xff; if input0 & 0x0f != 0x08 || (input0 & 0xf0) > 0x70 || input1 & 0x20 != 0 || (input0 << 8 | input1) % 31 != 0 { return Err(DecompressionError::BadZlibHeader); } self.consume_bits(16); self.state = State::BlockHeader; } State::BlockHeader => { self.read_block_header(&mut remaining_input)?; } State::CodeLengthCodes => { self.read_code_length_codes(&mut remaining_input)?; } State::CodeLengths => { self.read_code_lengths(&mut remaining_input)?; } State::CompressedData => { output_index = self.read_compressed(&mut remaining_input, output, output_index)? } State::UncompressedData => { // Drain any bytes from our buffer. debug_assert_eq!(self.nbits % 8, 0); while self.nbits > 0 && self.uncompressed_bytes_left > 0 && output_index < output.len() { output[output_index] = self.peak_bits(8) as u8; self.consume_bits(8); output_index += 1; self.uncompressed_bytes_left -= 1; } // Buffer may contain one additional byte. Clear it to avoid confusion. if self.nbits == 0 { self.buffer = 0; } // Copy subsequent bytes directly from the input. let copy_bytes = (self.uncompressed_bytes_left as usize) .min(remaining_input.len()) .min(output.len() - output_index); output[output_index..][..copy_bytes] .copy_from_slice(&remaining_input[..copy_bytes]); remaining_input = &remaining_input[copy_bytes..]; output_index += copy_bytes; self.uncompressed_bytes_left -= copy_bytes as u16; if self.uncompressed_bytes_left == 0 { self.state = if self.last_block { State::Checksum } else { State::BlockHeader }; } } State::Checksum => { self.fill_buffer(&mut remaining_input); let align_bits = self.nbits % 8; if self.nbits >= 32 + align_bits { self.checksum.write(&output[output_position..output_index]); if align_bits != 0 { self.consume_bits(align_bits); } #[cfg(not(fuzzing))] if !self.ignore_adler32 && (self.peak_bits(32) as u32).swap_bytes() != self.checksum.finish() { return Err(DecompressionError::WrongChecksum); } self.state = State::Done; self.consume_bits(32); break; } } State::Done => unreachable!(), } } if !self.ignore_adler32 && self.state != State::Done { self.checksum.write(&output[output_position..output_index]); } if self.state == State::Done || !end_of_input || output_index >= output.len() - 1 { let input_left = remaining_input.len(); Ok((input.len() - input_left, output_index - output_position)) } else { Err(DecompressionError::InsufficientInput) } } /// Returns true if the decompressor has finished decompressing the input. pub fn is_done(&self) -> bool { self.state == State::Done } } /// Decompress the given data. pub fn decompress_to_vec(input: &[u8]) -> Result, DecompressionError> { match decompress_to_vec_bounded(input, usize::MAX) { Ok(output) => Ok(output), Err(BoundedDecompressionError::DecompressionError { inner }) => Err(inner), Err(BoundedDecompressionError::OutputTooLarge { .. }) => { unreachable!("Impossible to allocate more than isize::MAX bytes") } } } /// An error encountered while decompressing a deflate stream given a bounded maximum output. pub enum BoundedDecompressionError { /// The input is not a valid deflate stream. DecompressionError { /// The underlying error. inner: DecompressionError, }, /// The output is too large. OutputTooLarge { /// The output decoded so far. partial_output: Vec, }, } impl From for BoundedDecompressionError { fn from(inner: DecompressionError) -> Self { BoundedDecompressionError::DecompressionError { inner } } } /// Decompress the given data, returning an error if the output is larger than /// `maxlen` bytes. pub fn decompress_to_vec_bounded( input: &[u8], maxlen: usize, ) -> Result, BoundedDecompressionError> { let mut decoder = Decompressor::new(); let mut output = vec![0; 1024.min(maxlen)]; let mut input_index = 0; let mut output_index = 0; loop { let (consumed, produced) = decoder.read(&input[input_index..], &mut output, output_index, true)?; input_index += consumed; output_index += produced; if decoder.is_done() || output_index == maxlen { break; } output.resize((output_index + 32 * 1024).min(maxlen), 0); } output.resize(output_index, 0); if decoder.is_done() { Ok(output) } else { Err(BoundedDecompressionError::OutputTooLarge { partial_output: output, }) } } #[cfg(test)] mod tests { use crate::tables::{self, LENGTH_TO_LEN_EXTRA, LENGTH_TO_SYMBOL}; use super::*; use rand::Rng; fn roundtrip(data: &[u8]) { let compressed = crate::compress_to_vec(data); let decompressed = decompress_to_vec(&compressed).unwrap(); assert_eq!(&decompressed, data); } fn roundtrip_miniz_oxide(data: &[u8]) { let compressed = miniz_oxide::deflate::compress_to_vec_zlib(data, 3); let decompressed = decompress_to_vec(&compressed).unwrap(); assert_eq!(decompressed.len(), data.len()); for (i, (a, b)) in decompressed.chunks(1).zip(data.chunks(1)).enumerate() { assert_eq!(a, b, "chunk {}..{}", i * 1, i * 1 + 1); } assert_eq!(&decompressed, data); } #[allow(unused)] fn compare_decompression(data: &[u8]) { // let decompressed0 = flate2::read::ZlibDecoder::new(std::io::Cursor::new(&data)) // .bytes() // .collect::, _>>() // .unwrap(); let decompressed = decompress_to_vec(&data).unwrap(); let decompressed2 = miniz_oxide::inflate::decompress_to_vec_zlib(&data).unwrap(); for i in 0..decompressed.len().min(decompressed2.len()) { if decompressed[i] != decompressed2[i] { panic!( "mismatch at index {} {:?} {:?}", i, &decompressed[i.saturating_sub(1)..(i + 16).min(decompressed.len())], &decompressed2[i.saturating_sub(1)..(i + 16).min(decompressed2.len())] ); } } if decompressed != decompressed2 { panic!( "length mismatch {} {} {:x?}", decompressed.len(), decompressed2.len(), &decompressed2[decompressed.len()..][..16] ); } //assert_eq!(decompressed, decompressed2); } #[test] fn tables() { for (i, &bits) in LEN_SYM_TO_LEN_EXTRA.iter().enumerate() { let len_base = LEN_SYM_TO_LEN_BASE[i]; for j in 0..(1 << bits) { if i == 27 && j == 31 { continue; } assert_eq!(LENGTH_TO_LEN_EXTRA[len_base + j - 3], bits, "{} {}", i, j); assert_eq!( LENGTH_TO_SYMBOL[len_base + j - 3], i as u16 + 257, "{} {}", i, j ); } } } #[test] fn fdeflate_table() { let mut compression = CompressedBlock { litlen_table: [0; 4096], dist_table: [0; 512], dist_symbol_lengths: [0; 30], dist_symbol_masks: [0; 30], dist_symbol_codes: [0; 30], secondary_table: Vec::new(), eof_code: 0, eof_mask: 0, eof_bits: 0, }; let mut lengths = tables::HUFFMAN_LENGTHS.to_vec(); lengths.resize(288, 0); lengths.push(1); lengths.resize(320, 0); Decompressor::build_tables(286, &lengths, &mut compression, 11).unwrap(); assert_eq!( compression, FDEFLATE_COMPRESSED_BLOCK, "{:#x?}", compression ); } #[test] fn it_works() { roundtrip(b"Hello world!"); } #[test] fn constant() { roundtrip_miniz_oxide(&vec![0; 50]); roundtrip_miniz_oxide(&vec![5; 2048]); roundtrip_miniz_oxide(&vec![128; 2048]); roundtrip_miniz_oxide(&vec![254; 2048]); } #[test] fn random() { let mut rng = rand::thread_rng(); let mut data = vec![0; 50000]; for _ in 0..10 { for byte in &mut data { *byte = rng.gen::() % 5; } println!("Random data: {:?}", data); roundtrip_miniz_oxide(&data); } } #[test] fn ignore_adler32() { let mut compressed = crate::compress_to_vec(b"Hello world!"); let last_byte = compressed.len() - 1; compressed[last_byte] = compressed[last_byte].wrapping_add(1); match decompress_to_vec(&compressed) { Err(DecompressionError::WrongChecksum) => {} r => panic!("expected WrongChecksum, got {:?}", r), } let mut decompressor = Decompressor::new(); decompressor.ignore_adler32(); let mut decompressed = vec![0; 1024]; let decompressed_len = decompressor .read(&compressed, &mut decompressed, 0, true) .unwrap() .1; assert_eq!(&decompressed[..decompressed_len], b"Hello world!"); } #[test] fn checksum_after_eof() { let input = b"Hello world!"; let compressed = crate::compress_to_vec(input); let mut decompressor = Decompressor::new(); let mut decompressed = vec![0; 1024]; let (input_consumed, output_written) = decompressor .read( &compressed[..compressed.len() - 1], &mut decompressed, 0, false, ) .unwrap(); assert_eq!(output_written, input.len()); assert_eq!(input_consumed, compressed.len() - 1); let (input_consumed, output_written) = decompressor .read( &compressed[input_consumed..], &mut decompressed[..output_written], output_written, true, ) .unwrap(); assert!(decompressor.is_done()); assert_eq!(input_consumed, 1); assert_eq!(output_written, 0); assert_eq!(&decompressed[..input.len()], input); } #[test] fn zero_length() { let mut compressed = crate::compress_to_vec(b"").to_vec(); // Splice in zero-length non-compressed blocks. for _ in 0..10 { println!("compressed len: {}", compressed.len()); compressed.splice(2..2, [0u8, 0, 0, 0xff, 0xff].into_iter()); } // Ensure that the full input is decompressed, regardless of whether // `end_of_input` is set. for end_of_input in [true, false] { let mut decompressor = Decompressor::new(); let (input_consumed, output_written) = decompressor .read(&compressed, &mut [], 0, end_of_input) .unwrap(); assert!(decompressor.is_done()); assert_eq!(input_consumed, compressed.len()); assert_eq!(output_written, 0); } } }