flate

deflate implementation
git clone git://git.suckless.org/flate
Log | Files | Refs | README

inflate_simple.c (6453B)


      1 /* no input validation, no bounds check, 2-3x slower than optimized inflate */
      2 
      3 typedef unsigned char uchar;
      4 typedef unsigned short ushort;
      5 typedef unsigned int uint;
      6 
      7 enum {
      8 	CodeBits  = 16,  /* max number of bits in a code + 1 */
      9 	Nlit      = 256, /* number of lit codes */
     10 	Nlen      = 29,  /* number of len codes */
     11 	Nlitlen   = Nlit + Nlen + 3, /* litlen codes + block end + 2 unused */
     12 	Ndist     = 30,  /* number of distance codes */
     13 	Nclen     = 19   /* number of code length codes */
     14 };
     15 
     16 typedef struct {
     17 	ushort count[CodeBits]; /* code length -> count */
     18 	ushort symbol[Nlitlen]; /* symbols ordered by code length */
     19 } Huff;
     20 
     21 typedef struct {
     22 	uchar *src;
     23 	uchar *dst;
     24 
     25 	uint bits;
     26 	uint nbits;
     27 
     28 	Huff lhuff; /* dynamic lit/len huffman code tree */
     29 	Huff dhuff; /* dynamic distance huffman code tree */
     30 } Stream;
     31 
     32 
     33 static Huff lhuff; /* fixed lit/len huffman code tree */
     34 static Huff dhuff; /* fixed distance huffman code tree */
     35 
     36 /* base offset tables */
     37 static ushort lenbase[Nlen];
     38 static ushort distbase[Ndist];
     39 
     40 /* extra bits tables */
     41 static uchar lenbits[Nlen] = {
     42 	0,  0,  0,  0,  0,  0,  0,  0,  1,  1,
     43 	1,  1,  2,  2,  2,  2,  3,  3,  3,  3,
     44 	4,  4,  4,  4,  5,  5,  5,  5,  0
     45 };
     46 static uchar distbits[Ndist] = {
     47 	0,  0,  0,  0,  1,  1,  2,  2,  3,  3,
     48 	4,  4,  5,  5,  6,  6,  7,  7,  8,  8,
     49 	9,  9, 10, 10, 11, 11, 12, 12, 13, 13
     50 };
     51 
     52 /* ordering of code lengths */
     53 static uchar clenorder[Nclen] = {
     54 	16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15
     55 };
     56 
     57 static void init_base_tables(void) {
     58 	uint base;
     59 	int i;
     60 
     61 	for (base = 3, i = 0; i < Nlen; i++) {
     62 		lenbase[i] = base;
     63 		base += 1 << lenbits[i];
     64 	}
     65 	lenbase[Nlen-1]--; /* deflate bug */
     66 	for (base = 1, i = 0; i < Ndist; i++) {
     67 		distbase[i] = base;
     68 		base += 1 << distbits[i];
     69 	}
     70 }
     71 
     72 static void init_fixed_huffs(void) {
     73 	int i;
     74 
     75 	lhuff.count[7] = 24;
     76 	lhuff.count[8] = 152;
     77 	lhuff.count[9] = 112;
     78 	for (i = 0; i < 24; i++)
     79 		lhuff.symbol[i] = 256 + i;
     80 	for (i = 0; i < 144; i++)
     81 		lhuff.symbol[24 + i] = i;
     82 	for (i = 0; i < 8; i++)
     83 		lhuff.symbol[24 + 144 + i] = 280 + i;
     84 	for (i = 0; i < 112; i++)
     85 		lhuff.symbol[24 + 144 + 8 + i] = 144 + i;
     86 	dhuff.count[5] = Ndist;
     87 	for (i = 0; i < Ndist; i++)
     88 		dhuff.symbol[i] = i;
     89 }
     90 
     91 /* build huffman code tree from code lengths */
     92 static void build_huff(Huff *h, const uchar *lens, int n) {
     93 	int offs[CodeBits];
     94 	int i, sum;
     95 
     96 	/* count code lengths and calc first code (offs) for each length */
     97 	for (i = 0; i < CodeBits; i++)
     98 		h->count[i] = 0;
     99 	for (i = 0; i < n; i++)
    100 		h->count[lens[i]]++;
    101 	h->count[0] = 0;
    102 	for (sum = 0, i = 0; i < CodeBits; i++) {
    103 		offs[i] = sum;
    104 		sum += h->count[i];
    105 	}
    106 	/* sort symbols by code length */
    107 	for (i = 0; i < n; i++)
    108 		if (lens[i])
    109 			h->symbol[offs[lens[i]]++] = i;
    110 }
    111 
    112 /* get one bit from stream */
    113 static uint getbit(Stream *s) {
    114 	uint bit;
    115 
    116 	if (!s->nbits--) {
    117 		s->bits = *s->src++;
    118 		s->nbits = 7;
    119 	}
    120 	bit = s->bits & 1;
    121 	s->bits >>= 1;
    122 	return bit;
    123 }
    124 
    125 /* get n bits from stream */
    126 static uint getbits(Stream *s, int n) {
    127 	uint bits = 0;
    128 	int i;
    129 
    130 	for (i = 0; i < n; i++)
    131 		bits |= getbit(s) << i;
    132 	return bits;
    133 }
    134 
    135 /* decode a symbol from stream with huffman code tree */
    136 static uint decode_symbol(Stream *s, Huff *h) {
    137 	int sum = 0, cur = 0;
    138 	ushort *count = h->count + 1;
    139 
    140 	for (;;) {
    141 		cur |= getbit(s);
    142 		sum += *count;
    143 		cur -= *count;
    144 		if (cur < 0)
    145 			break;
    146 		cur <<= 1;
    147 		count++;
    148 	}
    149 	return h->symbol[sum + cur];
    150 }
    151 
    152 /* decode dynamic huffman code trees from stream */
    153 static void decode_huffs(Stream *s) {
    154 	Huff chuff;
    155 	uchar lens[Nlitlen+Ndist];
    156 	uint nlit, ndist, nclen;
    157 	uint i;
    158 
    159 	nlit = 257 + getbits(s, 5);
    160 	ndist = 1 + getbits(s, 5);
    161 	nclen = 4 + getbits(s, 4);
    162 	/* build code length code tree */
    163 	for (i = 0; i < Nclen; i++)
    164 		lens[i] = 0;
    165 	for (i = 0; i < nclen; i++)
    166 		lens[clenorder[i]] = getbits(s, 3);
    167 	build_huff(&chuff, lens, Nclen);
    168 	/* decode code lengths for the dynamic code tree */
    169 	for (i = 0; i < nlit + ndist; ) {
    170 		uint sym = decode_symbol(s, &chuff);
    171 		uint len;
    172 		uchar c;
    173 
    174 		if (sym < 16) {
    175 			lens[i++] = sym;
    176 		} else if (sym == 16) {
    177 			/* copy previous code length 3-6 times */
    178 			c = lens[i - 1];
    179 			for (len = 3 + getbits(s, 2); len; len--)
    180 				lens[i++] = c;
    181 		} else if (sym == 17) {
    182 			/* repeat 0 for 3-10 times */
    183 			for (len = 3 + getbits(s, 3); len; len--)
    184 				lens[i++] = 0;
    185 		} else if (sym == 18) {
    186 			/* repeat 0 for 11-138 times */
    187 			for (len = 11 + getbits(s, 7); len; len--)
    188 				lens[i++] = 0;
    189 		}
    190 	}
    191 	/* build dynamic huffman code tree */
    192 	build_huff(&s->lhuff, lens, nlit);
    193 	build_huff(&s->dhuff, lens + nlit, ndist);
    194 }
    195 
    196 /* decode a block of data from stream with huffman code trees */
    197 static void decode_block(Stream *s, Huff *lhuff, Huff *dhuff) {
    198 	uint sym;
    199 
    200 	for (;;) {
    201 		sym = decode_symbol(s, lhuff);
    202 		if (sym == 256)
    203 			return;
    204 		if (sym < 256)
    205 			*s->dst++ = sym;
    206 		else {
    207 			uint len, dist;
    208 
    209 			sym -= 257;
    210 			len = lenbase[sym] + getbits(s, lenbits[sym]);
    211 			sym = decode_symbol(s, dhuff);
    212 			dist = distbase[sym] + getbits(s, distbits[sym]);
    213 			/* copy match */
    214 			while (len--) {
    215 				*s->dst = *(s->dst - dist);
    216 				s->dst++;
    217 			}
    218 		}
    219 	}
    220 }
    221 
    222 static void inflate_uncompressed_block(Stream *s) {
    223 	uint len;
    224 
    225 	s->nbits = 0; /* start block on a byte boundary */
    226 	len = (s->src[1] << 8) | s->src[0];
    227 	s->src += 4;
    228 	while (len--)
    229 		*s->dst++ = *s->src++;
    230 }
    231 
    232 static void inflate_fixed_block(Stream *s) {
    233 	decode_block(s, &lhuff, &dhuff);
    234 }
    235 
    236 static void inflate_dynamic_block(Stream *s) {
    237 	decode_huffs(s);
    238 	decode_block(s, &s->lhuff, &s->dhuff);
    239 }
    240 
    241 
    242 /* extern */
    243 
    244 /* inflate stream from src to dst, return end pointer */
    245 void *inflate(void *dst, void *src) {
    246 	Stream s;
    247 	uint final;
    248 
    249 	/* initialize global (static) data */
    250 	init_base_tables();
    251 	init_fixed_huffs();
    252 
    253 	s.src = src;
    254 	s.dst = dst;
    255 	s.nbits = 0;
    256 	do {
    257 		final = getbit(&s);
    258 		switch (getbits(&s, 2)) {
    259 		case 0:
    260 			inflate_uncompressed_block(&s);
    261 			break;
    262 		case 1:
    263 			inflate_fixed_block(&s);
    264 			break;
    265 		case 2:
    266 			inflate_dynamic_block(&s);
    267 			break;
    268 		}
    269 	} while (!final);
    270 	return s.dst;
    271 }
    272 
    273 
    274 /* simple test */
    275 
    276 #include <stdlib.h>
    277 #include <stdio.h>
    278 
    279 void *readall(FILE *in) {
    280 	uint len = 1 << 22;
    281 	void *buf;
    282 
    283 	buf = malloc(len);
    284 	fread(buf, 1, len, in);
    285 	fclose(in);
    286 	return buf;
    287 }
    288 
    289 int main(void) {
    290 	uint len = 1 << 24;
    291 	uchar *src, *dst;
    292 
    293 	src = readall(stdin);
    294 	dst = malloc(len);
    295 	len = (uchar *)inflate(dst, src) - dst;
    296 	fprintf(stderr, "decompressed %u bytes\n", len);
    297 	fwrite(dst, 1, len, stdout);
    298 	free(dst);
    299 	free(src);
    300 	return 0;
    301 }