commit 966429e2a5d9e15ac6ee0a47eb715b8dfd04550f
parent 96787dc8ded9ff295326fe1695e1eccf144bf0d6
Author: nsz <nszabolcs@gmail.com>
Date: Mon, 27 Apr 2009 23:09:29 +0200
inflate cleanup
Diffstat:
inflate.c | 237 +++++++++++++++++++++++++++++++++++++++++--------------------------------------
1 file changed, 122 insertions(+), 115 deletions(-)
diff --git a/inflate.c b/inflate.c
@@ -1,12 +1,8 @@
/*
TODO:
check int types
- error check (src len, dst len, hufftree)
+ better error handling
clever io
- optimization:
- bottleneck: decode_symbol in decode_block
- clever huff table
- fillbits(s, n)..
*/
typedef unsigned char uchar;
@@ -14,37 +10,38 @@ typedef unsigned short ushort;
typedef unsigned int uint;
enum {
- HuffBits = 16, /* max number of bits in a code */
- HuffLitlenBits = 9, /* log2(litlen huff table size) */
- HuffDistBits = 6, /* log2(dist huff table size) */
- HuffClenBits = 6, /* log2(clen huff table size) */
- HuffTableBits = HuffLitlenBits, /* log2(max huff table size) */
- Nlit = 256, /* number of lit codes */
- Nlen = 29, /* number of len codes */
- Nlitlen = Nlit+Nlen+3, /* litlen codes + block end + 2 unused */
- Ndist = 30, /* number of distance codes */
- Nclen = 19 /* number of code length codes */
+ CodeBits = 16, /* max number of bits in a code */
+ LitlenTableBits = 9, /* log2(litlen lookup table size) */
+ DistTableBits = 6, /* log2(dist lookup table size) */
+ ClenTableBits = 6, /* log2(clen lookup table size) */
+ TableBits = LitlenTableBits, /* log2(max lookup table size) */
+ Nlit = 256, /* number of lit codes */
+ Nlen = 29, /* number of len codes */
+ Nlitlen = Nlit+Nlen+3, /* litlen codes + block end + 2 unused */
+ Ndist = 30, /* number of distance codes */
+ Nclen = 19 /* number of code length codes */
};
enum {
- FlateOk = 0,
- FlateShortSrc = -2,
- FlateShortDst = -3,
- FlateCorrupted = -4
+ FlateOk = 0,
+ FlateShortSrc = -1,
+ FlateShortDst = -2,
+ FlateCorrupted = -3
};
typedef struct {
- short len; /* complete: code len, incomplete: -(extra bits), err: 0 */
- ushort sym; /* symbol if complete, decode helper if incomplete */
-} HuffEntry;
+ short len; /* code length */
+ ushort sym; /* symbol */
+} Entry;
+/* huffman code tree */
typedef struct {
- HuffEntry table[1 << HuffTableBits]; /* for decoding the first nbits */
- int nbits; /* table length is 1 << nbits */
- int sum; /* sum(count[0..nbits-1]) */
- ushort count[HuffBits]; /* code bit length -> count */
- ushort symbol[Nlitlen]; /* symbols ordered by code length */
-} HuffTable;
+ Entry table[1 << TableBits]; /* prefix lookup table */
+ uint nbits; /* prefix length (table size is 1 << nbits) */
+ uint sum; /* full codes in table: sum(count[0..nbits]) */
+ ushort count[CodeBits]; /* number of codes with given length */
+ ushort symbol[Nlitlen]; /* symbols ordered by code length (lexic.) */
+} Huff;
typedef struct {
uchar *src;
@@ -59,13 +56,13 @@ typedef struct {
int error;
- HuffTable ltab; /* dynamic lit/len table */
- HuffTable dtab; /* dynamic distance table */
-} FlateStream;
+ Huff lhuff; /* dynamic lit/len huffman code tree */
+ Huff dhuff; /* dynamic distance huffman code tree */
+} Stream;
-
-static HuffTable ltab; /* fixed lit/len table */
-static HuffTable dtab; /* fixed distance table */
+/* TODO: these globals are initialized in a lazy way (not thread safe) */
+static Huff lhuff; /* fixed lit/len huffman code tree */
+static Huff dhuff; /* fixed distance huffman code tree */
/* base offset and extra bits tables */
static uchar lenbits[Nlen] = {
@@ -86,8 +83,9 @@ static uchar clenorder[Nclen] = {
16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15
};
+/* TODO: this or normal inc + reverse() */
/* increment bitwise reversed n (msb is bit 0, lsb is bit len-1) */
-static uint revinc(uint n, int len) {
+static uint revinc(uint n, uint len) {
uint i;
i = 1 << (len - 1);
@@ -101,18 +99,18 @@ static uint revinc(uint n, int len) {
return n;
}
-/* build huffman code tree from code lengths (each should be < HuffBits) */
-static int build_table(HuffTable *h, uchar *lens, int n, int nbits) {
- int offs[HuffBits];
+/* build huffman code tree from code lengths (each should be < CodeBits) */
+static int build_huff(Huff *huff, uchar *lens, uint n, uint nbits) {
+ int offs[CodeBits];
int left;
uint i, c, sum, code, len, min, max;
- ushort *count = h->count;
- ushort *symbol = h->symbol;
- HuffEntry *table = h->table;
- HuffEntry entry;
+ ushort *count = huff->count;
+ ushort *symbol = huff->symbol;
+ Entry *table = huff->table;
+ Entry entry;
- /* count code lengths and calc first code for each length */
- for (i = 0; i < HuffBits; i++)
+ /* count code lengths */
+ for (i = 0; i < CodeBits; i++)
count[i] = 0;
for (i = 0; i < n; i++)
count[lens[i]]++;
@@ -120,18 +118,18 @@ static int build_table(HuffTable *h, uchar *lens, int n, int nbits) {
return -1;
count[0] = 0;
- /* bound code lengths, force root to be within code lengths */
- for (max = HuffBits - 1; max > 0; max--)
+ /* bound code lengths, force nbits to be within code lengths */
+ for (max = CodeBits - 1; max > 0; max--)
if (count[max] != 0)
break;
if (nbits > max)
nbits = max;
- for (min = 1; min < HuffBits; min++)
+ for (min = 1; min < CodeBits; min++)
if (count[min] != 0)
break;
if (nbits < min)
nbits = min;
- h->nbits = nbits;
+ huff->nbits = nbits;
/* check if length is over-subscribed or incomplete */
for (left = 1 << min, i = min; i <= max; left <<= 1, i++) {
@@ -147,9 +145,9 @@ static int build_table(HuffTable *h, uchar *lens, int n, int nbits) {
}
/* needed for decoding codes longer than nbits */
if (nbits < max)
- h->sum = offs[nbits + 1];
+ huff->sum = offs[nbits + 1];
- /* sort symbols by code length */
+ /* sort symbols by code length (lexicographic order) */
for (i = 0; i < n; i++)
if (lens[i])
symbol[offs[lens[i]]++] = i;
@@ -158,7 +156,7 @@ static int build_table(HuffTable *h, uchar *lens, int n, int nbits) {
for (i = 0; i < 1 << nbits; i++)
table[i].len = 0; /* invalid marker for incomplete code */
code = 0;
- /* ..if code is at most nbits */
+ /* ..if code is at most nbits (bits are in reverse order, sigh..) */
for (len = min; len <= nbits; len++)
for (c = count[len]; c > 0; c--) {
entry.len = len;
@@ -169,16 +167,17 @@ static int build_table(HuffTable *h, uchar *lens, int n, int nbits) {
symbol++;
code = revinc(code, len);
}
- /* ..if code is longer than nbits: nbits prefixes are marked */
+ /* ..if code is longer than nbits: values for simple bitwise decode */
for (i = 0; code; i++) {
- table[code].len = -1; /* TODO +validation? */
- table[code].sym = i;
+ table[code].len = -1;
+ table[code].sym = i << 1;
code = revinc(code, nbits);
}
return 0;
}
-static void init_fixed_trees(void) {
+/* fixed huffman code trees (should be done at compile time..) */
+static void init_fixed_huffs(void) {
int i;
uchar lens[Nlitlen];
@@ -188,18 +187,17 @@ static void init_fixed_trees(void) {
lens[i] = 9;
for (; i < 280; i++)
lens[i] = 7;
- for (; i < 288; i++) /* a complete, but wrong code set */
+ for (; i < Nlitlen; i++)
lens[i] = 8;
- build_table(&ltab, lens, Nlitlen, 8);
-
+ build_huff(&lhuff, lens, Nlitlen, 8);
- for (i = 0; i < Ndist; i++) /* an incomplete code set */
+ for (i = 0; i < Ndist; i++)
lens[i] = 5;
- build_table(&dtab, lens, Ndist, 5);
+ build_huff(&dhuff, lens, Ndist, 5);
}
/* get one bit from s->src stream */
-static uint getbit(FlateStream *s) {
+static uint getbit(Stream *s) {
uint bit;
if (!s->nbits--) {
@@ -216,7 +214,7 @@ static uint getbit(FlateStream *s) {
}
/* get n bits from s->src stream */
-static uint getbits(FlateStream *s, int n) {
+static uint getbits(Stream *s, int n) {
uint bits = 0;
int i;
@@ -228,35 +226,41 @@ static uint getbits(FlateStream *s, int n) {
}
/* decode a symbol from stream with tree */
-static uint decode_symbol(FlateStream *s, HuffTable *h) {
+static uint decode_symbol(Stream *s, Huff *huff) {
int sum, cur;
ushort *count;
- uint htbits = h->nbits;
- uint nbits = s->nbits;
+ uint huffbits = huff->nbits;
+ uint streambits = s->nbits;
uint bits = s->bits;
- uint mask = (1 << htbits) - 1;
- HuffEntry entry;
+ uint mask = (1 << huffbits) - 1;
+ Entry entry;
- /* TODO: check src */
- while (nbits < htbits) {
- bits |= (*s->src++ << nbits);
- nbits += 8;
+ /* avail src should be checked outside */
+ while (streambits < huffbits) {
+ /* TODO: decode_symbol is the performance bottleneck, do it faster */
+ if (s->src == s->srcend) {
+ s->error = FlateShortSrc;
+ return 0;
+ }
+ bits |= (*s->src++ << streambits);
+ streambits += 8;
}
- entry = h->table[bits & mask];
+ entry = huff->table[bits & mask];
if (entry.len > 0) {
s->bits = bits >> entry.len;
- s->nbits = nbits - entry.len;
+ s->nbits = streambits - entry.len;
return entry.sym;
- } else if (entry.len == 0)
- /* error */;
-
- bits >>= htbits;
- nbits -= htbits;
- s->bits = bits;
- s->nbits = nbits;
- cur = entry.sym << 1; /* TODO: do it in build_tab */
- sum = h->sum;
- count = h->count + htbits + 1;
+ } else if (entry.len == 0) {
+ s->error = FlateCorrupted;
+ return 0;
+ }
+
+ /* code is longer than huffbits: bitwise decode the rest */
+ s->bits = bits >> huffbits;
+ s->nbits = streambits - huffbits;
+ cur = entry.sym;
+ sum = huff->sum;
+ count = huff->count + huffbits + 1;
for (;;) {
cur |= getbit(s);
sum += *count;
@@ -265,13 +269,17 @@ static uint decode_symbol(FlateStream *s, HuffTable *h) {
break;
cur <<= 1;
count++;
+ if (count == huff->count + CodeBits) {
+ s->error = FlateCorrupted;
+ return 0;
+ }
}
- return h->symbol[sum + cur];
+ return huff->symbol[sum + cur];
}
-/* decode dynamic trees from stream */
-static void decode_tables(FlateStream *s, HuffTable *lt, HuffTable *dt) {
- HuffTable cltab;
+/* decode dynamic huffman code trees from stream */
+static void decode_huffs(Stream *s) {
+ Huff clhuff;
uchar lens[Nlitlen+Ndist];
uint nlit, ndist, nclen;
uint i;
@@ -291,18 +299,17 @@ static void decode_tables(FlateStream *s, HuffTable *lt, HuffTable *dt) {
lens[i] = 0;
for (i = 0; i < nclen; i++)
lens[clenorder[i]] = getbits(s, 3);
- if (build_table(&cltab, lens, Nclen, HuffClenBits) < 0) {
+ if (build_huff(&clhuff, lens, Nclen, ClenTableBits) < 0) {
s->error = FlateCorrupted;
return;
}
/* decode code lengths for the dynamic trees */
for (i = 0; i < nlit + ndist; ) {
- uint sym = decode_symbol(s, &cltab);
+ uint sym = decode_symbol(s, &clhuff);
uint len;
uchar c;
-/* fprintf(stderr, "sym: %u, nbits: %u, bits: x%03x, src: %u\n", sym, s->nbits, s->bits, (unsigned)s->src);*/
if (sym < 16) {
lens[i++] = sym;
} else if (sym == 16) {
@@ -323,20 +330,19 @@ static void decode_tables(FlateStream *s, HuffTable *lt, HuffTable *dt) {
if (s->error != FlateOk)
return;
}
- /* build dynamic huffman trees */
- if (build_table(lt, lens, nlit, HuffLitlenBits) < 0)
+ /* build dynamic huffman code trees */
+ if (build_huff(&s->lhuff, lens, nlit, LitlenTableBits) < 0)
s->error = FlateCorrupted;
- if (build_table(dt, lens + nlit, ndist, HuffDistBits) < 0)
+ if (build_huff(&s->dhuff, lens + nlit, ndist, DistTableBits) < 0)
s->error = FlateCorrupted;
}
/* decode a block of data from stream with trees */
-static void decode_block(FlateStream *s, HuffTable *lt, HuffTable *dt) {
+static void decode_block(Stream *s, Huff *lhuff, Huff *dhuff) {
uint sym;
for (;;) {
- sym = decode_symbol(s, lt);
-
+ sym = decode_symbol(s, lhuff);
if (s->error != FlateOk)
return;
if (sym == 256)
@@ -356,7 +362,7 @@ static void decode_block(FlateStream *s, HuffTable *lt, HuffTable *dt) {
return;
}
len = lenbase[sym] + getbits(s, lenbits[sym]);
- sym = decode_symbol(s, dt);
+ sym = decode_symbol(s, dhuff);
if (s->error != FlateOk)
return;
if (sym >= Ndist) {
@@ -375,6 +381,7 @@ static void decode_block(FlateStream *s, HuffTable *lt, HuffTable *dt) {
s->error = FlateCorrupted;
return;
}
+ /* TODO: unroll 3-4 loops */
for (i = 0; i < len; i++)
s->dst[i] = s->dst[i - dist];
s->dst += len;
@@ -382,9 +389,15 @@ static void decode_block(FlateStream *s, HuffTable *lt, HuffTable *dt) {
}
}
-static void inflate_uncompressed_block(FlateStream *s) {
+static void inflate_uncompressed_block(Stream *s) {
uint len, invlen;
+ /* start next block on a byte boundary */
+ while (s->nbits > 8) {
+ s->nbits -= 8;
+ s->src--;
+ }
+ s->nbits = 0;
len = (s->src[1] << 8) | s->src[0];
invlen = (s->src[3] << 8) | s->src[2];
s->src += 4;
@@ -392,10 +405,6 @@ static void inflate_uncompressed_block(FlateStream *s) {
s->error = FlateCorrupted;
return;
}
-
- /* start next block on a byte boundary */
- s->nbits = 0;
-
if (s->dst + len > s->dstend) {
s->error = FlateShortDst;
return;
@@ -404,32 +413,31 @@ static void inflate_uncompressed_block(FlateStream *s) {
s->error = FlateShortSrc;
return;
}
+ /* TODO: memcpy */
while (len--)
*s->dst++ = *s->src++;
}
-static void inflate_fixed_block(FlateStream *s) {
- decode_block(s, &ltab, &dtab);
+static void inflate_fixed_block(Stream *s) {
+ /* lazy initialization of fixed huff code trees */
+ if (lhuff.nbits == 0)
+ init_fixed_huffs();
+ decode_block(s, &lhuff, &dhuff);
}
-static void inflate_dynamic_block(FlateStream *s) {
- decode_tables(s, &s->ltab, &s->dtab);
+static void inflate_dynamic_block(Stream *s) {
+ decode_huffs(s);
if (s->error != FlateOk)
return;
- decode_block(s, &s->ltab, &s->dtab);
+ decode_block(s, &s->lhuff, &s->dhuff);
}
-/* extern functions */
-
-/* initialize global (static) data */
-void inflate_init(void) {
- init_fixed_trees();
-}
+/* extern */
/* inflate stream from src to dst */
uint inflate(void *dst, uint dstlen, void *src, uint srclen) {
- FlateStream s;
+ Stream s;
uint final;
s.src = src;
@@ -486,7 +494,6 @@ int main(int argc, char **argv) {
uint len = 1 << 24;
uchar *src, *dst;
- inflate_init();
src = readall(argc < 2 ? stdin : fopen(argv[1], "r"));
dst = malloc(len);
len = inflate(dst, len, src, len>>2);