flate

deflate implementation
git clone git://git.suckless.org/flate
Log | Files | Refs | README

commit 966429e2a5d9e15ac6ee0a47eb715b8dfd04550f
parent 96787dc8ded9ff295326fe1695e1eccf144bf0d6
Author: nsz <nszabolcs@gmail.com>
Date:   Mon, 27 Apr 2009 23:09:29 +0200

inflate cleanup
Diffstat:
inflate.c | 237 +++++++++++++++++++++++++++++++++++++++++--------------------------------------
1 file changed, 122 insertions(+), 115 deletions(-)

diff --git a/inflate.c b/inflate.c @@ -1,12 +1,8 @@ /* TODO: check int types - error check (src len, dst len, hufftree) + better error handling clever io - optimization: - bottleneck: decode_symbol in decode_block - clever huff table - fillbits(s, n).. */ typedef unsigned char uchar; @@ -14,37 +10,38 @@ typedef unsigned short ushort; typedef unsigned int uint; enum { - HuffBits = 16, /* max number of bits in a code */ - HuffLitlenBits = 9, /* log2(litlen huff table size) */ - HuffDistBits = 6, /* log2(dist huff table size) */ - HuffClenBits = 6, /* log2(clen huff table size) */ - HuffTableBits = HuffLitlenBits, /* log2(max huff table size) */ - Nlit = 256, /* number of lit codes */ - Nlen = 29, /* number of len codes */ - Nlitlen = Nlit+Nlen+3, /* litlen codes + block end + 2 unused */ - Ndist = 30, /* number of distance codes */ - Nclen = 19 /* number of code length codes */ + CodeBits = 16, /* max number of bits in a code */ + LitlenTableBits = 9, /* log2(litlen lookup table size) */ + DistTableBits = 6, /* log2(dist lookup table size) */ + ClenTableBits = 6, /* log2(clen lookup table size) */ + TableBits = LitlenTableBits, /* log2(max lookup table size) */ + Nlit = 256, /* number of lit codes */ + Nlen = 29, /* number of len codes */ + Nlitlen = Nlit+Nlen+3, /* litlen codes + block end + 2 unused */ + Ndist = 30, /* number of distance codes */ + Nclen = 19 /* number of code length codes */ }; enum { - FlateOk = 0, - FlateShortSrc = -2, - FlateShortDst = -3, - FlateCorrupted = -4 + FlateOk = 0, + FlateShortSrc = -1, + FlateShortDst = -2, + FlateCorrupted = -3 }; typedef struct { - short len; /* complete: code len, incomplete: -(extra bits), err: 0 */ - ushort sym; /* symbol if complete, decode helper if incomplete */ -} HuffEntry; + short len; /* code length */ + ushort sym; /* symbol */ +} Entry; +/* huffman code tree */ typedef struct { - HuffEntry table[1 << HuffTableBits]; /* for decoding the first nbits */ - int nbits; /* table length is 1 << nbits */ - int 
sum; /* sum(count[0..nbits-1]) */ - ushort count[HuffBits]; /* code bit length -> count */ - ushort symbol[Nlitlen]; /* symbols ordered by code length */ -} HuffTable; + Entry table[1 << TableBits]; /* prefix lookup table */ + uint nbits; /* prefix length (table size is 1 << nbits) */ + uint sum; /* full codes in table: sum(count[0..nbits]) */ + ushort count[CodeBits]; /* number of codes with given length */ + ushort symbol[Nlitlen]; /* symbols ordered by code length (lexic.) */ +} Huff; typedef struct { uchar *src; @@ -59,13 +56,13 @@ typedef struct { int error; - HuffTable ltab; /* dynamic lit/len table */ - HuffTable dtab; /* dynamic distance table */ -} FlateStream; + Huff lhuff; /* dynamic lit/len huffman code tree */ + Huff dhuff; /* dynamic distance huffman code tree */ +} Stream; - -static HuffTable ltab; /* fixed lit/len table */ -static HuffTable dtab; /* fixed distance table */ +/* TODO: these globals are initialized in a lazy way (not thread safe) */ +static Huff lhuff; /* fixed lit/len huffman code tree */ +static Huff dhuff; /* fixed distance huffman code tree */ /* base offset and extra bits tables */ static uchar lenbits[Nlen] = { @@ -86,8 +83,9 @@ static uchar clenorder[Nclen] = { 16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15 }; +/* TODO: this or normal inc + reverse() */ /* increment bitwise reversed n (msb is bit 0, lsb is bit len-1) */ -static uint revinc(uint n, int len) { +static uint revinc(uint n, uint len) { uint i; i = 1 << (len - 1); @@ -101,18 +99,18 @@ static uint revinc(uint n, int len) { return n; } -/* build huffman code tree from code lengths (each should be < HuffBits) */ -static int build_table(HuffTable *h, uchar *lens, int n, int nbits) { - int offs[HuffBits]; +/* build huffman code tree from code lengths (each should be < CodeBits) */ +static int build_huff(Huff *huff, uchar *lens, uint n, uint nbits) { + int offs[CodeBits]; int left; uint i, c, sum, code, len, min, max; - ushort *count = h->count; - ushort 
*symbol = h->symbol; - HuffEntry *table = h->table; - HuffEntry entry; + ushort *count = huff->count; + ushort *symbol = huff->symbol; + Entry *table = huff->table; + Entry entry; - /* count code lengths and calc first code for each length */ - for (i = 0; i < HuffBits; i++) + /* count code lengths */ + for (i = 0; i < CodeBits; i++) count[i] = 0; for (i = 0; i < n; i++) count[lens[i]]++; @@ -120,18 +118,18 @@ static int build_table(HuffTable *h, uchar *lens, int n, int nbits) { return -1; count[0] = 0; - /* bound code lengths, force root to be within code lengths */ - for (max = HuffBits - 1; max > 0; max--) + /* bound code lengths, force nbits to be within code lengths */ + for (max = CodeBits - 1; max > 0; max--) if (count[max] != 0) break; if (nbits > max) nbits = max; - for (min = 1; min < HuffBits; min++) + for (min = 1; min < CodeBits; min++) if (count[min] != 0) break; if (nbits < min) nbits = min; - h->nbits = nbits; + huff->nbits = nbits; /* check if length is over-subscribed or incomplete */ for (left = 1 << min, i = min; i <= max; left <<= 1, i++) { @@ -147,9 +145,9 @@ static int build_table(HuffTable *h, uchar *lens, int n, int nbits) { } /* needed for decoding codes longer than nbits */ if (nbits < max) - h->sum = offs[nbits + 1]; + huff->sum = offs[nbits + 1]; - /* sort symbols by code length */ + /* sort symbols by code length (lexicographic order) */ for (i = 0; i < n; i++) if (lens[i]) symbol[offs[lens[i]]++] = i; @@ -158,7 +156,7 @@ static int build_table(HuffTable *h, uchar *lens, int n, int nbits) { for (i = 0; i < 1 << nbits; i++) table[i].len = 0; /* invalid marker for incomplete code */ code = 0; - /* ..if code is at most nbits */ + /* ..if code is at most nbits (bits are in reverse order, sigh..) 
*/ for (len = min; len <= nbits; len++) for (c = count[len]; c > 0; c--) { entry.len = len; @@ -169,16 +167,17 @@ static int build_table(HuffTable *h, uchar *lens, int n, int nbits) { symbol++; code = revinc(code, len); } - /* ..if code is longer than nbits: nbits prefixes are marked */ + /* ..if code is longer than nbits: values for simple bitwise decode */ for (i = 0; code; i++) { - table[code].len = -1; /* TODO +validation? */ - table[code].sym = i; + table[code].len = -1; + table[code].sym = i << 1; code = revinc(code, nbits); } return 0; } -static void init_fixed_trees(void) { +/* fixed huffman code trees (should be done at compile time..) */ +static void init_fixed_huffs(void) { int i; uchar lens[Nlitlen]; @@ -188,18 +187,17 @@ static void init_fixed_trees(void) { lens[i] = 9; for (; i < 280; i++) lens[i] = 7; - for (; i < 288; i++) /* a complete, but wrong code set */ + for (; i < Nlitlen; i++) lens[i] = 8; - build_table(&ltab, lens, Nlitlen, 8); - + build_huff(&lhuff, lens, Nlitlen, 8); - for (i = 0; i < Ndist; i++) /* an incomplete code set */ + for (i = 0; i < Ndist; i++) lens[i] = 5; - build_table(&dtab, lens, Ndist, 5); + build_huff(&dhuff, lens, Ndist, 5); } /* get one bit from s->src stream */ -static uint getbit(FlateStream *s) { +static uint getbit(Stream *s) { uint bit; if (!s->nbits--) { @@ -216,7 +214,7 @@ static uint getbit(FlateStream *s) { } /* get n bits from s->src stream */ -static uint getbits(FlateStream *s, int n) { +static uint getbits(Stream *s, int n) { uint bits = 0; int i; @@ -228,35 +226,41 @@ static uint getbits(FlateStream *s, int n) { } /* decode a symbol from stream with tree */ -static uint decode_symbol(FlateStream *s, HuffTable *h) { +static uint decode_symbol(Stream *s, Huff *huff) { int sum, cur; ushort *count; - uint htbits = h->nbits; - uint nbits = s->nbits; + uint huffbits = huff->nbits; + uint streambits = s->nbits; uint bits = s->bits; - uint mask = (1 << htbits) - 1; - HuffEntry entry; + uint mask = (1 << huffbits) 
- 1; + Entry entry; - /* TODO: check src */ - while (nbits < htbits) { - bits |= (*s->src++ << nbits); - nbits += 8; + /* avail src should be checked outside */ + while (streambits < huffbits) { + /* TODO: decode_symbol is performace bottleneck, do it faster */ + if (s->src == s->srcend) { + s->error = FlateShortSrc; + return 0; + } + bits |= (*s->src++ << streambits); + streambits += 8; } - entry = h->table[bits & mask]; + entry = huff->table[bits & mask]; if (entry.len > 0) { s->bits = bits >> entry.len; - s->nbits = nbits - entry.len; + s->nbits = streambits - entry.len; return entry.sym; - } else if (entry.len == 0) - /* error */; - - bits >>= htbits; - nbits -= htbits; - s->bits = bits; - s->nbits = nbits; - cur = entry.sym << 1; /* TODO: do it in build_tab */ - sum = h->sum; - count = h->count + htbits + 1; + } else if (entry.len == 0) { + s->error = FlateCorrupted; + return 0; + } + + /* code is longer than huffbits: bitwise decode the rest */ + s->bits = bits >> huffbits; + s->nbits = streambits - huffbits; + cur = entry.sym; + sum = huff->sum; + count = huff->count + huffbits + 1; for (;;) { cur |= getbit(s); sum += *count; @@ -265,13 +269,17 @@ static uint decode_symbol(FlateStream *s, HuffTable *h) { break; cur <<= 1; count++; + if (count == huff->count + CodeBits) { + s->error = FlateCorrupted; + return 0; + } } - return h->symbol[sum + cur]; + return huff->symbol[sum + cur]; } -/* decode dynamic trees from stream */ -static void decode_tables(FlateStream *s, HuffTable *lt, HuffTable *dt) { - HuffTable cltab; +/* decode dynamic huffman code trees from stream */ +static void decode_huffs(Stream *s) { + Huff clhuff; uchar lens[Nlitlen+Ndist]; uint nlit, ndist, nclen; uint i; @@ -291,18 +299,17 @@ static void decode_tables(FlateStream *s, HuffTable *lt, HuffTable *dt) { lens[i] = 0; for (i = 0; i < nclen; i++) lens[clenorder[i]] = getbits(s, 3); - if (build_table(&cltab, lens, Nclen, HuffClenBits) < 0) { + if (build_huff(&clhuff, lens, Nclen, 
ClenTableBits) < 0) { s->error = FlateCorrupted; return; } /* decode code lengths for the dynamic trees */ for (i = 0; i < nlit + ndist; ) { - uint sym = decode_symbol(s, &cltab); + uint sym = decode_symbol(s, &clhuff); uint len; uchar c; -/* fprintf(stderr, "sym: %u, nbits: %u, bits: x%03x, src: %u\n", sym, s->nbits, s->bits, (unsigned)s->src);*/ if (sym < 16) { lens[i++] = sym; } else if (sym == 16) { @@ -323,20 +330,19 @@ static void decode_tables(FlateStream *s, HuffTable *lt, HuffTable *dt) { if (s->error != FlateOk) return; } - /* build dynamic huffman trees */ - if (build_table(lt, lens, nlit, HuffLitlenBits) < 0) + /* build dynamic huffman code trees */ + if (build_huff(&s->lhuff, lens, nlit, LitlenTableBits) < 0) s->error = FlateCorrupted; - if (build_table(dt, lens + nlit, ndist, HuffDistBits) < 0) + if (build_huff(&s->dhuff, lens + nlit, ndist, DistTableBits) < 0) s->error = FlateCorrupted; } /* decode a block of data from stream with trees */ -static void decode_block(FlateStream *s, HuffTable *lt, HuffTable *dt) { +static void decode_block(Stream *s, Huff *lhuff, Huff *dhuff) { uint sym; for (;;) { - sym = decode_symbol(s, lt); - + sym = decode_symbol(s, lhuff); if (s->error != FlateOk) return; if (sym == 256) @@ -356,7 +362,7 @@ static void decode_block(FlateStream *s, HuffTable *lt, HuffTable *dt) { return; } len = lenbase[sym] + getbits(s, lenbits[sym]); - sym = decode_symbol(s, dt); + sym = decode_symbol(s, dhuff); if (s->error != FlateOk) return; if (sym >= Ndist) { @@ -375,6 +381,7 @@ static void decode_block(FlateStream *s, HuffTable *lt, HuffTable *dt) { s->error = FlateCorrupted; return; } + /* TODO: unroll 3-4 loops */ for (i = 0; i < len; i++) s->dst[i] = s->dst[i - dist]; s->dst += len; @@ -382,9 +389,15 @@ static void decode_block(FlateStream *s, HuffTable *lt, HuffTable *dt) { } } -static void inflate_uncompressed_block(FlateStream *s) { +static void inflate_uncompressed_block(Stream *s) { uint len, invlen; + /* start next block on a byte 
boundary */ + while (s->nbits > 8) { + s->nbits -= 8; + s->src--; + } + s->nbits = 0; len = (s->src[1] << 8) | s->src[0]; invlen = (s->src[3] << 8) | s->src[2]; s->src += 4; @@ -392,10 +405,6 @@ static void inflate_uncompressed_block(FlateStream *s) { s->error = FlateCorrupted; return; } - - /* start next block on a byte boundary */ - s->nbits = 0; - if (s->dst + len > s->dstend) { s->error = FlateShortDst; return; @@ -404,32 +413,31 @@ static void inflate_uncompressed_block(FlateStream *s) { s->error = FlateShortSrc; return; } + /* TODO: memcpy */ while (len--) *s->dst++ = *s->src++; } -static void inflate_fixed_block(FlateStream *s) { - decode_block(s, &ltab, &dtab); +static void inflate_fixed_block(Stream *s) { + /* lazy initialization of fixed huff code trees */ + if (lhuff.nbits == 0) + init_fixed_huffs(); + decode_block(s, &lhuff, &dhuff); } -static void inflate_dynamic_block(FlateStream *s) { - decode_tables(s, &s->ltab, &s->dtab); +static void inflate_dynamic_block(Stream *s) { + decode_huffs(s); if (s->error != FlateOk) return; - decode_block(s, &s->ltab, &s->dtab); + decode_block(s, &s->lhuff, &s->dhuff); } -/* extern functions */ - -/* initialize global (static) data */ -void inflate_init(void) { - init_fixed_trees(); -} +/* extern */ /* inflate stream from src to dst */ uint inflate(void *dst, uint dstlen, void *src, uint srclen) { - FlateStream s; + Stream s; uint final; s.src = src; @@ -486,7 +494,6 @@ int main(int argc, char **argv) { uint len = 1 << 24; uchar *src, *dst; - inflate_init(); src = readall(argc < 2 ? stdin : fopen(argv[1], "r")); dst = malloc(len); len = inflate(dst, len, src, len>>2);