inflate.c (16393B)
1 #include <stdlib.h> 2 #include <string.h> 3 #include "flate.h" 4 5 enum { 6 CodeBits = 16, /* max number of bits in a code + 1 */ 7 LitlenTableBits = 9, /* litlen code bits used in lookup table */ 8 DistTableBits = 6, /* dist code bits used in lookup table */ 9 ClenTableBits = 6, /* clen code bits used in lookup table */ 10 TableBits = LitlenTableBits, /* log2(lookup table size) */ 11 Nlit = 256, /* number of lit codes */ 12 Nlen = 29, /* number of len codes */ 13 Nlitlen = Nlit+Nlen+3, /* litlen codes + block end + 2 unused */ 14 Ndist = 30, /* number of distance codes */ 15 Nclen = 19, /* number of code length codes */ 16 WinSize = 1 << 15 /* output window size */ 17 }; 18 19 /* states */ 20 enum { 21 BlockHead, 22 UncompressedBlock, 23 CopyUncompressed, 24 FixedHuff, 25 DynamicHuff, 26 DynamicHuffClen, 27 DynamicHuffLitlenDist, 28 DynamicHuffContinue, 29 DecodeBlock, 30 DecodeBlockLenBits, 31 DecodeBlockDist, 32 DecodeBlockDistBits, 33 DecodeBlockCopy 34 }; 35 36 typedef struct { 37 short len; /* code length */ 38 ushort sym; /* symbol */ 39 } Entry; 40 41 /* huffman code tree */ 42 typedef struct { 43 Entry table[1 << TableBits]; /* prefix lookup table */ 44 uint nbits; /* prefix length (table size is 1 << nbits) */ 45 uint sum; /* full codes in table: sum(count[0..nbits]) */ 46 ushort count[CodeBits]; /* number of codes with given length */ 47 ushort symbol[Nlitlen]; /* symbols ordered by code length (lexic.) */ 48 } Huff; 49 50 typedef struct { 51 uchar *src; /* input buffer pointer */ 52 uchar *srcend; 53 54 uint bits; 55 uint nbits; 56 57 uchar win[WinSize]; /* output window */ 58 uint pos; /* window pos */ 59 uint posout; /* used for flushing win */ 60 61 int state; /* decode state */ 62 int final; /* last block flag */ 63 char *err; /* TODO: error message */ 64 65 /* for decoding dynamic code trees in inflate() */ 66 int nlit; 67 int ndist; 68 int nclen; /* also used in decode_block() */ 69 int lenpos; /* also used in decode_block() */ 70 uchar lens[Nlitlen + Ndist]; 71 72 int fixed; /* fixed code tree flag */ 73 Huff lhuff; /* dynamic lit/len huffman code tree */ 74 Huff dhuff; /* dynamic distance huffman code tree */ 75 } State; 76 77 /* TODO: globals.. initialization is not thread safe */ 78 static Huff lhuff; /* fixed lit/len huffman code tree */ 79 static Huff dhuff; /* fixed distance huffman code tree */ 80 81 /* base offset and extra bits tables */ 82 static uchar lenbits[Nlen] = { 83 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 0 84 }; 85 static ushort lenbase[Nlen] = { 86 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 17, 19, 23, 27, 31, 35, 43, 51, 59, 67, 83, 99, 115, 131, 163, 195, 227, 258 87 }; 88 static uchar distbits[Ndist] = { 89 0, 0, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 13, 13 90 }; 91 static ushort distbase[Ndist] = { 92 1, 2, 3, 4, 5, 7, 9, 13, 17, 25, 33, 49, 65, 97, 129, 193, 257, 385, 513, 769, 1025, 1537, 2049, 3073, 4097, 6145, 8193, 12289, 16385, 24577 93 }; 94 95 /* ordering of code lengths */ 96 static uchar clenorder[Nclen] = { 97 16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15 98 }; 99 100 /* TODO: this or normal inc + reverse() */ 101 /* increment bitwise reversed n (msb is bit 0, lsb is bit len-1) */ 102 static uint revinc(uint n, uint len) { 103 uint i = 1 << (len - 1); 104 105 while (n & i) 106 i >>= 1; 107 if (i) { 108 n &= i - 1; 109 n |= i; 110 } else 111 n = 0; 112 return n; 113 } 114 115 /* build huffman code tree from code lengths (each should be < CodeBits) */ 116 static int build_huff(Huff *huff, uchar *lens, uint n, uint nbits) { 117 int offs[CodeBits]; 118 int left; 119 uint i, c, sum, code, len, min, max; 120 ushort *count = huff->count; 121 ushort *symbol = huff->symbol; 122 Entry *table = huff->table; 123 Entry entry; 124 125 /* count code lengths */ 126 for (i = 0; i < CodeBits; i++) 127 count[i] = 0; 128 for (i = 0; i < n; i++) 129 count[lens[i]]++; 130 if (count[0] == n) { 131 huff->nbits = table[0].len = 0; 132 return 0; 133 } 134 count[0] = 0; 135 136 /* bound code lengths, force nbits to be within the bounds */ 137 for (max = CodeBits - 1; max > 0; max--) 138 if (count[max] != 0) 139 break; 140 if (nbits > max) 141 nbits = max; 142 for (min = 1; min < CodeBits; min++) 143 if (count[min] != 0) 144 break; 145 if (nbits < min) { 146 nbits = min; 147 if (nbits > TableBits) 148 return -1; 149 } 150 huff->nbits = nbits; 151 152 /* check if length is over-subscribed or incomplete */ 153 for (left = 1 << min, i = min; i <= max; left <<= 1, i++) { 154 left -= count[i]; 155 /* left < 0: over-subscribed, left > 0: incomplete */ 156 if (left < 0) 157 return -1; 158 } 159 160 for (sum = 0, i = 0; i <= max; i++) { 161 offs[i] = sum; 162 sum += count[i]; 163 } 164 /* needed for decoding codes longer than nbits */ 165 if (nbits < max) 166 huff->sum = offs[nbits + 1]; 167 168 /* sort symbols by code length (lexicographic order) */ 169 for (i = 0; i < n; i++) 170 if (lens[i]) 171 symbol[offs[lens[i]]++] = i; 172 173 /* lookup table for decoding nbits from input.. */ 174 for (i = 0; i < 1 << nbits; i++) 175 table[i].len = table[i].sym = 0; 176 code = 0; 177 /* ..if code is at most nbits (bits are in reverse order, sigh..) */ 178 for (len = min; len <= nbits; len++) 179 for (c = count[len]; c > 0; c--) { 180 entry.len = len; 181 entry.sym = *symbol; 182 for (i = code; i < 1 << nbits; i += 1 << len) 183 table[i] = entry; 184 /* next code */ 185 symbol++; 186 code = revinc(code, len); 187 } 188 /* ..if code is longer than nbits: values for simple bitwise decode */ 189 for (i = 0; code; i++) { 190 table[code].len = -1; 191 table[code].sym = i << 1; 192 code = revinc(code, nbits); 193 } 194 return 0; 195 } 196 197 /* fixed huffman code trees (should be done at compile time..) */ 198 static void init_fixed_huffs(void) { 199 int i; 200 uchar lens[Nlitlen]; 201 202 for (i = 0; i < 144; i++) 203 lens[i] = 8; 204 for (; i < 256; i++) 205 lens[i] = 9; 206 for (; i < 280; i++) 207 lens[i] = 7; 208 for (; i < Nlitlen; i++) 209 lens[i] = 8; 210 build_huff(&lhuff, lens, Nlitlen, 8); 211 212 for (i = 0; i < Ndist; i++) 213 lens[i] = 5; 214 build_huff(&dhuff, lens, Ndist, 5); 215 } 216 217 /* fill *bits with n bits from *src */ 218 static int fillbits_fast(uchar **src, uchar *srcend, uint *bits, uint *nbits, uint n) { 219 while (*nbits < n) { 220 if (*src == srcend) 221 return 0; 222 *bits |= *(*src)++ << *nbits; 223 *nbits += 8; 224 } 225 return 1; 226 } 227 228 /* get n bits from *bits */ 229 static uint getbits_fast(uint *bits, uint *nbits, int n) { 230 uint k; 231 232 k = *bits & ((1 << n) - 1); 233 *bits >>= n; 234 *nbits -= n; 235 return k; 236 } 237 238 static int fillbits(State *s, uint n) { 239 return fillbits_fast(&s->src, s->srcend, &s->bits, &s->nbits, n); 240 } 241 242 static uint getbits(State *s, uint n) { 243 return getbits_fast(&s->bits, &s->nbits, n); 244 } 245 246 /* decode symbol bitwise if code is longer than huffbits */ 247 static uint decode_symbol_long(State *s, Huff *huff, uint bits, uint nbits, int cur) { 248 int sum = huff->sum; 249 uint huffbits = huff->nbits; 250 ushort *count = huff->count + huffbits + 1; 251 252 /* get bits if we are near the end */ 253 if (s->src + 2 >= s->srcend) { 254 while (nbits < CodeBits - 1 && s->src < s->srcend) { 255 bits |= *s->src++ << nbits; 256 nbits += 8; 257 } 258 s->bits = bits; 259 s->nbits = nbits; 260 } 261 bits >>= huffbits; 262 nbits -= huffbits; 263 for (;;) { 264 if (!nbits--) { 265 if (s->src == s->srcend) 266 return FlateIn; 267 bits = *s->src++; 268 nbits = 7; 269 } 270 cur |= bits & 1; 271 bits >>= 1; 272 sum += *count; 273 cur -= *count; 274 if (cur < 0) 275 break; 276 cur <<= 1; 277 count++; 278 if (count == huff->count + CodeBits) 279 return s->err = "symbol decoding failed.", FlateErr; 280 } 281 s->bits = bits; 282 s->nbits = nbits; 283 return huff->symbol[sum + cur]; 284 } 285 286 /* decode a symbol from stream with huff code */ 287 static uint decode_symbol(State *s, Huff *huff) { 288 uint huffbits = huff->nbits; 289 uint nbits = s->nbits; 290 uint bits = s->bits; 291 uint mask = (1 << huffbits) - 1; 292 Entry entry; 293 294 /* get enough bits efficiently */ 295 if (nbits < huffbits) { 296 uchar *src = s->src; 297 298 if (src + 2 < s->srcend) { 299 /* we assume huffbits <= 9 */ 300 bits |= *src++ << nbits; 301 nbits += 8; 302 bits |= *src++ << nbits; 303 nbits += 8; 304 bits |= *src++ << nbits; 305 nbits += 8; 306 s->src = src; 307 } else /* rare */ 308 do { 309 if (s->src == s->srcend) { 310 entry = huff->table[bits & mask]; 311 if (entry.len > 0 && entry.len <= nbits) { 312 s->bits = bits >> entry.len; 313 s->nbits = nbits - entry.len; 314 return entry.sym; 315 } 316 s->bits = bits; 317 s->nbits = nbits; 318 return FlateIn; 319 } 320 bits |= *s->src++ << nbits; 321 nbits += 8; 322 } while (nbits < huffbits); 323 } 324 /* decode bits */ 325 entry = huff->table[bits & mask]; 326 if (entry.len > 0) { 327 s->bits = bits >> entry.len; 328 s->nbits = nbits - entry.len; 329 return entry.sym; 330 } else if (entry.len == 0) 331 return s->err = "symbol decoding failed.", FlateErr; 332 return decode_symbol_long(s, huff, bits, nbits, entry.sym); 333 } 334 335 /* decode a block of data from stream with trees */ 336 static int decode_block(State *s, Huff *lhuff, Huff *dhuff) { 337 uchar *win = s->win; 338 uint pos = s->pos; 339 uint sym = s->nclen; 340 uint len = s->lenpos; 341 uint dist = s->nclen; 342 343 switch (s->state) { 344 case DecodeBlock: 345 for (;;) { 346 sym = decode_symbol(s, lhuff); 347 if (sym < 256) { 348 win[pos++] = sym; 349 if (pos == WinSize) { 350 s->pos = WinSize; 351 s->state = DecodeBlock; 352 return FlateOut; 353 } 354 } else if (sym > 256) { 355 sym -= 257; 356 if (sym >= Nlen) { 357 s->pos = pos; 358 s->state = DecodeBlock; 359 if (sym + 257 == (uint)FlateIn) 360 return FlateIn; 361 return FlateErr; 362 } 363 case DecodeBlockLenBits: 364 if (!fillbits_fast(&s->src, s->srcend, &s->bits, &s->nbits, lenbits[sym])) { 365 s->nclen = sym; /* using nclen to store sym */ 366 s->pos = pos; 367 s->state = DecodeBlockLenBits; 368 return FlateIn; 369 } 370 len = lenbase[sym] + getbits_fast(&s->bits, &s->nbits, lenbits[sym]); 371 case DecodeBlockDist: 372 sym = decode_symbol(s, dhuff); 373 if (sym == (uint)FlateIn) { 374 s->pos = pos; 375 s->lenpos = len; 376 s->state = DecodeBlockDist; 377 return FlateIn; 378 } 379 if (sym >= Ndist) 380 return FlateErr; 381 case DecodeBlockDistBits: 382 if (!fillbits_fast(&s->src, s->srcend, &s->bits, &s->nbits, distbits[sym])) { 383 s->nclen = sym; /* using nclen to store sym */ 384 s->pos = pos; 385 s->lenpos = len; 386 s->state = DecodeBlockDistBits; 387 return FlateIn; 388 } 389 dist = distbase[sym] + getbits_fast(&s->bits, &s->nbits, distbits[sym]); 390 /* copy match, loop unroll in common case */ 391 if (pos + len < WinSize) { 392 /* lenbase[sym] >= 3 */ 393 do { 394 win[pos] = win[(pos - dist) % WinSize]; 395 pos++; 396 win[pos] = win[(pos - dist) % WinSize]; 397 pos++; 398 win[pos] = win[(pos - dist) % WinSize]; 399 pos++; 400 len -= 3; 401 } while (len >= 3); 402 if (len--) { 403 win[pos] = win[(pos - dist) % WinSize]; 404 pos++; 405 if (len) { 406 win[pos] = win[(pos - dist) % WinSize]; 407 pos++; 408 } 409 } 410 } else { /* rare */ 411 case DecodeBlockCopy: 412 while (len--) { 413 win[pos] = win[(pos - dist) % WinSize]; 414 pos++; 415 if (pos == WinSize) { 416 s->pos = WinSize; 417 s->lenpos = len; 418 s->nclen = dist; /* using nclen to store dist */ 419 s->state = DecodeBlockCopy; 420 return FlateOut; 421 } 422 } 423 } 424 } else { /* EOB: sym == 256 */ 425 s->pos = pos; 426 return FlateOk; 427 } 428 } /* for (;;) */ 429 } /* switch () */ 430 return s->err = "corrupted state.", FlateErr; 431 } 432 433 /* inflate state machine (decodes s->src into s->win) */ 434 static int inflate_state(State *s) { 435 int n; 436 437 if (s->posout) 438 return FlateOut; 439 for (;;) { 440 switch (s->state) { 441 case BlockHead: 442 if (s->final) { 443 if (s->pos) 444 return FlateOut; 445 else 446 return FlateOk; 447 } 448 if (!fillbits(s, 3)) 449 return FlateIn; 450 s->final = getbits(s, 1); 451 n = getbits(s, 2); 452 if (n == 0) 453 s->state = UncompressedBlock; 454 else if (n == 1) 455 s->state = FixedHuff; 456 else if (n == 2) 457 s->state = DynamicHuff; 458 else 459 return s->err = "corrupt block header.", FlateErr; 460 break; 461 case UncompressedBlock: 462 /* start block on a byte boundary */ 463 s->bits >>= s->nbits & 7; 464 s->nbits &= ~7; 465 if (!fillbits(s, 32)) 466 return FlateIn; 467 s->lenpos = getbits(s, 16); 468 n = getbits(s, 16); 469 if (s->lenpos != (~n & 0xffff)) 470 return s->err = "corrupt uncompressed length.", FlateErr; 471 s->state = CopyUncompressed; 472 case CopyUncompressed: 473 /* TODO: untested, slow, memcpy etc */ 474 /* s->nbits should be 0 here */ 475 while (s->lenpos) { 476 if (s->src == s->srcend) 477 return FlateIn; 478 s->lenpos--; 479 s->win[s->pos++] = *s->src++; 480 if (s->pos == WinSize) 481 return FlateOut; 482 } 483 s->state = BlockHead; 484 break; 485 case FixedHuff: 486 s->fixed = 1; 487 s->state = DecodeBlock; 488 break; 489 case DynamicHuff: 490 /* decode dynamic huffman code trees */ 491 if (!fillbits(s, 14)) 492 return FlateIn; 493 s->nlit = 257 + getbits(s, 5); 494 s->ndist = 1 + getbits(s, 5); 495 s->nclen = 4 + getbits(s, 4); 496 if (s->nlit > Nlitlen || s->ndist > Ndist) 497 return s->err = "corrupt code tree.", FlateErr; 498 /* build code length tree */ 499 for (n = 0; n < Nclen; n++) 500 s->lens[n] = 0; 501 s->fixed = 0; 502 s->state = DynamicHuffClen; 503 s->lenpos = 0; 504 case DynamicHuffClen: 505 for (n = s->lenpos; n < s->nclen; n++) 506 if (fillbits(s, 3)) { 507 s->lens[clenorder[n]] = getbits(s, 3); 508 } else { 509 s->lenpos = n; 510 return FlateIn; 511 } 512 /* using lhuff for code length huff code */ 513 if (build_huff(&s->lhuff, s->lens, Nclen, ClenTableBits) < 0) 514 return s->err = "building clen tree failed.", FlateErr; 515 s->state = DynamicHuffLitlenDist; 516 s->lenpos = 0; 517 case DynamicHuffLitlenDist: 518 /* decode code lengths for the dynamic trees */ 519 for (n = s->lenpos; n < s->nlit + s->ndist; ) { 520 uint sym = decode_symbol(s, &s->lhuff); 521 uint len; 522 uchar c; 523 524 if (sym < 16) { 525 s->lens[n++] = sym; 526 continue; 527 } else if (sym == (uint)FlateIn) { 528 s->lenpos = n; 529 return FlateIn; 530 case DynamicHuffContinue: 531 n = s->lenpos; 532 sym = s->nclen; 533 s->state = DynamicHuffLitlenDist; 534 } 535 if (!fillbits(s, 7)) { 536 /* TODO: 7 is too much when an almost empty block is at the end */ 537 if (sym == (uint)FlateErr) 538 return FlateErr; 539 s->nclen = sym; 540 s->lenpos = n; 541 s->state = DynamicHuffContinue; 542 return FlateIn; 543 } 544 /* TODO: bound check s->lens */ 545 if (sym == 16) { 546 /* copy previous code length 3-6 times */ 547 c = s->lens[n - 1]; 548 for (len = 3 + getbits(s, 2); len; len--) 549 s->lens[n++] = c; 550 } else if (sym == 17) { 551 /* repeat 0 for 3-10 times */ 552 for (len = 3 + getbits(s, 3); len; len--) 553 s->lens[n++] = 0; 554 } else if (sym == 18) { 555 /* repeat 0 for 11-138 times */ 556 for (len = 11 + getbits(s, 7); len; len--) 557 s->lens[n++] = 0; 558 } else 559 return s->err = "corrupt code tree.", FlateErr; 560 } 561 /* build dynamic huffman code trees */ 562 if (build_huff(&s->lhuff, s->lens, s->nlit, LitlenTableBits) < 0) 563 return s->err = "building litlen tree failed.", FlateErr; 564 if (build_huff(&s->dhuff, s->lens + s->nlit, s->ndist, DistTableBits) < 0) 565 return s->err = "building dist tree failed.", FlateErr; 566 s->state = DecodeBlock; 567 case DecodeBlock: 568 case DecodeBlockLenBits: 569 case DecodeBlockDist: 570 case DecodeBlockDistBits: 571 case DecodeBlockCopy: 572 n = decode_block(s, s->fixed ? &lhuff : &s->lhuff, s->fixed ? &dhuff : &s->dhuff); 573 if (n != FlateOk) 574 return n; 575 s->state = BlockHead; 576 break; 577 default: 578 return s->err = "corrupt internal state.", FlateErr; 579 } 580 } 581 } 582 583 static State *alloc_state(void) { 584 State *s = malloc(sizeof(State)); 585 586 if (s) { 587 s->final = s->pos = s->posout = s->bits = s->nbits = 0; 588 s->state = BlockHead; 589 s->src = s->srcend = 0; 590 s->err = 0; 591 /* TODO: globals.. */ 592 if (lhuff.nbits == 0) 593 init_fixed_huffs(); 594 } 595 return s; 596 } 597 598 599 /* extern */ 600 601 int inflate(FlateStream *stream) { 602 State *s = stream->state; 603 int n; 604 605 if (stream->err) { 606 if (s) { 607 free(s); 608 stream->state = 0; 609 } 610 return FlateErr; 611 } 612 if (!s) { 613 s = stream->state = alloc_state(); 614 if (!s) 615 return stream->err = "no mem.", FlateErr; 616 } 617 if (stream->nin) { 618 s->src = stream->in; 619 s->srcend = s->src + stream->nin; 620 stream->nin = 0; 621 } 622 n = inflate_state(s); 623 if (n == FlateOut) { 624 if (s->pos < stream->nout) 625 stream->nout = s->pos; 626 memcpy(stream->out, s->win + s->posout, stream->nout); 627 s->pos -= stream->nout; 628 if (s->pos) 629 s->posout += stream->nout; 630 else 631 s->posout = 0; 632 } 633 if (n == FlateOk || n == FlateErr) { 634 if (s->nbits || s->src < s->srcend) { 635 s->nbits /= 8; 636 stream->in = s->src - s->nbits; 637 stream->nin = s->srcend - s->src + s->nbits; 638 } 639 stream->err = s->err; 640 free(s); 641 stream->state = 0; 642 } 643 return n; 644 }