commit db64c049e9f97322b457aca518b1899ce5444dcf
parent 180021497c5184b5eb3f713cb907266ac3407b46
Author: nsz <nszabolcs@gmail.com>
Date: Sun, 7 Jun 2009 11:41:19 +0200
inflate with state
Diffstat:
Makefile | | | 13 | ++++++++----- |
inflate.c | | | 664 | +++++++++++++++++++++++++++++++++++++++++++++++-------------------------------- |
inflate_callback.c | | | 590 | +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ |
3 files changed, 993 insertions(+), 274 deletions(-)
diff --git a/Makefile b/Makefile
@@ -1,24 +1,27 @@
#CFLAGS=-g -Wall -ansi -pedantic
-CFLAGS=-Os -Wall -ansi -pedantic
+CFLAGS=-O3 -Wall -ansi -pedantic
LDFLAGS=
-SRC=inflate.c inflate_simple.c
+SRC=inflate.c inflate_simple.c inflate_callback.c
OBJ=${SRC:.c=.o}
+EXE=${SRC:.c=}
-all: inflate inflate_simple
+all: ${EXE}
inflate: inflate.o
${CC} -o $@ $^ ${LDFLAGS}
inflate_simple: inflate_simple.o
${CC} -o $@ $^ ${LDFLAGS}
+inflate_callback: inflate_callback.o
+ ${CC} -o $@ $^ ${LDFLAGS}
${OBJ}: Makefile
.c.o:
${CC} -c ${CFLAGS} $<
clean:
- rm -f ${OBJ} inflate
+ rm -f ${OBJ} ${EXE}
prof: inflate.c
gcc -fprofile-arcs -ftest-coverage -pg -g -Wall $<
- cat d.dat | ./a.out > /dev/null
+ cat a.dat | ./a.out > /dev/null
gcov -b $< > /dev/null
gprof a.out > $<.gprof
gcc -g -Wall $<
diff --git a/inflate.c b/inflate.c
@@ -1,4 +1,5 @@
#include <stdlib.h>
+#include <stdio.h>
#include <string.h>
typedef unsigned char uchar;
@@ -21,13 +22,31 @@ enum {
WinMask = WinSize - 1 /* window pos (index) mask */
};
+/* return values */
enum {
FlateOk = 0,
- FlateInError = -1,
- FlateOutError = -2,
- FlateCorrupted = -3
+ FlateError = -1,
+ FlateNeedInput = -2,
+ FlateHasOutput = -3
+};
+
+/* states */
+enum {
+ BlockHead,
+ UncompressedBlock,
+ CopyUncompressed,
+ FixedHuff,
+ DynamicHuff,
+ DynamicHuffClen,
+ DynamicHuffLitlenDist,
+ DecodeBlock,
+ DecodeBlockLenBits,
+ DecodeBlockDist,
+ DecodeBlockDistBits,
+ DecodeBlockCopy
};
+
typedef struct {
short len; /* code length */
ushort sym; /* symbol */
@@ -43,25 +62,29 @@ typedef struct {
} Huff;
typedef struct {
- int (*r)(void *, int, void *);
- void *rdata;
- uchar *srcbegin;
- uchar *srcend;
uchar *src; /* input buffer pointer */
+ uchar *srcend;
uint bits;
uint nbits;
- int (*w)(void *, int, void *);
- void *wdata;
uchar *win; /* output window */
uint pos; /* window pos */
- int error;
-
+ int state; /* decode state */
+ int final; /* last block flag */
+
+ /* for decoding dynamic code trees */
+ int nlit;
+ int ndist;
+ int nclen; /* also used for saving decoded symbol */
+ int lenpos; /* also used for copy length */
+ uchar lens[Nlitlen + Ndist];
+
+ int fixed; /* fixed code tree flag */
Huff lhuff; /* dynamic lit/len huffman code tree */
Huff dhuff; /* dynamic distance huffman code tree */
-} Stream;
+} State;
/* TODO: these globals are initialized in a lazy way (not thread safe) */
static Huff lhuff; /* fixed lit/len huffman code tree */
@@ -120,7 +143,7 @@ static int build_huff(Huff *huff, uchar *lens, uint n, uint nbits) {
return -1;
count[0] = 0;
- /* bound code lengths, force nbits to be within code lengths */
+ /* bound code lengths, force nbits to be within the bounds */
for (max = CodeBits - 1; max > 0; max--)
if (count[max] != 0)
break;
@@ -198,347 +221,449 @@ static void init_fixed_huffs(void) {
build_huff(&dhuff, lens, Ndist, 5);
}
-/* check src before accessing it */
-static int checksrc(Stream *s) {
- if (s->src == s->srcend) {
- s->src = s->srcend = s->srcbegin;
- s->srcend += s->r(s->src, SrcSize, s->rdata);
- if (s->src >= s->srcend) {
- s->error = FlateInError;
+static int fillbits(State *s, int n) {
+ while (s->nbits < n) {
+ if (s->src == s->srcend)
return 0;
- }
+ s->bits |= *s->src++ << s->nbits;
+ s->nbits += 8;
}
return 1;
}
-/* check window position before writing to it */
-static int checkpos(Stream *s) {
- if (s->pos == WinSize) {
- s->pos = 0;
- if (s->w(s->win, WinSize, s->wdata) != WinSize) {
- s->error = FlateOutError;
- return 0;
- }
- }
- return 1;
-}
-
-/* flush output window */
-static void flush_win(Stream *s) {
- if (s->pos)
- if (s->w(s->win, s->pos, s->wdata) != s->pos)
- s->error = FlateOutError;
-}
-
-/* get one bit from stream */
-static uint getbit(Stream *s) {
- uint bit;
-
- if (!s->nbits--) {
- if (!checksrc(s))
- return 0;
- s->bits = *s->src++;
- s->nbits = 7;
- }
- bit = s->bits & 1;
- s->bits >>= 1;
- return bit;
-}
-
-/* get n bits from stream */
-static uint getbits(Stream *s, int n) {
- uint bits;
- uint nbits;
+static uint getbits(State *s, int n) {
+ uint k;
- if (n == 0)
- return 0;
- bits = s->bits;
- nbits = s->nbits;
- while (nbits < n) {
- if (!checksrc(s))
- return 0;
- bits |= (*s->src++ << nbits);
- nbits += 8;
- }
- s->bits = bits >> n;
- s->nbits = nbits - n;
- return bits & ((1 << n) - 1);
+ k = s->bits & ((1 << n) - 1);
+ s->bits >>= n;
+ s->nbits -= n;
+ return k;
}
/* decode a symbol from stream with huff code */
-static uint decode_symbol(Stream *s, Huff *huff) {
+static uint decode_symbol(State *s, Huff *huff) {
uint huffbits = huff->nbits;
- uint streambits = s->nbits;
+ uint nbits = s->nbits;
uint bits = s->bits;
uint mask = (1 << huffbits) - 1;
Entry entry;
/* get enough bits efficiently */
- if (streambits < huffbits) {
- if (s->src + 2 < s->srcend) {
+ if (nbits < huffbits) {
+ uchar *src = s->src;
+
+ if (src + 2 < s->srcend) {
/* we assume huffbits <= 9 */
- bits |= *s->src++ << streambits;
- streambits += 8;
- bits |= *s->src++ << streambits;
- streambits += 8;
- bits |= *s->src++ << streambits;
- streambits += 8;
- } else
- /* TODO: here we assume EOB length >= huffbits */
+ bits |= *src++ << nbits;
+ nbits += 8;
+ bits |= *src++ << nbits;
+ nbits += 8;
+ bits |= *src++ << nbits;
+ nbits += 8;
+ s->src = src;
+ } else /* rare */
do {
- if (!checksrc(s))
- return 0;
- bits |= *s->src++ << streambits;
- streambits += 8;
- } while (streambits < huffbits);
+ if (s->src == s->srcend) {
+ entry = huff->table[bits & mask];
+ if (entry.len > 0 && entry.len <= nbits) {
+ s->bits = bits >> entry.len;
+ s->nbits = nbits - entry.len;
+ return entry.sym;
+ }
+ return FlateNeedInput;
+ }
+ bits |= *s->src++ << nbits;
+ nbits += 8;
+ } while (nbits < huffbits);
}
entry = huff->table[bits & mask];
if (entry.len > 0) {
s->bits = bits >> entry.len;
- s->nbits = streambits - entry.len;
+ s->nbits = nbits - entry.len;
return entry.sym;
- } else if (entry.len == 0) {
- s->error = FlateCorrupted;
- return 0;
- }
- s->bits = bits >> huffbits;
- s->nbits = streambits - huffbits;
+ } else if (entry.len == 0)
+ return FlateError;
/* code is longer than huffbits: bitwise decode the rest */
{
int cur = entry.sym;
int sum = huff->sum;
/* TODO: count[0..huffbits] is never needed */
ushort *count = huff->count + huffbits + 1;
+ int needinput = 0;
+ /* save bits if we are near the end */
+ if (s->src + 2 >= s->srcend) {
+ while (nbits < CodeBits - 1 && s->src < s->srcend) {
+ bits |= *s->src++ << nbits;
+ nbits += 8;
+ }
+ s->bits = bits;
+ s->nbits = nbits;
+ needinput = 1;
+ }
+ bits >>= huffbits;
+ nbits -= huffbits;
for (;;) {
- cur |= getbit(s);
+ if (!nbits--) {
+ if (needinput)
+ return FlateNeedInput;
+ bits = *s->src++;
+ nbits = 7;
+ }
+ cur |= bits & 1;
+ bits >>= 1;
sum += *count;
cur -= *count;
if (cur < 0)
break;
cur <<= 1;
count++;
- if (count == huff->count + CodeBits) {
- s->error = FlateCorrupted;
- return 0;
- }
+ if (count == huff->count + CodeBits)
+ return FlateError;
}
+ s->bits = bits;
+ s->nbits = nbits;
return huff->symbol[sum + cur];
}
}
-/* decode dynamic huffman code trees from stream */
-static void decode_huffs(Stream *s) {
- Huff clhuff;
- uchar lens[Nlitlen+Ndist];
- uint nlit, ndist, nclen;
- uint i;
-
- nlit = 257 + getbits(s, 5);
- ndist = 1 + getbits(s, 5);
- nclen = 4 + getbits(s, 4);
- if (s->error != FlateOk)
- return;
- if (nlit > Nlitlen || ndist > Ndist) {
- s->error = FlateCorrupted;
- return;
- }
-
- /* build code length tree */
- for (i = 0; i < Nclen; i++)
- lens[i] = 0;
- for (i = 0; i < nclen; i++)
- lens[clenorder[i]] = getbits(s, 3);
- if (build_huff(&clhuff, lens, Nclen, ClenTableBits) < 0) {
- s->error = FlateCorrupted;
- return;
- }
-
- /* decode code lengths for the dynamic trees */
- for (i = 0; i < nlit + ndist; ) {
- uint sym = decode_symbol(s, &clhuff);
- uint len;
- uchar c;
-
- if (sym < 16) {
- lens[i++] = sym;
- } else if (sym == 16) {
- /* copy previous code length 3-6 times */
- c = lens[i - 1];
- for (len = 3 + getbits(s, 2); len; len--)
- lens[i++] = c;
- } else if (sym == 17) {
- /* repeat 0 for 3-10 times */
- for (len = 3 + getbits(s, 3); len; len--)
- lens[i++] = 0;
- } else if (sym == 18) {
- /* repeat 0 for 11-138 times */
- for (len = 11 + getbits(s, 7); len; len--)
- lens[i++] = 0;
- } else
- s->error = FlateCorrupted;
- if (s->error != FlateOk)
- return;
- }
- /* build dynamic huffman code trees */
- if (build_huff(&s->lhuff, lens, nlit, LitlenTableBits) < 0)
- s->error = FlateCorrupted;
- if (build_huff(&s->dhuff, lens + nlit, ndist, DistTableBits) < 0)
- s->error = FlateCorrupted;
-}
/* decode a block of data from stream with trees */
-static void decode_block(Stream *s, Huff *lhuff, Huff *dhuff) {
+static int decode_block(State *s, Huff *lhuff, Huff *dhuff) {
+ uchar *win = s->win;
+ uint pos = s->pos;
+ uint sym;
+ uint len;
+ uint dist;
+
+ switch (s->state) {
+ case DecodeBlockLenBits:
+ sym = s->nclen;
+ goto decode_lenbits;
+ case DecodeBlockDist:
+ len = s->lenpos;
+ goto decode_dist;
+ case DecodeBlockDistBits:
+ sym = s->nclen;
+ len = s->lenpos;
+ goto decode_distbits;
+ case DecodeBlockCopy:
+ dist = s->nclen;
+ len = s->lenpos;
+ goto decode_copy;
+ }
for (;;) {
- uint sym = decode_symbol(s, lhuff);
+ sym = decode_symbol(s, lhuff);
- if (s->error != FlateOk)
- return;
if (sym < 256) {
- s->win[s->pos++] = sym;
- if (!checkpos(s))
- return;
+ win[pos++] = sym;
+ if (pos == WinSize) {
+ s->pos = pos;
+ s->state = DecodeBlock;
+ return FlateHasOutput;
+ }
} else if (sym > 256) {
- uint len, dist;
-
sym -= 257;
if (sym >= Nlen) {
- s->error = FlateCorrupted;
- return;
+ s->pos = pos;
+ s->state = DecodeBlock;
+ if (sym + 257 == (uint)FlateNeedInput)
+ return FlateNeedInput;
+ return FlateError;
+ }
+decode_lenbits:
+ if (!fillbits(s, lenbits[sym])) {
+ s->nclen = sym; /* using nclen to store sym */
+ s->pos = pos;
+ s->state = DecodeBlockLenBits;
+ return FlateNeedInput;
}
len = lenbase[sym] + getbits(s, lenbits[sym]);
+decode_dist:
sym = decode_symbol(s, dhuff);
- if (s->error != FlateOk)
- return;
- if (sym >= Ndist) {
- s->error = FlateCorrupted;
- return;
+ if (sym == (uint)FlateNeedInput) {
+ s->pos = pos;
+ s->lenpos = len;
+ s->state = DecodeBlockDist;
+ return FlateNeedInput;
}
+ if (sym >= Ndist)
+ return FlateError;
+decode_distbits:
+ if (!fillbits(s, distbits[sym])) {
+ s->nclen = sym; /* using nclen to store sym */
+ s->pos = pos;
+ s->lenpos = len;
+ s->state = DecodeBlockDistBits;
+ return FlateNeedInput;
+ }
+ /* TODO: s/dist/sym/ */
dist = distbase[sym] + getbits(s, distbits[sym]);
- if (s->error != FlateOk)
- return;
/* copy match, loop unroll in common case */
- if (s->pos + len <= WinSize) {
- uint pos = s->pos;
-
+ if (pos + len < WinSize) {
/* lenbase[sym] >= 3 */
do {
- s->win[pos] = s->win[(pos - dist) & WinMask];
+ win[pos] = win[(pos - dist) & WinMask];
pos++;
- s->win[pos] = s->win[(pos - dist) & WinMask];
+ win[pos] = win[(pos - dist) & WinMask];
pos++;
- s->win[pos] = s->win[(pos - dist) & WinMask];
+ win[pos] = win[(pos - dist) & WinMask];
pos++;
len -= 3;
} while (len >= 3);
if (len--) {
- s->win[pos] = s->win[(pos - dist) & WinMask];
+ win[pos] = win[(pos - dist) & WinMask];
pos++;
if (len) {
- s->win[pos] = s->win[(pos - dist) & WinMask];
+ win[pos] = win[(pos - dist) & WinMask];
pos++;
}
}
- s->pos = pos;
- if (!checkpos(s))
- return;
- } else
+ } else {
+decode_copy:
while (len--) {
- s->win[s->pos] = s->win[(s->pos - dist) & WinMask];
- s->pos++;
- if (!checkpos(s))
- return;
+ win[pos] = win[(pos - dist) & WinMask];
+ pos++;
+ if (pos == WinSize) {
+ s->pos = pos;
+ s->lenpos = len;
+ s->nclen = dist;
+ s->state = DecodeBlockCopy;
+ return FlateHasOutput;
+ }
}
- } else /* EOB: sym == 256 */
- return;
+ }
+ } else { /* EOB: sym == 256 */
+ s->pos = pos;
+ return FlateOk;
+ }
}
}
-/* TODO: untested, slow code */
-static void inflate_uncompressed_block(Stream *s) {
- uint len, invlen;
-
- /* start next block on a byte boundary */
- s->bits >>= s->nbits & 7;
- s->nbits &= ~7;
-
- len = getbits(s, 16);
- invlen = getbits(s, 16);
- /* s->nbits should be 0 here */
- if (s->error != FlateOk)
- return;
- if (len != (~invlen & 0xffff)) {
- s->error = FlateCorrupted;
- return;
- }
- while (len--) {
- s->win[s->pos++] = getbits(s, 8);
- checkpos(s);
- }
-}
-static void inflate_fixed_block(Stream *s) {
- /* lazy initialization of fixed huff code trees */
+/* extern */
+
+int inflate_init(State *s) {
+ /* TODO */
if (lhuff.nbits == 0)
init_fixed_huffs();
- decode_block(s, &lhuff, &dhuff);
-}
-
-static void inflate_dynamic_block(Stream *s) {
- decode_huffs(s);
- if (s->error != FlateOk)
- return;
- decode_block(s, &s->lhuff, &s->dhuff);
+ s->pos = 0;
+ s->win = malloc(WinSize);
+ if (!s->win)
+ return FlateError;
+ s->bits = 0;
+ s->nbits = 0;
+
+ s->state = BlockHead;
+ s->final = 0;
+ return FlateOk;
}
+/* inflate, returns: short src, short dst, error, ok */
+int inflate(State *s) {
+ int n;
-/* extern */
-
-/* inflate with callbacks */
-int inflate(int (*r)(void *, int, void *), void *rdata, int (*w)(void *, int, void *), void *wdata) {
- Stream s;
- uint final;
-
- s.r = r;
- s.rdata = rdata;
- s.w = w;
- s.wdata = wdata;
- s.srcbegin = s.srcend = s.src = malloc(SrcSize);
- s.win = malloc(WinSize);
- s.pos = 0;
- s.error = FlateOk;
- s.nbits = 0;
- do {
- uint blocktype;
-
- final = getbit(&s);
- blocktype = getbits(&s, 2);
- if (s.error != FlateOk)
- return 0;
-
- /* decompress block */
- switch (blocktype) {
- case 0:
- inflate_uncompressed_block(&s);
+ for (;;) {
+ switch (s->state) {
+ case BlockHead:
+ if (s->final) {
+ if (s->pos)
+ goto hasoutput;
+ else
+ goto finish;
+ }
+ if (!fillbits(s, 3))
+ goto needinput;
+ s->final = getbits(s, 1);
+ n = getbits(s, 2);
+ if (n == 0)
+ s->state = UncompressedBlock;
+ else if (n == 1)
+ s->state = FixedHuff;
+ else if (n == 2)
+ s->state = DynamicHuff;
+ else
+ goto error;
break;
- case 1:
- inflate_fixed_block(&s);
+ case UncompressedBlock:
+ /* start block on a byte boundary */
+ s->bits >>= s->nbits & 7;
+ s->nbits &= ~7;
+ if (!fillbits(s, 32))
+ goto needinput;
+ s->lenpos = getbits(s, 16);
+ n = getbits(s, 16);
+ if (s->lenpos != (~n & 0xffff))
+ goto error;
+ s->state = CopyUncompressed;
+ case CopyUncompressed:
+ /* TODO: untested, slow */
+ /* s->nbits should be 0 here */
+ while (s->lenpos) {
+ if (s->src == s->srcend)
+ goto needinput;
+ s->lenpos--;
+ s->win[s->pos++] = *s->src++;
+ if (s->pos == WinSize)
+ goto hasoutput;
+ }
+ s->state = BlockHead;
break;
- case 2:
- inflate_dynamic_block(&s);
+ case FixedHuff:
+ if (lhuff.nbits == 0) /* lazy init */
+ init_fixed_huffs();
+ s->fixed = 1;
+ s->state = DecodeBlock;
+ break;
+ case DynamicHuff:
+ /* decode dynamic huffman code trees */
+ if (!fillbits(s, 14))
+ goto needinput;
+ s->nlit = 257 + getbits(s, 5);
+ s->ndist = 1 + getbits(s, 5);
+ s->nclen = 4 + getbits(s, 4);
+ if (s->nlit > Nlitlen || s->ndist > Ndist)
+ goto error;
+ /* build code length tree */
+ for (n = 0; n < Nclen; n++)
+ s->lens[n] = 0;
+ s->fixed = 0;
+ s->state = DynamicHuffClen;
+ s->lenpos = 0;
+ case DynamicHuffClen:
+ for (n = s->lenpos; n < s->nclen; n++)
+ if (fillbits(s, 3)) {
+ s->lens[clenorder[n]] = getbits(s, 3);
+ } else {
+ s->lenpos = n;
+ goto needinput;
+ }
+ /* using lhuff for code length huff code */
+ if (build_huff(&s->lhuff, s->lens, Nclen, ClenTableBits) < 0)
+ goto error;
+ s->state = DynamicHuffLitlenDist;
+ s->lenpos = 0;
+ s->nclen = -1; /* decoded symbol is stored in clen or -1 */
+ case DynamicHuffLitlenDist:
+ if (s->nclen >= 0)
+ goto dynhuff_continue;
+ /* decode code lengths for the dynamic trees */
+ for (n = s->lenpos; n < s->nlit + s->ndist; ) {
+ uint sym = decode_symbol(s, &s->lhuff);
+ uint len;
+ uchar c;
+
+ if (sym < 16) {
+ s->lens[n++] = sym;
+ continue;
+ } else if (sym == (uint)FlateNeedInput) {
+ s->nclen = -1;
+ s->lenpos = n;
+ goto needinput;
+dynhuff_continue:
+ n = s->lenpos;
+ sym = s->nclen;
+ }
+ if (!fillbits(s, 7)) {
+ /* TODO: 7 is too much */
+ if (sym == (uint)FlateError)
+ goto error;
+ s->nclen = sym;
+ s->lenpos = n;
+ goto needinput;
+ }
+ if (sym == 16) {
+ /* copy previous code length 3-6 times */
+ c = s->lens[n - 1];
+ for (len = 3 + getbits(s, 2); len; len--)
+ s->lens[n++] = c;
+ } else if (sym == 17) {
+ /* repeat 0 for 3-10 times */
+ for (len = 3 + getbits(s, 3); len; len--)
+ s->lens[n++] = 0;
+ } else if (sym == 18) {
+ /* repeat 0 for 11-138 times */
+ for (len = 11 + getbits(s, 7); len; len--)
+ s->lens[n++] = 0;
+ } else
+ goto error;
+ }
+ /* build dynamic huffman code trees */
+ if (build_huff(&s->lhuff, s->lens, s->nlit, LitlenTableBits) < 0)
+ goto error;
+ if (build_huff(&s->dhuff, s->lens + s->nlit, s->ndist, DistTableBits) < 0)
+ goto error;
+ s->state = DecodeBlock;
+ case DecodeBlock:
+ case DecodeBlockLenBits:
+ case DecodeBlockDist:
+ case DecodeBlockDistBits:
+ case DecodeBlockCopy:
+ if (s->fixed)
+ n = decode_block(s, &lhuff, &dhuff);
+ else
+ n = decode_block(s, &s->lhuff, &s->dhuff);
+ if (n == FlateNeedInput)
+ goto needinput;
+ if (n == FlateHasOutput)
+ goto hasoutput;
+ if (n == FlateError)
+ goto error;
+ s->state = BlockHead;
break;
default:
- s.error = FlateCorrupted;
+ goto error;
}
- if (s.error != FlateOk)
- break;
- } while (!final);
- flush_win(&s);
- free(s.win);
- free(s.srcbegin);
- return s.error;
+ }
+needinput:
+ return FlateNeedInput;
+hasoutput:
+ return FlateHasOutput;
+error:
+ free(s->win);
+ return FlateError;
+finish:
+ free(s->win);
+ return FlateOk;
}
+int inflate_callback(int (*r)(void *, int, void *), void *rdata, int (*w)(void *, int, void *), void *wdata) {
+ State s;
+ uchar *src;
+ int len;
+
+ if (inflate_init(&s) != FlateOk)
+ return FlateError;
+ s.src = s.srcend = src = malloc(4096);
+ if (src == NULL)
+ return FlateError;
+ for (;;)
+ switch (inflate(&s)) {
+ case FlateNeedInput:
+ len = r(src, 4096, rdata);
+ if (len <= 0) {
+ free(src);
+ return FlateError;
+ }
+ s.src = src;
+ s.srcend = src + len;
+ break;
+ case FlateHasOutput:
+ len = w(s.win, s.pos, wdata);
+ if (len != s.pos) {
+ free(src);
+ return FlateError;
+ }
+ s.pos = 0; /* ouch */
+ break;
+ case FlateOk:
+ free(src);
+ return FlateOk;
+ case FlateError:
+ free(src);
+ return FlateError;
+ }
+}
/* simple test */
@@ -575,7 +700,8 @@ int main() {
in.n = 0;
out.f = stdout;
out.n = 0;
- err = inflate(r, &in, w, &out);
+ err = inflate_callback(r, &in, w, &out);
fprintf(stderr, "error: %d, decompressed %u bytes\n", err, out.n);
return 0;
}
+
diff --git a/inflate_callback.c b/inflate_callback.c
@@ -0,0 +1,590 @@
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+
+typedef unsigned char uchar;
+typedef unsigned short ushort;
+typedef unsigned int uint;
+
+enum {
+ CodeBits = 16, /* max number of bits in a code + 1 */
+ LitlenTableBits = 9, /* litlen code bits used in lookup table */
+ DistTableBits = 6, /* dist code bits used in lookup table */
+ ClenTableBits = 6, /* clen code bits used in lookup table */
+ TableBits = LitlenTableBits, /* log2(lookup table size) */
+ Nlit = 256, /* number of lit codes */
+ Nlen = 29, /* number of len codes */
+ Nlitlen = Nlit+Nlen+3, /* litlen codes + block end + 2 unused */
+ Ndist = 30, /* number of distance codes */
+ Nclen = 19, /* number of code length codes */
+ SrcSize = 1 << 12, /* input buffer size */
+ WinSize = 1 << 15, /* output window size */
+ WinMask = WinSize - 1 /* window pos (index) mask */
+};
+
+enum {
+ FlateOk = 0,
+ FlateInError = -1,
+ FlateOutError = -2,
+ FlateCorrupted = -3
+};
+
+typedef struct {
+ short len; /* code length */
+ ushort sym; /* symbol */
+} Entry;
+
+/* huffman code tree */
+typedef struct {
+ Entry table[1 << TableBits]; /* prefix lookup table */
+ uint nbits; /* prefix length (table size is 1 << nbits) */
+ uint sum; /* full codes in table: sum(count[0..nbits]) */
+ ushort count[CodeBits]; /* number of codes with given length */
+ ushort symbol[Nlitlen]; /* symbols ordered by code length (lexic.) */
+} Huff;
+
+typedef struct {
+ int (*r)(void *, int, void *);
+ void *rdata;
+ uchar *srcbegin;
+ uchar *srcend;
+ uchar *src; /* input buffer pointer */
+
+ uint bits;
+ uint nbits;
+
+ int (*w)(void *, int, void *);
+ void *wdata;
+ uchar *win; /* output window */
+ uint pos; /* window pos */
+
+ int error;
+
+ Huff lhuff; /* dynamic lit/len huffman code tree */
+ Huff dhuff; /* dynamic distance huffman code tree */
+} Stream;
+
+/* TODO: these globals are initialized in a lazy way (not thread safe) */
+static Huff lhuff; /* fixed lit/len huffman code tree */
+static Huff dhuff; /* fixed distance huffman code tree */
+
+/* base offset and extra bits tables */
+static uchar lenbits[Nlen] = {
+ 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 0
+};
+static ushort lenbase[Nlen] = {
+ 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 17, 19, 23, 27, 31, 35, 43, 51, 59, 67, 83, 99, 115, 131, 163, 195, 227, 258
+};
+static uchar distbits[Ndist] = {
+ 0, 0, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 13, 13
+};
+static ushort distbase[Ndist] = {
+ 1, 2, 3, 4, 5, 7, 9, 13, 17, 25, 33, 49, 65, 97, 129, 193, 257, 385, 513, 769, 1025, 1537, 2049, 3073, 4097, 6145, 8193, 12289, 16385, 24577
+};
+
+/* ordering of code lengths */
+static uchar clenorder[Nclen] = {
+ 16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15
+};
+
+/* TODO: this or normal inc + reverse() */
+/* increment bitwise reversed n (msb is bit 0, lsb is bit len-1) */
+static uint revinc(uint n, uint len) {
+ uint i = 1 << (len - 1);
+
+ while (n & i)
+ i >>= 1;
+ if (i) {
+ n &= i - 1;
+ n |= i;
+ } else
+ n = 0;
+ return n;
+}
+
+/* build huffman code tree from code lengths (each should be < CodeBits) */
+static int build_huff(Huff *huff, uchar *lens, uint n, uint nbits) {
+ int offs[CodeBits];
+ int left;
+ uint i, c, sum, code, len, min, max;
+ ushort *count = huff->count;
+ ushort *symbol = huff->symbol;
+ Entry *table = huff->table;
+ Entry entry;
+
+ /* count code lengths */
+ for (i = 0; i < CodeBits; i++)
+ count[i] = 0;
+ for (i = 0; i < n; i++)
+ count[lens[i]]++;
+ if (count[0] == n)
+ return -1;
+ count[0] = 0;
+
+ /* bound code lengths, force nbits to be within the bounds */
+ for (max = CodeBits - 1; max > 0; max--)
+ if (count[max] != 0)
+ break;
+ if (nbits > max)
+ nbits = max;
+ for (min = 1; min < CodeBits; min++)
+ if (count[min] != 0)
+ break;
+ if (nbits < min)
+ nbits = min;
+ huff->nbits = nbits;
+
+ /* check if length is over-subscribed or incomplete */
+ for (left = 1 << min, i = min; i <= max; left <<= 1, i++) {
+ left -= count[i];
+ /* left < 0: over-subscribed, left > 0: incomplete */
+ if (left < 0)
+ return -1;
+ }
+
+ for (sum = 0, i = 0; i <= max; i++) {
+ offs[i] = sum;
+ sum += count[i];
+ }
+ /* needed for decoding codes longer than nbits */
+ if (nbits < max)
+ huff->sum = offs[nbits + 1];
+
+ /* sort symbols by code length (lexicographic order) */
+ for (i = 0; i < n; i++)
+ if (lens[i])
+ symbol[offs[lens[i]]++] = i;
+
+ /* lookup table for decoding nbits from input.. */
+ for (i = 0; i < 1 << nbits; i++)
+ table[i].len = 0; /* invalid marker for incomplete code */
+ code = 0;
+ /* ..if code is at most nbits (bits are in reverse order, sigh..) */
+ for (len = min; len <= nbits; len++)
+ for (c = count[len]; c > 0; c--) {
+ entry.len = len;
+ entry.sym = *symbol;
+ for (i = code; i < 1 << nbits; i += 1 << len)
+ table[i] = entry;
+ /* next code */
+ symbol++;
+ code = revinc(code, len);
+ }
+ /* ..if code is longer than nbits: values for simple bitwise decode */
+ for (i = 0; code; i++) {
+ table[code].len = -1;
+ table[code].sym = i << 1;
+ code = revinc(code, nbits);
+ }
+ return 0;
+}
+
+/* fixed huffman code trees (should be done at compile time..) */
+static void init_fixed_huffs(void) {
+ int i;
+ uchar lens[Nlitlen];
+
+ for (i = 0; i < 144; i++)
+ lens[i] = 8;
+ for (; i < 256; i++)
+ lens[i] = 9;
+ for (; i < 280; i++)
+ lens[i] = 7;
+ for (; i < Nlitlen; i++)
+ lens[i] = 8;
+ build_huff(&lhuff, lens, Nlitlen, 8);
+
+ for (i = 0; i < Ndist; i++)
+ lens[i] = 5;
+ build_huff(&dhuff, lens, Ndist, 5);
+}
+
+/* check src before accessing it */
+static int checksrc(Stream *s) {
+ if (s->src == s->srcend) {
+ s->src = s->srcend = s->srcbegin;
+ s->srcend += s->r(s->src, SrcSize, s->rdata);
+ if (s->src >= s->srcend) {
+ s->error = FlateInError;
+ return 0;
+ }
+ }
+ return 1;
+}
+
+/* check window position before writing to it */
+static int checkpos(Stream *s) {
+ if (s->pos == WinSize) {
+ s->pos = 0;
+ if (s->w(s->win, WinSize, s->wdata) != WinSize) {
+ s->error = FlateOutError;
+ return 0;
+ }
+ }
+ return 1;
+}
+
+/* flush output window */
+static void flush_win(Stream *s) {
+ if (s->pos)
+ if (s->w(s->win, s->pos, s->wdata) != s->pos)
+ s->error = FlateOutError;
+}
+
+/* get one bit from stream */
+static uint getbit(Stream *s) {
+ uint bit;
+
+ if (!s->nbits--) {
+ if (!checksrc(s))
+ return 0;
+ s->bits = *s->src++;
+ s->nbits = 7;
+ }
+ bit = s->bits & 1;
+ s->bits >>= 1;
+ return bit;
+}
+
+/* get n bits from stream */
+static uint getbits(Stream *s, int n) {
+ uint bits;
+ uint nbits;
+
+ if (n == 0)
+ return 0;
+ bits = s->bits;
+ nbits = s->nbits;
+ while (nbits < n) {
+ if (!checksrc(s))
+ return 0;
+ bits |= (*s->src++ << nbits);
+ nbits += 8;
+ }
+ s->bits = bits >> n;
+ s->nbits = nbits - n;
+ return bits & ((1 << n) - 1);
+}
+
+/* decode a symbol from stream with huff code */
+static uint decode_symbol(Stream *s, Huff *huff) {
+ uint huffbits = huff->nbits;
+ uint nbits = s->nbits;
+ uint bits = s->bits;
+ uint mask = (1 << huffbits) - 1;
+ Entry entry;
+
+ /* get enough bits efficiently */
+ if (nbits < huffbits) {
+ if (s->src + 2 < s->srcend) {
+ /* we assume huffbits <= 9 */
+ bits |= *s->src++ << nbits;
+ nbits += 8;
+ bits |= *s->src++ << nbits;
+ nbits += 8;
+ bits |= *s->src++ << nbits;
+ nbits += 8;
+ } else
+ do {
+ if (!checksrc(s)) {
+ /* EOB might be < huffbits */
+ entry = huff->table[bits & mask];
+ if (entry.len > 0 && entry.len <= nbits) {
+ s->bits = bits >> entry.len;
+ s->nbits = nbits - entry.len;
+ s->error = FlateOk;
+ return entry.sym;
+ }
+ return 0;
+ }
+ bits |= *s->src++ << nbits;
+ nbits += 8;
+ } while (nbits < huffbits);
+ }
+ entry = huff->table[bits & mask];
+ if (entry.len > 0) {
+ s->bits = bits >> entry.len;
+ s->nbits = nbits - entry.len;
+ return entry.sym;
+ } else if (entry.len == 0) {
+ s->error = FlateCorrupted;
+ return 0;
+ }
+ s->bits = bits >> huffbits;
+ s->nbits = nbits - huffbits;
+ /* code is longer than huffbits: bitwise decode the rest */
+ {
+ int cur = entry.sym;
+ int sum = huff->sum;
+ /* TODO: count[0..huffbits] is never needed */
+ ushort *count = huff->count + huffbits + 1;
+
+ for (;;) {
+ cur |= getbit(s);
+ sum += *count;
+ cur -= *count;
+ if (cur < 0)
+ break;
+ cur <<= 1;
+ count++;
+ if (count == huff->count + CodeBits) {
+ s->error = FlateCorrupted;
+ return 0;
+ }
+ }
+ return huff->symbol[sum + cur];
+ }
+}
+
+/* decode dynamic huffman code trees from stream */
+static void decode_huffs(Stream *s) {
+ Huff clhuff;
+ uchar lens[Nlitlen+Ndist];
+ uint nlit, ndist, nclen;
+ uint i;
+
+ nlit = 257 + getbits(s, 5);
+ ndist = 1 + getbits(s, 5);
+ nclen = 4 + getbits(s, 4);
+ if (s->error != FlateOk)
+ return;
+ if (nlit > Nlitlen || ndist > Ndist) {
+ s->error = FlateCorrupted;
+ return;
+ }
+
+ /* build code length tree */
+ for (i = 0; i < Nclen; i++)
+ lens[i] = 0;
+ for (i = 0; i < nclen; i++)
+ lens[clenorder[i]] = getbits(s, 3);
+ if (build_huff(&clhuff, lens, Nclen, ClenTableBits) < 0) {
+ s->error = FlateCorrupted;
+ return;
+ }
+
+ /* decode code lengths for the dynamic trees */
+ for (i = 0; i < nlit + ndist; ) {
+ uint sym = decode_symbol(s, &clhuff);
+ uint len;
+ uchar c;
+
+ if (sym < 16) {
+ lens[i++] = sym;
+ } else if (sym == 16) {
+ /* copy previous code length 3-6 times */
+ c = lens[i - 1];
+ for (len = 3 + getbits(s, 2); len; len--)
+ lens[i++] = c;
+ } else if (sym == 17) {
+ /* repeat 0 for 3-10 times */
+ for (len = 3 + getbits(s, 3); len; len--)
+ lens[i++] = 0;
+ } else if (sym == 18) {
+ /* repeat 0 for 11-138 times */
+ for (len = 11 + getbits(s, 7); len; len--)
+ lens[i++] = 0;
+ } else
+ s->error = FlateCorrupted;
+ if (s->error != FlateOk)
+ return;
+ }
+ /* build dynamic huffman code trees */
+ if (build_huff(&s->lhuff, lens, nlit, LitlenTableBits) < 0)
+ s->error = FlateCorrupted;
+ if (build_huff(&s->dhuff, lens + nlit, ndist, DistTableBits) < 0)
+ s->error = FlateCorrupted;
+}
+
+/* decode a block of data from stream with trees */
+static void decode_block(Stream *s, Huff *lhuff, Huff *dhuff) {
+ for (;;) {
+ uint sym = decode_symbol(s, lhuff);
+
+ if (s->error != FlateOk)
+ return;
+ if (sym < 256) {
+ s->win[s->pos++] = sym;
+ if (!checkpos(s))
+ return;
+ } else if (sym > 256) {
+ uint len, dist;
+
+ sym -= 257;
+ if (sym >= Nlen) {
+ s->error = FlateCorrupted;
+ return;
+ }
+ len = lenbase[sym] + getbits(s, lenbits[sym]);
+ sym = decode_symbol(s, dhuff);
+ if (s->error != FlateOk)
+ return;
+ if (sym >= Ndist) {
+ s->error = FlateCorrupted;
+ return;
+ }
+ dist = distbase[sym] + getbits(s, distbits[sym]);
+ if (s->error != FlateOk)
+ return;
+ /* copy match, loop unroll in common case */
+ if (s->pos + len <= WinSize) {
+ uint pos = s->pos;
+
+ /* lenbase[sym] >= 3 */
+ do {
+ s->win[pos] = s->win[(pos - dist) & WinMask];
+ pos++;
+ s->win[pos] = s->win[(pos - dist) & WinMask];
+ pos++;
+ s->win[pos] = s->win[(pos - dist) & WinMask];
+ pos++;
+ len -= 3;
+ } while (len >= 3);
+ if (len--) {
+ s->win[pos] = s->win[(pos - dist) & WinMask];
+ pos++;
+ if (len) {
+ s->win[pos] = s->win[(pos - dist) & WinMask];
+ pos++;
+ }
+ }
+ s->pos = pos;
+ if (!checkpos(s))
+ return;
+ } else
+ while (len--) {
+ s->win[s->pos] = s->win[(s->pos - dist) & WinMask];
+ s->pos++;
+ if (!checkpos(s))
+ return;
+ }
+ } else /* EOB: sym == 256 */
+ return;
+ }
+}
+
+/* TODO: untested, slow code */
+static void inflate_uncompressed_block(Stream *s) {
+ uint len, invlen;
+
+ /* start next block on a byte boundary */
+ s->bits >>= s->nbits & 7;
+ s->nbits &= ~7;
+
+ len = getbits(s, 16);
+ invlen = getbits(s, 16);
+ /* s->nbits should be 0 here */
+ if (s->error != FlateOk)
+ return;
+ if (len != (~invlen & 0xffff)) {
+ s->error = FlateCorrupted;
+ return;
+ }
+ while (len--) {
+ s->win[s->pos++] = getbits(s, 8);
+ checkpos(s);
+ }
+}
+
+static void inflate_fixed_block(Stream *s) {
+ /* lazy initialization of fixed huff code trees */
+ if (lhuff.nbits == 0)
+ init_fixed_huffs();
+ decode_block(s, &lhuff, &dhuff);
+}
+
+static void inflate_dynamic_block(Stream *s) {
+ decode_huffs(s);
+ if (s->error != FlateOk)
+ return;
+ decode_block(s, &s->lhuff, &s->dhuff);
+}
+
+
+/* extern */
+
+/* inflate with callbacks */
+int inflate(int (*r)(void *, int, void *), void *rdata, int (*w)(void *, int, void *), void *wdata) {
+ Stream s;
+ uint final;
+
+ s.r = r;
+ s.rdata = rdata;
+ s.w = w;
+ s.wdata = wdata;
+ s.srcbegin = s.srcend = s.src = malloc(SrcSize);
+ s.win = malloc(WinSize);
+ s.pos = 0;
+ s.error = FlateOk;
+ s.nbits = 0;
+ do {
+ uint blocktype;
+
+ final = getbit(&s);
+ blocktype = getbits(&s, 2);
+ if (s.error != FlateOk)
+ return 0;
+
+ /* decompress block */
+ switch (blocktype) {
+ case 0:
+ inflate_uncompressed_block(&s);
+ break;
+ case 1:
+ inflate_fixed_block(&s);
+ break;
+ case 2:
+ inflate_dynamic_block(&s);
+ break;
+ default:
+ s.error = FlateCorrupted;
+ }
+ if (s.error != FlateOk)
+ break;
+ } while (!final);
+ flush_win(&s);
+ free(s.win);
+ free(s.srcbegin);
+ return s.error;
+}
+
+
+/* simple test */
+
+#include <stdio.h>
+
+struct data {
+ FILE *f;
+ uint n;
+};
+
+int r(void *p, int siz, void *data) {
+ uint n;
+ struct data *d = data;
+
+ n = fread(p, 1, siz, d->f);
+ d->n += n;
+ return n;
+}
+
+int w(void *p, int siz, void *data) {
+ uint n;
+ struct data *d = data;
+
+ n = fwrite(p, 1, siz, d->f);
+ d->n += n;
+ return n;
+}
+
+int main() {
+ struct data in, out;
+ int err;
+
+ in.f = stdin;
+ in.n = 0;
+ out.f = stdout;
+ out.n = 0;
+ err = inflate(r, &in, w, &out);
+ fprintf(stderr, "error: %d, decompressed %u bytes\n", err, out.n);
+ return 0;
+}