flate

deflate implementation
git clone git://git.suckless.org/flate
Log | Files | Refs | README

commit b4ac11e701ddfb012b25d6f77f75ab8aa4db405e
parent f1d863cbfb093d47c811a5b2ea668c7990e91645
Author: nsz <nszabolcs@gmail.com>
Date:   Sat,  1 Aug 2009 19:15:39 +0200

inflate dynamic decode state fix, state->err, nice interface
Diffstat:
Makefile  |   7 ++++---
inflate.c | 253 ++++++++++++++++++++++++++++++++++++++-----------------------------------------
2 files changed, 125 insertions(+), 135 deletions(-)

diff --git a/Makefile b/Makefile @@ -1,12 +1,13 @@ #CFLAGS=-g -Wall -ansi -pedantic CFLAGS=-O3 -Wall -ansi -pedantic LDFLAGS= -SRC=inflate.c inflate_simple.c inflate_callback.c deflate.c deflate_simple.c +SRC=inflate.c inflate_simple.c inflate_example.c inflate_callback.c \ + deflate.c deflate_simple.c OBJ=${SRC:.c=.o} -EXE=${SRC:.c=} +EXE=inflate inflate_simple deflate deflate_simple all: ${EXE} -inflate: inflate.o +inflate: inflate.o inflate_example.o ${CC} -o $@ $^ ${LDFLAGS} inflate_simple: inflate_simple.o ${CC} -o $@ $^ ${LDFLAGS} diff --git a/inflate.c b/inflate.c @@ -1,8 +1,6 @@ #include <stdlib.h> - -typedef unsigned char uchar; -typedef unsigned short ushort; -typedef unsigned int uint; +#include <string.h> +#include "inflate.h" enum { CodeBits = 16, /* max number of bits in a code + 1 */ @@ -15,18 +13,9 @@ enum { Nlitlen = Nlit+Nlen+3, /* litlen codes + block end + 2 unused */ Ndist = 30, /* number of distance codes */ Nclen = 19, /* number of code length codes */ - SrcSize = 1 << 12, /* input buffer size */ WinSize = 1 << 15 /* output window size */ }; -/* return values */ -enum { - FlateOk = 0, - FlateError = -1, - FlateNeedInput = -2, - FlateHasOutput = -3 -}; - /* states */ enum { BlockHead, @@ -59,28 +48,30 @@ typedef struct { } Huff; typedef struct { - uchar *src; /* input buffer pointer */ + uchar *src; /* input buffer pointer */ uchar *srcend; uint bits; uint nbits; - uchar *win; /* output window */ - uint pos; /* window pos */ + uchar win[WinSize]; /* output window */ + uint pos; /* window pos */ + uint posout; /* used for flushing win */ + + int state; /* decode state */ + int final; /* last block flag */ + char *err; /* error message */ - int state; /* decode state */ - int final; /* last block flag */ - /* for decoding dynamic code trees in inflate() */ int nlit; int ndist; - int nclen; /* also used in decode_block() */ - int lenpos; /* also used in decode_block() */ + int nclen; /* also used in decode_block() */ + int lenpos; /* also used in 
decode_block() */ uchar lens[Nlitlen + Ndist]; - int fixed; /* fixed code tree flag */ - Huff lhuff; /* dynamic lit/len huffman code tree */ - Huff dhuff; /* dynamic distance huffman code tree */ + int fixed; /* fixed code tree flag */ + Huff lhuff; /* dynamic lit/len huffman code tree */ + Huff dhuff; /* dynamic distance huffman code tree */ } State; /* TODO: globals.. initialization is not thread safe */ @@ -285,7 +276,7 @@ static uint decode_symbol_long(State *s, Huff *huff, uint bits, uint nbits, int cur <<= 1; count++; if (count == huff->count + CodeBits) - return FlateError; + return s->err = "symbol decoding failed.", FlateError; } s->bits = bits; s->nbits = nbits; @@ -337,7 +328,7 @@ static uint decode_symbol(State *s, Huff *huff) { s->nbits = nbits - entry.len; return entry.sym; } else if (entry.len == 0) - return FlateError; + return s->err = "symbol decoding failed.", FlateError; return decode_symbol_long(s, huff, bits, nbits, entry.sym); } @@ -356,7 +347,7 @@ static int decode_block(State *s, Huff *lhuff, Huff *dhuff) { if (sym < 256) { win[pos++] = sym; if (pos == WinSize) { - s->pos = pos; + s->pos = WinSize; s->state = DecodeBlock; return FlateHasOutput; } @@ -422,7 +413,7 @@ static int decode_block(State *s, Huff *lhuff, Huff *dhuff) { win[pos] = win[(pos - dist) % WinSize]; pos++; if (pos == WinSize) { - s->pos = pos; + s->pos = WinSize; s->lenpos = len; s->nclen = dist; /* using nclen to store dist */ s->state = DecodeBlockCopy; @@ -436,42 +427,23 @@ static int decode_block(State *s, Huff *lhuff, Huff *dhuff) { } } /* for (;;) */ } /* switch () */ - return FlateError; -} - - -/* extern */ - -int inflate_init(State *s) { - /* TODO */ - if (lhuff.nbits == 0) - init_fixed_huffs(); - s->pos = 0; - s->win = malloc(WinSize); - if (!s->win) - return FlateError; - s->bits = 0; - s->nbits = 0; - - s->state = BlockHead; - s->final = 0; - return FlateOk; + return s->err = "corrupted state.", FlateError; } -/* inflate, returns: short src, short dst, error, ok 
*/ -int inflate(State *s) { +/* inflate state machine (decodes s->src into s->win) */ +static int inflate_state(State *s) { int n; + if (s->posout) + return FlateHasOutput; for (;;) { switch (s->state) { case BlockHead: if (s->final) { if (s->pos) return FlateHasOutput; - else { - free(s->win); + else return FlateOk; - } } if (!fillbits(s, 3)) return FlateNeedInput; @@ -484,7 +456,7 @@ int inflate(State *s) { else if (n == 2) s->state = DynamicHuff; else - goto error; + return s->err = "corrupt block header.", FlateError; break; case UncompressedBlock: /* start block on a byte boundary */ @@ -495,10 +467,10 @@ int inflate(State *s) { s->lenpos = getbits(s, 16); n = getbits(s, 16); if (s->lenpos != (~n & 0xffff)) - goto error; + return s->err = "corrupt uncompressed length.", FlateError; s->state = CopyUncompressed; case CopyUncompressed: - /* TODO: untested, slow */ + /* TODO: untested, slow, memcpy etc */ /* s->nbits should be 0 here */ while (s->lenpos) { if (s->src == s->srcend) @@ -522,7 +494,7 @@ int inflate(State *s) { s->ndist = 1 + getbits(s, 5); s->nclen = 4 + getbits(s, 4); if (s->nlit > Nlitlen || s->ndist > Ndist) - goto error; + return s->err = "corrupt code tree.", FlateError; /* build code length tree */ for (n = 0; n < Nclen; n++) s->lens[n] = 0; @@ -539,7 +511,7 @@ int inflate(State *s) { } /* using lhuff for code length huff code */ if (build_huff(&s->lhuff, s->lens, Nclen, ClenTableBits) < 0) - goto error; + return s->err = "building clen tree failed.", FlateError; s->state = DynamicHuffLitlenDist; s->lenpos = 0; case DynamicHuffLitlenDist: @@ -558,17 +530,17 @@ int inflate(State *s) { case DynamicHuffContinue: n = s->lenpos; sym = s->nclen; + s->state = DynamicHuffLitlenDist; } if (!fillbits(s, 7)) { - /* TODO: 7 is too much */ + /* TODO: 7 is too much when an almost empty block is at the end */ if (sym == (uint)FlateError) - goto error; + return FlateError; s->nclen = sym; s->lenpos = n; s->state = DynamicHuffContinue; return FlateNeedInput; } 
- /* TODO: bound check s->lens */ if (sym == 16) { /* copy previous code length 3-6 times */ @@ -584,13 +556,13 @@ int inflate(State *s) { for (len = 11 + getbits(s, 7); len; len--) s->lens[n++] = 0; } else - goto error; + return s->err = "corrupt code tree.", FlateError; } /* build dynamic huffman code trees */ if (build_huff(&s->lhuff, s->lens, s->nlit, LitlenTableBits) < 0) - goto error; + return s->err = "building litlen tree failed.", FlateError; if (build_huff(&s->dhuff, s->lens + s->nlit, s->ndist, DistTableBits) < 0) - goto error; + return s->err = "building dist tree failed.", FlateError; s->state = DecodeBlock; case DecodeBlock: case DecodeBlockLenBits: @@ -603,93 +575,110 @@ int inflate(State *s) { if (n == FlateHasOutput) return FlateHasOutput; if (n == FlateError) - goto error; + return FlateError; s->state = BlockHead; break; default: - goto error; + return s->err = "corrupt state.", FlateError; + } + } +} + +static State *alloc_state(void) { + State *s = malloc(sizeof(State)); + + if (s) { + s->final = s->pos = s->posout = s->bits = s->nbits = 0; + s->state = BlockHead; + s->src = s->srcend = 0; + s->err = 0; + /* TODO: globals.. 
*/ + if (lhuff.nbits == 0) + init_fixed_huffs(); + } + return s; +} + + +/* extern */ + +int inflate(FlateStream *stream) { + State *s = stream->state; + int n; + + if (stream->err) { + if (s) { + free(s); + stream->state = 0; } + return FlateError; + } + if (!s) { + s = stream->state = alloc_state(); + if (!s) + return stream->err = "no mem.", FlateError; + } + if (stream->nin) { + s->src = stream->in; + s->srcend = s->src + stream->nin; + stream->nin = 0; + } + n = inflate_state(s); + if (n == FlateHasOutput) { + if (s->pos < stream->nout) + stream->nout = s->pos; + memcpy(stream->out, s->win + s->posout, stream->nout); + s->pos -= stream->nout; + if (s->pos) + s->posout += stream->nout; + else + s->posout = 0; } -error: - free(s->win); - return FlateError; + if (n == FlateOk || n == FlateError) { + stream->err = s->err; + free(s); + stream->state = 0; + } + return n; } int inflate_callback(int (*r)(void *, int, void *), void *rdata, int (*w)(void *, int, void *), void *wdata) { - State s; + State *s; uchar *src; - int len; + int len, n; + enum {SrcSize = 4096}; - if (inflate_init(&s) != FlateOk) + s = alloc_state(); + if (!s) return FlateError; - s.src = s.srcend = src = malloc(4096); - if (src == NULL) + s->src = s->srcend = src = malloc(SrcSize); + if (!src) { + free(s); return FlateError; + } + n = FlateNeedInput; for (;;) - switch (inflate(&s)) { + switch (n) { case FlateNeedInput: - len = r(src, 4096, rdata); - if (len <= 0) { - free(src); - return FlateError; - } - s.src = src; - s.srcend = src + len; + len = r(src, SrcSize, rdata); + if (len > 0) { + s->srcend = src + len; + n = inflate_state(s); + } else + n = FlateError; break; case FlateHasOutput: - len = w(s.win, s.pos, wdata); - if (len != s.pos) { - free(src); - return FlateError; - } - s.pos = 0; /* TODO: ouch */ + len = w(s->win, s->pos, wdata); + if (len == s->pos) { + s->pos = 0; + n = inflate_state(s); + } else + n = FlateError; break; case FlateOk: - free(src); - return FlateOk; case 
FlateError: free(src); - return FlateError; + free(s); + return n; } } - -/* simple test */ - -#include <stdio.h> - -struct data { - FILE *f; - uint n; -}; - -int r(void *p, int siz, void *data) { - uint n; - struct data *d = data; - - n = fread(p, 1, siz, d->f); - d->n += n; - return n; -} - -int w(void *p, int siz, void *data) { - uint n; - struct data *d = data; - - n = fwrite(p, 1, siz, d->f); - d->n += n; - return n; -} - -int main() { - struct data in, out; - int err; - - in.f = stdin; - in.n = 0; - out.f = stdout; - out.n = 0; - err = inflate_callback(r, &in, w, &out); - fprintf(stderr, "error: %d, decompressed %u bytes\n", err, out.n); - return 0; -} -