flate

deflate implementation
git clone git://git.suckless.org/flate
Log | Files | Refs | README

commit c289ca3da482ab0d0e8bca34550583fc39e5d41c
parent 746ffd95970aa1019058f0baba125bfcc3a0fa02
Author: nsz <nszabolcs@gmail.com>
Date:   Tue, 21 Apr 2009 23:14:30 +0200

full validation (bogus)
Diffstat:
inflate.c | 204 ++++++++++++++++++++++++++++++++++++++++++++++++++++++-------------------------
1 file changed, 140 insertions(+), 64 deletions(-)

diff --git a/inflate.c b/inflate.c @@ -12,29 +12,41 @@ TODO: typedef unsigned char uchar; typedef unsigned short ushort; typedef unsigned int uint; -typedef unsigned long ulong; enum { - HuffBits = 16, /* max number of bits in a code */ - Nlit = 256, /* number of lit codes */ - Nlen = 29, /* number of len codes */ - Nlitlen = Nlit + Nlen + 3, /* number of litlen codes + block end + 2 unused */ - Ndist = 30, /* number of distance codes */ - Nclen = 19 /* number of code length codes */ + HuffBits = 16, /* max number of bits in a code */ + Nlit = 256, /* number of lit codes */ + Nlen = 29, /* number of len codes */ + Nlitlen = Nlit + Nlen + 3, /* number of litlen codes + block end + 2 unused */ + Ndist = 30, /* number of distance codes */ + Nclen = 19 /* number of code length codes */ }; +enum { + FlateOk = 0, + FlateShortSrc = -2, + FlateShortDst = -3, + FlateCorrupted = -4 +}; + + typedef struct { ushort count[HuffBits]; /* code bit length -> count */ ushort symbol[Nlitlen]; /* symbols ordered by code length */ } HuffTree; typedef struct { - const uchar *src; + uchar *src; + uchar *srcend; + + uchar *dst; + uchar *dstbegin; + uchar *dstend; + uint bits; uint nbits; - uchar *dst; - uint *dstlen; + int error; HuffTree ltree; /* dynamic lit/len tree */ HuffTree dtree; /* dynamic distance tree */ @@ -61,12 +73,13 @@ static uchar distbits[Ndist] = { static ushort distbase[Ndist]; /* ordering of code lengths */ -static const uchar clenorder[Nclen] = { +static uchar clenorder[Nclen] = { 16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15 }; static void init_base_tables(void) { - int i, base; + uint base; + int i; for (base = 3, i = 0; i < Nlen; i++) { lenbase[i] = base; @@ -101,9 +114,9 @@ static void init_fixed_trees(void) { } /* build huffman code tree from code lengths */ -static void build_tree(HuffTree *t, const uchar *lens, uint n) { +static int build_tree(HuffTree *t, const uchar *lens, int n) { int offs[HuffBits]; - int i, sum; + int i, sum, left; for 
(i = 0; i < HuffBits; i++) t->count[i] = 0; @@ -111,7 +124,19 @@ static void build_tree(HuffTree *t, const uchar *lens, uint n) { /* count code lengths and calc first code for each length */ for (i = 0; i < n; i++) t->count[lens[i]]++; + if (t->count[0] == n) + return -1; t->count[0] = 0; + + /* check if length is over-subscribed or incomplete */ + for (left = i = 1; i <= HuffBits; i++) { + left <<= 1; + left -= t->count[i]; + /* left < 0: over-subscribed, left > 0: incomplete */ + if (left < 0) + return -1; + } + for (sum = 0, i = 0; i < HuffBits; i++) { offs[i] = sum; sum += t->count[i]; @@ -121,6 +146,7 @@ static void build_tree(HuffTree *t, const uchar *lens, uint n) { for (i = 0; i < n; i++) if (lens[i]) t->symbol[offs[lens[i]]++] = i; + return 0; } /* get one bit from s->src stream */ @@ -128,6 +154,10 @@ static uint getbit(FlateStream *s) { uint bit; if (!s->nbits--) { + if (s->src == s->srcend) { + s->error = FlateShortSrc; + return 0; + } s->bits = *s->src++; s->nbits = 7; } @@ -168,12 +198,18 @@ static uint decode_symbol(FlateStream *s, HuffTree *t) { cur <<= 1; count++; } + if (s->src == s->srcend) { + s->error = FlateShortSrc; + return 0; + } bits = *s->src++; nbits = 8; } found: - if (count >= t->count + HuffBits) - /* error */; + if (count >= t->count + HuffBits) { + s->error = FlateCorrupted; + return 0; + } s->bits = bits; s->nbits = nbits; return t->symbol[sum + cur]; @@ -189,15 +225,22 @@ static void decode_trees(FlateStream *s, HuffTree *lt, HuffTree *dt) { nlit = 257 + getbits(s, 5); ndist = 1 + getbits(s, 5); nclen = 4 + getbits(s, 4); - if (nlit > Nlitlen || ndist > Ndist) - /* error */; + if (s->error != FlateOk) + return; + if (nlit > Nlitlen || ndist > Ndist) { + s->error = FlateCorrupted; + return; + } /* build code length tree */ for (i = 0; i < Nclen; i++) lens[i] = 0; for (i = 0; i < nclen; i++) lens[clenorder[i]] = getbits(s, 3); - build_tree(&ctree, lens, Nclen); + if (build_tree(&ctree, lens, Nclen) < 0) { + s->error = 
FlateCorrupted; + return; + } /* decode code lengths for the dynamic trees */ for (i = 0; i < nlit + ndist; ) { @@ -221,40 +264,62 @@ static void decode_trees(FlateStream *s, HuffTree *lt, HuffTree *dt) { for (len = 11 + getbits(s, 7); len; len--) lens[i++] = 0; } else - /* error */; + s->error = FlateCorrupted; + if (s->error != FlateOk) + return; } /* build dynamic huffman trees */ - build_tree(lt, lens, nlit); - build_tree(dt, lens + nlit, ndist); + if (build_tree(lt, lens, nlit) < 0) + s->error = FlateCorrupted; + if (build_tree(dt, lens + nlit, ndist) < 0) + s->error = FlateCorrupted; } /* decode a block of data from stream with trees */ -static int decode_block(FlateStream *s, HuffTree *lt, HuffTree *dt) { - uchar *start = s->dst; +static void decode_block(FlateStream *s, HuffTree *lt, HuffTree *dt) { uint sym; for (;;) { sym = decode_symbol(s, lt); - if (sym == 256) { - *s->dstlen += s->dst - start; - return 0; - } - if (sym < 256) + if (s->error != FlateOk) + return; + if (sym == 256) + return; + if (sym < 256) { + if (s->dst == s->dstend) { + s->error = FlateShortDst; + return; + } *s->dst++ = sym; - else { - uint len, dist; - uint i; + } else { + int i, len, dist; sym -= 257; - if (sym >= Nlen) - /* error */; + if (sym >= Nlen) { + s->error = FlateCorrupted; + return; + } len = lenbase[sym] + getbits(s, lenbits[sym]); sym = decode_symbol(s, dt); - if (sym >= Ndist) - /* error */; + if (s->error != FlateOk) + return; + if (sym >= Ndist) { + s->error = FlateCorrupted; + return; + } dist = distbase[sym] + getbits(s, distbits[sym]); + if (s->error != FlateOk) + return; /* copy match */ + if (s->dst + len > s->dstend) { + s->error = FlateShortDst; + return; + } + if (s->dst - dist < s->dstbegin) { + s->error = FlateCorrupted; + return; + } for (i = 0; i < len; i++) s->dst[i] = s->dst[i - dist]; s->dst += len; @@ -262,31 +327,41 @@ static int decode_block(FlateStream *s, HuffTree *lt, HuffTree *dt) { } } -static int inflate_uncompressed_block(FlateStream *s) { 
+static void inflate_uncompressed_block(FlateStream *s) { uint len, invlen; len = (s->src[1] << 8) | s->src[0]; invlen = (s->src[3] << 8) | s->src[2]; s->src += 4; - if (len != (~invlen & 0x0000ffff)) - /* error */; + if (len != (~invlen & 0x0000ffff)) { + s->error = FlateCorrupted; + return; + } /* start next block on a byte boundary */ s->nbits = 0; - *s->dstlen += len; + if (s->dst + len > s->dstend) { + s->error = FlateShortDst; + return; + } + if (s->src + len > s->srcend) { + s->error = FlateShortSrc; + return; + } while (len--) *s->dst++ = *s->src++; - return 0; } -static int inflate_fixed_block(FlateStream *s) { - return decode_block(s, &ltree, &dtree); +static void inflate_fixed_block(FlateStream *s) { + decode_block(s, &ltree, &dtree); } -static int inflate_dynamic_block(FlateStream *s) { +static void inflate_dynamic_block(FlateStream *s) { decode_trees(s, &s->ltree, &s->dtree); - return decode_block(s, &s->ltree, &s->dtree); + if (s->error != FlateOk) + return; + decode_block(s, &s->ltree, &s->dtree); } @@ -298,47 +373,49 @@ void inflate_init(void) { init_base_tables(); } +#include <stdlib.h> +#include <stdio.h> + /* inflate stream from src to dst */ -int inflate(void *dst, uint *dstlen, const void *src, uint srclen) { +uint inflate(void *dst, uint dstlen, void *src, uint srclen) { FlateStream s; uint final; s.src = src; + s.srcend = s.src + srclen; + s.dstbegin = s.dst = dst; + s.dstend = s.dst + dstlen; + s.error = FlateOk; s.nbits = 0; - s.dst = dst; - s.dstlen = dstlen; - *dstlen = 0; do { uint blocktype; - int res; final = getbit(&s); blocktype = getbits(&s, 2); + if (s.error != FlateOk) + return fprintf(stderr, "error: %d\n", s.error); /* decompress block */ switch (blocktype) { case 0: - res = inflate_uncompressed_block(&s); + inflate_uncompressed_block(&s); break; case 1: - res = inflate_fixed_block(&s); + inflate_fixed_block(&s); break; case 2: - res = inflate_dynamic_block(&s); + inflate_dynamic_block(&s); break; default: - return -1; + s.error 
= FlateCorrupted; } - if (res != 0) - return -1; + if (s.error != FlateOk) + return fprintf(stderr, "error: %d\n", s.error); } while (!final); - return 0; + return s.dst - s.dstbegin; } -#include <stdio.h> -#include <stdlib.h> - void *readall(char *name, uint *len) { uint size = 1 << 22; void *buf; @@ -354,7 +431,6 @@ void *readall(char *name, uint *len) { } int main(int argc, char **argv) { - int ret; uchar *src, *dst; uint srclen, dstlen=1<<22; @@ -362,13 +438,13 @@ int main(int argc, char **argv) { return -1; src = readall(argv[1], &srclen); inflate_init(); - ret = inflate(dst = malloc(dstlen), &dstlen, src, srclen); - if (ret) + dstlen = inflate(dst = malloc(dstlen), dstlen, src, srclen); + if (dstlen == 0) fputs("inflate: error\n", stderr); else fprintf(stderr, "inflate: uncompressed %u bytes\n", dstlen); fwrite(dst, 1, dstlen, stdout); free(dst); free(src); - return ret; + return 0; }