flate

deflate implementation
git clone git://git.suckless.org/flate
Log | Files | Refs | README

commit db64c049e9f97322b457aca518b1899ce5444dcf
parent 180021497c5184b5eb3f713cb907266ac3407b46
Author: nsz <nszabolcs@gmail.com>
Date:   Sun,  7 Jun 2009 11:41:19 +0200

inflate with state
Diffstat:
Makefile | 13++++++++-----
inflate.c | 664+++++++++++++++++++++++++++++++++++++++++++++++--------------------------------
inflate_callback.c | 590+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
3 files changed, 993 insertions(+), 274 deletions(-)

diff --git a/Makefile b/Makefile @@ -1,24 +1,27 @@ #CFLAGS=-g -Wall -ansi -pedantic -CFLAGS=-Os -Wall -ansi -pedantic +CFLAGS=-O3 -Wall -ansi -pedantic LDFLAGS= -SRC=inflate.c inflate_simple.c +SRC=inflate.c inflate_simple.c inflate_callback.c OBJ=${SRC:.c=.o} +EXE=${SRC:.c=} -all: inflate inflate_simple +all: ${EXE} inflate: inflate.o ${CC} -o $@ $^ ${LDFLAGS} inflate_simple: inflate_simple.o ${CC} -o $@ $^ ${LDFLAGS} +inflate_callback: inflate_callback.o + ${CC} -o $@ $^ ${LDFLAGS} ${OBJ}: Makefile .c.o: ${CC} -c ${CFLAGS} $< clean: - rm -f ${OBJ} inflate + rm -f ${OBJ} ${EXE} prof: inflate.c gcc -fprofile-arcs -ftest-coverage -pg -g -Wall $< - cat d.dat | ./a.out > /dev/null + cat a.dat | ./a.out > /dev/null gcov -b $< > /dev/null gprof a.out > $<.gprof gcc -g -Wall $< diff --git a/inflate.c b/inflate.c @@ -1,4 +1,5 @@ #include <stdlib.h> +#include <stdio.h> #include <string.h> typedef unsigned char uchar; @@ -21,13 +22,31 @@ enum { WinMask = WinSize - 1 /* window pos (index) mask */ }; +/* return values */ enum { FlateOk = 0, - FlateInError = -1, - FlateOutError = -2, - FlateCorrupted = -3 + FlateError = -1, + FlateNeedInput = -2, + FlateHasOutput = -3 +}; + +/* states */ +enum { + BlockHead, + UncompressedBlock, + CopyUncompressed, + FixedHuff, + DynamicHuff, + DynamicHuffClen, + DynamicHuffLitlenDist, + DecodeBlock, + DecodeBlockLenBits, + DecodeBlockDist, + DecodeBlockDistBits, + DecodeBlockCopy }; + typedef struct { short len; /* code length */ ushort sym; /* symbol */ @@ -43,25 +62,29 @@ typedef struct { } Huff; typedef struct { - int (*r)(void *, int, void *); - void *rdata; - uchar *srcbegin; - uchar *srcend; uchar *src; /* input buffer pointer */ + uchar *srcend; uint bits; uint nbits; - int (*w)(void *, int, void *); - void *wdata; uchar *win; /* output window */ uint pos; /* window pos */ - int error; - + int state; /* decode state */ + int final; /* last block flag */ + + /* for decoding dynamic code trees */ + int nlit; + int ndist; + int nclen; /* also used for saving decoded symbol */ + int lenpos; /* also used for copy length */ + uchar lens[Nlitlen + Ndist]; + + int fixed; /* fixed code tree flag */ Huff lhuff; /* dynamic lit/len huffman code tree */ Huff dhuff; /* dynamic distance huffman code tree */ -} Stream; +} State; /* TODO: these globals are initialized in a lazy way (not thread safe) */ static Huff lhuff; /* fixed lit/len huffman code tree */ @@ -120,7 +143,7 @@ static int build_huff(Huff *huff, uchar *lens, uint n, uint nbits) { return -1; count[0] = 0; - /* bound code lengths, force nbits to be within code lengths */ + /* bound code lengths, force nbits to be within the bounds */ for (max = CodeBits - 1; max > 0; max--) if (count[max] != 0) break; @@ -198,347 +221,449 @@ static void init_fixed_huffs(void) { build_huff(&dhuff, lens, Ndist, 5); } -/* check src before accessing it */ -static int checksrc(Stream *s) { - if (s->src == s->srcend) { - s->src = s->srcend = s->srcbegin; - s->srcend += s->r(s->src, SrcSize, s->rdata); - if (s->src >= s->srcend) { - s->error = FlateInError; +static int fillbits(State *s, int n) { + while (s->nbits < n) { + if (s->src == s->srcend) return 0; - } + s->bits |= *s->src++ << s->nbits; + s->nbits += 8; } return 1; } -/* check window position before writing to it */ -static int checkpos(Stream *s) { - if (s->pos == WinSize) { - s->pos = 0; - if (s->w(s->win, WinSize, s->wdata) != WinSize) { - s->error = FlateOutError; - return 0; - } - } - return 1; -} - -/* flush output window */ -static void flush_win(Stream *s) { - if (s->pos) - if (s->w(s->win, s->pos, s->wdata) != s->pos) - s->error = FlateOutError; -} - -/* get one bit from stream */ -static uint getbit(Stream *s) { - uint bit; - - if (!s->nbits--) { - if (!checksrc(s)) - return 0; - s->bits = *s->src++; - s->nbits = 7; - } - bit = s->bits & 1; - s->bits >>= 1; - return bit; -} - -/* get n bits from stream */ -static uint getbits(Stream *s, int n) { - uint bits; - uint nbits; +static uint getbits(State *s, int n) { + uint k; - if (n == 0) - return 0; - bits = s->bits; - nbits = s->nbits; - while (nbits < n) { - if (!checksrc(s)) - return 0; - bits |= (*s->src++ << nbits); - nbits += 8; - } - s->bits = bits >> n; - s->nbits = nbits - n; - return bits & ((1 << n) - 1); + k = s->bits & ((1 << n) - 1); + s->bits >>= n; + s->nbits -= n; + return k; } /* decode a symbol from stream with huff code */ -static uint decode_symbol(Stream *s, Huff *huff) { +static uint decode_symbol(State *s, Huff *huff) { uint huffbits = huff->nbits; - uint streambits = s->nbits; + uint nbits = s->nbits; uint bits = s->bits; uint mask = (1 << huffbits) - 1; Entry entry; /* get enough bits efficiently */ - if (streambits < huffbits) { - if (s->src + 2 < s->srcend) { + if (nbits < huffbits) { + uchar *src = s->src; + + if (src + 2 < s->srcend) { /* we assume huffbits <= 9 */ - bits |= *s->src++ << streambits; - streambits += 8; - bits |= *s->src++ << streambits; - streambits += 8; - bits |= *s->src++ << streambits; - streambits += 8; - } else - /* TODO: here we assume EOB length >= huffbits */ + bits |= *src++ << nbits; + nbits += 8; + bits |= *src++ << nbits; + nbits += 8; + bits |= *src++ << nbits; + nbits += 8; + s->src = src; + } else /* rare */ do { - if (!checksrc(s)) - return 0; - bits |= *s->src++ << streambits; - streambits += 8; - } while (streambits < huffbits); + if (s->src == s->srcend) { + entry = huff->table[bits & mask]; + if (entry.len > 0 && entry.len <= nbits) { + s->bits = bits >> entry.len; + s->nbits = nbits - entry.len; + return entry.sym; + } + return FlateNeedInput; + } + bits |= *s->src++ << nbits; + nbits += 8; + } while (nbits < huffbits); } entry = huff->table[bits & mask]; if (entry.len > 0) { s->bits = bits >> entry.len; - s->nbits = streambits - entry.len; + s->nbits = nbits - entry.len; return entry.sym; - } else if (entry.len == 0) { - s->error = FlateCorrupted; - return 0; - } - s->bits = bits >> huffbits; - s->nbits = streambits - huffbits; + } else if (entry.len == 0) + return FlateError; /* code is longer than huffbits: bitwise decode the rest */ { int cur = entry.sym; int sum = huff->sum; /* TODO: count[0..huffbits] is never needed */ ushort *count = huff->count + huffbits + 1; + int needinput = 0; + /* save bits if we are near the end */ + if (s->src + 2 >= s->srcend) { + while (nbits < CodeBits - 1 && s->src < s->srcend) { + bits |= *s->src++ << nbits; + nbits += 8; + } + s->bits = bits; + s->nbits = nbits; + needinput = 1; + } + bits >>= huffbits; + nbits -= huffbits; for (;;) { - cur |= getbit(s); + if (!nbits--) { + if (needinput) + return FlateNeedInput; + bits = *s->src++; + nbits = 7; + } + cur |= bits & 1; + bits >>= 1; sum += *count; cur -= *count; if (cur < 0) break; cur <<= 1; count++; - if (count == huff->count + CodeBits) { - s->error = FlateCorrupted; - return 0; - } + if (count == huff->count + CodeBits) + return FlateError; } + s->bits = bits; + s->nbits = nbits; return huff->symbol[sum + cur]; } } -/* decode dynamic huffman code trees from stream */ -static void decode_huffs(Stream *s) { - Huff clhuff; - uchar lens[Nlitlen+Ndist]; - uint nlit, ndist, nclen; - uint i; - - nlit = 257 + getbits(s, 5); - ndist = 1 + getbits(s, 5); - nclen = 4 + getbits(s, 4); - if (s->error != FlateOk) - return; - if (nlit > Nlitlen || ndist > Ndist) { - s->error = FlateCorrupted; - return; - } - - /* build code length tree */ - for (i = 0; i < Nclen; i++) - lens[i] = 0; - for (i = 0; i < nclen; i++) - lens[clenorder[i]] = getbits(s, 3); - if (build_huff(&clhuff, lens, Nclen, ClenTableBits) < 0) { - s->error = FlateCorrupted; - return; - } - - /* decode code lengths for the dynamic trees */ - for (i = 0; i < nlit + ndist; ) { - uint sym = decode_symbol(s, &clhuff); - uint len; - uchar c; - - if (sym < 16) { - lens[i++] = sym; - } else if (sym == 16) { - /* copy previous code length 3-6 times */ - c = lens[i - 1]; - for (len = 3 + getbits(s, 2); len; len--) - lens[i++] = c; - } else if (sym == 17) { - /* repeat 0 for 3-10 times */ - for (len = 3 + getbits(s, 3); len; len--) - lens[i++] = 0; - } else if (sym == 18) { - /* repeat 0 for 11-138 times */ - for (len = 11 + getbits(s, 7); len; len--) - lens[i++] = 0; - } else - s->error = FlateCorrupted; - if (s->error != FlateOk) - return; - } - /* build dynamic huffman code trees */ - if (build_huff(&s->lhuff, lens, nlit, LitlenTableBits) < 0) - s->error = FlateCorrupted; - if (build_huff(&s->dhuff, lens + nlit, ndist, DistTableBits) < 0) - s->error = FlateCorrupted; -} /* decode a block of data from stream with trees */ -static void decode_block(Stream *s, Huff *lhuff, Huff *dhuff) { +static int decode_block(State *s, Huff *lhuff, Huff *dhuff) { + uchar *win = s->win; + uint pos = s->pos; + uint sym; + uint len; + uint dist; + + switch (s->state) { + case DecodeBlockLenBits: + sym = s->nclen; + goto decode_lenbits; + case DecodeBlockDist: + len = s->lenpos; + goto decode_dist; + case DecodeBlockDistBits: + sym = s->nclen; + len = s->lenpos; + goto decode_distbits; + case DecodeBlockCopy: + dist = s->nclen; + len = s->lenpos; + goto decode_copy; + } for (;;) { - uint sym = decode_symbol(s, lhuff); + sym = decode_symbol(s, lhuff); - if (s->error != FlateOk) - return; if (sym < 256) { - s->win[s->pos++] = sym; - if (!checkpos(s)) - return; + win[pos++] = sym; + if (pos == WinSize) { + s->pos = pos; + s->state = DecodeBlock; + return FlateHasOutput; + } } else if (sym > 256) { - uint len, dist; - sym -= 257; if (sym >= Nlen) { - s->error = FlateCorrupted; - return; + s->pos = pos; + s->state = DecodeBlock; + if (sym + 257 == (uint)FlateNeedInput) + return FlateNeedInput; + return FlateError; + } +decode_lenbits: + if (!fillbits(s, lenbits[sym])) { + s->nclen = sym; /* using nclen to store sym */ + s->pos = pos; + s->state = DecodeBlockLenBits; + return FlateNeedInput; } len = lenbase[sym] + getbits(s, lenbits[sym]); +decode_dist: sym = decode_symbol(s, dhuff); - if (s->error != FlateOk) - return; - if (sym >= Ndist) { - s->error = FlateCorrupted; - return; + if (sym == (uint)FlateNeedInput) { + s->pos = pos; + s->lenpos = len; + s->state = DecodeBlockDist; + return FlateNeedInput; } + if (sym >= Ndist) + return FlateError; +decode_distbits: + if (!fillbits(s, distbits[sym])) { + s->nclen = sym; /* using nclen to store sym */ + s->pos = pos; + s->lenpos = len; + s->state = DecodeBlockDistBits; + return FlateNeedInput; + } + /* TODO: s/dist/sym/ */ dist = distbase[sym] + getbits(s, distbits[sym]); - if (s->error != FlateOk) - return; /* copy match, loop unroll in common case */ - if (s->pos + len <= WinSize) { - uint pos = s->pos; - + if (pos + len < WinSize) { /* lenbase[sym] >= 3 */ do { - s->win[pos] = s->win[(pos - dist) & WinMask]; + win[pos] = win[(pos - dist) & WinMask]; pos++; - s->win[pos] = s->win[(pos - dist) & WinMask]; + win[pos] = win[(pos - dist) & WinMask]; pos++; - s->win[pos] = s->win[(pos - dist) & WinMask]; + win[pos] = win[(pos - dist) & WinMask]; pos++; len -= 3; } while (len >= 3); if (len--) { - s->win[pos] = s->win[(pos - dist) & WinMask]; + win[pos] = win[(pos - dist) & WinMask]; pos++; if (len) { - s->win[pos] = s->win[(pos - dist) & WinMask]; + win[pos] = win[(pos - dist) & WinMask]; pos++; } } - s->pos = pos; - if (!checkpos(s)) - return; - } else + } else { +decode_copy: while (len--) { - s->win[s->pos] = s->win[(s->pos - dist) & WinMask]; - s->pos++; - if (!checkpos(s)) - return; + win[pos] = win[(pos - dist) & WinMask]; + pos++; + if (pos == WinSize) { + s->pos = pos; + s->lenpos = len; + s->nclen = dist; + s->state = DecodeBlockCopy; + return FlateHasOutput; + } } - } else /* EOB: sym == 256 */ - return; + } + } else { /* EOB: sym == 256 */ + s->pos = pos; + return FlateOk; + } } } -/* TODO: untested, slow code */ -static void inflate_uncompressed_block(Stream *s) { - uint len, invlen; - - /* start next block on a byte boundary */ - s->bits >>= s->nbits & 7; - s->nbits &= ~7; - - len = getbits(s, 16); - invlen = getbits(s, 16); - /* s->nbits should be 0 here */ - if (s->error != FlateOk) - return; - if (len != (~invlen & 0xffff)) { - s->error = FlateCorrupted; - return; - } - while (len--) { - s->win[s->pos++] = getbits(s, 8); - checkpos(s); - } -} -static void inflate_fixed_block(Stream *s) { - /* lazy initialization of fixed huff code trees */ +/* extern */ + +int inflate_init(State *s) { + /* TODO */ if (lhuff.nbits == 0) init_fixed_huffs(); - decode_block(s, &lhuff, &dhuff); -} - -static void inflate_dynamic_block(Stream *s) { - decode_huffs(s); - if (s->error != FlateOk) - return; - decode_block(s, &s->lhuff, &s->dhuff); + s->pos = 0; + s->win = malloc(WinSize); + if (!s->win) + return FlateError; + s->bits = 0; + s->nbits = 0; + + s->state = BlockHead; + s->final = 0; + return FlateOk; } +/* inflate, returns: short src, short dst, error, ok */ +int inflate(State *s) { + int n; -/* extern */ - -/* inflate with callbacks */ -int inflate(int (*r)(void *, int, void *), void *rdata, int (*w)(void *, int, void *), void *wdata) { - Stream s; - uint final; - - s.r = r; - s.rdata = rdata; - s.w = w; - s.wdata = wdata; - s.srcbegin = s.srcend = s.src = malloc(SrcSize); - s.win = malloc(WinSize); - s.pos = 0; - s.error = FlateOk; - s.nbits = 0; - do { - uint blocktype; - - final = getbit(&s); - blocktype = getbits(&s, 2); - if (s.error != FlateOk) - return 0; - - /* decompress block */ - switch (blocktype) { - case 0: - inflate_uncompressed_block(&s); + for (;;) { + switch (s->state) { + case BlockHead: + if (s->final) { + if (s->pos) + goto hasoutput; + else + goto finish; + } + if (!fillbits(s, 3)) + goto needinput; + s->final = getbits(s, 1); + n = getbits(s, 2); + if (n == 0) + s->state = UncompressedBlock; + else if (n == 1) + s->state = FixedHuff; + else if (n == 2) + s->state = DynamicHuff; + else + goto error; break; - case 1: - inflate_fixed_block(&s); + case UncompressedBlock: + /* start block on a byte boundary */ + s->bits >>= s->nbits & 7; + s->nbits &= ~7; + if (!fillbits(s, 32)) + goto needinput; + s->lenpos = getbits(s, 16); + n = getbits(s, 16); + if (s->lenpos != (~n & 0xffff)) + goto error; + s->state = CopyUncompressed; + case CopyUncompressed: + /* TODO: untested, slow */ + /* s->nbits should be 0 here */ + while (s->lenpos) { + if (s->src == s->srcend) + goto needinput; + s->lenpos--; + s->win[s->pos++] = *s->src++; + if (s->pos == WinSize) + goto hasoutput; + } + s->state = BlockHead; break; - case 2: - inflate_dynamic_block(&s); + case FixedHuff: + if (lhuff.nbits == 0) /* lazy init */ + init_fixed_huffs(); + s->fixed = 1; + s->state = DecodeBlock; + break; + case DynamicHuff: + /* decode dynamic huffman code trees */ + if (!fillbits(s, 14)) + goto needinput; + s->nlit = 257 + getbits(s, 5); + s->ndist = 1 + getbits(s, 5); + s->nclen = 4 + getbits(s, 4); + if (s->nlit > Nlitlen || s->ndist > Ndist) + goto error; + /* build code length tree */ + for (n = 0; n < Nclen; n++) + s->lens[n] = 0; + s->fixed = 0; + s->state = DynamicHuffClen; + s->lenpos = 0; + case DynamicHuffClen: + for (n = s->lenpos; n < s->nclen; n++) + if (fillbits(s, 3)) { + s->lens[clenorder[n]] = getbits(s, 3); + } else { + s->lenpos = n; + goto needinput; + } + /* using lhuff for code length huff code */ + if (build_huff(&s->lhuff, s->lens, Nclen, ClenTableBits) < 0) + goto error; + s->state = DynamicHuffLitlenDist; + s->lenpos = 0; + s->nclen = -1; /* decoded symbol is stored in clen or -1 */ + case DynamicHuffLitlenDist: + if (s->nclen >= 0) + goto dynhuff_continue; + /* decode code lengths for the dynamic trees */ + for (n = s->lenpos; n < s->nlit + s->ndist; ) { + uint sym = decode_symbol(s, &s->lhuff); + uint len; + uchar c; + + if (sym < 16) { + s->lens[n++] = sym; + continue; + } else if (sym == (uint)FlateNeedInput) { + s->nclen = -1; + s->lenpos = n; + goto needinput; +dynhuff_continue: + n = s->lenpos; + sym = s->nclen; + } + if (!fillbits(s, 7)) { + /* TODO: 7 is too much */ + if (sym == (uint)FlateError) + goto error; + s->nclen = sym; + s->lenpos = n; + goto needinput; + } + if (sym == 16) { + /* copy previous code length 3-6 times */ + c = s->lens[n - 1]; + for (len = 3 + getbits(s, 2); len; len--) + s->lens[n++] = c; + } else if (sym == 17) { + /* repeat 0 for 3-10 times */ + for (len = 3 + getbits(s, 3); len; len--) + s->lens[n++] = 0; + } else if (sym == 18) { + /* repeat 0 for 11-138 times */ + for (len = 11 + getbits(s, 7); len; len--) + s->lens[n++] = 0; + } else + goto error; + } + /* build dynamic huffman code trees */ + if (build_huff(&s->lhuff, s->lens, s->nlit, LitlenTableBits) < 0) + goto error; + if (build_huff(&s->dhuff, s->lens + s->nlit, s->ndist, DistTableBits) < 0) + goto error; + s->state = DecodeBlock; + case DecodeBlock: + case DecodeBlockLenBits: + case DecodeBlockDist: + case DecodeBlockDistBits: + case DecodeBlockCopy: + if (s->fixed) + n = decode_block(s, &lhuff, &dhuff); + else + n = decode_block(s, &s->lhuff, &s->dhuff); + if (n == FlateNeedInput) + goto needinput; + if (n == FlateHasOutput) + goto hasoutput; + if (n == FlateError) + goto error; + s->state = BlockHead; break; default: - s.error = FlateCorrupted; + goto error; } - if (s.error != FlateOk) - break; - } while (!final); - flush_win(&s); - free(s.win); - free(s.srcbegin); - return s.error; + } +needinput: + return FlateNeedInput; +hasoutput: + return FlateHasOutput; +error: + free(s->win); + return FlateError; +finish: + free(s->win); + return FlateOk; } +int inflate_callback(int (*r)(void *, int, void *), void *rdata, int (*w)(void *, int, void *), void *wdata) { + State s; + uchar *src; + int len; + + if (inflate_init(&s) != FlateOk) + return FlateError; + s.src = s.srcend = src = malloc(4096); + if (src == NULL) + return FlateError; + for (;;) + switch (inflate(&s)) { + case FlateNeedInput: + len = r(src, 4096, rdata); + if (len <= 0) { + free(src); + return FlateError; + } + s.src = src; + s.srcend = src + len; + break; + case FlateHasOutput: + len = w(s.win, s.pos, wdata); + if (len != s.pos) { + free(src); + return FlateError; + } + s.pos = 0; /* ouch */ + break; + case FlateOk: + free(src); + return FlateOk; + case FlateError: + free(src); + return FlateError; + } +} /* simple test */ @@ -575,7 +700,8 @@ int main() { in.n = 0; out.f = stdout; out.n = 0; - err = inflate(r, &in, w, &out); + err = inflate_callback(r, &in, w, &out); fprintf(stderr, "error: %d, decompressed %u bytes\n", err, out.n); return 0; } + diff --git a/inflate_callback.c b/inflate_callback.c @@ -0,0 +1,590 @@ +#include <stdio.h> +#include <stdlib.h> +#include <string.h> + +typedef unsigned char uchar; +typedef unsigned short ushort; +typedef unsigned int uint; + +enum { + CodeBits = 16, /* max number of bits in a code + 1 */ + LitlenTableBits = 9, /* litlen code bits used in lookup table */ + DistTableBits = 6, /* dist code bits used in lookup table */ + ClenTableBits = 6, /* clen code bits used in lookup table */ + TableBits = LitlenTableBits, /* log2(lookup table size) */ + Nlit = 256, /* number of lit codes */ + Nlen = 29, /* number of len codes */ + Nlitlen = Nlit+Nlen+3, /* litlen codes + block end + 2 unused */ + Ndist = 30, /* number of distance codes */ + Nclen = 19, /* number of code length codes */ + SrcSize = 1 << 12, /* input buffer size */ + WinSize = 1 << 15, /* output window size */ + WinMask = WinSize - 1 /* window pos (index) mask */ +}; + +enum { + FlateOk = 0, + FlateInError = -1, + FlateOutError = -2, + FlateCorrupted = -3 +}; + +typedef struct { + short len; /* code length */ + ushort sym; /* symbol */ +} Entry; + +/* huffman code tree */ +typedef struct { + Entry table[1 << TableBits]; /* prefix lookup table */ + uint nbits; /* prefix length (table size is 1 << nbits) */ + uint sum; /* full codes in table: sum(count[0..nbits]) */ + ushort count[CodeBits]; /* number of codes with given length */ + ushort symbol[Nlitlen]; /* symbols ordered by code length (lexic.) */ +} Huff; + +typedef struct { + int (*r)(void *, int, void *); + void *rdata; + uchar *srcbegin; + uchar *srcend; + uchar *src; /* input buffer pointer */ + + uint bits; + uint nbits; + + int (*w)(void *, int, void *); + void *wdata; + uchar *win; /* output window */ + uint pos; /* window pos */ + + int error; + + Huff lhuff; /* dynamic lit/len huffman code tree */ + Huff dhuff; /* dynamic distance huffman code tree */ +} Stream; + +/* TODO: these globals are initialized in a lazy way (not thread safe) */ +static Huff lhuff; /* fixed lit/len huffman code tree */ +static Huff dhuff; /* fixed distance huffman code tree */ + +/* base offset and extra bits tables */ +static uchar lenbits[Nlen] = { + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 0 +}; +static ushort lenbase[Nlen] = { + 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 17, 19, 23, 27, 31, 35, 43, 51, 59, 67, 83, 99, 115, 131, 163, 195, 227, 258 +}; +static uchar distbits[Ndist] = { + 0, 0, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 13, 13 +}; +static ushort distbase[Ndist] = { + 1, 2, 3, 4, 5, 7, 9, 13, 17, 25, 33, 49, 65, 97, 129, 193, 257, 385, 513, 769, 1025, 1537, 2049, 3073, 4097, 6145, 8193, 12289, 16385, 24577 +}; + +/* ordering of code lengths */ +static uchar clenorder[Nclen] = { + 16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15 +}; + +/* TODO: this or normal inc + reverse() */ +/* increment bitwise reversed n (msb is bit 0, lsb is bit len-1) */ +static uint revinc(uint n, uint len) { + uint i = 1 << (len - 1); + + while (n & i) + i >>= 1; + if (i) { + n &= i - 1; + n |= i; + } else + n = 0; + return n; +} + +/* build huffman code tree from code lengths (each should be < CodeBits) */ +static int build_huff(Huff *huff, uchar *lens, uint n, uint nbits) { + int offs[CodeBits]; + int left; + uint i, c, sum, code, len, min, max; + ushort *count = huff->count; + ushort *symbol = huff->symbol; + Entry *table = huff->table; + Entry entry; + + /* count code lengths */ + for (i = 0; i < CodeBits; i++) + count[i] = 0; + for (i = 0; i < n; i++) + count[lens[i]]++; + if (count[0] == n) + return -1; + count[0] = 0; + + /* bound code lengths, force nbits to be within the bounds */ + for (max = CodeBits - 1; max > 0; max--) + if (count[max] != 0) + break; + if (nbits > max) + nbits = max; + for (min = 1; min < CodeBits; min++) + if (count[min] != 0) + break; + if (nbits < min) + nbits = min; + huff->nbits = nbits; + + /* check if length is over-subscribed or incomplete */ + for (left = 1 << min, i = min; i <= max; left <<= 1, i++) { + left -= count[i]; + /* left < 0: over-subscribed, left > 0: incomplete */ + if (left < 0) + return -1; + } + + for (sum = 0, i = 0; i <= max; i++) { + offs[i] = sum; + sum += count[i]; + } + /* needed for decoding codes longer than nbits */ + if (nbits < max) + huff->sum = offs[nbits + 1]; + + /* sort symbols by code length (lexicographic order) */ + for (i = 0; i < n; i++) + if (lens[i]) + symbol[offs[lens[i]]++] = i; + + /* lookup table for decoding nbits from input.. */ + for (i = 0; i < 1 << nbits; i++) + table[i].len = 0; /* invalid marker for incomplete code */ + code = 0; + /* ..if code is at most nbits (bits are in reverse order, sigh..) */ + for (len = min; len <= nbits; len++) + for (c = count[len]; c > 0; c--) { + entry.len = len; + entry.sym = *symbol; + for (i = code; i < 1 << nbits; i += 1 << len) + table[i] = entry; + /* next code */ + symbol++; + code = revinc(code, len); + } + /* ..if code is longer than nbits: values for simple bitwise decode */ + for (i = 0; code; i++) { + table[code].len = -1; + table[code].sym = i << 1; + code = revinc(code, nbits); + } + return 0; +} + +/* fixed huffman code trees (should be done at compile time..) */ +static void init_fixed_huffs(void) { + int i; + uchar lens[Nlitlen]; + + for (i = 0; i < 144; i++) + lens[i] = 8; + for (; i < 256; i++) + lens[i] = 9; + for (; i < 280; i++) + lens[i] = 7; + for (; i < Nlitlen; i++) + lens[i] = 8; + build_huff(&lhuff, lens, Nlitlen, 8); + + for (i = 0; i < Ndist; i++) + lens[i] = 5; + build_huff(&dhuff, lens, Ndist, 5); +} + +/* check src before accessing it */ +static int checksrc(Stream *s) { + if (s->src == s->srcend) { + s->src = s->srcend = s->srcbegin; + s->srcend += s->r(s->src, SrcSize, s->rdata); + if (s->src >= s->srcend) { + s->error = FlateInError; + return 0; + } + } + return 1; +} + +/* check window position before writing to it */ +static int checkpos(Stream *s) { + if (s->pos == WinSize) { + s->pos = 0; + if (s->w(s->win, WinSize, s->wdata) != WinSize) { + s->error = FlateOutError; + return 0; + } + } + return 1; +} + +/* flush output window */ +static void flush_win(Stream *s) { + if (s->pos) + if (s->w(s->win, s->pos, s->wdata) != s->pos) + s->error = FlateOutError; +} + +/* get one bit from stream */ +static uint getbit(Stream *s) { + uint bit; + + if (!s->nbits--) { + if (!checksrc(s)) + return 0; + s->bits = *s->src++; + s->nbits = 7; + } + bit = s->bits & 1; + s->bits >>= 1; + return bit; +} + +/* get n bits from stream */ +static uint getbits(Stream *s, int n) { + uint bits; + uint nbits; + + if (n == 0) + return 0; + bits = s->bits; + nbits = s->nbits; + while (nbits < n) { + if (!checksrc(s)) + return 0; + bits |= (*s->src++ << nbits); + nbits += 8; + } + s->bits = bits >> n; + s->nbits = nbits - n; + return bits & ((1 << n) - 1); +} + +/* decode a symbol from stream with huff code */ +static uint decode_symbol(Stream *s, Huff *huff) { + uint huffbits = huff->nbits; + uint nbits = s->nbits; + uint bits = s->bits; + uint mask = (1 << huffbits) - 1; + Entry entry; + + /* get enough bits efficiently */ + if (nbits < huffbits) { + if (s->src + 2 < s->srcend) { + /* we assume huffbits <= 9 */ + bits |= *s->src++ << nbits; + nbits += 8; + bits |= *s->src++ << nbits; + nbits += 8; + bits |= *s->src++ << nbits; + nbits += 8; + } else + do { + if (!checksrc(s)) { + /* EOB might be < huffbits */ + entry = huff->table[bits & mask]; + if (entry.len > 0 && entry.len <= nbits) { + s->bits = bits >> entry.len; + s->nbits = nbits - entry.len; + s->error = FlateOk; + return entry.sym; + } + return 0; + } + bits |= *s->src++ << nbits; + nbits += 8; + } while (nbits < huffbits); + } + entry = huff->table[bits & mask]; + if (entry.len > 0) { + s->bits = bits >> entry.len; + s->nbits = nbits - entry.len; + return entry.sym; + } else if (entry.len == 0) { + s->error = FlateCorrupted; + return 0; + } + s->bits = bits >> huffbits; + s->nbits = nbits - huffbits; + /* code is longer than huffbits: bitwise decode the rest */ + { + int cur = entry.sym; + int sum = huff->sum; + /* TODO: count[0..huffbits] is never needed */ + ushort *count = huff->count + huffbits + 1; + + for (;;) { + cur |= getbit(s); + sum += *count; + cur -= *count; + if (cur < 0) + break; + cur <<= 1; + count++; + if (count == huff->count + CodeBits) { + s->error = FlateCorrupted; + return 0; + } + } + return huff->symbol[sum + cur]; + } +} + +/* decode dynamic huffman code trees from stream */ +static void decode_huffs(Stream *s) { + Huff clhuff; + uchar lens[Nlitlen+Ndist]; + uint nlit, ndist, nclen; + uint i; + + nlit = 257 + getbits(s, 5); + ndist = 1 + getbits(s, 5); + nclen = 4 + getbits(s, 4); + if (s->error != FlateOk) + return; + if (nlit > Nlitlen || ndist > Ndist) { + s->error = FlateCorrupted; + return; + } + + /* build code length tree */ + for (i = 0; i < Nclen; i++) + lens[i] = 0; + for (i = 0; i < nclen; i++) + lens[clenorder[i]] = getbits(s, 3); + if (build_huff(&clhuff, lens, Nclen, ClenTableBits) < 0) { + s->error = FlateCorrupted; + return; + } + + /* decode code lengths for the dynamic trees */ + for (i = 0; i < nlit + ndist; ) { + uint sym = decode_symbol(s, &clhuff); + uint len; + uchar c; + + if (sym < 16) { + lens[i++] = sym; + } else if (sym == 16) { + /* copy previous code length 3-6 times */ + c = lens[i - 1]; + for (len = 3 + getbits(s, 2); len; len--) + lens[i++] = c; + } else if (sym == 17) { + /* repeat 0 for 3-10 times */ + for (len = 3 + getbits(s, 3); len; len--) + lens[i++] = 0; + } else if (sym == 18) { + /* repeat 0 for 11-138 times */ + for (len = 11 + getbits(s, 7); len; len--) + lens[i++] = 0; + } else + s->error = FlateCorrupted; + if (s->error != FlateOk) + return; + } + /* build dynamic huffman code trees */ + if (build_huff(&s->lhuff, lens, nlit, LitlenTableBits) < 0) + s->error = FlateCorrupted; + if (build_huff(&s->dhuff, lens + nlit, ndist, DistTableBits) < 0) + s->error = FlateCorrupted; +} + +/* decode a block of data from stream with trees */ +static void decode_block(Stream *s, Huff *lhuff, Huff *dhuff) { + for (;;) { + uint sym = decode_symbol(s, lhuff); + + if (s->error != FlateOk) + return; + if (sym < 256) { + s->win[s->pos++] = sym; + if (!checkpos(s)) + return; + } else if (sym > 256) { + uint len, dist; + + sym -= 257; + if (sym >= Nlen) { + s->error = FlateCorrupted; + return; + } + len = lenbase[sym] + getbits(s, lenbits[sym]); + sym = decode_symbol(s, dhuff); + if (s->error != FlateOk) + return; + if (sym >= Ndist) { + s->error = FlateCorrupted; + return; + } + dist = distbase[sym] + getbits(s, distbits[sym]); + if (s->error != FlateOk) + return; + /* copy match, loop unroll in common case */ + if (s->pos + len <= WinSize) { + uint pos = s->pos; + + /* lenbase[sym] >= 3 */ + do { + s->win[pos] = s->win[(pos - dist) & WinMask]; + pos++; + s->win[pos] = s->win[(pos - dist) & WinMask]; + pos++; + s->win[pos] = s->win[(pos - dist) & WinMask]; + pos++; + len -= 3; + } while (len >= 3); + if (len--) { + s->win[pos] = s->win[(pos - dist) & WinMask]; + pos++; + if (len) { + s->win[pos] = s->win[(pos - dist) & WinMask]; + pos++; + } + } + s->pos = pos; + if (!checkpos(s)) + return; + } else + while (len--) { + s->win[s->pos] = s->win[(s->pos - dist) & WinMask]; + s->pos++; + if (!checkpos(s)) + return; + } + } else /* EOB: sym == 256 */ + return; + } +} + +/* TODO: untested, slow code */ +static void inflate_uncompressed_block(Stream *s) { + uint len, invlen; + + /* start next block on a byte boundary */ + s->bits >>= s->nbits & 7; + s->nbits &= ~7; + + len = getbits(s, 16); + invlen = getbits(s, 16); + /* s->nbits should be 0 here */ + if (s->error != FlateOk) + return; + if (len != (~invlen & 0xffff)) { + s->error = FlateCorrupted; + return; + } + while (len--) { + s->win[s->pos++] = getbits(s, 8); + checkpos(s); + } +} + +static void inflate_fixed_block(Stream *s) { + /* lazy initialization of fixed huff code trees */ + if (lhuff.nbits == 0) + init_fixed_huffs(); + decode_block(s, &lhuff, &dhuff); +} + +static void inflate_dynamic_block(Stream *s) { + decode_huffs(s); + if (s->error != FlateOk) + return; + decode_block(s, &s->lhuff, &s->dhuff); +} + + +/* extern */ + +/* inflate with callbacks */ +int inflate(int (*r)(void *, int, void *), void *rdata, int (*w)(void *, int, void *), void *wdata) { + Stream s; + uint final; + + s.r = r; + s.rdata = rdata; + s.w = w; + s.wdata = wdata; + s.srcbegin = s.srcend = s.src = malloc(SrcSize); + s.win = malloc(WinSize); + s.pos = 0; + s.error = FlateOk; + s.nbits = 0; + do { + uint blocktype; + + final = getbit(&s); + blocktype = getbits(&s, 2); + if (s.error != FlateOk) + return 0; + + /* decompress block */ + switch (blocktype) { + case 0: + inflate_uncompressed_block(&s); + break; + case 1: + inflate_fixed_block(&s); + break; + case 2: + inflate_dynamic_block(&s); + break; + default: + s.error = FlateCorrupted; + } + if (s.error != FlateOk) + break; + } while (!final); + flush_win(&s); + free(s.win); + free(s.srcbegin); + return s.error; +} + + +/* simple test */ + +#include <stdio.h> + +struct data { + FILE *f; + uint n; +}; + +int r(void *p, int siz, void *data) { + uint n; + struct data *d = data; + + n = fread(p, 1, siz, d->f); + d->n += n; + return n; +} + +int w(void *p, int siz, void *data) { + uint n; + struct data *d = data; + + n = fwrite(p, 1, siz, d->f); + d->n += n; + return n; +} + +int main() { + struct data in, out; + int err; + + in.f = stdin; + in.n = 0; + out.f = stdout; + out.n = 0; + err = inflate(r, &in, w, &out); + fprintf(stderr, "error: %d, decompressed %u bytes\n", err, out.n); + return 0; +}